@@ -201,9 +201,7 @@ class Model(object):
         return list(self.network.values())[-1]

     def make_layer(self, name, input, units, filter_size=(3,3), stride=(1,1), pad=(1,1), alpha=0.25):
-        # bias = None if normalized else lasagne.init.Constant(0.0)
         conv = ConvLayer(input, units, filter_size=filter_size, stride=stride, pad=pad, nonlinearity=None)
-        # if normalized: conv = lasagne.layers.BatchNormLayer(conv)
         prelu = lasagne.layers.ParametricRectifierLayer(conv, alpha=lasagne.init.Constant(alpha))
         self.network[name+'x'] = conv
         self.network[name+'>'] = prelu
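Note: make_layer registers the raw convolution under `name+'x'` and its PReLU activation under `name+'>'`, so either tensor can be fetched by key later. A minimal standalone sketch of the same conv + PReLU pairing (illustrative only, reusing the module's ConvLayer alias):

    inp = lasagne.layers.InputLayer((None, 3, 64, 64))
    conv = ConvLayer(inp, 32, filter_size=(3,3), stride=(1,1), pad=(1,1), nonlinearity=None)
    prelu = lasagne.layers.ParametricRectifierLayer(conv, alpha=lasagne.init.Constant(0.25))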
@@ -230,15 +228,6 @@ class Model(object):
         self.network['out'] = ConvLayer(self.last_layer(), 3, filter_size=(5,5), stride=(1,1), pad=(2,2),
                                         nonlinearity=lasagne.nonlinearities.tanh)

-    def setup_discriminator(self):
-        self.network['disc1'] = ConvLayer(self.network['conv1_2'], 64, filter_size=(7,7), stride=(4,4), pad=(3,3))
-        self.network['disc2'] = ConvLayer(self.network['conv2_2'], 128, filter_size=(5,5), stride=(2,2), pad=(2,2))
-        self.network['disc3'] = ConvLayer(self.network['conv3_2'], 256, filter_size=(3,3), stride=(1,1), pad=(1,1))
-        hypercolumn = ConcatLayer([self.network['disc1'], self.network['disc2'], self.network['disc3']])
-        self.network['disc4'] = ConvLayer(hypercolumn, 192, filter_size=(3,3), stride=(1,1))
-        self.network['disc'] = ConvLayer(self.last_layer(), 1, filter_size=(1,1), stride=(1,1), pad=(0,0),
-                                         nonlinearity=lasagne.nonlinearities.sigmoid)
-
     def setup_perceptual(self, input):
         """Use lasagne to create a network of convolution layers using pre-trained VGG19 weights.
         """
@@ -268,6 +257,18 @@ class Model(object):
         self.network['conv5_3'] = ConvLayer(self.network['conv5_2'], 512, 3, pad=1)
         self.network['conv5_4'] = ConvLayer(self.network['conv5_3'], 512, 3, pad=1)

+    def setup_discriminator(self):
+        self.make_layer('disc1.1', batch_norm(self.network['conv1_2']), 64, filter_size=(5,5), stride=(2,2), pad=(2,2))
+        self.make_layer('disc1.2', self.last_layer(), 64, filter_size=(5,5), stride=(2,2), pad=(2,2))
+        self.make_layer('disc2', self.network['conv2_2'], 128, filter_size=(5,5), stride=(2,2), pad=(2,2))
+        self.make_layer('disc3', self.network['conv3_2'], 256, filter_size=(3,3), stride=(1,1), pad=(1,1))
+        hypercolumn = ConcatLayer([self.network['disc1.2>'], self.network['disc2>'], self.network['disc3>']])
+        self.make_layer('disc4', hypercolumn, 192, filter_size=(5,5), stride=(2,2))
+        self.make_layer('disc5', self.last_layer(), 96, filter_size=(5,5), stride=(2,2))
+        self.network['disc'] = batch_norm(ConvLayer(self.last_layer(), 1, filter_size=(1,1),
+                                                    nonlinearity=lasagne.nonlinearities.sigmoid))
+
+
     #------------------------------------------------------------------------------------------------------------------
     # Input / Output
     #------------------------------------------------------------------------------------------------------------------
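Note: the rebuilt discriminator is a hypercolumn design: `disc1.2>`, `disc2>` and `disc3>` tap VGG features at three depths and ConcatLayer fuses them along the channel axis. The strides are what make the concat valid, since all three branches must agree on spatial size. A rough shape sketch for a 256x256 input (assuming standard VGG19 pooling):

    # conv1_2: 256x256 -> disc1.1 (stride 2) -> 128 -> disc1.2 (stride 2) -> 64
    # conv2_2: 128x128 -> disc2   (stride 2) -> 64
    # conv3_2:  64x64  -> disc3   (stride 1) -> 64   => all 64x64, safe to concatenate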
@@ -299,13 +300,13 @@ class Model(object):
         print(' - Saved model as `{}` after training.'.format(filename))

     def load_generator(self):
-        if not args.load: return
         filename = args.model % 2**args.scales
-        if not os.path.exists(filename): return
+        if not os.path.exists(filename) or not args.load: return
         params = pickle.load(bz2.open(filename, 'rb'))
         for k, l in self.list_generator_layers():
             if k not in params: continue
-            (p.set_value(v) for p, v in zip(l.get_params(), params[k]))
+            for p, v in zip(l.get_params(), params[k]):
+                p.set_value(v.astype(np.float32))
         print(' - Loaded file `{}` with trained model.'.format(filename))

     #------------------------------------------------------------------------------------------------------------------
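Note: the load_generator fix matters more than it looks. The old `(p.set_value(v) for ...)` built a generator expression, which is lazy and was never iterated, so no weights were ever loaded; the replacement uses a real loop and casts to float32 to match theano.config.floatX. The trap in miniature (hypothetical list, not part of the patch):

    acc = []
    (acc.append(x) for x in range(3))     # lazy generator, never consumed: acc stays []
    for x in range(3): acc.append(x)      # eager loop: acc becomes [0, 1, 2]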
@@ -335,8 +336,8 @@ class Model(object):
         self.gen_lr = theano.shared(np.array(0.0, dtype=theano.config.floatX))
         self.adversary_weight = theano.shared(np.array(0.0, dtype=theano.config.floatX))
         gen_losses = [self.loss_perceptual(percept_out) * args.perceptual_weight,
-                      self.loss_total_variation(gen_out) * args.smoothness_weight]
-                      #self.loss_adversarial(disc_out) * self.adversary_weight]
+                      self.loss_total_variation(gen_out) * args.smoothness_weight,
+                      self.loss_adversarial(disc_out) * self.adversary_weight]
         gen_params = lasagne.layers.get_all_params(self.network['out'], trainable=True)
         print(' - {} tensors learned for generator.'.format(len(gen_params)))
         gen_updates = lasagne.updates.adam(sum(gen_losses, 0.0), gen_params, learning_rate=self.gen_lr)
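Note: the adversarial loss term is now live, scaled by the shared variable self.adversary_weight. Since that shared value is initialised to 0.0, the generator still warms up on the perceptual and smoothness terms alone until the training loop raises the weight. How a Theano shared variable behaves here, in miniature (hypothetical value):

    w = theano.shared(np.array(0.0, dtype=theano.config.floatX))
    w.set_value(0.5)        # takes effect immediately in already-compiled functions
    float(w.get_value())    # -> 0.5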
@@ -349,13 +350,13 @@ class Model(object):
         disc_updates = lasagne.updates.adam(sum(disc_losses, 0.0), disc_params, learning_rate=self.disc_lr)

         # Combined Theano function for updating both generator and discriminator at the same time.
-        updates = list(gen_updates.items()) # + list(disc_updates.items())
+        updates = list(gen_updates.items()) + list(disc_updates.items())
         self.fit = theano.function([input_tensor], gen_losses, updates=collections.OrderedDict(updates))

         # Helper function for rendering test images deterministically, computing statistics.
-        outputs = lasagne.layers.get_output([self.network[k] for k in ['img', 'seed', 'out']],
-                                            input_layers, deterministic=True)
-        self.predict = theano.function([input_tensor], outputs)
+        *outputs, disc_out = lasagne.layers.get_output([self.network[k] for k in ['img', 'seed', 'out', 'disc']],
+                                                       input_layers, deterministic=True)
+        self.predict = theano.function([input_tensor], outputs + [disc_out.mean(axis=(1,2,3))])


 class NeuralEnhancer(object):
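Note: `*outputs, disc_out = ...` is Python 3 extended unpacking: the img, seed and out tensors land in the list `outputs` and the last element in `disc_out`; `disc_out.mean(axis=(1,2,3))` then collapses the discriminator's per-pixel map into a single realism score per image. The unpacking rule in miniature:

    *init, last = [1, 2, 3]    # init == [1, 2], last == 3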
@@ -395,6 +396,7 @@ class NeuralEnhancer(object):
                 t_cur += 1
                 l_r = 1E-4
                 self.model.gen_lr.set_value(l_r)
+                self.model.disc_lr.set_value(l_r)

                 if t_cur >= t_i:
                     t_cur, t_i = 0, int(t_i * t_mult)
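Note: t_cur, t_i and t_mult implement a warm-restart learning-rate schedule (the interval between resets grows by t_mult each time); the added line just keeps the discriminator's learning rate in lockstep with the generator's. The counter logic in isolation (hypothetical starting values):

    t_cur, t_i, t_mult = 0, 100, 2.0
    for step in range(1000):
        t_cur += 1
        if t_cur >= t_i:
            t_cur, t_i = 0, int(t_i * t_mult)   # restart; next interval is twice as long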
@@ -409,7 +411,7 @@ class NeuralEnhancer(object):
                 running = l if running is None else running * 0.9 + 0.1 * l
                 print('↑' if l > running else '↓', end=' ', flush=True)

-            orign, scald, repro = self.model.predict(images)
+            orign, scald, repro, stats = self.model.predict(images)
             self.show_progress(orign, scald, repro)
             total /= args.epoch_size
             totals, labels = [sum(total)] + list(total), ['total', 'prcpt', 'smthn', 'advrs']
@@ -417,11 +419,11 @@ class NeuralEnhancer(object):
             print('\rEpoch #{} at {:4.1f}s{}'.format(epoch+1, time.time()-start, ' '*args.epoch_size))
             print(' - generator {}'.format(' '.join(gen_info)))
-            # print(' - discriminator {}'.format(' '.join(gen_stats)))
-            if epoch == 0:
-                self.model.adversary_weight.set_value(args.adversary_weight)
-                running = None
-
+            real, fake = stats[:args.batch_size], stats[args.batch_size:]
+            print(' - discriminator', real.mean(), len(np.where(real > 0.5)[0]), fake.mean(), len(np.where(fake < 0.5)[0]))
+            # print(stats[:args.batch_size].mean(), stats[args.batch_size:].mean())
+            # if epoch == 0: self.model.disc_lr.set_value(l_r)
+            # if epoch == 1: self.model.adversary_weight.set_value(args.adversary_weight)
+
     except KeyboardInterrupt:
         pass
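Note: the new discriminator report appears to assume each batch stacks the real images first and the generated ones second, so stats[:args.batch_size] scores the real half and stats[args.batch_size:] the fake half; the len(np.where(...)) counts are how many of each the discriminator gets right at a 0.5 threshold. An equivalent accuracy form (sketch, not part of the patch):

    real, fake = stats[:args.batch_size], stats[args.batch_size:]
    acc_real = (real > 0.5).mean()    # fraction of real images scored as real
    acc_fake = (fake < 0.5).mean()    # fraction of generated images scored as fake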