|
|
|
|
@ -19,6 +19,7 @@ import sys
|
|
|
|
|
import bz2
|
|
|
|
|
import glob
|
|
|
|
|
import math
|
|
|
|
|
import time
|
|
|
|
|
import pickle
|
|
|
|
|
import random
|
|
|
|
|
import argparse
|
|
|
|
|
@ -34,12 +35,13 @@ add_arg = parser.add_argument
|
|
|
|
|
# Command-line options. Every option string is registered exactly once: argparse's
# default conflict handling raises ArgumentError when the same option string is
# added a second time, so `--epochs` and `--smoothness-weight` appear only with
# their most recent defaults.
add_arg('--batch-size', default=15, type=int)
add_arg('--batch-resolution', default=256, type=int)
add_arg('--epoch-size', default=36, type=int)
add_arg('--epochs', default=25, type=int)
add_arg('--generator-filters', default=128, type=int)
add_arg('--generator-blocks', default=4, type=int)
add_arg('--generator-residual', default=2, type=int)
add_arg('--perceptual-layer', default='conv2_2', type=str)
add_arg('--perceptual-weight', default=1e0, type=float)
add_arg('--smoothness-weight', default=1e6, type=float)
add_arg('--adversary-weight', default=0.0, type=float)
add_arg('--scales', default=1, type=int, help='')
add_arg('--device', default='gpu0', type=str, help='Name of the CPU/GPU number to use, for Theano.')
|
|
|
|
|
@ -50,8 +52,8 @@ args = parser.parse_args()
|
|
|
|
|
|
|
|
|
|
# Color coded output helps visualize the information a little better, plus it looks cool!
|
|
|
|
|
class ansi:
|
|
|
|
|
BOLD = '\033[1;97m'
|
|
|
|
|
WHITE = '\033[0;97m'
|
|
|
|
|
WHITE_B = '\033[1;97m'
|
|
|
|
|
YELLOW = '\033[0;33m'
|
|
|
|
|
YELLOW_B = '\033[1;33m'
|
|
|
|
|
RED = '\033[0;31m'
|
|
|
|
|
@ -118,7 +120,6 @@ class DataLoader(threading.Thread):
|
|
|
|
|
|
|
|
|
|
def run(self):
|
|
|
|
|
files, cache = glob.glob('train/*.jpg'), {}
|
|
|
|
|
|
|
|
|
|
while True:
|
|
|
|
|
random.shuffle(files)
|
|
|
|
|
for i, f in enumerate(files[:args.batch_size]):
|
|
|
|
|
@ -190,14 +191,15 @@ class Model(object):
|
|
|
|
|
def last_layer(self):
    """Return the final value of `self.network`, following its iteration order.

    Raises IndexError when `self.network` holds no entries.
    """
    layers = list(self.network.values())
    return layers[-1]
|
|
|
|
|
|
|
|
|
|
def make_layer(self, input, units, filter_size=(3,3), stride=(1,1), pad=(1,1), nl=None):
    """Create a single convolution layer on top of `input`.

    Parameters:
        input: the incoming layer the convolution is attached to.
        units: number of filters, passed straight through to `ConvLayer`.
        filter_size, stride, pad: forwarded unchanged to `ConvLayer`.
        nl: optional nonlinearity; when omitted (or otherwise falsy) the layer
            falls back to `lasagne.nonlinearities.rectify`.

    Returns:
        The newly constructed `ConvLayer`.
    """
    return ConvLayer(input, units, filter_size=filter_size, stride=stride, pad=pad,
                     nonlinearity=nl or lasagne.nonlinearities.rectify)

def make_block(self, input, units):
    """Stack two `make_layer` convolutions, optionally with a skip connection.

    Returns the element-wise sum of `input` and the second convolution when
    `args.generator_residual` is greater than zero; otherwise returns the
    second convolution on its own.
    """
    l1 = self.make_layer(input, units)
    l2 = self.make_layer(l1, units)
    # The skip connection is only added when residual blocks are requested.
    return ElemwiseSumLayer([input, l2]) if args.generator_residual > 0 else l2
|
|
|
|
|
|
|
|
|
|
def setup_generator(self, input):
|
|
|
|
|
units = args.generator_filters
|
|
|
|
|
@ -220,15 +222,15 @@ class Model(object):
|
|
|
|
|
self.network['disc3'] = ConvLayer(self.network['conv3_2'], 256, filter_size=(3,3), stride=(1,1), pad=(1,1))
|
|
|
|
|
hypercolumn = ConcatLayer([self.network['disc1'], self.network['disc2'], self.network['disc3']])
|
|
|
|
|
self.network['disc4'] = ConvLayer(hypercolumn, 192, filter_size=(3,3), stride=(1,1))
|
|
|
|
|
self.network['disc'] = batch_norm(ConvLayer(self.last_layer(), 1, filter_size=(1,1), stride=(1,1), pad=(0,0),
|
|
|
|
|
nonlinearity=lasagne.nonlinearities.sigmoid))
|
|
|
|
|
self.network['disc'] = ConvLayer(self.last_layer(), 1, filter_size=(1,1), stride=(1,1), pad=(0,0),
|
|
|
|
|
nonlinearity=lasagne.nonlinearities.sigmoid)
|
|
|
|
|
|
|
|
|
|
def setup_perceptual(self, input):
|
|
|
|
|
"""Use lasagne to create a network of convolution layers using pre-trained VGG19 weights.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
offset = np.array([103.939, 116.779, 123.680], dtype=np.float32).reshape((1,3,1,1))
|
|
|
|
|
self.network['percept'] = lasagne.layers.NonlinearityLayer(input, lambda x: ((x+0.5).clip(0.0, 1.0)*255.0) - offset)
|
|
|
|
|
self.network['percept'] = lasagne.layers.NonlinearityLayer(input, lambda x: ((x+0.5)*255.0) - offset)
|
|
|
|
|
|
|
|
|
|
self.network['mse'] = self.network['percept']
|
|
|
|
|
self.network['conv1_1'] = ConvLayer(self.network['percept'], 64, 3, pad=1)
|
|
|
|
|
@ -273,9 +275,10 @@ class Model(object):
|
|
|
|
|
|
|
|
|
|
# Generator loss function, parameters and updates.
|
|
|
|
|
self.gen_lr = theano.shared(np.array(0.0, dtype=theano.config.floatX))
|
|
|
|
|
self.adversary_weight = theano.shared(np.array(0.0, dtype=theano.config.floatX))
|
|
|
|
|
gen_losses = [self.loss_perceptual(percept_out) * args.perceptual_weight,
|
|
|
|
|
self.loss_total_variation(gen_out) * args.smoothness_weight,
|
|
|
|
|
self.loss_adversarial(disc_out) * args.adversary_weight]
|
|
|
|
|
self.loss_total_variation(gen_out) * args.smoothness_weight]
|
|
|
|
|
#self.loss_adversarial(disc_out) * self.adversary_weight]
|
|
|
|
|
gen_params = lasagne.layers.get_all_params(self.network['out'], trainable=True)
|
|
|
|
|
print(' - {} tensors learned for generator.'.format(len(gen_params)))
|
|
|
|
|
gen_updates = lasagne.updates.adam(sum(gen_losses, 0.0), gen_params, learning_rate=self.gen_lr)
|
|
|
|
|
@ -292,9 +295,9 @@ class Model(object):
|
|
|
|
|
self.fit = theano.function([input_tensor], gen_losses, updates=collections.OrderedDict(updates))
|
|
|
|
|
|
|
|
|
|
# Helper function for rendering test images deterministically, computing statistics.
|
|
|
|
|
gen_out, gen_inp = lasagne.layers.get_output([self.network['out'], self.network['img']],
|
|
|
|
|
gen_out, gen_inp, disc_out = lasagne.layers.get_output([self.network[l] for l in ['out', 'img', 'disc']],
|
|
|
|
|
input_layers, deterministic=True)
|
|
|
|
|
self.predict = theano.function([input_tensor], [gen_out, gen_inp])
|
|
|
|
|
self.predict = theano.function([input_tensor], [gen_out, gen_inp]) # disc_out.mean(axis=(1,2,3))
|
|
|
|
|
|
|
|
|
|
def loss_perceptual(self, p):
    """Mean squared error between the two parts of `p`.

    `p` is split at index `args.batch_size` along its first axis: the leading
    slice is compared element-wise with everything after it, and the squared
    differences are averaged.
    NOTE(review): presumably the two parts hold features for two sets of
    images of equal size -- confirm against the caller.
    """
    leading = p[:args.batch_size]
    trailing = p[args.batch_size:]
    squared = lasagne.objectives.squared_error(leading, trailing)
    return squared.mean()
|
|
|
|
|
@ -327,18 +330,17 @@ class NeuralEnhancer(object):
|
|
|
|
|
|
|
|
|
|
def show_progress(self, repro, orign):
    """Write each image of the batch to disk through `self.imsave`.

    For every index below `args.batch_size`, the entries `orign[index]` and
    `repro[index]` are saved as numbered PNG files, first under `test/` and
    then under `valid/`.
    """
    # (filename template, image batch) pairs, listed in the order they are written.
    targets = (
        ('test/%03i_orign.png', orign),
        ('test/%03i_repro.png', repro),
        ('valid/%03i_orign.png', orign),
        ('valid/%03i_repro.png', repro),
    )
    for index in range(args.batch_size):
        for template, batch in targets:
            self.imsave(template % index, batch[index])
|
|
|
|
|
|
|
|
|
|
def train(self):
|
|
|
|
|
images = np.zeros((args.batch_size, 3, args.batch_resolution, args.batch_resolution), dtype=np.float32)
|
|
|
|
|
|
|
|
|
|
l_min, l_max, l_mult = 1E-7, 1E-3, 0.2
|
|
|
|
|
t_cur, t_i, t_mult = 120, 150, 1
|
|
|
|
|
|
|
|
|
|
i, running = 0, None
|
|
|
|
|
for _ in range(args.epochs):
|
|
|
|
|
total = None
|
|
|
|
|
for k in range(args.epochs):
|
|
|
|
|
total, start = None, time.time()
|
|
|
|
|
for _ in range(args.epoch_size):
|
|
|
|
|
i += 1
|
|
|
|
|
l_r = l_min + 0.5 * (l_max - l_min) * (1.0 + math.cos(t_cur / t_i * math.pi))
|
|
|
|
|
@ -346,23 +348,28 @@ class NeuralEnhancer(object):
|
|
|
|
|
self.model.gen_lr.set_value(l_r)
|
|
|
|
|
|
|
|
|
|
if t_cur >= t_i:
|
|
|
|
|
t_cur = 0
|
|
|
|
|
t_i = int(t_i * t_mult)
|
|
|
|
|
l_max = max(l_max * l_mult, 1e-10)
|
|
|
|
|
l_min = max(l_min * l_mult, 1e-6)
|
|
|
|
|
t_cur, t_i = 0, int(t_i * t_mult)
|
|
|
|
|
l_max = max(l_max * l_mult, 1e-12)
|
|
|
|
|
l_min = max(l_min * l_mult, 1e-8)
|
|
|
|
|
|
|
|
|
|
self.thread.copy(images)
|
|
|
|
|
losses = np.array(self.model.fit(images), dtype=np.float32)
|
|
|
|
|
total = total + losses if total is not None else losses
|
|
|
|
|
l = np.sum(losses)
|
|
|
|
|
assert not np.isnan(losses).any()
|
|
|
|
|
running = l if running is None else running * 0.9 + 0.1 * l
|
|
|
|
|
|
|
|
|
|
print('↑' if l > running else '↓', end=' ', flush=True)
|
|
|
|
|
|
|
|
|
|
self.show_progress(*self.model.predict(images))
|
|
|
|
|
total = total / args.epoch_size
|
|
|
|
|
labels = ['{}={:4.2e}'.format(k, v) for k, v in zip(['prcpt', 'smthn', 'advrs'], total)]
|
|
|
|
|
print('\nLosses: total={:4.2e} {}'.format(sum(total), ' '.join(labels)))
|
|
|
|
|
repro, orign = self.model.predict(images)
|
|
|
|
|
self.show_progress(repro, orign)
|
|
|
|
|
total /= args.epoch_size
|
|
|
|
|
totals, labels = [sum(total)] + list(total), ['total', 'prcpt', 'smthn', 'advrs']
|
|
|
|
|
losses = ['{}{}{}={:4.2e}'.format(ansi.WHITE_B, k, ansi.ENDC, v) for k, v in zip(labels, totals)]
|
|
|
|
|
print('\rEpoch #{} in {:4.1f}s{}'.format(k+1, time.time()-start, ' '*args.epoch_size))
|
|
|
|
|
print(' - losses {}\n'.format(' '.join(losses)))
|
|
|
|
|
# print(stats[:args.batch_size].mean(), stats[args.batch_size:].mean())
|
|
|
|
|
if k == 0: self.model.disc_lr.set_value(l_r)
|
|
|
|
|
if k == 1: self.model.adversary_weight.set_value(args.adversary_weight)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
|