diff --git a/enhance.py b/enhance.py
index e1cfe5e..3bce683 100644
--- a/enhance.py
+++ b/enhance.py
@@ -130,7 +130,10 @@ class DataLoader(threading.Thread):
         self.data_copied = threading.Event()

         self.resolution = args.batch_resolution
+        self.seed_resolution = int(args.batch_resolution / 2**args.scales)
+
         self.buffer = np.zeros((args.buffer_size, 3, self.resolution, self.resolution), dtype=np.float32)
+        self.seed_buffer = np.zeros((args.buffer_size, 3, self.seed_resolution, self.seed_resolution), dtype=np.float32)
         self.files = glob.glob(args.train)
         if len(self.files) == 0:
             error("There were no files found to train from searching for `{}`".format(args.train),
@@ -168,17 +171,23 @@ class DataLoader(threading.Thread):
         i = self.available.pop()
         self.buffer[i] = np.transpose(copy / 255.0 - 0.5, (2, 0, 1))
+        seed_copy = scipy.misc.imresize(copy,
+                                        size=(self.seed_resolution, self.seed_resolution),
+                                        interp='bilinear')
+        self.seed_buffer[i] = np.transpose(seed_copy / 255.0 - 0.5, (2, 0, 1))
         self.ready.add(i)

         if len(self.ready) >= args.batch_size:
             self.data_ready.set()

-    def copy(self, output):
+    def copy(self, images_out, seeds_out):
         self.data_ready.wait()
         self.data_ready.clear()

         for i, j in enumerate(random.sample(self.ready, args.batch_size)):
-            output[i] = self.buffer[j]
+            images_out[i] = self.buffer[j]
+            seeds_out[i] = self.seed_buffer[j]
+            self.available.add(j)
         self.data_copied.set()
@@ -214,7 +223,7 @@ class Model(object):
         self.network = collections.OrderedDict()
         if args.train:
             self.network['img'] = InputLayer((None, 3, None, None))
-            self.network['seed'] = PoolLayer(self.network['img'], pool_size=2**args.scales, mode='average_exc_pad')
+            self.network['seed'] = InputLayer((None, 3, None, None))
         else:
             self.network['img'] = InputLayer((None, 3, None, None))
             self.network['seed'] = self.network['img']
@@ -380,9 +389,10 @@ class Model(object):

     def compile(self):
         # Helper function for rendering test images during training, or standalone non-training mode.
         input_tensor = T.tensor4()
-        input_layers = {self.network['img']: input_tensor}
+        seed_tensor = T.tensor4()
+        input_layers = {self.network['img']: input_tensor, self.network['seed']: seed_tensor}
         output = lasagne.layers.get_output([self.network[k] for k in ['img', 'seed', 'out']], input_layers, deterministic=True)
-        self.predict = theano.function([input_tensor], output)
+        self.predict = theano.function([input_tensor, seed_tensor], output)

         if not args.train: return
@@ -408,7 +418,7 @@
         # Combined Theano function for updating both generator and discriminator at the same time.
         updates = collections.OrderedDict(list(gen_updates.items()) + list(disc_updates.items()))
-        self.fit = theano.function([input_tensor], gen_losses + [disc_out.mean(axis=(1,2,3))], updates=updates)
+        self.fit = theano.function([input_tensor, seed_tensor], gen_losses + [disc_out.mean(axis=(1,2,3))], updates=updates)
@@ -448,7 +458,9 @@ class NeuralEnhancer(object):
             if t_cur % args.learning_period == 0: l_r *= args.learning_decay

     def train(self):
+        seed_size = int(args.batch_resolution / 2**args.scales)
         images = np.zeros((args.batch_size, 3, args.batch_resolution, args.batch_resolution), dtype=np.float32)
+        seeds = np.zeros((args.batch_size, 3, seed_size, seed_size), dtype=np.float32)
         learning_rate = self.decay_learning_rate()
         try:
             running, start = None, time.time()
@@ -459,8 +471,8 @@
                 if epoch >= args.discriminator_start:
                     self.model.disc_lr.set_value(l_r)
                 for _ in range(args.epoch_size):
-                    self.thread.copy(images)
-                    output = self.model.fit(images)
+                    self.thread.copy(images, seeds)
+                    output = self.model.fit(images, seeds)
                     losses = np.array(output[:3], dtype=np.float32)
                     stats = (stats + output[3]) if stats is not None else output[3]
                     total = total + losses if total is not None else losses
@@ -469,7 +481,7 @@
                     running = l if running is None else running * 0.95 + 0.05 * l
                     print('↑' if l > running else '↓', end='', flush=True)

-                orign, scald, repro = self.model.predict(images)
+                orign, scald, repro = self.model.predict(images, seeds)
                 self.show_progress(orign, scald, repro)
                 total /= args.epoch_size
                 stats /= args.epoch_size
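
Note for reviewers: this patch moves the seed downscaling out of the Theano graph (previously a `PoolLayer` with `mode='average_exc_pad'` over `img`) and into the `DataLoader` thread, which now fills a second `seed_buffer` with bilinear downscales of each crop via `scipy.misc.imresize`; `Model.fit` and `Model.predict` accordingly take a seed tensor alongside the image tensor. If you test this on a recent environment, be aware that `scipy.misc.imresize` was deprecated in SciPy 1.0 and removed in 1.3. Below is a minimal sketch of what the loader computes per crop, using Pillow as a stand-in for `imresize`; the helper name `make_seed` is illustrative and not part of the patch:

    import numpy as np
    from PIL import Image

    def make_seed(crop, seed_resolution):
        # `crop` is an HxWx3 uint8 array, like `copy` in DataLoader above.
        # Bilinear downscale to seed_resolution x seed_resolution...
        small = Image.fromarray(crop).resize((seed_resolution, seed_resolution), Image.BILINEAR)
        # ...then the same normalization applied to the full-size buffer:
        # [0, 255] uint8 -> [-0.5, 0.5] float32, HWC -> CHW layout.
        return np.transpose(np.asarray(small) / 255.0 - 0.5, (2, 0, 1)).astype(np.float32)

With the buffers filled this way, the training loop simply carries the pair around: `thread.copy(images, seeds)` fills both batches, and the same pair is passed to `model.fit(images, seeds)` and `model.predict(images, seeds)`.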