Refactor of changes related to training.

main
Alex J. Champandard 9 years ago
parent 2b67daedb6
commit 17fcad8d28

@@ -38,13 +38,13 @@ add_arg('files', nargs='*', default=[])
 add_arg('--scales', default=2, type=int, help='How many times to perform 2x upsampling.')
 add_arg('--model', default='small', type=str, help='Name of the neural network to load/save.')
 add_arg('--train', default=False, type=str, help='File pattern to load for training.')
-add_arg('--save-every-epoch', default=False, action='store_true', help='Save generator after every training epoch.')
-add_arg('--batch-resolution', default=192, type=int, help='Resolution of images in training batch.')
-add_arg('--epochs', default=10, type=int, help='Total number of iterations in training.')
-add_arg('--epoch-size', default=72, type=int, help='Number of batches trained in an epoch.')
+add_arg('--save-every', default=10, type=int, help='Save generator after every training epoch.')
+add_arg('--batch-shape', default=192, type=int, help='Resolution of images in training batch.')
 add_arg('--batch-size', default=15, type=int, help='Number of images per training batch.')
 add_arg('--buffer-size', default=1500, type=int, help='Total image fragments kept in cache.')
 add_arg('--buffer-similar', default=5, type=int, help='Fragments cached for each image loaded.')
+add_arg('--epochs', default=10, type=int, help='Total number of iterations in training.')
+add_arg('--epoch-size', default=72, type=int, help='Number of batches trained in an epoch.')
 add_arg('--learning-rate', default=1E-4, type=float, help='Parameter for the ADAM optimizer.')
 add_arg('--learning-period', default=50, type=int, help='How often to decay the learning rate.')
 add_arg('--learning-decay', default=0.5, type=float, help='How much to decay the learning rate.')
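
The one-shot --save-every-epoch switch becomes an integer --save-every interval, and --batch-resolution becomes --batch-shape. As a rough sketch of the arithmetic these flags drive (defaults as above; the variable names are illustrative only):

    # Seed resolution handed to the generator: halved once per 2x upsampling step.
    batch_shape, scales = 192, 2             # --batch-shape and --scales defaults
    seed_shape = batch_shape // 2**scales    # 192 // 2**2 = 48 pixels per side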
@@ -129,11 +129,10 @@ class DataLoader(threading.Thread):
         self.data_ready = threading.Event()
         self.data_copied = threading.Event()
 
-        self.resolution = args.batch_resolution
-        self.seed_resolution = int(args.batch_resolution / 2**args.scales)
+        self.orig_shape, self.seed_shape = args.batch_shape, int(args.batch_shape / 2**args.scales)
 
-        self.buffer = np.zeros((args.buffer_size, 3, self.resolution, self.resolution), dtype=np.float32)
-        self.seed_buffer = np.zeros((args.buffer_size, 3, self.seed_resolution, self.seed_resolution), dtype=np.float32)
+        self.orig_buffer = np.zeros((args.buffer_size, 3, self.orig_shape, self.orig_shape), dtype=np.float32)
+        self.seed_buffer = np.zeros((args.buffer_size, 3, self.seed_shape, self.seed_shape), dtype=np.float32)
         self.files = glob.glob(args.train)
         if len(self.files) == 0:
             error("There were no files found to train from searching for `{}`".format(args.train),
@@ -161,31 +160,29 @@ class DataLoader(threading.Thread):
         for _ in range(args.buffer_similar):
             copy = img[:,::-1] if random.choice([True, False]) else img
-            h = random.randint(0, copy.shape[0] - self.resolution)
-            w = random.randint(0, copy.shape[1] - self.resolution)
-            copy = copy[h:h+self.resolution, w:w+self.resolution]
+            h = random.randint(0, copy.shape[0] - self.orig_shape)
+            w = random.randint(0, copy.shape[1] - self.orig_shape)
+            copy = copy[h:h+self.orig_shape, w:w+self.orig_shape]
 
             while len(self.available) == 0:
                 self.data_copied.wait()
                 self.data_copied.clear()
 
             i = self.available.pop()
-            self.buffer[i] = np.transpose(copy / 255.0 - 0.5, (2, 0, 1))
-            seed_copy = scipy.misc.imresize(copy,
-                                            size=(self.seed_resolution, self.seed_resolution),
-                                            interp='bilinear')
-            self.seed_buffer[i] = np.transpose(seed_copy / 255.0 - 0.5, (2, 0, 1))
+            self.orig_buffer[i] = np.transpose(copy / 255.0 - 0.5, (2, 0, 1))
+            seed = scipy.misc.imresize(copy, size=(self.seed_shape, self.seed_shape), interp='bilinear')
+            self.seed_buffer[i] = np.transpose(seed / 255.0 - 0.5, (2, 0, 1))
             self.ready.add(i)
 
             if len(self.ready) >= args.batch_size:
                 self.data_ready.set()
 
-    def copy(self, images_out, seeds_out):
+    def copy(self, origs_out, seeds_out):
         self.data_ready.wait()
         self.data_ready.clear()
 
         for i, j in enumerate(random.sample(self.ready, args.batch_size)):
-            images_out[i] = self.buffer[j]
+            origs_out[i] = self.orig_buffer[j]
             seeds_out[i] = self.seed_buffer[j]
             self.available.add(j)
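
Stripped of the thread synchronization, the per-fragment work above is flip, crop, downscale, normalize. A minimal standalone sketch; the function name and signature are invented for illustration, and the input is assumed to be an HxWx3 uint8 array at least orig_shape on each side:

    import random
    import numpy as np
    import scipy.misc

    def make_fragment_pair(img, orig_shape, seed_shape):
        # Random horizontal flip, then a random orig_shape square crop.
        copy = img[:, ::-1] if random.choice([True, False]) else img
        h = random.randint(0, copy.shape[0] - orig_shape)
        w = random.randint(0, copy.shape[1] - orig_shape)
        copy = copy[h:h+orig_shape, w:w+orig_shape]
        # Bilinear downscale produces the low-resolution seed.
        seed = scipy.misc.imresize(copy, size=(seed_shape, seed_shape), interp='bilinear')
        # Normalize to [-0.5, +0.5] and move to channels-first layout.
        as_chw = lambda a: np.transpose(a / 255.0 - 0.5, (2, 0, 1)).astype(np.float32)
        return as_chw(copy), as_chw(seed)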
@@ -384,11 +381,10 @@ class Model(object):
     def compile(self):
         # Helper function for rendering test images during training, or standalone non-training mode.
-        input_tensor = T.tensor4()
-        seed_tensor = T.tensor4()
+        input_tensor, seed_tensor = T.tensor4(), T.tensor4()
         input_layers = {self.network['img']: input_tensor, self.network['seed']: seed_tensor}
-        output = lasagne.layers.get_output([self.network[k] for k in ['img', 'seed', 'out']], input_layers, deterministic=True)
-        self.predict = theano.function([input_tensor, seed_tensor], output)
+        output = lasagne.layers.get_output([self.network[k] for k in ['seed', 'out']], input_layers, deterministic=True)
+        self.predict = theano.function([seed_tensor], output)
 
         if not args.train: return
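
With 'img' dropped from the requested outputs, the compiled predictor takes only the low-resolution seed array and returns the 'seed' and 'out' layer outputs, unpacked as scald and repro in the training loop. A usage sketch at the default shapes (`model` stands for a compiled Model instance):

    seeds = np.zeros((15, 3, 48, 48), dtype=np.float32)   # --batch-size x RGB x seed pixels
    scald, repro = model.predict(seeds)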
@@ -455,12 +451,12 @@ class NeuralEnhancer(object):
             if t_cur % args.learning_period == 0: l_r *= args.learning_decay
 
     def train(self):
-        seed_size = int(args.batch_resolution / 2**args.scales)
-        images = np.zeros((args.batch_size, 3, args.batch_resolution, args.batch_resolution), dtype=np.float32)
+        seed_size = int(args.batch_shape / 2**args.scales)
+        images = np.zeros((args.batch_size, 3, args.batch_shape, args.batch_shape), dtype=np.float32)
         seeds = np.zeros((args.batch_size, 3, seed_size, seed_size), dtype=np.float32)
         learning_rate = self.decay_learning_rate()
         try:
-            running, start = None, time.time()
+            average, start = None, time.time()
             for epoch in range(args.epochs):
                 total, stats = None, None
                 l_r = next(learning_rate)
@@ -475,11 +471,11 @@ class NeuralEnhancer(object):
                     total = total + losses if total is not None else losses
                     l = np.sum(losses)
                     assert not np.isnan(losses).any()
-                    running = l if running is None else running * 0.95 + 0.05 * l
-                    print('↑' if l > running else '↓', end='', flush=True)
+                    average = l if average is None else average * 0.95 + 0.05 * l
+                    print('↑' if l > average else '↓', end='', flush=True)
 
-                orign, scald, repro = self.model.predict(images, seeds)
-                self.show_progress(orign, scald, repro)
+                scald, repro = self.model.predict(seeds)
+                self.show_progress(images, scald, repro)
                 total /= args.epoch_size
                 stats /= args.epoch_size
                 totals, labels = [sum(total)] + list(total), ['total', 'prcpt', 'smthn', 'advrs']
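
The renamed `average` is an exponential moving average of the summed batch loss, decayed at 0.95 per batch. A toy trace of the update with made-up loss values:

    average = None
    for l in (2.0, 1.0, 1.0):
        average = l if average is None else average * 0.95 + 0.05 * l
        print(average)   # 2.0, then 1.95, then 1.9025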
@@ -490,10 +486,11 @@ class NeuralEnhancer(object):
                 real, fake = stats[:args.batch_size], stats[args.batch_size:]
                 print(' - discriminator', real.mean(), len(np.where(real > 0.5)[0]), fake.mean(), len(np.where(fake < -0.5)[0]))
 
                 if epoch == args.adversarial_start-1:
-                    print(' - adversary mode: generator engaging discriminator.')
+                    print(' - generator now optimizing against discriminator.')
                     self.model.adversary_weight.set_value(args.adversary_weight)
                     running = None
-                if args.save_every_epoch:
+
+                if (epoch+1) % args.save_every == 0:
                     print(' - saving current generator layers to disk...')
                     self.model.save_generator()
         except KeyboardInterrupt:
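
Checkpointing moves from a per-epoch boolean to a modulo test on the new --save-every interval. With the default of 10:

    save_every = 10
    saves = [epoch for epoch in range(30) if (epoch + 1) % save_every == 0]
    # saves == [9, 19, 29]: every tenth epoch, counting from one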
@@ -506,7 +503,7 @@ class NeuralEnhancer(object):
 
     def process(self, image):
         img = np.transpose(image / 255.0 - 0.5, (2, 0, 1))[np.newaxis].astype(np.float32)
-        *_, repro = self.model.predict(img, img)
+        *_, repro = self.model.predict(img)
         repro = np.transpose(repro[0] + 0.5, (1, 2, 0)).clip(0.0, 1.0)
         return scipy.misc.toimage(repro * 255.0, cmin=0, cmax=255)
@@ -516,11 +513,9 @@ if __name__ == "__main__":
     if args.train:
         enhancer.train()
-
-    for filename in args.files:
-        print(filename)
-        out = enhancer.process(scipy.ndimage.imread(filename, mode='RGB'))
-        out.save(os.path.splitext(filename)[0]+'_ne%ix.png'%(2**args.scales))
-
-    if args.files:
+    else:
+        for filename in args.files:
+            print(filename)
+            out = enhancer.process(scipy.ndimage.imread(filename, mode='RGB'))
+            out.save(os.path.splitext(filename)[0]+'_ne%ix.png'%(2**args.scales))
         print(ansi.ENDC)
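
For reference, the unchanged save line names outputs after the upsampling factor 2**scales. A worked example with an invented filename:

    import os
    filename, scales = 'photo.jpg', 2                             # hypothetical input
    os.path.splitext(filename)[0] + '_ne%ix.png' % (2**scales)    # -> 'photo_ne4x.png'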
