Add extra padding on input to avoid zero-padding. Experiment with training values from ENet (segmentation).

Branch: main
Alex J. Champandard committed 9 years ago
parent 7924cc4a85 · commit 3b2a6b9d8d

@@ -46,10 +46,10 @@ add_arg('--epochs', default=10, type=int, help='Total
add_arg('--epoch-size', default=72, type=int, help='Number of batches trained in an epoch.')
add_arg('--save-every', default=10, type=int, help='Save generator after every training epoch.')
add_arg('--batch-shape', default=192, type=int, help='Resolution of images in training batch.')
add_arg('--batch-size', default=15, type=int, help='Number of images per training batch.')
add_arg('--batch-size', default=10, type=int, help='Number of images per training batch.')
add_arg('--buffer-size', default=1500, type=int, help='Total image fragments kept in cache.')
add_arg('--buffer-similar', default=5, type=int, help='Fragments cached for each image loaded.')
add_arg('--learning-rate', default=1E-4, type=float, help='Parameter for the ADAM optimizer.')
add_arg('--learning-rate', default=5E-4, type=float, help='Parameter for the ADAM optimizer.')
add_arg('--learning-period', default=100, type=int, help='How often to decay the learning rate.')
add_arg('--learning-decay', default=0.5, type=float, help='How much to decay the learning rate.')
add_arg('--generator-upscale', default=2, type=int, help='Steps of 2x up-sampling as post-process.')
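Note: the retuned defaults above (smaller batches, a 5x higher initial learning rate) can also be reproduced outside enhance.py. A minimal sketch, assuming `add_arg` simply forwards to argparse's `add_argument` (the wrapper itself is not shown in this hunk):

import argparse

# Hypothetical standalone equivalent of the retuned flags above; defaults match this commit.
parser = argparse.ArgumentParser()
parser.add_argument('--batch-size', default=10, type=int, help='Number of images per training batch.')
parser.add_argument('--learning-rate', default=5E-4, type=float, help='Parameter for the ADAM optimizer.')
parser.add_argument('--learning-period', default=100, type=int, help='How often to decay the learning rate.')
parser.add_argument('--learning-decay', default=0.5, type=float, help='How much to decay the learning rate.')
args = parser.parse_args([])            # empty list keeps the defaults shown in this commit
print(args.batch_size, args.learning_rate)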
@@ -248,14 +248,25 @@ class Model(object):
config, params = self.load_model()
self.setup_generator(self.last_layer(), config)
# Compute batch-shape to take into account that there's no zero-padding in the generator's convolution layers.
s = args.batch_shape // args.zoom
current = lasagne.layers.helper.get_output_shape(self.network['out'], {self.network['seed']: (1, 3, s, s)})
args.batch_shape = args.batch_shape * 2 - current[2]
self.network['img'].shape = (args.batch_size, 3, args.batch_shape, args.batch_shape)
self.network['seed'].shape = (args.batch_size, 3, args.batch_shape // args.zoom, args.batch_shape // args.zoom)
# How can we force this shape recomputation more elegantly using Lasagne?
self.network['out'].input_shape = lasagne.layers.get_output_shape(self.network['out'].input_layer,
{self.network['seed']: self.network['seed'].shape})
if args.train:
concatenated = lasagne.layers.ConcatLayer([self.network['img'], self.network['out']], axis=0)
concatenated = lasagne.layers.ConcatLayer([self.network['img'], self.network['out']],
axis=0, cropping=(None, None, 'center', 'center'))
self.setup_perceptual(concatenated)
self.load_perceptual()
self.setup_discriminator()
self.load_generator(params)
self.compile()
#------------------------------------------------------------------------------------------------------------------
# Network Configuration
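Note: with `pad_override = (0, 0)` the generator's convolutions run unpadded, so its output is smaller than the requested crop; the `batch_shape * 2 - current[2]` line above widens the high-resolution crop by exactly the pixels that get trimmed away. A rough sketch of that arithmetic, with an assumed network depth rather than the real `get_output_shape` call:

# Illustrative numbers only; the real code queries Lasagne for the exact output shape.
batch_shape, zoom = 192, 2
seed = batch_shape // zoom                       # low-resolution input to the generator
trimmed = 8 * 2                                  # e.g. eight unpadded 3x3 convolutions, 2 pixels each
current = (seed - trimmed) * zoom                # assumed output size after the 2x up-sampling
batch_shape = batch_shape * 2 - current          # same as batch_shape + (batch_shape - current)
print(batch_shape)                               # 224: the crop grows by the 32 pixels lost to 'valid' convs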
@@ -265,7 +276,7 @@ class Model(object):
return list(self.network.values())[-1]
def make_layer(self, name, input, units, filter_size=(3,3), stride=(1,1), pad=(1,1), alpha=0.25):
conv = ConvLayer(input, units, filter_size=filter_size, stride=stride, pad=pad, nonlinearity=None)
conv = ConvLayer(input, units, filter_size, stride=stride, pad=self.pad_override or pad, nonlinearity=None)
prelu = lasagne.layers.ParametricRectifierLayer(conv, alpha=lasagne.init.Constant(alpha))
self.network[name+'x'] = conv
self.network[name+'>'] = prelu
@@ -277,6 +288,7 @@ class Model(object):
return ElemwiseSumLayer([input, self.last_layer()]) if args.generator_residual else self.last_layer()
def setup_generator(self, input, config):
self.pad_override = (0, 0)
for k, v in config.items(): setattr(args, k, v)
args.zoom = 2**(args.generator_upscale - args.generator_downscale)
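Note: `make_layer` now reads `self.pad_override`, so while `setup_generator` runs every convolution is built unpadded instead of same-padded. A minimal sketch of that conv + PReLU block in isolation (layer sizes here are assumptions, not the repository's configuration):

import lasagne
from lasagne.layers import InputLayer, Conv2DLayer, ParametricRectifierLayer, get_output_shape

# Stand-alone version of the pattern built by make_layer(); pad_override=(0,0) selects 'valid' convolution.
net = InputLayer((1, 3, 96, 96))
pad_override = (0, 0)                            # active for the duration of setup_generator()
conv = Conv2DLayer(net, num_filters=64, filter_size=(3, 3), stride=(1, 1),
                   pad=pad_override or (1, 1), nonlinearity=None)
prelu = ParametricRectifierLayer(conv, alpha=lasagne.init.Constant(0.25))
print(get_output_shape(prelu))                   # (1, 64, 94, 94): each unpadded 3x3 conv trims 2 pixels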
@@ -301,6 +313,7 @@ class Model(object):
self.network['out'] = ConvLayer(self.last_layer(), 3, filter_size=(5,5), stride=(1,1), pad=(2,2),
nonlinearity=lasagne.nonlinearities.tanh)
self.pad_override = None
def setup_perceptual(self, input):
"""Use lasagne to create a network of convolution layers using pre-trained VGG19 weights.
@@ -405,13 +418,13 @@ class Model(object):
return T.mean(((x[:,:,:-1,:-1] - x[:,:,1:,:-1])**2 + (x[:,:,:-1,:-1] - x[:,:,:-1,1:])**2)**1.25)
def loss_adversarial(self, d):
return T.mean(1.0 - T.nnet.softplus(d[args.batch_size:]))
return T.mean(1.0 - T.nnet.softminus(d[args.batch_size:]))
def loss_discriminator(self, d):
return T.mean(T.nnet.softminus(d[args.batch_size:]) - T.nnet.softplus(d[:args.batch_size]))
def compile(self):
# Helper function for rendering test images during training, or standalone non-training mode.
# Helper function for rendering test images during training, or standalone inference mode.
input_tensor, seed_tensor = T.tensor4(), T.tensor4()
input_layers = {self.network['img']: input_tensor, self.network['seed']: seed_tensor}
output = lasagne.layers.get_output([self.network[k] for k in ['seed', 'out']], input_layers, deterministic=True)
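Note: `loss_adversarial` now uses the softminus curve rather than softplus, mirroring `loss_discriminator`. softminus is assumed here to follow the usual definition x - softplus(x) = log(sigmoid(x)); a small numpy sketch of how the two terms behave on example discriminator scores:

import numpy as np

def softplus(x):  return np.log1p(np.exp(x))     # smooth approximation of max(0, x)
def softminus(x): return x - softplus(x)         # smooth approximation of min(0, x), i.e. log(sigmoid(x))

# Example scores only: the first half of the batch is real images, the second half generator output.
real = np.array([2.0, 1.5])
fake = np.array([-1.0, 0.5])
adversarial   = np.mean(1.0 - softminus(fake))                      # generator term from loss_adversarial
discriminator = np.mean(softminus(fake)) - np.mean(softplus(real))  # split used by loss_discriminator
print(adversarial, discriminator)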
@@ -437,7 +450,7 @@ class Model(object):
disc_losses = [self.loss_discriminator(disc_out)]
disc_params = list(itertools.chain(*[l.get_params() for k, l in self.network.items() if 'disc' in k]))
print(' - {} tensors learned for discriminator.'.format(len(disc_params)))
grads = [g.clip(-1.0, +1.0) for g in T.grad(sum(disc_losses, 0.0), disc_params)]
grads = [g.clip(-5.0, +5.0) for g in T.grad(sum(disc_losses, 0.0), disc_params)]
disc_updates = lasagne.updates.adam(grads, disc_params, learning_rate=self.disc_lr)
# Combined Theano function for updating both generator and discriminator at the same time.
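Note: the elementwise clip on the discriminator gradients is widened from ±1 to ±5 before they reach Adam. The same operation in numpy, for illustration (the real code clips symbolic Theano gradients):

import numpy as np

grads = [np.array([-7.2, 0.3, 4.1]), np.array([12.0, -0.5])]   # made-up gradient tensors
clipped = [np.clip(g, -5.0, +5.0) for g in grads]               # values outside [-5, +5] are clamped
print(clipped)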
@@ -459,6 +472,7 @@ class NeuralEnhancer(object):
self.model = Model()
self.thread = DataLoader() if loader else None
self.model.compile()
print('{}'.format(ansi.ENDC))
@@ -476,7 +490,7 @@ class NeuralEnhancer(object):
l_r, t_cur = args.learning_rate, 0
while True:
yield l_r if t_cur > 0 else l_r * 0.1
yield l_r
t_cur += 1
if t_cur % args.learning_period == 0: l_r *= args.learning_decay
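Note: the schedule generator above no longer damps the very first epoch to a tenth of the learning rate; it now yields the full rate from epoch one and multiplies it by `learning-decay` every `learning-period` epochs. A sketch of the resulting schedule with this commit's defaults:

# Reproduction of the decay schedule above with the new defaults (5E-4, period 100, decay 0.5).
def decay_schedule(l_r=5E-4, period=100, decay=0.5):
    t_cur = 0
    while True:
        yield l_r
        t_cur += 1
        if t_cur % period == 0: l_r *= decay

rates = [lr for lr, _ in zip(decay_schedule(), range(250))]
print(rates[0], rates[99], rates[100], rates[200])   # 0.0005 0.0005 0.00025 0.000125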
@@ -510,7 +524,7 @@ class NeuralEnhancer(object):
stats /= args.epoch_size
totals, labels = [sum(total)] + list(total), ['total', 'prcpt', 'smthn', 'advrs']
gen_info = ['{}{}{}={:4.2e}'.format(ansi.WHITE_B, k, ansi.ENDC, v) for k, v in zip(labels, totals)]
print('\rEpoch #{} at {:4.1f}s, lr={:4.2e}{}'.format(epoch+1, time.time()-start, l_r, ' '*(args.epoch_size-35)))
print('\rEpoch #{} at {:4.1f}s, lr={:4.2e}{}'.format(epoch+1, time.time()-start, l_r, ' '*(args.epoch_size-30)))
print(' - generator {}'.format(' '.join(gen_info)))
real, fake = stats[:args.batch_size], stats[args.batch_size:]
