diff --git a/enhance.py b/enhance.py
index 0f030ae..569391b 100644
--- a/enhance.py
+++ b/enhance.py
@@ -44,7 +44,7 @@ add_arg('--generator-blocks', default=4, type=int)
 add_arg('--generator-residual', default=2, type=int)
 add_arg('--perceptual-layer', default='conv2_2', type=str)
 add_arg('--perceptual-weight', default=1e0, type=float)
-add_arg('--smoothness-weight', default=1e6, type=float)
+add_arg('--smoothness-weight', default=1e7, type=float)
 add_arg('--adversary-weight', default=0.0, type=float)
 add_arg('--scales', default=1, type=int, help='')
 add_arg('--device', default='gpu0', type=str, help='Name of the CPU/GPU number to use, for Theano.')
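Note on the hunk above: raising --smoothness-weight from 1e6 to 1e7 weighs the smoothness term of the generator loss (logged as 'smthn' in train() below) ten times more heavily against the perceptual term. The loss definition itself is outside this patch; assuming it is the usual total-variation penalty on the generator output, a minimal NumPy sketch of the quantity this flag scales:

    import numpy as np

    def total_variation(img):
        # Mean squared difference between neighbouring pixels of a (B, C, H, W)
        # batch; weighting this term more heavily favours smoother output.
        dh = img[:, :, 1:, :] - img[:, :, :-1, :]
        dw = img[:, :, :, 1:] - img[:, :, :, :-1]
        return (dh ** 2).mean() + (dw ** 2).mean()

    # hypothetical use, mirroring the flag above:
    # total_loss = perceptual_loss + args.smoothness_weight * total_variation(out)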
@@ -100,7 +100,7 @@ if sys.platform == 'win32':
 # Deep Learning Framework
 import lasagne
 from lasagne.layers import Conv2DLayer as ConvLayer, Deconv2DLayer as DeconvLayer, Pool2DLayer as PoolLayer
-from lasagne.layers import InputLayer, ConcatLayer, batch_norm, ElemwiseSumLayer
+from lasagne.layers import InputLayer, ConcatLayer, ElemwiseSumLayer
 
 print('{} - Using the device `{}` for neural computation.{}\n'.format(ansi.CYAN, theano.config.device, ansi.ENDC))
 
@@ -182,7 +182,7 @@ class Model(object):
     def __init__(self):
         self.network = collections.OrderedDict()
         self.network['img'] = InputLayer((None, 3, None, None))
-        self.network['seed'] = PoolLayer(self.network['img'], pool_size=2**args.scales)
+        self.network['seed'] = PoolLayer(self.network['img'], pool_size=2**args.scales, mode='average_exc_pad')
         self.setup_generator(self.network['seed'])
 
         concatenated = lasagne.layers.ConcatLayer([self.network['img'], self.network['out']], axis=0)
@@ -200,31 +200,32 @@ class Model(object):
     def last_layer(self):
         return list(self.network.values())[-1]
 
-    def make_layer(self, input, units, filter_size=(3,3), stride=(1,1), pad=(1,1), nl='prelu'):
-        conv = ConvLayer(input, units, filter_size=filter_size, stride=stride, pad=pad,
-                         nonlinearity=lasagne.nonlinearities.linear)
-        if nl == 'relu': conv.nonlinearity = lasagne.nonlinearities.rectify
-        if nl == 'prelu': conv = lasagne.layers.prelu(conv)
-        return conv
+    def make_layer(self, name, input, units, filter_size=(3,3), stride=(1,1), pad=(1,1), alpha=0.25):
+        # bias = None if normalized else lasagne.init.Constant(0.0)
+        conv = ConvLayer(input, units, filter_size=filter_size, stride=stride, pad=pad, nonlinearity=None)
+        # if normalized: conv = lasagne.layers.BatchNormLayer(conv)
+        prelu = lasagne.layers.ParametricRectifierLayer(conv, alpha=lasagne.init.Constant(alpha))
+        self.network[name+'x'] = conv
+        self.network[name+'>'] = prelu
+        return prelu
 
     def make_block(self, name, input, units):
-        self.network[name+'|Ac'] = self.make_layer(input, units)
-        self.network[name+'|An'] = batch_norm(self.last_layer()).input_layer
-        self.network[name+'|Bc'] = self.make_layer(self.last_layer(), units)
-        self.network[name+'|Bn'] = batch_norm(self.last_layer()).input_layer
+        self.make_layer(name+'-A', input, units, alpha=0.25)
+        self.make_layer(name+'-B', self.last_layer(), units, alpha=1.0)
         return ElemwiseSumLayer([input, self.last_layer()]) if args.generator_residual else self.last_layer()
 
     def setup_generator(self, input):
         units = args.generator_filters
-        self.network['iter.0'] = self.make_layer(input, units, filter_size=(5,5), pad=(2,2))
+        self.make_layer('iter.0', input, units, filter_size=(5,5), pad=(2,2))
 
         for i in range(0, args.generator_blocks):
             self.network['iter.%i'%(i+1)] = self.make_block('iter.%i'%(i+1), self.last_layer(), units)
 
         for i in range(args.scales, 0, -1):
-            self.network['scale%i.3'%i] = self.make_layer(self.last_layer(), units*2)
-            self.network['scale%i.2'%i] = SubpixelShuffle(self.network['scale%i.3'%i], units//2, 2)
-            self.network['scale%i.1'%i] = self.make_layer(self.network['scale%i.2'%i], units)
+            u = units // 2**(args.scales-i)
+            self.make_layer('scale%i.3'%i, self.last_layer(), u*4)
+            self.network['scale%i.2'%i] = SubpixelShuffle(self.last_layer(), u, 2)
+            self.make_layer('scale%i.1'%i, self.network['scale%i.2'%i], u)
 
         self.network['out'] = ConvLayer(self.last_layer(), 3, filter_size=(5,5), stride=(1,1), pad=(2,2),
                                         nonlinearity=lasagne.nonlinearities.tanh)
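Note on the generator changes above: make_layer now registers an explicit convolution ('x') and ParametricRectifierLayer ('>') pair instead of relying on batch_norm, and the upsampling loop widens each scale to u*4 feature maps before SubpixelShuffle, so that rearranging every 2x2 group of channels into space leaves u maps at twice the resolution (u = units // 2**(args.scales-i) halves the width at each additional scale, where the old code used a fixed units//2). SubpixelShuffle itself is defined elsewhere in enhance.py; a NumPy sketch of the depth-to-space rearrangement it is assumed to perform (sub-pixel convolution, Shi et al. 2016):

    import numpy as np

    def depth_to_space(x, r):
        # (B, C*r*r, H, W) -> (B, C, H*r, W*r): each group of r*r channels
        # becomes one channel at r times the spatial resolution.
        b, c, h, w = x.shape
        assert c % (r * r) == 0
        x = x.reshape(b, c // (r * r), r, r, h, w)
        x = x.transpose(0, 1, 4, 2, 5, 3)  # interleave the r*r grid with H and W
        return x.reshape(b, c // (r * r), h * r, w * r)

    # e.g. u = 32 at the final scale: 'scale1.3' produces 128 maps, and the
    # shuffle turns (1, 128, 16, 16) into (1, 32, 32, 32).
    assert depth_to_space(np.zeros((1, 128, 16, 16), np.float32), 2).shape == (1, 32, 32, 32)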
@@ -235,8 +236,8 @@ class Model(object):
         self.network['disc3'] = ConvLayer(self.network['conv3_2'], 256, filter_size=(3,3), stride=(1,1), pad=(1,1))
         hypercolumn = ConcatLayer([self.network['disc1'], self.network['disc2'], self.network['disc3']])
         self.network['disc4'] = ConvLayer(hypercolumn, 192, filter_size=(3,3), stride=(1,1))
-        self.network['disc'] = batch_norm(ConvLayer(self.last_layer(), 1, filter_size=(1,1), stride=(1,1), pad=(0,0),
-                                          nonlinearity=lasagne.nonlinearities.sigmoid))
+        self.network['disc'] = ConvLayer(self.last_layer(), 1, filter_size=(1,1), stride=(1,1), pad=(0,0),
+                                         nonlinearity=lasagne.nonlinearities.sigmoid)
 
     def setup_perceptual(self, input):
         """Use lasagne to create a network of convolution layers using pre-trained VGG19 weights.
@@ -352,10 +353,9 @@ class Model(object):
         self.fit = theano.function([input_tensor], gen_losses, updates=collections.OrderedDict(updates))
 
         # Helper function for rendering test images deterministically, computing statistics.
-        gen_out, gen_inp = lasagne.layers.get_output([self.network['out'], self.network['img']],
-                                                     input_layers, deterministic=True)
-        self.predict = theano.function([input_tensor], [gen_out, gen_inp])
-
+        outputs = lasagne.layers.get_output([self.network[k] for k in ['img', 'seed', 'out']],
+                                            input_layers, deterministic=True)
+        self.predict = theano.function([input_tensor], outputs)
 
 
 class NeuralEnhancer(object):
@@ -374,59 +374,63 @@ class NeuralEnhancer(object):
         image = scipy.misc.toimage(img * 255.0, cmin=0, cmax=255)
         image.save(fn)
 
-    def show_progress(self, repro, orign):
+    def show_progress(self, orign, scald, repro):
         for i in range(args.batch_size):
-            self.imsave('valid/%03i_orign.png' % i, orign[i])
-            self.imsave('valid/%03i_repro.png' % i, repro[i])
+            self.imsave('valid/%03i_origin.png' % i, orign[i])
+            self.imsave('valid/%03i_pixels.png' % i, scald[i])
+            self.imsave('valid/%03i_reprod.png' % i, repro[i])
 
     def train(self):
         images = np.zeros((args.batch_size, 3, args.batch_resolution, args.batch_resolution), dtype=np.float32)
         l_min, l_max, l_mult = 1E-7, 1E-3, 0.2
         t_cur, t_i, t_mult = 120, 150, 1
 
-        i, running, start = 0, None, time.time()
-        for k in range(args.epochs):
-            total = None
-            for _ in range(args.epoch_size):
-                i += 1
-                l_r = l_min + 0.5 * (l_max - l_min) * (1.0 + math.cos(t_cur / t_i * math.pi))
-                t_cur += 1
-                self.model.gen_lr.set_value(l_r)
-
-                if t_cur >= t_i:
-                    t_cur, t_i = 0, int(t_i * t_mult)
-                    l_max = max(l_max * l_mult, 1e-12)
-                    l_min = max(l_min * l_mult, 1e-8)
-
-                self.thread.copy(images)
-                losses = np.array(self.model.fit(images), dtype=np.float32)
-                total = total + losses if total is not None else losses
-                l = np.sum(losses)
-                assert not np.isnan(losses).any()
-                running = l if running is None else running * 0.9 + 0.1 * l
-                print('↑' if l > running else '↓', end=' ', flush=True)
-
-            repro, orign = self.model.predict(images)
-            self.show_progress(repro, orign)
-            total /= args.epoch_size
-            totals, labels = [sum(total)] + list(total), ['total', 'prcpt', 'smthn', 'advrs']
-            losses = ['{}{}{}={:4.2e}'.format(ansi.WHITE_B, k, ansi.ENDC, v) for k, v in zip(labels, totals)]
-            print('\rEpoch #{} at {:4.1f}s{}'.format(k+1, time.time()-start, ' '*args.epoch_size))
-            print(' - losses {}'.format(' '.join(losses)))
-
-            # print(stats[:args.batch_size].mean(), stats[args.batch_size:].mean())
-            # if k == 0: self.model.disc_lr.set_value(l_r)
-            # if k == 1: self.model.adversary_weight.set_value(args.adversary_weight)
+        try:
+            i, running, start = 0, None, time.time()
+            for epoch in range(args.epochs):
+                total = None
+                for _ in range(args.epoch_size):
+                    i += 1
+                    l_r = l_min + 0.5 * (l_max - l_min) * (1.0 + math.cos(t_cur / t_i * math.pi))
+                    t_cur += 1
+                    l_r = 1E-4
+                    self.model.gen_lr.set_value(l_r)
+
+                    if t_cur >= t_i:
+                        t_cur, t_i = 0, int(t_i * t_mult)
+                        l_max = max(l_max * l_mult, 1e-11)
+                        l_min = max(l_min * l_mult, 1e-7)
+
+                    self.thread.copy(images)
+                    losses = np.array(self.model.fit(images), dtype=np.float32)
+                    total = total + losses if total is not None else losses
+                    l = np.sum(losses)
+                    assert not np.isnan(losses).any()
+                    running = l if running is None else running * 0.9 + 0.1 * l
+                    print('↑' if l > running else '↓', end=' ', flush=True)
+
+                orign, scald, repro = self.model.predict(images)
+                self.show_progress(orign, scald, repro)
+                total /= args.epoch_size
+                totals, labels = [sum(total)] + list(total), ['total', 'prcpt', 'smthn', 'advrs']
+                gen_info = ['{}{}{}={:4.2e}'.format(ansi.WHITE_B, k, ansi.ENDC, v) for k, v in zip(labels, totals)]
+                print('\rEpoch #{} at {:4.1f}s{}'.format(epoch+1, time.time()-start, ' '*args.epoch_size))
+                print(' - generator {}'.format(' '.join(gen_info)))
+
+                # print(' - discriminator {}'.format(' '.join(gen_stats)))
+
+                # print(stats[:args.batch_size].mean(), stats[args.batch_size:].mean())
+                # if epoch == 0: self.model.disc_lr.set_value(l_r)
+                # if epoch == 1: self.model.adversary_weight.set_value(args.adversary_weight)
+        except KeyboardInterrupt:
+            pass
 
         print('\n{}Trained {}x super-resolution for {} epochs.{}'\
-                .format(ansi.CYAN_B, 2**args.scales, args.epochs, ansi.CYAN))
+                .format(ansi.CYAN_B, 2**args.scales, epoch, ansi.CYAN))
        self.model.save_generator()
         print(ansi.ENDC)
 
 
 if __name__ == "__main__":
     enhancer = NeuralEnhancer()
-    try:
-        enhancer.train()
-    except KeyboardInterrupt:
-        pass
+    enhancer.train()
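Note on the training loop above: the learning-rate code kept inside the new try block is cosine annealing with warm restarts in the style of SGDR (Loshchilov & Hutter, 2016); the rate follows a half-cosine from l_max down to l_min over t_i steps, and both bounds shrink by l_mult at each restart (floors raised here to 1e-11 and 1e-7). As committed, though, the newly added l_r = 1E-4 overwrites the scheduled value before gen_lr.set_value reads it, so the cosine schedule is currently dead code, presumably a temporary experiment. A self-contained sketch of the schedule itself:

    import math

    def cosine_annealed_lr(l_min, l_max, t_cur, t_i):
        # Half-cosine decay: l_max at t_cur == 0 down to l_min at t_cur == t_i.
        return l_min + 0.5 * (l_max - l_min) * (1.0 + math.cos(t_cur / t_i * math.pi))

    # First cycle as configured in train(): t_cur starts at 120 of t_i = 150,
    # so training begins near the bottom of the first cosine segment.
    rates = [cosine_annealed_lr(1e-7, 1e-3, t, 150) for t in range(120, 151)]
    assert rates[0] > rates[-1] and rates[-1] == 1e-7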