Add loading of parameters from saved models. Clean up learning-rate code.

main
Alex J. Champandard 9 years ago
parent bb3b24b04c
commit 619fad7f3c

@@ -1,9 +1,9 @@
Neural Enhance
==============
-`As seen on TV! <https://www.youtube.com/watch?v=LhF_56SxrGk>`_ What if you could increase the resolution of your photos using techniques from CSI laboratories? Thanks to deep learning, it's now possible to train a neural network to zoom in to your images using examples. You'll get better results by increasing the number of neurons and specializing the training images (e.g. faces).
+`As seen on TV! <https://www.youtube.com/watch?v=LhF_56SxrGk>`_ What if you could increase the resolution of your photos using technology from CSI laboratories? Thanks to deep learning, it's now possible to train a neural network to zoom in to your images using examples. You'll get better results by increasing the number of neurons and specializing the training images (e.g. faces).
-The catch? The neural network is hallucinating details based on its analysis of example images. It's not reconstructing the image exactly as it would have been if it was HD. That's only possible in Holywood — but deep learning as "Creative AI" works and its just as cool! Here's how you can get started...
+The catch? The neural network is hallucinating details based on its training from example images. It's not reconstructing the image exactly as it would have been if it were HD. That's only possible in Hollywood — but deep learning as "Creative AI" works and it's just as cool! Here's how you can get started...
1. `Examples & Usage <#1-examples--usage>`_
2. `Installation <#2-installation--setup>`_
@@ -152,7 +152,7 @@ It seems your terminal is misconfigured and not compatible with the way Python t
Q: Is there an application for this? I want to download it!
-----------------------------------------------------------
-There are many online services that provide basic style transfer with neural networks. We run `@DeepForger <https://deepforger.com/>`_, a Twitter & Facebook bot with web interface, that can take your requests too. It takes time to make forgeries, so there's a queue... be patient!
+A: Not yet.
----

@@ -50,7 +50,7 @@ add_arg('--adversary-weight', default=1e2, type=float, help='Weight
add_arg('--generator-start', default=0, type=int, help='Epoch count to start training generator.')
add_arg('--discriminator-start',default=1, type=int, help='Epoch count to update the discriminator.')
add_arg('--adversarial-start', default=2, type=int, help='Epoch for generator to use discriminator.')
-add_arg('--device', default='cpu', type=str, help='Name of the CPU/GPU to use, for Theano.')
+add_arg('--device', default='gpu0', type=str, help='Name of the CPU/GPU to use, for Theano.')
args = parser.parse_args()
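
The default compute device changes from 'cpu' to 'gpu0'. Theano picks up its device before the library is imported, normally from the THEANO_FLAGS environment variable, so a flag like --device only works if it is applied early. A minimal sketch of that wiring, assuming the parsed value is available as a plain string (illustrative, not necessarily the exact code in this script):

    import os

    device = 'gpu0'   # e.g. the parsed value of --device
    os.environ.setdefault('THEANO_FLAGS', 'floatX=float32,device={}'.format(device))

    import theano     # the device flag only takes effect if set before this import
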
@@ -132,7 +132,7 @@ class DataLoader(threading.Thread):
filename = os.path.join(self.cwd, f)
try:
if f not in cache:
-if len(cache) > 3172:
+if len(cache) > 2048:
del cache[random.choice(list(cache.keys()))]
img = scipy.ndimage.imread(filename, mode='RGB')
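
The image cache bound drops from 3172 to 2048 entries; when the limit is exceeded, one randomly chosen entry is evicted before the new file is decoded, which keeps memory bounded without tracking access order. A self-contained sketch of that eviction policy (the dict and load_image below are stand-ins for the loader's shared cache and scipy.ndimage.imread):

    import random

    cache, limit = {}, 2048

    def cached_read(filename, load_image):
        # Evict one random entry once the cache is over the limit, then decode and remember the new file.
        if filename not in cache:
            if len(cache) > limit:
                del cache[random.choice(list(cache.keys()))]
            cache[filename] = load_image(filename)
        return cache[filename]
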
@@ -200,14 +200,16 @@ class Model(object):
else:
self.network['img'] = InputLayer((None, 3, None, None))
self.network['seed'] = self.network['img']
-self.setup_generator(self.last_layer())
+config, params = self.load_model()
+self.setup_generator(self.last_layer(), config)
if args.train:
concatenated = lasagne.layers.ConcatLayer([self.network['img'], self.network['out']], axis=0)
self.setup_perceptual(concatenated)
self.load_perceptual()
self.setup_discriminator()
-self.load_generator()
+self.load_generator(params)
self.compile()
@@ -230,7 +232,8 @@ class Model(object):
self.make_layer(name+'-B', self.last_layer(), units, alpha=1.0)
return ElemwiseSumLayer([input, self.last_layer()]) if args.generator_residual else self.last_layer()
-def setup_generator(self, input):
+def setup_generator(self, input, config):
+for k, v in config.items(): setattr(args, k, v)
units = args.generator_filters
self.make_layer('iter.0', input, units, filter_size=(5,5), pad=(2,2))
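
Copying the saved config onto args before the layers are built means the architecture recorded in the checkpoint overrides whatever was passed on the command line, so a loaded model is always rebuilt with matching shapes. A toy illustration of that override (a made-up one-flag parser, not the script's real argument list):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--generator-filters', default=64, type=int)
    args = parser.parse_args([])                 # command line asks for 64 filters

    config = {'generator_filters': 128}          # but the saved model was trained with 128
    for k, v in config.items(): setattr(args, k, v)

    print(args.generator_filters)                # -> 128, matching the checkpoint
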
@@ -312,21 +315,25 @@ class Model(object):
def save_generator(self):
def cast(p): return p.get_value().astype(np.float16)
params = {k: [cast(p) for p in l.get_params()] for (k, l) in self.list_generator_layers()}
+config = {k: getattr(args, k) for k in ['generator_blocks', 'generator_residual', 'generator_filters']}
filename = args.model % 2**args.scales
-pickle.dump(params, bz2.open(filename, 'wb'))
+pickle.dump((config, params), bz2.open(filename, 'wb'))
print(' - Saved model as `{}` after training.'.format(filename))
-def load_generator(self):
+def load_model(self):
filename = args.model % 2**args.scales
-if not os.path.exists(filename): return
+if not os.path.exists(filename): return {}, {}
-params = pickle.load(bz2.open(filename, 'rb'))
+print(' - Loaded file `{}` with trained model.'.format(filename))
+return pickle.load(bz2.open(filename, 'rb'))
+def load_generator(self, params):
+if len(params) == 0: return
for k, l in self.list_generator_layers():
assert k in params, "Couldn't find layer `%s` in loaded model.'"
assert len(l.get_params()) == len(params[k]), "Mismatch in types of layers."
for p, v in zip(l.get_params(), params[k]):
assert v.shape == p.get_value().shape, "Mismatch in number of parameters."
p.set_value(v.astype(np.float32))
-print(' - Loaded file `{}` with trained model.'.format(filename))
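
The checkpoint format changes accordingly: save_generator now pickles a (config, params) tuple into the bz2 file instead of the bare params dict, and load_model returns that tuple, or two empty dicts when no file exists. A minimal sketch of the round-trip, using a made-up filename and layer key purely for illustration:

    import bz2, pickle
    import numpy as np

    config = {'generator_blocks': 4, 'generator_residual': 1, 'generator_filters': 64}
    params = {'iter.0': [np.zeros((64, 3, 5, 5), dtype=np.float16)]}   # layer name -> list of weight arrays

    with bz2.open('ne2x_example.pkl.bz2', 'wb') as f:
        pickle.dump((config, params), f)

    with bz2.open('ne2x_example.pkl.bz2', 'rb') as f:
        config, params = pickle.load(f)

Checkpoints written before this change contain only the params dict, so they will not unpack into the new (config, params) tuple without conversion.
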
#------------------------------------------------------------------------------------------------------------------ #------------------------------------------------------------------------------------------------------------------
# Training & Loss Functions # Training & Loss Functions
@@ -401,29 +408,31 @@ class NeuralEnhancer(object):
self.imsave('valid/%03i_pixels.png' % i, scald[i])
self.imsave('valid/%03i_reprod.png' % i, repro[i])
+def decay_with_restart(self):
+l_min, l_max, l_mult = 1E-7, 1E-3, 0.5
+t_cur, t_i, t_mult = 10, 10, 1
+while True:
+yield l_min + 0.5 * (l_max - l_min) * (1.0 + math.cos(t_cur / t_i * math.pi))
+t_cur += 1
+if t_cur > t_i:
+t_cur, t_i = 0, int(t_i * t_mult)
+l_max = max(l_max * l_mult, 1e-12)
+l_min = max(l_min * l_mult, 1e-8)
def train(self):
images = np.zeros((args.batch_size, 3, args.batch_resolution, args.batch_resolution), dtype=np.float32)
-l_min, l_max, l_mult = 1E-7, 1E-3, 0.2
-t_cur, t_i, t_mult = 120, 150, 1
+learning_rate = self.decay_with_restart()
try:
-i, running, start = 0, None, time.time()
+running, start = None, time.time()
for epoch in range(args.epochs):
total, stats = None, None
-for _ in range(args.epoch_size):
-i += 1
-l_r = l_min + 0.5 * (l_max - l_min) * (1.0 + math.cos(t_cur / t_i * math.pi))
-t_cur += 1
-l_r = 1E-4
+l_r = next(learning_rate)
if epoch >= args.generator_start: self.model.gen_lr.set_value(l_r)
if epoch >= args.discriminator_start: self.model.disc_lr.set_value(l_r)
-if t_cur >= t_i:
-t_cur, t_i = 0, int(t_i * t_mult)
-l_max = max(l_max * l_mult, 1e-11)
-l_min = max(l_min * l_mult, 1e-7)
+for _ in range(args.epoch_size):
self.thread.copy(images)
output = self.model.fit(images)
losses = np.array(output[:3], dtype=np.float32)
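
The inline learning-rate bookkeeping in train() is replaced by the decay_with_restart generator above: a cosine decay from l_max to l_min over t_i epochs, with both bounds halved at every restart (l_mult = 0.5) and a fixed cycle length since t_mult = 1. The rate is also now drawn once per epoch rather than once per training step, and the old dead assignment l_r = 1E-4, which had effectively pinned the rate to a constant, disappears. A standalone copy of the schedule so it can be inspected outside the class:

    import math

    def decay_with_restart():
        l_min, l_max, l_mult = 1E-7, 1E-3, 0.5
        t_cur, t_i, t_mult = 10, 10, 1
        while True:
            yield l_min + 0.5 * (l_max - l_min) * (1.0 + math.cos(t_cur / t_i * math.pi))
            t_cur += 1
            if t_cur > t_i:
                t_cur, t_i = 0, int(t_i * t_mult)
                l_max = max(l_max * l_mult, 1e-12)
                l_min = max(l_min * l_mult, 1e-8)

    schedule = decay_with_restart()
    print([round(next(schedule), 8) for _ in range(12)])
    # first value sits at the floor (t_cur starts equal to t_i), then each 10-step cycle
    # restarts from the halved l_max and decays back down towards the halved l_min
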
@@ -440,12 +449,12 @@ class NeuralEnhancer(object):
stats /= args.epoch_size
totals, labels = [sum(total)] + list(total), ['total', 'prcpt', 'smthn', 'advrs']
gen_info = ['{}{}{}={:4.2e}'.format(ansi.WHITE_B, k, ansi.ENDC, v) for k, v in zip(labels, totals)]
-print('\rEpoch #{} at {:4.1f}s{}'.format(epoch+1, time.time()-start, ' '*args.epoch_size))
+print('\rEpoch #{} at {:4.1f}s, lr={:4.2e} {}'.format(epoch+1, time.time()-start, l_r, ' '*args.epoch_size))
print(' - generator {}'.format(' '.join(gen_info)))
real, fake = stats[:args.batch_size], stats[args.batch_size:]
print(' - discriminator', real.mean(), len(np.where(real > 0.5)[0]), fake.mean(), len(np.where(fake < 0.5)[0]))
-if epoch == args.adversary_start-1:
+if epoch == args.adversarial_start-1:
print(' - adversary mode: generator engaging discriminator.')
self.model.adversary_weight.set_value(args.adversary_weight)
running = None
