diff --git a/enhance.py b/enhance.py index bc40cdb..0a98c66 100755 --- a/enhance.py +++ b/enhance.py @@ -38,12 +38,13 @@ add_arg('files', nargs='*', default=[]) add_arg('--scales', default=2, type=int, help='How many times to perform 2x upsampling.') add_arg('--model', default='small', type=str, help='Name of the neural network to load/save.') add_arg('--train', default=False, type=str, help='File pattern to load for training.') -add_arg('--batch-resolution', default=192, type=int, help='Resolution of images in training batch.') +add_arg('--epochs', default=10, type=int, help='Total number of iterations in training.') +add_arg('--epoch-size', default=72, type=int, help='Number of batches trained in an epoch.') +add_arg('--save-every', default=10, type=int, help='Save generator after every training epoch.') +add_arg('--batch-shape', default=192, type=int, help='Resolution of images in training batch.') add_arg('--batch-size', default=15, type=int, help='Number of images per training batch.') add_arg('--buffer-size', default=1500, type=int, help='Total image fragments kept in cache.') add_arg('--buffer-similar', default=5, type=int, help='Fragments cached for each image loaded.') -add_arg('--epochs', default=10, type=int, help='Total number of iterations in training.') -add_arg('--epoch-size', default=72, type=int, help='Number of batches trained in an epoch.') add_arg('--learning-rate', default=1E-4, type=float, help='Parameter for the ADAM optimizer.') add_arg('--learning-period', default=50, type=int, help='How often to decay the learning rate.') add_arg('--learning-decay', default=0.5, type=float, help='How much to decay the learning rate.') @@ -128,8 +129,10 @@ class DataLoader(threading.Thread): self.data_ready = threading.Event() self.data_copied = threading.Event() - self.resolution = args.batch_resolution - self.buffer = np.zeros((args.buffer_size, 3, self.resolution, self.resolution), dtype=np.float32) + self.orig_shape, self.seed_shape = args.batch_shape, int(args.batch_shape / 2**args.scales) + + self.orig_buffer = np.zeros((args.buffer_size, 3, self.orig_shape, self.orig_shape), dtype=np.float32) + self.seed_buffer = np.zeros((args.buffer_size, 3, self.seed_shape, self.seed_shape), dtype=np.float32) self.files = glob.glob(args.train) if len(self.files) == 0: error("There were no files found to train from searching for `{}`".format(args.train), @@ -157,27 +160,31 @@ class DataLoader(threading.Thread): for _ in range(args.buffer_similar): copy = img[:,::-1] if random.choice([True, False]) else img - h = random.randint(0, copy.shape[0] - self.resolution) - w = random.randint(0, copy.shape[1] - self.resolution) - copy = copy[h:h+self.resolution, w:w+self.resolution] + h = random.randint(0, copy.shape[0] - self.orig_shape) + w = random.randint(0, copy.shape[1] - self.orig_shape) + copy = copy[h:h+self.orig_shape, w:w+self.orig_shape] while len(self.available) == 0: self.data_copied.wait() self.data_copied.clear() i = self.available.pop() - self.buffer[i] = np.transpose(copy / 255.0 - 0.5, (2, 0, 1)) + self.orig_buffer[i] = np.transpose(copy / 255.0 - 0.5, (2, 0, 1)) + seed = scipy.misc.imresize(copy, size=(self.seed_shape, self.seed_shape), interp='bilinear') + self.seed_buffer[i] = np.transpose(seed / 255.0 - 0.5, (2, 0, 1)) self.ready.add(i) if len(self.ready) >= args.batch_size: self.data_ready.set() - def copy(self, output): + def copy(self, origs_out, seeds_out): self.data_ready.wait() self.data_ready.clear() for i, j in enumerate(random.sample(self.ready, args.batch_size)): - output[i] = self.buffer[j] + origs_out[i] = self.orig_buffer[j] + seeds_out[i] = self.seed_buffer[j] + self.available.add(j) self.data_copied.set() @@ -211,12 +218,8 @@ class Model(object): def __init__(self): self.network = collections.OrderedDict() - if args.train: - self.network['img'] = InputLayer((None, 3, None, None)) - self.network['seed'] = PoolLayer(self.network['img'], pool_size=2**args.scales, mode='average_exc_pad') - else: - self.network['img'] = InputLayer((None, 3, None, None)) - self.network['seed'] = self.network['img'] + self.network['img'] = InputLayer((None, 3, None, None)) + self.network['seed'] = InputLayer((None, 3, None, None)) config, params = self.load_model() self.setup_generator(self.last_layer(), config) @@ -378,10 +381,10 @@ class Model(object): def compile(self): # Helper function for rendering test images during training, or standalone non-training mode. - input_tensor = T.tensor4() - input_layers = {self.network['img']: input_tensor} - output = lasagne.layers.get_output([self.network[k] for k in ['img', 'seed', 'out']], input_layers, deterministic=True) - self.predict = theano.function([input_tensor], output) + input_tensor, seed_tensor = T.tensor4(), T.tensor4() + input_layers = {self.network['img']: input_tensor, self.network['seed']: seed_tensor} + output = lasagne.layers.get_output([self.network[k] for k in ['seed', 'out']], input_layers, deterministic=True) + self.predict = theano.function([seed_tensor], output) if not args.train: return @@ -407,7 +410,7 @@ class Model(object): # Combined Theano function for updating both generator and discriminator at the same time. updates = collections.OrderedDict(list(gen_updates.items()) + list(disc_updates.items())) - self.fit = theano.function([input_tensor], gen_losses + [disc_out.mean(axis=(1,2,3))], updates=updates) + self.fit = theano.function([input_tensor, seed_tensor], gen_losses + [disc_out.mean(axis=(1,2,3))], updates=updates) @@ -448,10 +451,12 @@ class NeuralEnhancer(object): if t_cur % args.learning_period == 0: l_r *= args.learning_decay def train(self): - images = np.zeros((args.batch_size, 3, args.batch_resolution, args.batch_resolution), dtype=np.float32) + seed_size = int(args.batch_shape / 2**args.scales) + images = np.zeros((args.batch_size, 3, args.batch_shape, args.batch_shape), dtype=np.float32) + seeds = np.zeros((args.batch_size, 3, seed_size, seed_size), dtype=np.float32) learning_rate = self.decay_learning_rate() try: - running, start = None, time.time() + average, start = None, time.time() for epoch in range(args.epochs): total, stats = None, None l_r = next(learning_rate) @@ -459,18 +464,18 @@ class NeuralEnhancer(object): if epoch >= args.discriminator_start: self.model.disc_lr.set_value(l_r) for _ in range(args.epoch_size): - self.thread.copy(images) - output = self.model.fit(images) + self.thread.copy(images, seeds) + output = self.model.fit(images, seeds) losses = np.array(output[:3], dtype=np.float32) stats = (stats + output[3]) if stats is not None else output[3] total = total + losses if total is not None else losses l = np.sum(losses) assert not np.isnan(losses).any() - running = l if running is None else running * 0.95 + 0.05 * l - print('↑' if l > running else '↓', end='', flush=True) + average = l if average is None else average * 0.95 + 0.05 * l + print('↑' if l > average else '↓', end='', flush=True) - orign, scald, repro = self.model.predict(images) - self.show_progress(orign, scald, repro) + scald, repro = self.model.predict(seeds) + self.show_progress(images, scald, repro) total /= args.epoch_size stats /= args.epoch_size totals, labels = [sum(total)] + list(total), ['total', 'prcpt', 'smthn', 'advrs'] @@ -481,9 +486,12 @@ class NeuralEnhancer(object): real, fake = stats[:args.batch_size], stats[args.batch_size:] print(' - discriminator', real.mean(), len(np.where(real > 0.5)[0]), fake.mean(), len(np.where(fake < -0.5)[0])) if epoch == args.adversarial_start-1: - print(' - adversary mode: generator engaging discriminator.') + print(' - generator now optimizing against discriminator.') self.model.adversary_weight.set_value(args.adversary_weight) running = None + if (epoch+1) % args.save_every == 0: + print(' - saving current generator layers to disk...') + self.model.save_generator() except KeyboardInterrupt: pass @@ -505,11 +513,9 @@ if __name__ == "__main__": if args.train: enhancer.train() - - for filename in args.files: - print(filename) - out = enhancer.process(scipy.ndimage.imread(filename, mode='RGB')) - out.save(os.path.splitext(filename)[0]+'_ne%ix.png'%(2**args.scales)) - - if args.files: + else: + for filename in args.files: + print(filename) + out = enhancer.process(scipy.ndimage.imread(filename, mode='RGB')) + out.save(os.path.splitext(filename)[0]+'_ne%ix.png'%(2**args.scales)) print(ansi.ENDC)