Add command-line interface for super-resolving images without training.

main
Alex J. Champandard 9 years ago
parent 03a6813e95
commit 9558a11397

@ -1,4 +1,4 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" _ _ """ _ _
_ __ ___ _ _ _ __ __ _| | ___ _ __ | |__ __ _ _ __ ___ ___ _ __ ___ _ _ _ __ __ _| | ___ _ __ | |__ __ _ _ __ ___ ___
| '_ \ / _ \ | | | '__/ _` | | / _ \ '_ \| '_ \ / _` | '_ \ / __/ _ \ | '_ \ / _ \ | | | '__/ _` | | / _ \ '_ \| '_ \ / _` | '_ \ / __/ _ \
@ -32,22 +32,24 @@ import collections
parser = argparse.ArgumentParser(description='Generate a new image by applying style onto a content image.', parser = argparse.ArgumentParser(description='Generate a new image by applying style onto a content image.',
formatter_class=argparse.ArgumentDefaultsHelpFormatter) formatter_class=argparse.ArgumentDefaultsHelpFormatter)
add_arg = parser.add_argument add_arg = parser.add_argument
add_arg('files', nargs='*', default=[])
add_arg('--train', default=False, action='store_true')
add_arg('--load', default=None, action='store_true') add_arg('--load', default=None, action='store_true')
add_arg('--save', default=None, action='store_true') add_arg('--save', default=None, action='store_true')
add_arg('--model', default='ne%ix.pkl.bz2', type=str) add_arg('--model', default='ne%ix.pkl.bz2', type=str)
add_arg('--batch-size', default=15, type=int) add_arg('--batch-size', default=1, type=int)
add_arg('--batch-resolution', default=256, type=int) add_arg('--batch-resolution', default=224, type=int)
add_arg('--epoch-size', default=36, type=int) add_arg('--epoch-size', default=36, type=int)
add_arg('--epochs', default=10, type=int) add_arg('--epochs', default=10, type=int)
add_arg('--generator-filters', default=128, type=int) add_arg('--generator-filters', default=256, type=int)
add_arg('--generator-blocks', default=4, type=int) add_arg('--generator-blocks', default=4, type=int)
add_arg('--generator-residual', default=2, type=int) add_arg('--generator-residual', default=2, type=int)
add_arg('--perceptual-layer', default='conv2_2', type=str) add_arg('--perceptual-layer', default='conv2_2', type=str)
add_arg('--perceptual-weight', default=1e0, type=float) add_arg('--perceptual-weight', default=1e0, type=float)
add_arg('--smoothness-weight', default=1e7, type=float) add_arg('--smoothness-weight', default=2e5, type=float)
add_arg('--adversary-weight', default=0.0, type=float) add_arg('--adversary-weight', default=2e2, type=float)
add_arg('--scales', default=1, type=int, help='') add_arg('--scales', default=2, type=int, help='')
add_arg('--device', default='gpu0', type=str, help='Name of the CPU/GPU number to use, for Theano.') add_arg('--device', default='cpu', type=str, help='Name of the CPU/GPU number to use, for Theano.')
args = parser.parse_args() args = parser.parse_args()
@ -100,7 +102,7 @@ if sys.platform == 'win32':
# Deep Learning Framework # Deep Learning Framework
import lasagne import lasagne
from lasagne.layers import Conv2DLayer as ConvLayer, Deconv2DLayer as DeconvLayer, Pool2DLayer as PoolLayer from lasagne.layers import Conv2DLayer as ConvLayer, Deconv2DLayer as DeconvLayer, Pool2DLayer as PoolLayer
from lasagne.layers import InputLayer, ConcatLayer, ElemwiseSumLayer from lasagne.layers import InputLayer, ConcatLayer, ElemwiseSumLayer, batch_norm
print('{} - Using the device `{}` for neural computation.{}\n'.format(ansi.CYAN, theano.config.device, ansi.ENDC)) print('{} - Using the device `{}` for neural computation.{}\n'.format(ansi.CYAN, theano.config.device, ansi.ENDC))
@ -157,12 +159,12 @@ class DataLoader(threading.Thread):
# Convolution Networks # Convolution Networks
#====================================================================================================================== #======================================================================================================================
class SubpixelShuffle(lasagne.layers.Layer): class SubpixelReshuffleLayer(lasagne.layers.Layer):
"""Based on the code by ajbrock: https://github.com/ajbrock/Neural-Photo-Editor/ """Based on the code by ajbrock: https://github.com/ajbrock/Neural-Photo-Editor/
""" """
def __init__(self, incoming, channels, upscale, **kwargs): def __init__(self, incoming, channels, upscale, **kwargs):
super(SubpixelShuffle, self).__init__(incoming, **kwargs) super(SubpixelReshuffleLayer, self).__init__(incoming, **kwargs)
self.upscale = upscale self.upscale = upscale
self.channels = channels self.channels = channels
@ -181,10 +183,14 @@ class Model(object):
def __init__(self): def __init__(self):
self.network = collections.OrderedDict() self.network = collections.OrderedDict()
if args.train:
self.network['img'] = InputLayer((None, 3, None, None)) self.network['img'] = InputLayer((None, 3, None, None))
self.network['seed'] = PoolLayer(self.network['img'], pool_size=2**args.scales, mode='average_exc_pad') self.network['seed'] = PoolLayer(self.network['img'], pool_size=2**args.scales, mode='average_exc_pad')
self.setup_generator(self.network['seed']) else:
self.network['img'] = InputLayer((None, 3, None, None))
self.setup_generator(self.last_layer())
if args.train:
concatenated = lasagne.layers.ConcatLayer([self.network['img'], self.network['out']], axis=0) concatenated = lasagne.layers.ConcatLayer([self.network['img'], self.network['out']], axis=0)
self.setup_perceptual(concatenated) self.setup_perceptual(concatenated)
self.load_perceptual() self.load_perceptual()
@ -222,7 +228,7 @@ class Model(object):
for i in range(args.scales, 0, -1): for i in range(args.scales, 0, -1):
u = units // 2**(args.scales-i) u = units // 2**(args.scales-i)
self.make_layer('scale%i.3'%i, self.last_layer(), u*4) self.make_layer('scale%i.3'%i, self.last_layer(), u*4)
self.network['scale%i.2'%i] = SubpixelShuffle(self.last_layer(), u, 2) self.network['scale%i.2'%i] = SubpixelReshuffleLayer(self.last_layer(), u, 2)
self.make_layer('scale%i.1'%i, self.network['scale%i.2'%i], u) self.make_layer('scale%i.1'%i, self.network['scale%i.2'%i], u)
self.network['out'] = ConvLayer(self.last_layer(), 3, filter_size=(5,5), stride=(1,1), pad=(2,2), self.network['out'] = ConvLayer(self.last_layer(), 3, filter_size=(5,5), stride=(1,1), pad=(2,2),
@ -260,11 +266,11 @@ class Model(object):
def setup_discriminator(self): def setup_discriminator(self):
self.make_layer('disc1.1', batch_norm(self.network['conv1_2']), 64, filter_size=(5,5), stride=(2,2), pad=(2,2)) self.make_layer('disc1.1', batch_norm(self.network['conv1_2']), 64, filter_size=(5,5), stride=(2,2), pad=(2,2))
self.make_layer('disc1.2', self.last_layer(), 64, filter_size=(5,5), stride=(2,2), pad=(2,2)) self.make_layer('disc1.2', self.last_layer(), 64, filter_size=(5,5), stride=(2,2), pad=(2,2))
self.make_layer('disc2', self.network['conv2_2'], 128, filter_size=(5,5), stride=(2,2), pad=(2,2)) self.make_layer('disc2', batch_norm(self.network['conv2_2']), 128, filter_size=(5,5), stride=(2,2), pad=(2,2))
self.make_layer('disc3', self.network['conv3_2'], 256, filter_size=(3,3), stride=(1,1), pad=(1,1)) # self.make_layer('disc3', batch_norm(self.network['conv3_2']), 256, filter_size=(3,3), stride=(1,1), pad=(1,1))
hypercolumn = ConcatLayer([self.network['disc1.2>'], self.network['disc2>'], self.network['disc3>']]) hypercolumn = ConcatLayer([self.network['disc1.2>'], self.network['disc2>']])
self.make_layer('disc4', hypercolumn, 192, filter_size=(5,5), stride=(2,2)) self.make_layer('disc4', hypercolumn, 128, filter_size=(3,3), stride=(1,1))
self.make_layer('disc5', self.last_layer(), 96, filter_size=(5,5), stride=(2,2)) self.make_layer('disc5', self.last_layer(), 64, filter_size=(3,3), stride=(1,1))
self.network['disc'] = batch_norm(ConvLayer(self.last_layer(), 1, filter_size=(1,1), self.network['disc'] = batch_norm(ConvLayer(self.last_layer(), 1, filter_size=(1,1),
nonlinearity=lasagne.nonlinearities.sigmoid)) nonlinearity=lasagne.nonlinearities.sigmoid))
@ -286,7 +292,7 @@ class Model(object):
for p, d in zip(itertools.chain(*[l.get_params() for l in layers]), data): p.set_value(d) for p, d in zip(itertools.chain(*[l.get_params() for l in layers]), data): p.set_value(d)
def list_generator_layers(self): def list_generator_layers(self):
for l in lasagne.layers.get_all_layers(self.network['out'], treat_as_input=[self.network['seed']]): for l in lasagne.layers.get_all_layers(self.network['out'], treat_as_input=[self.network['img']]):
if not l.get_params(): continue if not l.get_params(): continue
name = list(self.network.keys())[list(self.network.values()).index(l)] name = list(self.network.keys())[list(self.network.values()).index(l)]
yield (name, l) yield (name, l)
@ -304,8 +310,10 @@ class Model(object):
if not os.path.exists(filename) or not args.load: return if not os.path.exists(filename) or not args.load: return
params = pickle.load(bz2.open(filename, 'rb')) params = pickle.load(bz2.open(filename, 'rb'))
for k, l in self.list_generator_layers(): for k, l in self.list_generator_layers():
if k not in params: continue assert k in params, "Couldn't find layer `%s` in loaded model.'"
assert len(l.get_params()) == len(params[k]), "Mismatch in types of layers."
for p, v in zip(l.get_params(), params[k]): for p, v in zip(l.get_params(), params[k]):
assert v.shape == p.get_value().shape, "Mismatch in number of parameters."
p.set_value(v.astype(np.float32)) p.set_value(v.astype(np.float32))
print(' - Loaded file `{}` with trained model.'.format(filename)) print(' - Loaded file `{}` with trained model.'.format(filename))
@ -326,10 +334,15 @@ class Model(object):
return T.mean(T.log(d[args.batch_size:]) + T.log(1.0 - d[:args.batch_size])) return T.mean(T.log(d[args.batch_size:]) + T.log(1.0 - d[:args.batch_size]))
def compile(self): def compile(self):
# Helper function for rendering test images during training, or standalone non-training mode.
input_tensor = T.tensor4() input_tensor = T.tensor4()
output_layers = [self.network['out'], self.network[args.perceptual_layer], self.network['disc']]
input_layers = {self.network['img']: input_tensor} input_layers = {self.network['img']: input_tensor}
output = lasagne.layers.get_output([self.network[k] for k in ['img', 'out']], input_layers, deterministic=True)
self.predict = theano.function([input_tensor], output)
if not args.train: return
output_layers = [self.network['out'], self.network[args.perceptual_layer], self.network['disc']]
gen_out, percept_out, disc_out = lasagne.layers.get_output(output_layers, input_layers, deterministic=False) gen_out, percept_out, disc_out = lasagne.layers.get_output(output_layers, input_layers, deterministic=False)
# Generator loss function, parameters and updates. # Generator loss function, parameters and updates.
@ -350,13 +363,9 @@ class Model(object):
disc_updates = lasagne.updates.adam(sum(disc_losses, 0.0), disc_params, learning_rate=self.disc_lr) disc_updates = lasagne.updates.adam(sum(disc_losses, 0.0), disc_params, learning_rate=self.disc_lr)
# Combined Theano function for updating both generator and discriminator at the same time. # Combined Theano function for updating both generator and discriminator at the same time.
updates = list(gen_updates.items()) + list(disc_updates.items()) updates = collections.OrderedDict(list(gen_updates.items()) + list(disc_updates.items()))
self.fit = theano.function([input_tensor], gen_losses, updates=collections.OrderedDict(updates)) self.fit = theano.function([input_tensor], gen_losses + [disc_out.mean(axis=(1,2,3))], updates=updates)
# Helper function for rendering test images deterministically, computing statistics.
*outputs, disc_out = lasagne.layers.get_output([self.network[k] for k in ['img', 'seed', 'out', 'disc']],
input_layers, deterministic=True)
self.predict = theano.function([input_tensor], outputs + [disc_out.mean(axis=(1,2,3))])
class NeuralEnhancer(object): class NeuralEnhancer(object):
@ -389,7 +398,7 @@ class NeuralEnhancer(object):
try: try:
i, running, start = 0, None, time.time() i, running, start = 0, None, time.time()
for epoch in range(args.epochs): for epoch in range(args.epochs):
total = None total, stats = None, None
for _ in range(args.epoch_size): for _ in range(args.epoch_size):
i += 1 i += 1
l_r = l_min + 0.5 * (l_max - l_min) * (1.0 + math.cos(t_cur / t_i * math.pi)) l_r = l_min + 0.5 * (l_max - l_min) * (1.0 + math.cos(t_cur / t_i * math.pi))
@ -404,16 +413,19 @@ class NeuralEnhancer(object):
l_min = max(l_min * l_mult, 1e-7) l_min = max(l_min * l_mult, 1e-7)
self.thread.copy(images) self.thread.copy(images)
losses = np.array(self.model.fit(images), dtype=np.float32) output = self.model.fit(images)
losses = np.array(output[:3], dtype=np.float32)
stats = (stats + output[3]) if stats is not None else output[3]
total = total + losses if total is not None else losses total = total + losses if total is not None else losses
l = np.sum(losses) l = np.sum(losses)
assert not np.isnan(losses).any() assert not np.isnan(losses).any()
running = l if running is None else running * 0.9 + 0.1 * l running = l if running is None else running * 0.9 + 0.1 * l
print('' if l > running else '', end=' ', flush=True) print('' if l > running else '', end=' ', flush=True)
orign, scald, repro, stats = self.model.predict(images) orign, scald, repro = self.model.predict(images)
self.show_progress(orign, scald, repro) self.show_progress(orign, scald, repro)
total /= args.epoch_size total /= args.epoch_size
stats /= args.epoch_size
totals, labels = [sum(total)] + list(total), ['total', 'prcpt', 'smthn', 'advrs'] totals, labels = [sum(total)] + list(total), ['total', 'prcpt', 'smthn', 'advrs']
gen_info = ['{}{}{}={:4.2e}'.format(ansi.WHITE_B, k, ansi.ENDC, v) for k, v in zip(labels, totals)] gen_info = ['{}{}{}={:4.2e}'.format(ansi.WHITE_B, k, ansi.ENDC, v) for k, v in zip(labels, totals)]
print('\rEpoch #{} at {:4.1f}s{}'.format(epoch+1, time.time()-start, ' '*args.epoch_size)) print('\rEpoch #{} at {:4.1f}s{}'.format(epoch+1, time.time()-start, ' '*args.epoch_size))
@ -432,7 +444,22 @@ class NeuralEnhancer(object):
self.model.save_generator() self.model.save_generator()
print(ansi.ENDC) print(ansi.ENDC)
def process(self, images):
_, repro = self.model.predict(images)
return repro
if __name__ == "__main__": if __name__ == "__main__":
enhancer = NeuralEnhancer() enhancer = NeuralEnhancer()
if args.train:
enhancer.train() enhancer.train()
for filename in args.files:
img = scipy.ndimage.imread(filename, mode='RGB')
img = np.transpose(img / 255.0 - 0.5, (2, 0, 1))[np.newaxis]
out = enhancer.process(img.astype(np.float32))
out = np.transpose((out[0] + 0.5) * 255.0, (1, 2, 0)).astype(np.uint8)
scipy.misc.imsave(os.path.splitext(filename)[0]+'_enhanced.png', out)

Loading…
Cancel
Save