From 37cb20837487ed9958724df85d7d9d78b4e1fea4 Mon Sep 17 00:00:00 2001 From: Tom White Date: Mon, 31 Oct 2016 01:18:46 +1300 Subject: [PATCH 1/2] Move generation of seeds out of training network This moves the generation of the image seeds out of the training network and into the DataLoader. Currently seeds are computed as bilinear downsamplings of the original image. This is almost functionally equivalent to the version it replaces, but opens up new possibilities at training time because the seeds are now decoupled from the network. For example, seeds could be made with different interpolations or even with other transformations such as image compression. --- enhance.py | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/enhance.py b/enhance.py index e1cfe5e..3bce683 100644 --- a/enhance.py +++ b/enhance.py @@ -130,7 +130,10 @@ class DataLoader(threading.Thread): self.data_copied = threading.Event() self.resolution = args.batch_resolution + self.seed_resolution = int(args.batch_resolution / 2**args.scales) + self.buffer = np.zeros((args.buffer_size, 3, self.resolution, self.resolution), dtype=np.float32) + self.seed_buffer = np.zeros((args.buffer_size, 3, self.seed_resolution, self.seed_resolution), dtype=np.float32) self.files = glob.glob(args.train) if len(self.files) == 0: error("There were no files found to train from searching for `{}`".format(args.train), @@ -168,17 +171,23 @@ class DataLoader(threading.Thread): i = self.available.pop() self.buffer[i] = np.transpose(copy / 255.0 - 0.5, (2, 0, 1)) + seed_copy = scipy.misc.imresize(copy, + size=(self.seed_resolution, self.seed_resolution), + interp='bilinear') + self.seed_buffer[i] = np.transpose(seed_copy / 255.0 - 0.5, (2, 0, 1)) self.ready.add(i) if len(self.ready) >= args.batch_size: self.data_ready.set() - def copy(self, output): + def copy(self, images_out, seeds_out): self.data_ready.wait() self.data_ready.clear() for i, j in enumerate(random.sample(self.ready, 
args.batch_size)): - output[i] = self.buffer[j] + images_out[i] = self.buffer[j] + seeds_out[i] = self.seed_buffer[j] + self.available.add(j) self.data_copied.set() @@ -214,7 +223,7 @@ class Model(object): self.network = collections.OrderedDict() if args.train: self.network['img'] = InputLayer((None, 3, None, None)) - self.network['seed'] = PoolLayer(self.network['img'], pool_size=2**args.scales, mode='average_exc_pad') + self.network['seed'] = InputLayer((None, 3, None, None)) else: self.network['img'] = InputLayer((None, 3, None, None)) self.network['seed'] = self.network['img'] @@ -380,9 +389,10 @@ class Model(object): def compile(self): # Helper function for rendering test images during training, or standalone non-training mode. input_tensor = T.tensor4() - input_layers = {self.network['img']: input_tensor} + seed_tensor = T.tensor4() + input_layers = {self.network['img']: input_tensor, self.network['seed']: seed_tensor} output = lasagne.layers.get_output([self.network[k] for k in ['img', 'seed', 'out']], input_layers, deterministic=True) - self.predict = theano.function([input_tensor], output) + self.predict = theano.function([input_tensor, seed_tensor], output) if not args.train: return @@ -408,7 +418,7 @@ class Model(object): # Combined Theano function for updating both generator and discriminator at the same time. 
updates = collections.OrderedDict(list(gen_updates.items()) + list(disc_updates.items())) - self.fit = theano.function([input_tensor], gen_losses + [disc_out.mean(axis=(1,2,3))], updates=updates) + self.fit = theano.function([input_tensor, seed_tensor], gen_losses + [disc_out.mean(axis=(1,2,3))], updates=updates) @@ -448,7 +458,9 @@ class NeuralEnhancer(object): if t_cur % args.learning_period == 0: l_r *= args.learning_decay def train(self): + seed_size = int(args.batch_resolution / 2**args.scales) images = np.zeros((args.batch_size, 3, args.batch_resolution, args.batch_resolution), dtype=np.float32) + seeds = np.zeros((args.batch_size, 3, seed_size, seed_size), dtype=np.float32) learning_rate = self.decay_learning_rate() try: running, start = None, time.time() @@ -459,8 +471,8 @@ class NeuralEnhancer(object): if epoch >= args.discriminator_start: self.model.disc_lr.set_value(l_r) for _ in range(args.epoch_size): - self.thread.copy(images) - output = self.model.fit(images) + self.thread.copy(images, seeds) + output = self.model.fit(images, seeds) losses = np.array(output[:3], dtype=np.float32) stats = (stats + output[3]) if stats is not None else output[3] total = total + losses if total is not None else losses @@ -469,7 +481,7 @@ class NeuralEnhancer(object): running = l if running is None else running * 0.95 + 0.05 * l print('↑' if l > running else '↓', end='', flush=True) - orign, scald, repro = self.model.predict(images) + orign, scald, repro = self.model.predict(images, seeds) self.show_progress(orign, scald, repro) total /= args.epoch_size stats /= args.epoch_size From 8f5167d235ad831e7a7c38faae2d4231bf39a8ce Mon Sep 17 00:00:00 2001 From: Tom White Date: Mon, 31 Oct 2016 02:13:02 +1300 Subject: [PATCH 2/2] Fix enhancer.process to pass img, seed --- enhance.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/enhance.py b/enhance.py index 3bce683..c73213a 100644 --- a/enhance.py +++ b/enhance.py @@ -221,12 +221,8 @@ class 
Model(object): def __init__(self): self.network = collections.OrderedDict() - if args.train: - self.network['img'] = InputLayer((None, 3, None, None)) - self.network['seed'] = InputLayer((None, 3, None, None)) - else: - self.network['img'] = InputLayer((None, 3, None, None)) - self.network['seed'] = self.network['img'] + self.network['img'] = InputLayer((None, 3, None, None)) + self.network['seed'] = InputLayer((None, 3, None, None)) config, params = self.load_model() self.setup_generator(self.last_layer(), config) @@ -507,7 +503,7 @@ class NeuralEnhancer(object): def process(self, image): img = np.transpose(image / 255.0 - 0.5, (2, 0, 1))[np.newaxis].astype(np.float32) - *_, repro = self.model.predict(img) + *_, repro = self.model.predict(img, img) repro = np.transpose(repro[0] + 0.5, (1, 2, 0)).clip(0.0, 1.0) return scipy.misc.toimage(repro * 255.0, cmin=0, cmax=255)