diff --git a/enhance.py b/enhance.py
index 0ce6440..cbc8822 100644
--- a/enhance.py
+++ b/enhance.py
@@ -171,6 +171,7 @@ class Model(object):
         concatenated = lasagne.layers.ConcatLayer([self.network['img'], self.network['out']], axis=0)
         self.setup_perceptual(concatenated)
         self.load_perceptual()
+        self.setup_discriminator()
         self.compile()

     def last_layer(self):
@@ -195,6 +196,15 @@ class Model(object):
         l2 = batch_norm(ConvLayer(l1, units, filter_size=(3,3), stride=(1,1), pad=1))
         return ElemwiseSumLayer([input, l2])

+    def setup_discriminator(self):
+        self.network['disc1'] = ConvLayer(self.network['conv1_2'], 64, filter_size=(7,7), stride=(4,4), pad=(3,3))
+        self.network['disc2'] = ConvLayer(self.network['conv2_2'], 128, filter_size=(5,5), stride=(2,2), pad=(2,2))
+        self.network['disc3'] = ConvLayer(self.network['conv3_2'], 256, filter_size=(3,3), stride=(1,1), pad=(1,1))
+        hypercolumn = ConcatLayer([self.network['disc1'], self.network['disc2'], self.network['disc3']])
+        self.network['disc4'] = ConvLayer(hypercolumn, 192, filter_size=(3,3), stride=(1,1))
+        self.network['disc'] = batch_norm(ConvLayer(self.last_layer(), 1, filter_size=(1,1), stride=(1,1), pad=(0,0),
+                                                    nonlinearity=lasagne.nonlinearities.sigmoid))
+
     def setup_perceptual(self, input):
         """Use lasagne to create a network of convolution layers using pre-trained VGG19 weights.
         """
@@ -246,13 +256,21 @@ class Model(object):
         # Generator loss function, parameters and updates.
         self.gen_lr = theano.shared(np.array(0.0, dtype=theano.config.floatX))
         gen_losses = [self.loss_perceptual(percept_out) * args.perceptual_weight,
-                      self.loss_total_variation(gen_out) * args.smoothness_weight]
+                      self.loss_total_variation(gen_out) * args.smoothness_weight,
+                      self.loss_adversarial(disc_out) * args.adversary_weight]
         gen_params = lasagne.layers.get_all_params(self.network['out'], trainable=True)
         print(' - {} tensors learned for generator.'.format(len(gen_params)))
         gen_updates = lasagne.updates.adam(sum(gen_losses, 0.0), gen_params, learning_rate=self.gen_lr)

+        # Discriminator loss function, parameters and updates.
+        self.disc_lr = theano.shared(np.array(0.0, dtype=theano.config.floatX))
+        disc_losses = [self.loss_discriminator(disc_out)]
+        disc_params = list(itertools.chain(*[l.get_params() for k, l in self.network.items() if 'disc' in k]))
+        print(' - {} tensors learned for discriminator.'.format(len(disc_params)))
+        disc_updates = lasagne.updates.adam(sum(disc_losses, 0.0), disc_params, learning_rate=self.disc_lr)
+
         # Combined Theano function for updating both generator and discriminator at the same time.
-        updates = list(gen_updates.items())
+        updates = list(gen_updates.items()) # + list(disc_updates.items())
         self.fit = theano.function([input_tensor], gen_losses, updates=collections.OrderedDict(updates))

         # Helper function for rendering test images deterministically, computing statistics.
@@ -266,6 +284,12 @@ class Model(object):
     def loss_total_variation(self, x):
         return T.mean(((x[:,:,:-1,:-1] - x[:,:,1:,:-1])**2 + (x[:,:,:-1,:-1] - x[:,:,:-1,1:])**2)**1.25)

+    def loss_adversarial(self, d):
+        return 1.0 - T.log(d[args.batch_size:]).mean()
+
+    def loss_discriminator(self, d):
+        return T.mean(T.log(d[args.batch_size:]) + T.log(1.0 - d[:args.batch_size]))
+

 class NeuralEnhancer(object):
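
A note on why the three disc* branches in setup_discriminator can be concatenated: in VGG19, conv1_2 sits at full input resolution, conv2_2 at 1/2 scale (after pool1), and conv3_2 at 1/4 scale (after pool2), so the strides of 4, 2 and 1 chosen above bring all three branches onto the same 1/4-scale grid before ConcatLayer joins them along the channel axis. A quick sanity check of the spatial sizes using standard convolution arithmetic (the 256-pixel input is an illustrative assumption, not a value from the diff):

    def conv_out(size, fsize, stride, pad):
        # Standard convolution output size: floor((size + 2*pad - fsize) / stride) + 1.
        return (size + 2 * pad - fsize) // stride + 1

    size = 256  # hypothetical input resolution
    print(conv_out(size,      7, 4, 3))  # disc1 on conv1_2 (full res) -> 64
    print(conv_out(size // 2, 5, 2, 2))  # disc2 on conv2_2 (1/2 res)  -> 64
    print(conv_out(size // 4, 3, 1, 1))  # disc3 on conv3_2 (1/4 res)  -> 64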
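
The two loss terms index the discriminator output d by halves because the batch fed through the perceptual network concatenates real images first and generated outputs second (the ConcatLayer with axis=0 at the top of the diff). A minimal NumPy sketch of the same arithmetic, with the batch size and sigmoid outputs as made-up values, shows the sign conventions: minimising loss_adversarial pushes D(generated) towards 1, while minimising loss_discriminator pushes D(generated) towards 0 and D(real) towards 1.

    import numpy as np

    batch_size = 2                 # stand-in for args.batch_size
    d_real = np.array([0.9, 0.8])  # sigmoid outputs for the real half of the batch
    d_fake = np.array([0.2, 0.3])  # sigmoid outputs for the generated half
    d = np.concatenate([d_real, d_fake])

    # Generator term: maximises log D(generated).
    loss_adversarial = 1.0 - np.log(d[batch_size:]).mean()
    # Discriminator term: drives D(generated) down and D(real) up.
    loss_discriminator = np.mean(np.log(d[batch_size:]) + np.log(1.0 - d[:batch_size]))
    print(loss_adversarial, loss_discriminator)

Note that as the diff stands, the discriminator's Adam updates are computed but left commented out of the combined updates list, so self.fit still trains only the generator.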