diff --git a/enhance.py b/enhance.py
index 882fa68..60128cf 100755
--- a/enhance.py
+++ b/enhance.py
@@ -46,10 +46,10 @@ add_arg('--epochs', default=10, type=int, help='Total
 add_arg('--epoch-size', default=72, type=int, help='Number of batches trained in an epoch.')
 add_arg('--save-every', default=10, type=int, help='Save generator after every training epoch.')
 add_arg('--batch-shape', default=192, type=int, help='Resolution of images in training batch.')
-add_arg('--batch-size', default=15, type=int, help='Number of images per training batch.')
+add_arg('--batch-size', default=10, type=int, help='Number of images per training batch.')
 add_arg('--buffer-size', default=1500, type=int, help='Total image fragments kept in cache.')
 add_arg('--buffer-similar', default=5, type=int, help='Fragments cached for each image loaded.')
-add_arg('--learning-rate', default=1E-4, type=float, help='Parameter for the ADAM optimizer.')
+add_arg('--learning-rate', default=5E-4, type=float, help='Parameter for the ADAM optimizer.')
 add_arg('--learning-period', default=100, type=int, help='How often to decay the learning rate.')
 add_arg('--learning-decay', default=0.5, type=float, help='How much to decay the learning rate.')
 add_arg('--generator-upscale', default=2, type=int, help='Steps of 2x up-sampling as post-process.')
@@ -248,14 +248,25 @@ class Model(object):
 
         config, params = self.load_model()
         self.setup_generator(self.last_layer(), config)
 
+        # Compute batch-shape to take into account that there is no zero-padding in the generator convolution layers.
+        s = args.batch_shape // args.zoom
+        current = lasagne.layers.helper.get_output_shape(self.network['out'], {self.network['seed']: (1, 3, s, s)})
+        args.batch_shape = args.batch_shape * 2 - current[2]
+
+        self.network['img'].shape = (args.batch_size, 3, args.batch_shape, args.batch_shape)
+        self.network['seed'].shape = (args.batch_size, 3, args.batch_shape // args.zoom, args.batch_shape // args.zoom)
+        # TODO: How can this shape be recomputed more elegantly using Lasagne?
+        self.network['out'].input_shape = lasagne.layers.get_output_shape(self.network['out'].input_layer,
+                                                                          {self.network['seed']: self.network['seed'].shape})
+
         if args.train:
-            concatenated = lasagne.layers.ConcatLayer([self.network['img'], self.network['out']], axis=0)
+            concatenated = lasagne.layers.ConcatLayer([self.network['img'], self.network['out']],
+                                                      axis=0, cropping=(None, None, 'center', 'center'))
             self.setup_perceptual(concatenated)
             self.load_perceptual()
             self.setup_discriminator()
         self.load_generator(params)
-        self.compile()
 
     #------------------------------------------------------------------------------------------------------------------
     # Network Configuration
@@ -265,7 +276,7 @@ class Model(object):
         return list(self.network.values())[-1]
 
     def make_layer(self, name, input, units, filter_size=(3,3), stride=(1,1), pad=(1,1), alpha=0.25):
-        conv = ConvLayer(input, units, filter_size=filter_size, stride=stride, pad=pad, nonlinearity=None)
+        conv = ConvLayer(input, units, filter_size, stride=stride, pad=self.pad_override or pad, nonlinearity=None)
         prelu = lasagne.layers.ParametricRectifierLayer(conv, alpha=lasagne.init.Constant(alpha))
         self.network[name+'x'] = conv
         self.network[name+'>'] = prelu
@@ -277,6 +288,7 @@ class Model(object):
         return ElemwiseSumLayer([input, self.last_layer()]) if args.generator_residual else self.last_layer()
 
     def setup_generator(self, input, config):
+        self.pad_override = (0, 0)
         for k, v in config.items(): setattr(args, k, v)
         args.zoom = 2**(args.generator_upscale - args.generator_downscale)
 
@@ -301,6 +313,7 @@ class Model(object):
 
         self.network['out'] = ConvLayer(self.last_layer(), 3, filter_size=(5,5), stride=(1,1), pad=(2,2),
                                         nonlinearity=lasagne.nonlinearities.tanh)
+        self.pad_override = None
 
     def setup_perceptual(self, input):
         """Use lasagne to create a network of convolution layers using pre-trained VGG19 weights.
@@ -405,13 +418,13 @@ class Model(object):
         return T.mean(((x[:,:,:-1,:-1] - x[:,:,1:,:-1])**2 + (x[:,:,:-1,:-1] - x[:,:,:-1,1:])**2)**1.25)
 
     def loss_adversarial(self, d):
-        return T.mean(1.0 - T.nnet.softplus(d[args.batch_size:]))
+        return T.mean(1.0 - T.nnet.softminus(d[args.batch_size:]))
 
     def loss_discriminator(self, d):
         return T.mean(T.nnet.softminus(d[args.batch_size:]) - T.nnet.softplus(d[:args.batch_size]))
 
     def compile(self):
-        # Helper function for rendering test images during training, or standalone non-training mode.
+        # Helper function for rendering test images during training, or standalone inference mode.
         input_tensor, seed_tensor = T.tensor4(), T.tensor4()
         input_layers = {self.network['img']: input_tensor, self.network['seed']: seed_tensor}
         output = lasagne.layers.get_output([self.network[k] for k in ['seed', 'out']], input_layers, deterministic=True)
@@ -437,7 +450,7 @@ class Model(object):
         disc_losses = [self.loss_discriminator(disc_out)]
         disc_params = list(itertools.chain(*[l.get_params() for k, l in self.network.items() if 'disc' in k]))
         print('  - {} tensors learned for discriminator.'.format(len(disc_params)))
-        grads = [g.clip(-1.0, +1.0) for g in T.grad(sum(disc_losses, 0.0), disc_params)]
+        grads = [g.clip(-5.0, +5.0) for g in T.grad(sum(disc_losses, 0.0), disc_params)]
         disc_updates = lasagne.updates.adam(grads, disc_params, learning_rate=self.disc_lr)
 
         # Combined Theano function for updating both generator and discriminator at the same time.
@@ -459,6 +472,7 @@ class NeuralEnhancer(object):
 
         self.model = Model()
         self.thread = DataLoader() if loader else None
+        self.model.compile()
 
         print('{}'.format(ansi.ENDC))
 
@@ -476,7 +490,7 @@ class NeuralEnhancer(object):
 
         l_r, t_cur = args.learning_rate, 0
         while True:
-            yield l_r if t_cur > 0 else l_r * 0.1
+            yield l_r
             t_cur += 1
             if t_cur % args.learning_period == 0: l_r *= args.learning_decay
 
@@ -510,7 +524,7 @@ class NeuralEnhancer(object):
             stats /= args.epoch_size
             totals, labels = [sum(total)] + list(total), ['total', 'prcpt', 'smthn', 'advrs']
             gen_info = ['{}{}{}={:4.2e}'.format(ansi.WHITE_B, k, ansi.ENDC, v) for k, v in zip(labels, totals)]
-            print('\rEpoch #{} at {:4.1f}s, lr={:4.2e}{}'.format(epoch+1, time.time()-start, l_r, ' '*(args.epoch_size-35)))
+            print('\rEpoch #{} at {:4.1f}s, lr={:4.2e}{}'.format(epoch+1, time.time()-start, l_r, ' '*(args.epoch_size-30)))
             print('  - generator {}'.format(' '.join(gen_info)))
             real, fake = stats[:args.batch_size], stats[args.batch_size:]
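
---

Notes on the patch. The sketches below are illustrative only, not part of the diff.

The batch-shape hunk works because a fixed stack of unpadded convolutions shrinks its input by a constant number of pixels, independent of the input size. Probing the generator once measures that shrinkage, and `batch_shape * 2 - current[2]` enlarges the training fragments by exactly that amount. A minimal sketch of the arithmetic, using a hypothetical generator of 8 unpadded 3x3 convolutions after 2x upsampling (the layer counts are made up, not read from the real config):

    def generator_output_size(seed_size, zoom=2, n_convs=8):
        # Hypothetical generator: upsample by `zoom`, then apply `n_convs`
        # unpadded 3x3 convolutions, each shaving 2 pixels off every
        # spatial dimension.
        return seed_size * zoom - 2 * n_convs

    batch_shape, zoom = 192, 2
    current = generator_output_size(batch_shape // zoom)  # 192 - 16 = 176
    batch_shape = batch_shape * 2 - current               # 384 - 176 = 208

    # Probing again with the enlarged shape: the generator output now lands
    # exactly on the resolution that was originally requested.
    assert generator_output_size(batch_shape // zoom) == 192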
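Because the generator output is now smaller than the `img` input, the ConcatLayer needs `cropping=(None, None, 'center', 'center')`: in Lasagne's merge-layer semantics this center-crops every input to the smallest spatial size on the last two axes before concatenating along the batch axis. A NumPy model of that behaviour, with made-up shapes:

    import numpy as np

    def center_crop_concat(img, out):
        # NumPy model of ConcatLayer(..., axis=0,
        # cropping=(None, None, 'center', 'center')): both inputs are
        # center-cropped to the smaller spatial size, then stacked batch-wise.
        h = min(img.shape[2], out.shape[2])
        w = min(img.shape[3], out.shape[3])
        def crop(t):
            y, x = (t.shape[2] - h) // 2, (t.shape[3] - w) // 2
            return t[:, :, y:y + h, x:x + w]
        return np.concatenate([crop(img), crop(out)], axis=0)

    img = np.zeros((10, 3, 208, 208))  # enlarged training fragments
    out = np.zeros((10, 3, 192, 192))  # smaller unpadded generator output
    print(center_crop_concat(img, out).shape)  # (20, 3, 192, 192)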
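On the loss change: `T.nnet.softminus` is not a stock Theano function; enhance.py appears to patch it onto Theano as `softminus(x) = x - softplus(x)`, which equals log(sigmoid(x)), so both losses above are ordinary log-sigmoid GAN losses. A quick NumPy check under that assumption:

    import numpy as np

    def softplus(x):
        return np.log1p(np.exp(x))

    def softminus(x):
        # Assumed definition, matching the `T.nnet.softminus` helper that
        # enhance.py patches onto Theano: softminus(x) = x - softplus(x).
        return x - softplus(x)

    x = np.linspace(-4.0, 4.0, 9)
    # softminus(x) == log(sigmoid(x)) == -softplus(-x), so the adversarial
    # and discriminator losses above are log-sigmoid losses in disguise.
    assert np.allclose(softminus(x), np.log(1.0 / (1.0 + np.exp(-x))))
    assert np.allclose(softminus(x), -softplus(-x))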
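The schedule change drops the 0.1x warm-up that previously applied only to the first yielded rate; with the new 5E-4 default the generator now trains at full rate from step zero. A standalone sketch of the generator after the change, with the period shortened to 3 so the decay is visible:

    import itertools

    def decay_learning_rate(learning_rate=5e-4, period=100, decay=0.5):
        # Mirrors the patched generator: a constant rate, scaled by `decay`
        # every `period` steps, with no reduced warm-up step.
        l_r, t_cur = learning_rate, 0
        while True:
            yield l_r
            t_cur += 1
            if t_cur % period == 0: l_r *= decay

    rates = itertools.islice(decay_learning_rate(period=3), 8)
    print(['{:.2e}'.format(r) for r in rates])
    # ['5.00e-04', '5.00e-04', '5.00e-04', '2.50e-04', '2.50e-04',
    #  '2.50e-04', '1.25e-04', '1.25e-04']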