diff --git a/enhance.py b/enhance.py index bc40cdb..1dbe199 100755 --- a/enhance.py +++ b/enhance.py @@ -38,6 +38,7 @@ add_arg('files', nargs='*', default=[]) add_arg('--scales', default=2, type=int, help='How many times to perform 2x upsampling.') add_arg('--model', default='small', type=str, help='Name of the neural network to load/save.') add_arg('--train', default=False, type=str, help='File pattern to load for training.') +add_arg('--save-every-epoch', default=False, action='store_true', help='Save generator after every training epoch.') add_arg('--batch-resolution', default=192, type=int, help='Resolution of images in training batch.') add_arg('--batch-size', default=15, type=int, help='Number of images per training batch.') add_arg('--buffer-size', default=1500, type=int, help='Total image fragments kept in cache.') @@ -484,6 +485,8 @@ class NeuralEnhancer(object): print(' - adversary mode: generator engaging discriminator.') self.model.adversary_weight.set_value(args.adversary_weight) running = None + if args.save_every_epoch: + self.model.save_generator() except KeyboardInterrupt: pass