Stabilize the training, tune the architecture.

main
Alex J. Champandard 9 years ago
parent 4580e5531a
commit 6542f435b4

@@ -44,7 +44,7 @@ add_arg('--generator-blocks', default=4, type=int)
 add_arg('--generator-residual', default=2, type=int)
 add_arg('--perceptual-layer', default='conv2_2', type=str)
 add_arg('--perceptual-weight', default=1e0, type=float)
-add_arg('--smoothness-weight', default=1e6, type=float)
+add_arg('--smoothness-weight', default=1e7, type=float)
 add_arg('--adversary-weight', default=0.0, type=float)
 add_arg('--scales', default=1, type=int, help='')
 add_arg('--device', default='gpu0', type=str, help='Name of the CPU/GPU number to use, for Theano.')
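
The smoothness weight scales a total-variation style penalty on the generator's output, so the change from 1e6 to 1e7 leans an order of magnitude harder on locally smooth reconstructions. As a rough sketch only (the script's actual loss term is defined elsewhere in the file and may differ in detail), such a penalty over a batch in the network's (batch, channels, height, width) layout looks like this in NumPy:

import numpy as np

def total_variation(img):
    # img: float array shaped (batch, channels, height, width), as in the model.
    # Penalize squared differences between vertically and horizontally
    # adjacent pixels; a large multiplier such as 1e7 favours smooth output.
    dy = img[:, :, 1:, :] - img[:, :, :-1, :]
    dx = img[:, :, :, 1:] - img[:, :, :, :-1]
    return (dy ** 2).sum() + (dx ** 2).sum()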
@@ -100,7 +100,7 @@ if sys.platform == 'win32':
 # Deep Learning Framework
 import lasagne
 from lasagne.layers import Conv2DLayer as ConvLayer, Deconv2DLayer as DeconvLayer, Pool2DLayer as PoolLayer
-from lasagne.layers import InputLayer, ConcatLayer, batch_norm, ElemwiseSumLayer
+from lasagne.layers import InputLayer, ConcatLayer, ElemwiseSumLayer
 
 print('{} - Using the device `{}` for neural computation.{}\n'.format(ansi.CYAN, theano.config.device, ansi.ENDC))
@@ -182,7 +182,7 @@ class Model(object):
     def __init__(self):
        self.network = collections.OrderedDict()
        self.network['img'] = InputLayer((None, 3, None, None))
-       self.network['seed'] = PoolLayer(self.network['img'], pool_size=2**args.scales)
+       self.network['seed'] = PoolLayer(self.network['img'], pool_size=2**args.scales, mode='average_exc_pad')
        self.setup_generator(self.network['seed'])
 
        concatenated = lasagne.layers.ConcatLayer([self.network['img'], self.network['out']], axis=0)
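
Pool2DLayer defaults to max pooling, so the added mode='average_exc_pad' is significant: the 'seed' the generator upscales becomes a box-filtered average of the training image, a closer stand-in for a genuine low-resolution photograph than local maxima (the 'exc_pad' variant simply excludes padding from each average). A minimal NumPy sketch of the pool_size=2 case:

import numpy as np

def average_pool_2x(img):
    # Downsample (batch, channels, H, W) by averaging each 2x2 block,
    # the pool_size=2 case of what the 'seed' layer computes.
    b, c, h, w = img.shape
    blocks = img.reshape(b, c, h // 2, 2, w // 2, 2)
    return blocks.mean(axis=(3, 5))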
@@ -200,31 +200,32 @@ class Model(object):
     def last_layer(self):
         return list(self.network.values())[-1]
 
-    def make_layer(self, input, units, filter_size=(3,3), stride=(1,1), pad=(1,1), nl='prelu'):
-        conv = ConvLayer(input, units, filter_size=filter_size, stride=stride, pad=pad,
-                         nonlinearity=lasagne.nonlinearities.linear)
-        if nl == 'relu': conv.nonlinearity = lasagne.nonlinearities.rectify
-        if nl == 'prelu': conv = lasagne.layers.prelu(conv)
-        return conv
+    def make_layer(self, name, input, units, filter_size=(3,3), stride=(1,1), pad=(1,1), alpha=0.25):
+        # bias = None if normalized else lasagne.init.Constant(0.0)
+        conv = ConvLayer(input, units, filter_size=filter_size, stride=stride, pad=pad, nonlinearity=None)
+        # if normalized: conv = lasagne.layers.BatchNormLayer(conv)
+        prelu = lasagne.layers.ParametricRectifierLayer(conv, alpha=lasagne.init.Constant(alpha))
+        self.network[name+'x'] = conv
+        self.network[name+'>'] = prelu
+        return prelu
 
     def make_block(self, name, input, units):
-        self.network[name+'|Ac'] = self.make_layer(input, units)
-        self.network[name+'|An'] = batch_norm(self.last_layer()).input_layer
-        self.network[name+'|Bc'] = self.make_layer(self.last_layer(), units)
-        self.network[name+'|Bn'] = batch_norm(self.last_layer()).input_layer
+        self.make_layer(name+'-A', input, units, alpha=0.25)
+        self.make_layer(name+'-B', self.last_layer(), units, alpha=1.0)
         return ElemwiseSumLayer([input, self.last_layer()]) if args.generator_residual else self.last_layer()
 
     def setup_generator(self, input):
         units = args.generator_filters
-        self.network['iter.0'] = self.make_layer(input, units, filter_size=(5,5), pad=(2,2))
+        self.make_layer('iter.0', input, units, filter_size=(5,5), pad=(2,2))
         for i in range(0, args.generator_blocks):
             self.network['iter.%i'%(i+1)] = self.make_block('iter.%i'%(i+1), self.last_layer(), units)
         for i in range(args.scales, 0, -1):
-            self.network['scale%i.3'%i] = self.make_layer(self.last_layer(), units*2)
-            self.network['scale%i.2'%i] = SubpixelShuffle(self.network['scale%i.3'%i], units//2, 2)
-            self.network['scale%i.1'%i] = self.make_layer(self.network['scale%i.2'%i], units)
+            u = units // 2**(args.scales-i)
+            self.make_layer('scale%i.3'%i, self.last_layer(), u*4)
+            self.network['scale%i.2'%i] = SubpixelShuffle(self.last_layer(), u, 2)
+            self.make_layer('scale%i.1'%i, self.network['scale%i.2'%i], u)
         self.network['out'] = ConvLayer(self.last_layer(), 3, filter_size=(5,5), stride=(1,1), pad=(2,2),
                                         nonlinearity=lasagne.nonlinearities.tanh)
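
The rewritten upscaling loop also halves the filter count at each scale (u = units // 2**(args.scales-i)) and widens each stage to u*4 channels right before SubpixelShuffle. That layer is defined elsewhere in this file; assuming it implements the standard periodic shuffle of sub-pixel convolution (Shi et al. 2016), each group of four feature maps folds into one 2x2 block of output pixels, so the upscaling itself is learned by the preceding convolution. A NumPy sketch of that rearrangement:

import numpy as np

def subpixel_shuffle(feat, channels, ratio=2):
    # Rearrange (batch, channels * ratio**2, H, W) feature maps into a
    # (batch, channels, H * ratio, W * ratio) image: every group of
    # ratio**2 maps becomes one ratio x ratio block of output pixels.
    b, _, h, w = feat.shape
    feat = feat.reshape(b, channels, ratio, ratio, h, w)
    feat = feat.transpose(0, 1, 4, 2, 5, 3)
    return feat.reshape(b, channels, h * ratio, w * ratio)

With ratio=2, folding u*4 maps this way is why make_layer is asked for u*4 filters before each shuffle.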
@@ -235,8 +236,8 @@ class Model(object):
         self.network['disc3'] = ConvLayer(self.network['conv3_2'], 256, filter_size=(3,3), stride=(1,1), pad=(1,1))
         hypercolumn = ConcatLayer([self.network['disc1'], self.network['disc2'], self.network['disc3']])
         self.network['disc4'] = ConvLayer(hypercolumn, 192, filter_size=(3,3), stride=(1,1))
-        self.network['disc'] = batch_norm(ConvLayer(self.last_layer(), 1, filter_size=(1,1), stride=(1,1), pad=(0,0),
-                                                    nonlinearity=lasagne.nonlinearities.sigmoid))
+        self.network['disc'] = ConvLayer(self.last_layer(), 1, filter_size=(1,1), stride=(1,1), pad=(0,0),
+                                         nonlinearity=lasagne.nonlinearities.sigmoid)
 
     def setup_perceptual(self, input):
         """Use lasagne to create a network of convolution layers using pre-trained VGG19 weights.
@@ -352,10 +353,9 @@ class Model(object):
         self.fit = theano.function([input_tensor], gen_losses, updates=collections.OrderedDict(updates))
 
         # Helper function for rendering test images deterministically, computing statistics.
-        gen_out, gen_inp = lasagne.layers.get_output([self.network['out'], self.network['img']],
+        outputs = lasagne.layers.get_output([self.network[k] for k in ['img', 'seed', 'out']],
                                             input_layers, deterministic=True)
-        self.predict = theano.function([input_tensor], [gen_out, gen_inp])
+        self.predict = theano.function([input_tensor], outputs)
 
 
 class NeuralEnhancer(object):
@@ -374,29 +374,32 @@ class NeuralEnhancer(object):
         image = scipy.misc.toimage(img * 255.0, cmin=0, cmax=255)
         image.save(fn)
 
-    def show_progress(self, repro, orign):
+    def show_progress(self, orign, scald, repro):
         for i in range(args.batch_size):
-            self.imsave('valid/%03i_orign.png' % i, orign[i])
-            self.imsave('valid/%03i_repro.png' % i, repro[i])
+            self.imsave('valid/%03i_origin.png' % i, orign[i])
+            self.imsave('valid/%03i_pixels.png' % i, scald[i])
+            self.imsave('valid/%03i_reprod.png' % i, repro[i])
 
     def train(self):
         images = np.zeros((args.batch_size, 3, args.batch_resolution, args.batch_resolution), dtype=np.float32)
         l_min, l_max, l_mult = 1E-7, 1E-3, 0.2
         t_cur, t_i, t_mult = 120, 150, 1
         try:
             i, running, start = 0, None, time.time()
-            for k in range(args.epochs):
+            for epoch in range(args.epochs):
                 total = None
                 for _ in range(args.epoch_size):
                     i += 1
-                    l_r = l_min + 0.5 * (l_max - l_min) * (1.0 + math.cos(t_cur / t_i * math.pi))
                     t_cur += 1
+                    l_r = 1E-4
                     self.model.gen_lr.set_value(l_r)
                     if t_cur >= t_i:
                         t_cur, t_i = 0, int(t_i * t_mult)
-                        l_max = max(l_max * l_mult, 1e-12)
-                        l_min = max(l_min * l_mult, 1e-8)
+                        l_max = max(l_max * l_mult, 1e-11)
+                        l_min = max(l_min * l_mult, 1e-7)
 
                     self.thread.copy(images)
                     losses = np.array(self.model.fit(images), dtype=np.float32)
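
For context on the surrounding lines: l_min, l_max, t_cur and t_i implement cosine annealing with warm restarts (SGDR-style). The removed line swept the rate from l_max down to l_min along a half-cosine over each t_i-step period, after which both bounds shrink by l_mult; this commit pins l_r at a constant 1E-4 but keeps that bookkeeping in place. A standalone sketch of the schedule the removed line computed, using the script's constants:

import math

def sgdr_schedule(steps, l_min=1E-7, l_max=1E-3, l_mult=0.2, t_i=150):
    # Reproduce the cosine-annealed rate the removed line computed,
    # including the warm restart that rescales both bounds every t_i steps.
    t_cur = 0
    for _ in range(steps):
        yield l_min + 0.5 * (l_max - l_min) * (1.0 + math.cos(t_cur / t_i * math.pi))
        t_cur += 1
        if t_cur >= t_i:
            t_cur = 0
            l_max = max(l_max * l_mult, 1e-11)
            l_min = max(l_min * l_mult, 1e-7)

With these defaults the rate falls from 1E-3 to 1E-7 across the first 150 steps, then each restart begins a factor of five lower (t_mult=1 keeps the period fixed).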
@@ -406,27 +409,28 @@ class NeuralEnhancer(object):
                     running = l if running is None else running * 0.9 + 0.1 * l
                     print('↑' if l > running else '↓', end=' ', flush=True)
 
-                repro, orign = self.model.predict(images)
-                self.show_progress(repro, orign)
+                orign, scald, repro = self.model.predict(images)
+                self.show_progress(orign, scald, repro)
                 total /= args.epoch_size
                 totals, labels = [sum(total)] + list(total), ['total', 'prcpt', 'smthn', 'advrs']
-                losses = ['{}{}{}={:4.2e}'.format(ansi.WHITE_B, k, ansi.ENDC, v) for k, v in zip(labels, totals)]
-                print('\rEpoch #{} at {:4.1f}s{}'.format(k+1, time.time()-start, ' '*args.epoch_size))
-                print(' - losses {}'.format(' '.join(losses)))
+                gen_info = ['{}{}{}={:4.2e}'.format(ansi.WHITE_B, k, ansi.ENDC, v) for k, v in zip(labels, totals)]
+                print('\rEpoch #{} at {:4.1f}s{}'.format(epoch+1, time.time()-start, ' '*args.epoch_size))
+                print(' - generator {}'.format(' '.join(gen_info)))
+                # print(' - discriminator {}'.format(' '.join(gen_stats)))
                 # print(stats[:args.batch_size].mean(), stats[args.batch_size:].mean())
-                # if k == 0: self.model.disc_lr.set_value(l_r)
-                # if k == 1: self.model.adversary_weight.set_value(args.adversary_weight)
+                # if epoch == 0: self.model.disc_lr.set_value(l_r)
+                # if epoch == 1: self.model.adversary_weight.set_value(args.adversary_weight)
         except KeyboardInterrupt:
             pass
 
         print('\n{}Trained {}x super-resolution for {} epochs.{}'\
-              .format(ansi.CYAN_B, 2**args.scales, args.epochs, ansi.CYAN))
+              .format(ansi.CYAN_B, 2**args.scales, epoch, ansi.CYAN))
         self.model.save_generator()
         print(ansi.ENDC)
 
 
 if __name__ == "__main__":
     enhancer = NeuralEnhancer()
     try:
         enhancer.train()
     except KeyboardInterrupt:
         pass
