From d8c9292b80d1381bb999c4519965de8d310bfcca Mon Sep 17 00:00:00 2001 From: "Alex J. Champandard" Date: Tue, 11 Oct 2016 08:59:42 +0200 Subject: [PATCH] Tweak code for generalized super-resolution. --- enhance.py | 76 +++++++++++++++++++++++++++++++----------------------- 1 file changed, 44 insertions(+), 32 deletions(-) diff --git a/enhance.py b/enhance.py index 080cc3e..20abbbf 100644 --- a/enhance.py +++ b/enhance.py @@ -34,20 +34,22 @@ parser = argparse.ArgumentParser(description='Generate a new image by applying s add_arg = parser.add_argument add_arg('files', nargs='*', default=[]) add_arg('--train', default=False, action='store_true') -add_arg('--load', default=None, action='store_true') add_arg('--save', default=None, action='store_true') add_arg('--model', default='ne%ix.pkl.bz2', type=str) -add_arg('--batch-size', default=1, type=int) -add_arg('--batch-resolution', default=224, type=int) -add_arg('--epoch-size', default=36, type=int) -add_arg('--epochs', default=10, type=int) -add_arg('--generator-filters', default=256, type=int) -add_arg('--generator-blocks', default=4, type=int) +add_arg('--batch-size', default=15, type=int) +add_arg('--batch-resolution', default=192, type=int) +add_arg('--generator-filters', default=128, type=int) +add_arg('--generator-blocks', default=16, type=int) add_arg('--generator-residual', default=2, type=int) add_arg('--perceptual-layer', default='conv2_2', type=str) add_arg('--perceptual-weight', default=1e0, type=float) add_arg('--smoothness-weight', default=2e5, type=float) -add_arg('--adversary-weight', default=2e2, type=float) +add_arg('--adversary-weight', default=1e2, type=float) +add_arg('--epoch-size', default=36, type=int) +add_arg('--epochs', default=10, type=int) +add_arg('--generator-start', default=0, type=int) +add_arg('--discriminator-start',default=1, type=int) +add_arg('--adversarial-start', default=2, type=int) add_arg('--scales', default=2, type=int, help='') add_arg('--device', default='cpu', type=str, help='Name of the CPU/GPU number to use, for Theano.') args = parser.parse_args() @@ -84,7 +86,7 @@ print("""{} {}Super Resolution for images and videos powered by Deep Learning! # Load the underlying deep learning libraries based on the device specified. If you specify THEANO_FLAGS manually, # the code assumes you know what you are doing and they are not overriden! os.environ.setdefault('THEANO_FLAGS', 'floatX=float32,device={},force_device=True,allow_gc=True,'\ - 'print_active_device=False'.format(args.device)) + 'print_active_device=False,lib.cnmem=1.0'.format(args.device)) # Scientific & Imaging Libraries import numpy as np @@ -130,7 +132,17 @@ class DataLoader(threading.Thread): for i, f in enumerate(files[:args.batch_size]): filename = os.path.join(self.cwd, f) try: - img = cache.setdefault(f, scipy.ndimage.imread(filename, mode='RGB')) + if f not in cache: + if len(cache) > 3172: + del cache[random.choice(list(cache.keys()))] + + img = scipy.ndimage.imread(filename, mode='RGB') + ratio = min(1024 / img.shape[0], 1024 / img.shape[1]) + if ratio < 1.0: + img = scipy.misc.imresize(img, ratio, interp='bicubic') + cache[f] = img + else: + img = cache[f] except Exception as e: warn('Could not load `{}` as image.'.format(filename), ' - Try fixing or removing the file before next run.') @@ -188,6 +200,7 @@ class Model(object): self.network['seed'] = PoolLayer(self.network['img'], pool_size=2**args.scales, mode='average_exc_pad') else: self.network['img'] = InputLayer((None, 3, None, None)) + self.network['seed'] = self.network['img'] self.setup_generator(self.last_layer()) if args.train: @@ -267,10 +280,10 @@ class Model(object): self.make_layer('disc1.1', batch_norm(self.network['conv1_2']), 64, filter_size=(5,5), stride=(2,2), pad=(2,2)) self.make_layer('disc1.2', self.last_layer(), 64, filter_size=(5,5), stride=(2,2), pad=(2,2)) self.make_layer('disc2', batch_norm(self.network['conv2_2']), 128, filter_size=(5,5), stride=(2,2), pad=(2,2)) - # self.make_layer('disc3', batch_norm(self.network['conv3_2']), 256, filter_size=(3,3), stride=(1,1), pad=(1,1)) - hypercolumn = ConcatLayer([self.network['disc1.2>'], self.network['disc2>']]) - self.make_layer('disc4', hypercolumn, 128, filter_size=(3,3), stride=(1,1)) - self.make_layer('disc5', self.last_layer(), 64, filter_size=(3,3), stride=(1,1)) + self.make_layer('disc3', batch_norm(self.network['conv3_2']), 192, filter_size=(3,3), stride=(1,1), pad=(1,1)) + hypercolumn = ConcatLayer([self.network['disc1.2>'], self.network['disc2>'], self.network['disc3>']]) + self.make_layer('disc4', hypercolumn, 192, filter_size=(3,3), stride=(1,1)) + self.make_layer('disc5', self.last_layer(), 96, filter_size=(3,3), stride=(1,1)) self.network['disc'] = batch_norm(ConvLayer(self.last_layer(), 1, filter_size=(1,1), nonlinearity=lasagne.nonlinearities.sigmoid)) @@ -307,7 +320,7 @@ class Model(object): def load_generator(self): filename = args.model % 2**args.scales - if not os.path.exists(filename) or not args.load: return + if not os.path.exists(filename): return params = pickle.load(bz2.open(filename, 'rb')) for k, l in self.list_generator_layers(): assert k in params, "Couldn't find layer `%s` in loaded model.'" @@ -328,16 +341,16 @@ class Model(object): return T.mean(((x[:,:,:-1,:-1] - x[:,:,1:,:-1])**2 + (x[:,:,:-1,:-1] - x[:,:,:-1,1:])**2)**1.25) def loss_adversarial(self, d): - return 1.0 - T.log(d[args.batch_size:]).mean() + return 1.0 - T.log(1E-6 + d[args.batch_size:]).mean() def loss_discriminator(self, d): - return T.mean(T.log(d[args.batch_size:]) + T.log(1.0 - d[:args.batch_size])) + return T.mean(T.log(1E-6 + d[args.batch_size:]) + T.log(1E-6 + 1.0 - d[:args.batch_size])) def compile(self): # Helper function for rendering test images during training, or standalone non-training mode. input_tensor = T.tensor4() input_layers = {self.network['img']: input_tensor} - output = lasagne.layers.get_output([self.network[k] for k in ['img', 'out']], input_layers, deterministic=True) + output = lasagne.layers.get_output([self.network[k] for k in ['img', 'seed', 'out']], input_layers, deterministic=True) self.predict = theano.function([input_tensor], output) if not args.train: return @@ -404,8 +417,9 @@ class NeuralEnhancer(object): l_r = l_min + 0.5 * (l_max - l_min) * (1.0 + math.cos(t_cur / t_i * math.pi)) t_cur += 1 l_r = 1E-4 - self.model.gen_lr.set_value(l_r) - self.model.disc_lr.set_value(l_r) + + if epoch >= args.generator_start: self.model.gen_lr.set_value(l_r) + if epoch >= args.discriminator_start: self.model.disc_lr.set_value(l_r) if t_cur >= t_i: t_cur, t_i = 0, int(t_i * t_mult) @@ -433,20 +447,23 @@ class NeuralEnhancer(object): real, fake = stats[:args.batch_size], stats[args.batch_size:] print(' - discriminator', real.mean(), len(np.where(real > 0.5)[0]), fake.mean(), len(np.where(fake < 0.5)[0])) - if epoch == 0: + if epoch == args.adversary_start-1: + print(' - adversary mode: generator engaging discriminator.') self.model.adversary_weight.set_value(args.adversary_weight) running = None except KeyboardInterrupt: pass print('\n{}Trained {}x super-resolution for {} epochs.{}'\ - .format(ansi.CYAN_B, 2**args.scales, epoch, ansi.CYAN)) + .format(ansi.CYAN_B, 2**args.scales, epoch+1, ansi.CYAN)) self.model.save_generator() print(ansi.ENDC) - def process(self, images): - _, repro = self.model.predict(images) - return repro + def process(self, image): + img = np.transpose(image / 255.0 - 0.5, (2, 0, 1))[np.newaxis].astype(np.float32) + *_, repro = self.model.predict(img) + repro = np.transpose(repro[0] + 0.5, (1, 2, 0)).clip(0.0, 1.0) + return scipy.misc.toimage(repro * 255.0, cmin=0, cmax=255) if __name__ == "__main__": @@ -456,10 +473,5 @@ if __name__ == "__main__": enhancer.train() for filename in args.files: - img = scipy.ndimage.imread(filename, mode='RGB') - img = np.transpose(img / 255.0 - 0.5, (2, 0, 1))[np.newaxis] - - out = enhancer.process(img.astype(np.float32)) - - out = np.transpose((out[0] + 0.5) * 255.0, (1, 2, 0)).astype(np.uint8) - scipy.misc.imsave(os.path.splitext(filename)[0]+'_enhanced.png', out) + out = enhancer.process(scipy.ndimage.imread(filename, mode='RGB')) + out.save(os.path.splitext(filename)[0]+'_enhanced.png')