From c456221cb53a2af8cd736ee1dd53611b5bad44d1 Mon Sep 17 00:00:00 2001 From: "Alex J. Champandard" Date: Fri, 28 Oct 2016 04:28:30 +0200 Subject: [PATCH] Experiment with recursive super-resolution and weight reuse, mixed results. --- enhance.py | 47 ++++++++++++++++++++++++++++++++++++----------- 1 file changed, 36 insertions(+), 11 deletions(-) diff --git a/enhance.py b/enhance.py index eaa31ea..36982e2 100644 --- a/enhance.py +++ b/enhance.py @@ -85,6 +85,8 @@ def warn(message, *lines): string = "\n{}WARNING: " + message + "{}\n" + "\n".join(lines) + "{}\n" print(string.format(ansi.YELLOW_B, ansi.YELLOW, ansi.ENDC)) +def extend(lst): return itertools.chain(lst, itertools.repeat(lst[-1])) + print("""{} {}Super Resolution for images and videos powered by Deep Learning!{} - Code licensed as AGPLv3, models under CC BY-NC-SA.{}""".format(ansi.CYAN_B, __doc__, ansi.CYAN, ansi.ENDC)) @@ -231,8 +233,23 @@ class Model(object): return list(self.network.values())[-1] def make_layer(self, name, input, units, filter_size=(3,3), stride=(1,1), pad=(1,1), alpha=0.25): - conv = ConvLayer(input, units, filter_size=filter_size, stride=stride, pad=pad, nonlinearity=None) - prelu = lasagne.layers.ParametricRectifierLayer(conv, alpha=lasagne.init.Constant(alpha)) + orig = '1.'+''.join(name.split('.')[1:]) + if orig+'x' in self.network: + print('reused', orig, 'for', name) + l = self.network[orig +'x'] + extra = {'W': l.W, 'b': l.b} + else: + extra = {} + + conv = ConvLayer(input, units, filter_size=filter_size, stride=stride, pad=pad, nonlinearity=None, **extra) + + alpha = lasagne.init.Constant(alpha) + if orig +'>' in self.network: + l = self.network[orig +'>'] + alpha = l.alpha + + prelu = lasagne.layers.ParametricRectifierLayer(conv, alpha=alpha) + self.network[name+'x'] = conv self.network[name+'>'] = prelu return prelu @@ -244,17 +261,25 @@ class Model(object): def setup_generator(self, input, config): for k, v in config.items(): setattr(args, k, v) - units = args.generator_filters - self.make_layer('iter.0', input, units, filter_size=(5,5), pad=(2,2)) - - for i in range(0, args.generator_blocks): - self.network['iter.%i'%(i+1)] = self.make_block('iter.%i'%(i+1), self.last_layer(), units) - - for i in range(args.scales, 0, -1): - u = units // 2**(args.scales-i) + units_iter = extend(args.generator_filters) + units = next(units_iter) + self.make_layer('iter.0-A', input, units, filter_size=(5,5), pad=(2,2)) + self.make_layer('iter.0-B', self.last_layer(), units, filter_size=(5,5), pad=(2,2)) + self.network['iter.0'] = self.last_layer() + + for i in range(0, args.generator_iters): + base = self.last_layer() + for j in range(0, args.generator_blocks): + self.make_block('%i.iter-%i'%(i+1, j), self.last_layer(), units) + print('iter.%i-%i'%(i+1, j)) + # self.network['iter.%i'%(i+1)] = DropPathLayer([base, self.last_layer()]) + + for i in range(0, args.scales): + u = next(units_iter) self.make_layer('scale%i.3'%i, self.last_layer(), u*4) self.network['scale%i.2'%i] = SubpixelReshuffleLayer(self.last_layer(), u, 2) - self.make_layer('scale%i.1'%i, self.network['scale%i.2'%i], u) + self.make_layer('scale%i.1'%i, self.last_layer(), u) + self.network['out'] = ConvLayer(self.last_layer(), 3, filter_size=(5,5), stride=(1,1), pad=(2,2), nonlinearity=lasagne.nonlinearities.tanh)