Improve code for simply applying super-resolution.

main
Alex J. Champandard 9 years ago
parent 30534c6dd1
commit f868514be3

@ -1,11 +1,13 @@
Neural Enhance
==============
**Example #1** — China Town: `view comparison <http://5.9.70.47:4141/w/3b3c8054-9d00-11e6-9558-c86000be451f/view>`_ in 24-bit HD, `original photo <https://flic.kr/p/gnxcXH>`_ CC-BY-SA @cyalex.
.. image:: docs/OldStation_example.gif
.. image:: docs/Chinatown_example.gif
**Example #1** — Old Station: `view comparison <http://5.9.70.47:4141/w/0f5177f4-9ce6-11e6-992c-c86000be451f/view>`_ in 24-bit HD, `original photo <https://flic.kr/p/oYhbBv>`_ CC-BY-SA @siv-athens.
----
`As seen on TV! <https://www.youtube.com/watch?v=LhF_56SxrGk>`_ What if you could increase the resolution of your photos using technology from CSI laboratories? Thanks to deep learning and ``#NeuralEnhance``, it's now possible to train a neural network to zoom in to your images at 2x or even 4x. You'll get even better results by increasing the number of neurons or using specialized training images (e.g. faces).
`As seen on TV! <https://www.youtube.com/watch?v=LhF_56SxrGk>`_ What if you could increase the resolution of your photos using technology from CSI laboratories? Thanks to deep learning and ``#NeuralEnhance``, it's now possible to train a neural network to zoom in to your images at 2x or even 4x. You'll get even better results by increasing the number of neurons or training with a dataset similar to your low resolution image.
The catch? The neural network is hallucinating details based on its training from example images. It's not reconstructing your photo exactly as it would have been if it was HD. That's only possible in Holywood — but using deep learning as "Creative AI" works and its just as cool! Here's how you can get started...
@ -58,10 +60,10 @@ The default is to use ``--device=cpu``, if you have NVIDIA card setup with CUDA
--smoothness-weight=5e4 --adversary-weight=2e2 \
--generator-start=1 --discriminator-start=0 --adversarial-start=1
**Example #2** — Bank Lobby: `view comparison <http://5.9.70.47:4141/w/38d10880-9ce6-11e6-becb-c86000be451f/view>`_ in 24-bit HD, `original photo <https://flic.kr/p/6a8cwm>`_ CC-BY-SA @benarent.
.. image:: docs/BankLobby_example.gif
**Example #2** — Bank Lobby: `view comparison <http://5.9.70.47:4141/w/38d10880-9ce6-11e6-becb-c86000be451f/view>`_ in 24-bit HD, `original photo <https://flic.kr/p/6a8cwm>`_ CC-BY-SA @benarent.
2. Installation & Setup
=======================
@ -100,7 +102,7 @@ After this, you should have ``pillow``, ``theano`` and ``lasagne`` installed in
3. Background & Research
========================
This code uses a combination of techniques from the following papers, as well as some minor improvements yet to be documented:
This code uses a combination of techniques from the following papers, as well as some minor improvements yet to be documented (watch this repository for updates):
1. `Perceptual Losses for Real-Time Style Transfer and Super-Resolution <http://arxiv.org/abs/1603.08155>`_
2. `Real-Time Super-Resolution Using Efficient Sub-Pixel Convolution <https://arxiv.org/abs/1609.05158>`_
@ -142,10 +144,9 @@ It seems your terminal is misconfigured and not compatible with the way Python t
**FIX:** ``export LC_ALL=en_US.UTF-8``
.. image:: docs/Chinatown_example.gif
.. image:: docs/OldStation_example.gif
**Example #3** — Old Station: `view comparison <http://5.9.70.47:4141/w/0f5177f4-9ce6-11e6-992c-c86000be451f/view>`_ in 24-bit HD, `original photo <https://flic.kr/p/oYhbBv>`_ CC-BY-SA @siv-athens.
**Example #3** — China Town: `view comparison <http://5.9.70.47:4141/w/3b3c8054-9d00-11e6-9558-c86000be451f/view>`_ in 24-bit HD, `original photo <https://flic.kr/p/gnxcXH>`_ CC-BY-SA @cyalex.
----

@ -14,6 +14,8 @@
# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
#
__version__ = '0.1'
import os
import sys
import bz2
@ -34,7 +36,7 @@ parser = argparse.ArgumentParser(description='Generate a new image by applying s
add_arg = parser.add_argument
add_arg('files', nargs='*', default=[])
add_arg('--scales', default=2, type=int, help='How many times to perform 2x upsampling.')
add_arg('--model', default='ne%ix.pkl.bz2', type=str, help='Name of the neural network to load/save.')
add_arg('--model', default='medium', type=str, help='Name of the neural network to load/save.')
add_arg('--train', default=False, action='store_true', help='Learn new or fine-tune a neural network.')
add_arg('--batch-resolution', default=192, type=int, help='Resolution of images in training batch.')
add_arg('--batch-size', default=15, type=int, help='Number of images per training batch.')
@ -46,11 +48,11 @@ add_arg('--learning-rate', default=1E-4, type=float, help='Parame
add_arg('--learning-period', default=50, type=int, help='How often to decay the learning rate.')
add_arg('--learning-decay', default=0.5, type=float, help='How much to decay the learning rate.')
add_arg('--generator-filters', default=[64], nargs='+', type=int, help='Number of convolution units in network.')
add_arg('--generator-blocks', default=12, type=int, help='Number of residual blocks per iteration.')
add_arg('--generator-iters', default=1, type=int, help='Number of iterations in total.')
add_arg('--generator-blocks', default=4, type=int, help='Number of residual blocks per iteration.')
add_arg('--generator-residual', default=2, type=int, help='Number of layers in a residual block.')
add_arg('--perceptual-layer', default='conv2_2', type=str, help='Which VGG layer to use as loss component.')
add_arg('--perceptual-weight', default=1e0, type=float, help='Weight for VGG-layer perceptual loss.')
add_arg('--discriminator-size', default=32, type=int, help='Multiplier for number of filters in D.')
add_arg('--smoothness-weight', default=2e5, type=float, help='Weight of the total-variation loss.')
add_arg('--adversary-weight', default=1e2, type=float, help='Weight of adversarial loss compoment.')
add_arg('--generator-start', default=0, type=int, help='Epoch count to start training generator.')
@ -102,7 +104,7 @@ import scipy.optimize, scipy.ndimage, scipy.misc
# Numeric Computing (GPU)
import theano
import theano.tensor as T
import theano.tensor.nnet.neighbours
T.nnet.softminus = lambda x: x - T.nnet.softplus(x)
# Support ansi colors in Windows too.
if sys.platform == 'win32':
@ -233,23 +235,8 @@ class Model(object):
return list(self.network.values())[-1]
def make_layer(self, name, input, units, filter_size=(3,3), stride=(1,1), pad=(1,1), alpha=0.25):
orig = '1.'+''.join(name.split('.')[1:])
if orig+'x' in self.network:
print('reused', orig, 'for', name)
l = self.network[orig +'x']
extra = {'W': l.W, 'b': l.b}
else:
extra = {}
conv = ConvLayer(input, units, filter_size=filter_size, stride=stride, pad=pad, nonlinearity=None, **extra)
alpha = lasagne.init.Constant(alpha)
if orig +'>' in self.network:
l = self.network[orig +'>']
alpha = l.alpha
prelu = lasagne.layers.ParametricRectifierLayer(conv, alpha=alpha)
conv = ConvLayer(input, units, filter_size=filter_size, stride=stride, pad=pad, nonlinearity=None)
prelu = lasagne.layers.ParametricRectifierLayer(conv, alpha=lasagne.init.Constant(alpha))
self.network[name+'x'] = conv
self.network[name+'>'] = prelu
return prelu
@ -267,12 +254,8 @@ class Model(object):
self.make_layer('iter.0-B', self.last_layer(), units, filter_size=(5,5), pad=(2,2))
self.network['iter.0'] = self.last_layer()
for i in range(0, args.generator_iters):
base = self.last_layer()
for j in range(0, args.generator_blocks):
self.make_block('%i.iter-%i'%(i+1, j), self.last_layer(), units)
print('iter.%i-%i'%(i+1, j))
# self.network['iter.%i'%(i+1)] = DropPathLayer([base, self.last_layer()])
for i in range(0, args.generator_blocks):
self.make_block('iter.%i'%(i+1), self.last_layer(), units)
for i in range(0, args.scales):
u = next(units_iter)
@ -280,7 +263,6 @@ class Model(object):
self.network['scale%i.2'%i] = SubpixelReshuffleLayer(self.last_layer(), u, 2)
self.make_layer('scale%i.1'%i, self.last_layer(), u)
self.network['out'] = ConvLayer(self.last_layer(), 3, filter_size=(5,5), stride=(1,1), pad=(2,2),
nonlinearity=lasagne.nonlinearities.tanh)
@ -314,15 +296,17 @@ class Model(object):
self.network['conv5_4'] = ConvLayer(self.network['conv5_3'], 512, 3, pad=1)
def setup_discriminator(self):
self.make_layer('disc1.1', batch_norm(self.network['conv1_2']), 64, filter_size=(5,5), stride=(2,2), pad=(2,2))
self.make_layer('disc1.2', self.last_layer(), 64, filter_size=(5,5), stride=(2,2), pad=(2,2))
self.make_layer('disc2', batch_norm(self.network['conv2_2']), 128, filter_size=(5,5), stride=(2,2), pad=(2,2))
self.make_layer('disc3', batch_norm(self.network['conv3_2']), 192, filter_size=(3,3), stride=(1,1), pad=(1,1))
c = args.discriminator_size
self.make_layer('disc1.1', batch_norm(self.network['conv1_2']), 1*c, filter_size=(5,5), stride=(2,2), pad=(2,2))
self.make_layer('disc1.2', self.last_layer(), 1*c, filter_size=(5,5), stride=(2,2), pad=(2,2))
self.make_layer('disc2', batch_norm(self.network['conv2_2']), 2*c, filter_size=(5,5), stride=(2,2), pad=(2,2))
self.make_layer('disc3', batch_norm(self.network['conv3_2']), 3*c, filter_size=(3,3), stride=(1,1), pad=(1,1))
hypercolumn = ConcatLayer([self.network['disc1.2>'], self.network['disc2>'], self.network['disc3>']])
self.make_layer('disc4', hypercolumn, 192, filter_size=(3,3), stride=(1,1))
self.make_layer('disc5', self.last_layer(), 96, filter_size=(3,3), stride=(1,1))
self.make_layer('disc4', hypercolumn, 4*c, filter_size=(1,1), stride=(1,1), pad=(0,0))
self.make_layer('disc5', self.last_layer(), 3*c, filter_size=(3,3), stride=(2,2))
self.make_layer('disc6', self.last_layer(), 2*c, filter_size=(1,1), stride=(1,1), pad=(0,0))
self.network['disc'] = batch_norm(ConvLayer(self.last_layer(), 1, filter_size=(1,1),
nonlinearity=lasagne.nonlinearities.sigmoid))
nonlinearity=lasagne.nonlinearities.linear))
#------------------------------------------------------------------------------------------------------------------
@ -351,20 +335,23 @@ class Model(object):
def cast(p): return p.get_value().astype(np.float16)
params = {k: [cast(p) for p in l.get_params()] for (k, l) in self.list_generator_layers()}
config = {k: getattr(args, k) for k in ['generator_blocks', 'generator_residual', 'generator_filters']}
filename = args.model % 2**args.scales
filename = 'ne%ix-%s-%s.pkl.bz2' % (2**args.scales, args.model, __version__)
pickle.dump((config, params), bz2.open(filename, 'wb'))
print(' - Saved model as `{}` after training.'.format(filename))
def load_model(self):
filename = args.model % 2**args.scales
if not os.path.exists(filename): return {}, {}
filename = 'ne%ix-%s-%s.pkl.bz2' % (2**args.scales, args.model, __version__)
if not os.path.exists(filename):
if args.train: return {}, {}
error("Model file with pre-trained convolution layers not found. Download it here...",
"https://github.com/alexjc/neural-enhance/releases/download/v%s/%s"%(__version__, filename))
print(' - Loaded file `{}` with trained model.'.format(filename))
return pickle.load(bz2.open(filename, 'rb'))
def load_generator(self, params):
if len(params) == 0: return
for k, l in self.list_generator_layers():
assert k in params, "Couldn't find layer `%s` in loaded model.'"
assert k in params, "Couldn't find layer `%s` in loaded model.'" % k
assert len(l.get_params()) == len(params[k]), "Mismatch in types of layers."
for p, v in zip(l.get_params(), params[k]):
assert v.shape == p.get_value().shape, "Mismatch in number of parameters."
@ -381,10 +368,10 @@ class Model(object):
return T.mean(((x[:,:,:-1,:-1] - x[:,:,1:,:-1])**2 + (x[:,:,:-1,:-1] - x[:,:,:-1,1:])**2)**1.25)
def loss_adversarial(self, d):
return 1.0 - T.log(1E-6 + d[args.batch_size:]).mean()
return T.mean(1.0 - T.nnet.softplus(d[args.batch_size:]))
def loss_discriminator(self, d):
return T.mean(T.log(1E-6 + d[args.batch_size:]) + T.log(1E-6 + 1.0 - d[:args.batch_size]))
return T.mean(T.nnet.softminus(d[args.batch_size:]) - T.nnet.softplus(d[:args.batch_size]))
def compile(self):
# Helper function for rendering test images during training, or standalone non-training mode.
@ -424,8 +411,12 @@ class Model(object):
class NeuralEnhancer(object):
def __init__(self):
print('{}Training {} epochs on random image sections with batch size {}.{}'\
.format(ansi.BLUE_B, args.epochs, args.batch_size, ansi.BLUE))
if args.train:
print('{}Training {} epochs on random image sections with batch size {}.{}'\
.format(ansi.BLUE_B, args.epochs, args.batch_size, ansi.BLUE))
else:
print('{}Enhancing {} image(s) specified on the command-line.{}'\
.format(ansi.BLUE_B, len(args.files), ansi.BLUE))
self.thread = DataLoader()
self.model = Model()
@ -483,7 +474,7 @@ class NeuralEnhancer(object):
print(' - generator {}'.format(' '.join(gen_info)))
real, fake = stats[:args.batch_size], stats[args.batch_size:]
print(' - discriminator', real.mean(), len(np.where(real > 0.5)[0]), fake.mean(), len(np.where(fake < 0.5)[0]))
print(' - discriminator', real.mean(), len(np.where(real > 0.5)[0]), fake.mean(), len(np.where(fake < -0.5)[0]))
if epoch == args.adversarial_start-1:
print(' - adversary mode: generator engaging discriminator.')
self.model.adversary_weight.set_value(args.adversary_weight)
@ -511,5 +502,9 @@ if __name__ == "__main__":
enhancer.train()
for filename in args.files:
print(filename)
out = enhancer.process(scipy.ndimage.imread(filename, mode='RGB'))
out.save(os.path.splitext(filename)[0]+'_enhanced.png')
out.save(os.path.splitext(filename)[0]+'_ne%ix-%s.png'%(2**args.scales, args.model))
if args.files:
print(ansi.ENDC)

Loading…
Cancel
Save