Add discriminator network too.

main
Alex J. Champandard 9 years ago
parent 3f24714039
commit 9f13695050

@ -171,6 +171,7 @@ class Model(object):
concatenated = lasagne.layers.ConcatLayer([self.network['img'], self.network['out']], axis=0) concatenated = lasagne.layers.ConcatLayer([self.network['img'], self.network['out']], axis=0)
self.setup_perceptual(concatenated) self.setup_perceptual(concatenated)
self.load_perceptual() self.load_perceptual()
self.setup_discriminator()
self.compile() self.compile()
def last_layer(self): def last_layer(self):
@ -195,6 +196,15 @@ class Model(object):
l2 = batch_norm(ConvLayer(l1, units, filter_size=(3,3), stride=(1,1), pad=1)) l2 = batch_norm(ConvLayer(l1, units, filter_size=(3,3), stride=(1,1), pad=1))
return ElemwiseSumLayer([input, l2]) return ElemwiseSumLayer([input, l2])
def setup_discriminator(self):
self.network['disc1'] = ConvLayer(self.network['conv1_2'], 64, filter_size=(7,7), stride=(4,4), pad=(3,3))
self.network['disc2'] = ConvLayer(self.network['conv2_2'], 128, filter_size=(5,5), stride=(2,2), pad=(2,2))
self.network['disc3'] = ConvLayer(self.network['conv3_2'], 256, filter_size=(3,3), stride=(1,1), pad=(1,1))
hypercolumn = ConcatLayer([self.network['disc1'], self.network['disc2'], self.network['disc3']])
self.network['disc4'] = ConvLayer(hypercolumn, 192, filter_size=(3,3), stride=(1,1))
self.network['disc'] = batch_norm(ConvLayer(self.last_layer(), 1, filter_size=(1,1), stride=(1,1), pad=(0,0),
nonlinearity=lasagne.nonlinearities.sigmoid))
def setup_perceptual(self, input): def setup_perceptual(self, input):
"""Use lasagne to create a network of convolution layers using pre-trained VGG19 weights. """Use lasagne to create a network of convolution layers using pre-trained VGG19 weights.
""" """
@ -246,13 +256,21 @@ class Model(object):
# Generator loss function, parameters and updates. # Generator loss function, parameters and updates.
self.gen_lr = theano.shared(np.array(0.0, dtype=theano.config.floatX)) self.gen_lr = theano.shared(np.array(0.0, dtype=theano.config.floatX))
gen_losses = [self.loss_perceptual(percept_out) * args.perceptual_weight, gen_losses = [self.loss_perceptual(percept_out) * args.perceptual_weight,
self.loss_total_variation(gen_out) * args.smoothness_weight] self.loss_total_variation(gen_out) * args.smoothness_weight,
self.loss_adversarial(disc_out) * args.adversary_weight]
gen_params = lasagne.layers.get_all_params(self.network['out'], trainable=True) gen_params = lasagne.layers.get_all_params(self.network['out'], trainable=True)
print(' - {} tensors learned for generator.'.format(len(gen_params))) print(' - {} tensors learned for generator.'.format(len(gen_params)))
gen_updates = lasagne.updates.adam(sum(gen_losses, 0.0), gen_params, learning_rate=self.gen_lr) gen_updates = lasagne.updates.adam(sum(gen_losses, 0.0), gen_params, learning_rate=self.gen_lr)
# Discriminator loss function, parameters and updates.
self.disc_lr = theano.shared(np.array(0.0, dtype=theano.config.floatX))
disc_losses = [self.loss_discriminator(disc_out)]
disc_params = list(itertools.chain(*[l.get_params() for k, l in self.network.items() if 'disc' in k]))
print(' - {} tensors learned for discriminator.'.format(len(disc_params)))
disc_updates = lasagne.updates.adam(sum(disc_losses, 0.0), disc_params, learning_rate=self.disc_lr)
# Combined Theano function for updating both generator and discriminator at the same time. # Combined Theano function for updating both generator and discriminator at the same time.
updates = list(gen_updates.items()) updates = list(gen_updates.items()) # + list(disc_updates.items())
self.fit = theano.function([input_tensor], gen_losses, updates=collections.OrderedDict(updates)) self.fit = theano.function([input_tensor], gen_losses, updates=collections.OrderedDict(updates))
# Helper function for rendering test images deterministically, computing statistics. # Helper function for rendering test images deterministically, computing statistics.
@ -266,6 +284,12 @@ class Model(object):
def loss_total_variation(self, x): def loss_total_variation(self, x):
return T.mean(((x[:,:,:-1,:-1] - x[:,:,1:,:-1])**2 + (x[:,:,:-1,:-1] - x[:,:,:-1,1:])**2)**1.25) return T.mean(((x[:,:,:-1,:-1] - x[:,:,1:,:-1])**2 + (x[:,:,:-1,:-1] - x[:,:,:-1,1:])**2)**1.25)
def loss_adversarial(self, d):
return 1.0 - T.log(d[args.batch_size:]).mean()
def loss_discriminator(self, d):
return T.mean(T.log(d[args.batch_size:]) + T.log(1.0 - d[:args.batch_size]))
class NeuralEnhancer(object): class NeuralEnhancer(object):

Loading…
Cancel
Save