Add residual blocks to generator.

main
Alex J. Champandard 9 years ago
parent 058e3d3b9e
commit 3f24714039

@ -32,14 +32,15 @@ parser = argparse.ArgumentParser(description='Generate a new image by applying s
formatter_class=argparse.ArgumentDefaultsHelpFormatter) formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# Command-line options for the training script (post-commit values).
# NOTE(review): reconstructed from a corrupted side-by-side diff rendering;
# `parser` is the argparse.ArgumentParser created just above this block.
add_arg = parser.add_argument
add_arg('--batch-size', default=15, type=int)
add_arg('--batch-resolution', default=256, type=int)
add_arg('--epoch-size', default=36, type=int)
add_arg('--epochs', default=100, type=int)
add_arg('--network-filters', default=128, type=int)
# Number of residual blocks stacked in the generator (added by this commit).
add_arg('--network-blocks', default=4, type=int)
add_arg('--perceptual-layer', default='mse', type=str)
add_arg('--perceptual-weight', default=1e0, type=float)
add_arg('--smoothness-weight', default=0.0, type=float)
# Adversarial term disabled by default in this revision (was 1e4 before).
add_arg('--adversary-weight', default=0.0, type=float)
add_arg('--scales', default=1, type=int, help='')
add_arg('--device', default='gpu0', type=str, help='Name of the CPU/GPU number to use, for Theano.')
args = parser.parse_args()
@ -94,7 +95,7 @@ if sys.platform == 'win32':
# Deep Learning Framework # Deep Learning Framework
import lasagne import lasagne
from lasagne.layers import Conv2DLayer as ConvLayer, Deconv2DLayer as DeconvLayer, Pool2DLayer as PoolLayer from lasagne.layers import Conv2DLayer as ConvLayer, Deconv2DLayer as DeconvLayer, Pool2DLayer as PoolLayer
from lasagne.layers import InputLayer, ConcatLayer, batch_norm from lasagne.layers import InputLayer, ConcatLayer, batch_norm, ElemwiseSumLayer
print('{} - Using the device `{}` for neural computation.{}\n'.format(ansi.CYAN, theano.config.device, ansi.ENDC)) print('{} - Using the device `{}` for neural computation.{}\n'.format(ansi.CYAN, theano.config.device, ansi.ENDC))
@ -164,19 +165,36 @@ class Model(object):
def __init__(self):
    """Build the full model: generator upscaling network plus VGG-based
    perceptual network, then compile the Theano training/prediction functions.

    Layers are stored in an OrderedDict so ``last_layer()`` can return the
    most recently added one by insertion order.
    """
    self.network = collections.OrderedDict()
    # Variable-sized RGB input; batch/height/width are left unspecified.
    self.network['img'] = InputLayer((None, 3, None, None))
    # Downsample the target image; the generator learns to re-upscale it.
    low_res = PoolLayer(self.network['img'], pool_size=2**args.scales)
    self.setup_generator(low_res)
    # Stack originals and reproductions along the batch axis so the
    # perceptual network processes both in a single pass.
    concatenated = lasagne.layers.ConcatLayer([self.network['img'], self.network['out']], axis=0)
    self.setup_perceptual(concatenated)
    self.load_perceptual()
    self.compile()
def last_layer(self): def last_layer(self):
return list(self.network.values())[-1] return list(self.network.values())[-1]
def setup_generator(self, input):
    """Build the generator: a 1x1 entry convolution, a chain of residual
    blocks, then one deconv+conv upscaling stage per scale, ending in a
    1x1 tanh projection back to 3 RGB channels.

    The layer dict is updated in insertion order so ``last_layer()`` always
    refers to the previous stage.
    """
    f = args.network_filters
    # Project RGB input up to `f` feature maps without changing resolution.
    self.network['iter.0'] = ConvLayer(input, f, filter_size=(1,1), stride=(1,1), pad=0)
    # Residual blocks added by this commit; each keeps resolution and width.
    for i in range(0, args.network_blocks):
        self.network['iter.%i'%(i+1)] = self.make_block(self.last_layer(), f)
    # One 2x upscaling stage per scale, from coarsest to finest.
    for i in range(args.scales, 0, -1):
        self.network['scale%i.2'%i] = DeconvLayer(self.last_layer(), f, filter_size=(4,4), stride=(2,2), crop=1)
        self.network['scale%i.1'%i] = ConvLayer(self.network['scale%i.2'%i], f, filter_size=(3,3), pad=1)
    # Final 1x1 projection to RGB; tanh keeps outputs in [-1, +1], no bias.
    self.network['out'] = ConvLayer(self.last_layer(), 3, filter_size=(1,1), stride=(1,1), pad=0, b=None,
                                    nonlinearity=lasagne.nonlinearities.tanh)
def make_block(self, input, units):
    """Create one residual block: two batch-normalized 3x3 convolutions
    whose output is summed element-wise with the block's input (identity
    shortcut). `input` must already have `units` channels for the sum to
    be shape-compatible.
    """
    l1 = batch_norm(ConvLayer(input, units, filter_size=(3,3), stride=(1,1), pad=1))
    l2 = batch_norm(ConvLayer(l1, units, filter_size=(3,3), stride=(1,1), pad=1))
    return ElemwiseSumLayer([input, l2])
def setup_perceptual(self, input): def setup_perceptual(self, input):
"""Use lasagne to create a network of convolution layers using pre-trained VGG19 weights. """Use lasagne to create a network of convolution layers using pre-trained VGG19 weights.
""" """
@ -218,32 +236,26 @@ class Model(object):
layers = lasagne.layers.get_all_layers(self.last_layer(), treat_as_input=[self.network['percept']]) layers = lasagne.layers.get_all_layers(self.last_layer(), treat_as_input=[self.network['percept']])
for p, d in zip(itertools.chain(*[l.get_params() for l in layers]), data): p.set_value(d) for p, d in zip(itertools.chain(*[l.get_params() for l in layers]), data): p.set_value(d)
def setup_generator(self, input):
f = args.network_filters
self.network['iter.0'] = ConvLayer(input, f, filter_size=(1,1), stride=(1,1), pad=0,)
for i in range(args.scales, 0, -1):
self.network['scale%i.2'%i] = DeconvLayer(self.last_layer(), f, filter_size=(4,4), stride=(2,2), crop=1)
self.network['scale%i.1'%i] = ConvLayer(self.network['scale%i.2'%i], f, filter_size=(3,3), pad=1)
self.network['out'] = ConvLayer(self.last_layer(), 3, filter_size=(1,1), stride=(1,1), pad=0, b=None,
nonlinearity=lasagne.nonlinearities.tanh)
def compile(self):
    """Compile the Theano training (`self.fit`) and inference
    (`self.predict`) functions for the generator.

    NOTE(review): reconstructed from a corrupted diff view. `self.network['disc']`
    is referenced here but not built by the visible setup methods — presumably a
    discriminator created elsewhere in the file; confirm before relying on it.
    """
    input_tensor = T.tensor4()
    output_layers = [self.network['out'], self.network[args.perceptual_layer], self.network['disc']]
    input_layers = {self.network['img']: input_tensor}
    gen_out, percept_out, disc_out = lasagne.layers.get_output(output_layers, input_layers, deterministic=False)

    # Generator loss function, parameters and updates. The learning rate is a
    # shared variable so the training loop can anneal it between batches.
    self.gen_lr = theano.shared(np.array(0.0, dtype=theano.config.floatX))
    gen_losses = [self.loss_perceptual(percept_out) * args.perceptual_weight,
                  self.loss_total_variation(gen_out) * args.smoothness_weight]
    gen_params = lasagne.layers.get_all_params(self.network['out'], trainable=True)
    print(' - {} tensors learned for generator.'.format(len(gen_params)))
    gen_updates = lasagne.updates.adam(sum(gen_losses, 0.0), gen_params, learning_rate=self.gen_lr)

    # Combined Theano function for updating both generator and discriminator at the same time.
    updates = list(gen_updates.items())
    self.fit = theano.function([input_tensor], gen_losses, updates=collections.OrderedDict(updates))

    # Helper function for rendering test images deterministically, computing statistics.
    gen_out, gen_inp = lasagne.layers.get_output([self.network['out'], self.network['img']],
                                                 input_layers, deterministic=True)
    self.predict = theano.function([input_tensor], [gen_out, gen_inp])
@ -252,15 +264,20 @@ class Model(object):
return lasagne.objectives.squared_error(p[:args.batch_size], p[args.batch_size:]).mean() return lasagne.objectives.squared_error(p[:args.batch_size], p[args.batch_size:]).mean()
def loss_total_variation(self, x):
    """Total-variation smoothness penalty over an NCHW image batch.

    Squared vertical and horizontal neighbor differences are summed per
    pixel, raised to the 1.25 power, and averaged over all elements.
    """
    vert = x[:, :, :-1, :-1] - x[:, :, 1:, :-1]
    horz = x[:, :, :-1, :-1] - x[:, :, :-1, 1:]
    # .mean() on the tensor is equivalent to T.mean(...) for Theano tensors.
    return ((vert ** 2 + horz ** 2) ** 1.25).mean()
class NeuralEnhancer(object): class NeuralEnhancer(object):
def __init__(self):
    """Start the background data loader and build the model; finally reset
    terminal colors left over from the model's construction output.
    """
    self.thread = DataLoader()
    self.model = Model()
    print('\n{}'.format(ansi.ENDC))
def imsave(self, fn, img): def imsave(self, fn, img):
img = np.transpose(img + 0.5, (1, 2, 0)).clip(0.0, 1.0) img = np.transpose(img + 0.5, (1, 2, 0)).clip(0.0, 1.0)
image = scipy.misc.toimage(img * 255.0, cmin=0, cmax=255) image = scipy.misc.toimage(img * 255.0, cmin=0, cmax=255)
@ -272,22 +289,19 @@ class NeuralEnhancer(object):
self.imsave('test/%03i_repro.png' % i, repro[i]) self.imsave('test/%03i_repro.png' % i, repro[i])
def train(self): def train(self):
print('\n{}Training {} epochs with batch size {}.{}'\
.format(ansi.BLUE_B, args.epochs, args.batch_size, ansi.ENDC))
images = np.zeros((args.batch_size, 3, args.batch_resolution, args.batch_resolution), dtype=np.float32) images = np.zeros((args.batch_size, 3, args.batch_resolution, args.batch_resolution), dtype=np.float32)
l_min, l_max, l_mult = 1E-7, 1E-3, 0.2 l_min, l_max, l_mult = 1E-7, 1E-3, 0.2
t_cur, t_i, t_mult = 0, 150, 1 t_cur, t_i, t_mult = 0, 150, 1
i, last, running = 0, float('inf'), None i, running = 0, None
for _ in range(args.epochs): for _ in range(args.epochs):
total = 0.0 total = None
for _ in range(args.epoch_size): for _ in range(args.epoch_size):
i += 1 i += 1
l_r = l_min + 0.5 * (l_max - l_min) * (1.0 + math.cos(t_cur / t_i * math.pi)) l_r = l_min + 0.5 * (l_max - l_min) * (1.0 + math.cos(t_cur / t_i * math.pi))
t_cur += 1 t_cur += 1
self.model.learning_rate.set_value(l_r) self.model.gen_lr.set_value(l_r)
if t_cur >= t_i: if t_cur >= t_i:
t_cur = 0 t_cur = 0
@ -296,16 +310,17 @@ class NeuralEnhancer(object):
l_min = max(l_min * l_mult, 1e-6) l_min = max(l_min * l_mult, 1e-6)
self.thread.copy(images) self.thread.copy(images)
losses = self.model.fit(images) losses = np.array(self.model.fit(images), dtype=np.float32)
l = sum(losses) total = total + losses if total is not None else losses
total += l l = np.sum(losses)
running = l if running is None else running * 0.9 + 0.1 * l running = l if running is None else running * 0.9 + 0.1 * l
print('' if l >= running else '', end=' ', flush=True) print('' if l >= running else '', end=' ', flush=True)
self.show_progress(*self.model.predict(images)) self.show_progress(*self.model.predict(images))
last = total / args.epoch_size total = total / args.epoch_size
print('\nLosses total:', last) labels = ['{}={:4.2e}'.format(k, v) for k, v in zip(['prcpt', 'smthn', 'advrs'], total)]
print('\nLosses: total={:4.2e} {}'.format(sum(total), ' '.join(labels)))
if __name__ == "__main__": if __name__ == "__main__":

Loading…
Cancel
Save