Tweak code for generalized super-resolution.

main
Alex J. Champandard 9 years ago
parent 9558a11397
commit d8c9292b80

@ -34,20 +34,22 @@ parser = argparse.ArgumentParser(description='Generate a new image by applying s
add_arg = parser.add_argument add_arg = parser.add_argument
add_arg('files', nargs='*', default=[]) add_arg('files', nargs='*', default=[])
add_arg('--train', default=False, action='store_true') add_arg('--train', default=False, action='store_true')
add_arg('--load', default=None, action='store_true')
add_arg('--save', default=None, action='store_true') add_arg('--save', default=None, action='store_true')
add_arg('--model', default='ne%ix.pkl.bz2', type=str) add_arg('--model', default='ne%ix.pkl.bz2', type=str)
add_arg('--batch-size', default=1, type=int) add_arg('--batch-size', default=15, type=int)
add_arg('--batch-resolution', default=224, type=int) add_arg('--batch-resolution', default=192, type=int)
add_arg('--epoch-size', default=36, type=int) add_arg('--generator-filters', default=128, type=int)
add_arg('--epochs', default=10, type=int) add_arg('--generator-blocks', default=16, type=int)
add_arg('--generator-filters', default=256, type=int)
add_arg('--generator-blocks', default=4, type=int)
add_arg('--generator-residual', default=2, type=int) add_arg('--generator-residual', default=2, type=int)
add_arg('--perceptual-layer', default='conv2_2', type=str) add_arg('--perceptual-layer', default='conv2_2', type=str)
add_arg('--perceptual-weight', default=1e0, type=float) add_arg('--perceptual-weight', default=1e0, type=float)
add_arg('--smoothness-weight', default=2e5, type=float) add_arg('--smoothness-weight', default=2e5, type=float)
add_arg('--adversary-weight', default=2e2, type=float) add_arg('--adversary-weight', default=1e2, type=float)
add_arg('--epoch-size', default=36, type=int)
add_arg('--epochs', default=10, type=int)
add_arg('--generator-start', default=0, type=int)
add_arg('--discriminator-start',default=1, type=int)
add_arg('--adversarial-start', default=2, type=int)
add_arg('--scales', default=2, type=int, help='') add_arg('--scales', default=2, type=int, help='')
add_arg('--device', default='cpu', type=str, help='Name of the CPU/GPU number to use, for Theano.') add_arg('--device', default='cpu', type=str, help='Name of the CPU/GPU number to use, for Theano.')
args = parser.parse_args() args = parser.parse_args()
@ -84,7 +86,7 @@ print("""{} {}Super Resolution for images and videos powered by Deep Learning!
# Load the underlying deep learning libraries based on the device specified. If you specify THEANO_FLAGS manually, # Load the underlying deep learning libraries based on the device specified. If you specify THEANO_FLAGS manually,
# the code assumes you know what you are doing and they are not overriden! # the code assumes you know what you are doing and they are not overriden!
os.environ.setdefault('THEANO_FLAGS', 'floatX=float32,device={},force_device=True,allow_gc=True,'\ os.environ.setdefault('THEANO_FLAGS', 'floatX=float32,device={},force_device=True,allow_gc=True,'\
'print_active_device=False'.format(args.device)) 'print_active_device=False,lib.cnmem=1.0'.format(args.device))
# Scientific & Imaging Libraries # Scientific & Imaging Libraries
import numpy as np import numpy as np
@ -130,7 +132,17 @@ class DataLoader(threading.Thread):
for i, f in enumerate(files[:args.batch_size]): for i, f in enumerate(files[:args.batch_size]):
filename = os.path.join(self.cwd, f) filename = os.path.join(self.cwd, f)
try: try:
img = cache.setdefault(f, scipy.ndimage.imread(filename, mode='RGB')) if f not in cache:
if len(cache) > 3172:
del cache[random.choice(list(cache.keys()))]
img = scipy.ndimage.imread(filename, mode='RGB')
ratio = min(1024 / img.shape[0], 1024 / img.shape[1])
if ratio < 1.0:
img = scipy.misc.imresize(img, ratio, interp='bicubic')
cache[f] = img
else:
img = cache[f]
except Exception as e: except Exception as e:
warn('Could not load `{}` as image.'.format(filename), warn('Could not load `{}` as image.'.format(filename),
' - Try fixing or removing the file before next run.') ' - Try fixing or removing the file before next run.')
@ -188,6 +200,7 @@ class Model(object):
self.network['seed'] = PoolLayer(self.network['img'], pool_size=2**args.scales, mode='average_exc_pad') self.network['seed'] = PoolLayer(self.network['img'], pool_size=2**args.scales, mode='average_exc_pad')
else: else:
self.network['img'] = InputLayer((None, 3, None, None)) self.network['img'] = InputLayer((None, 3, None, None))
self.network['seed'] = self.network['img']
self.setup_generator(self.last_layer()) self.setup_generator(self.last_layer())
if args.train: if args.train:
@ -267,10 +280,10 @@ class Model(object):
self.make_layer('disc1.1', batch_norm(self.network['conv1_2']), 64, filter_size=(5,5), stride=(2,2), pad=(2,2)) self.make_layer('disc1.1', batch_norm(self.network['conv1_2']), 64, filter_size=(5,5), stride=(2,2), pad=(2,2))
self.make_layer('disc1.2', self.last_layer(), 64, filter_size=(5,5), stride=(2,2), pad=(2,2)) self.make_layer('disc1.2', self.last_layer(), 64, filter_size=(5,5), stride=(2,2), pad=(2,2))
self.make_layer('disc2', batch_norm(self.network['conv2_2']), 128, filter_size=(5,5), stride=(2,2), pad=(2,2)) self.make_layer('disc2', batch_norm(self.network['conv2_2']), 128, filter_size=(5,5), stride=(2,2), pad=(2,2))
# self.make_layer('disc3', batch_norm(self.network['conv3_2']), 256, filter_size=(3,3), stride=(1,1), pad=(1,1)) self.make_layer('disc3', batch_norm(self.network['conv3_2']), 192, filter_size=(3,3), stride=(1,1), pad=(1,1))
hypercolumn = ConcatLayer([self.network['disc1.2>'], self.network['disc2>']]) hypercolumn = ConcatLayer([self.network['disc1.2>'], self.network['disc2>'], self.network['disc3>']])
self.make_layer('disc4', hypercolumn, 128, filter_size=(3,3), stride=(1,1)) self.make_layer('disc4', hypercolumn, 192, filter_size=(3,3), stride=(1,1))
self.make_layer('disc5', self.last_layer(), 64, filter_size=(3,3), stride=(1,1)) self.make_layer('disc5', self.last_layer(), 96, filter_size=(3,3), stride=(1,1))
self.network['disc'] = batch_norm(ConvLayer(self.last_layer(), 1, filter_size=(1,1), self.network['disc'] = batch_norm(ConvLayer(self.last_layer(), 1, filter_size=(1,1),
nonlinearity=lasagne.nonlinearities.sigmoid)) nonlinearity=lasagne.nonlinearities.sigmoid))
@ -307,7 +320,7 @@ class Model(object):
def load_generator(self): def load_generator(self):
filename = args.model % 2**args.scales filename = args.model % 2**args.scales
if not os.path.exists(filename) or not args.load: return if not os.path.exists(filename): return
params = pickle.load(bz2.open(filename, 'rb')) params = pickle.load(bz2.open(filename, 'rb'))
for k, l in self.list_generator_layers(): for k, l in self.list_generator_layers():
assert k in params, "Couldn't find layer `%s` in loaded model.'" assert k in params, "Couldn't find layer `%s` in loaded model.'"
@ -328,16 +341,16 @@ class Model(object):
return T.mean(((x[:,:,:-1,:-1] - x[:,:,1:,:-1])**2 + (x[:,:,:-1,:-1] - x[:,:,:-1,1:])**2)**1.25) return T.mean(((x[:,:,:-1,:-1] - x[:,:,1:,:-1])**2 + (x[:,:,:-1,:-1] - x[:,:,:-1,1:])**2)**1.25)
def loss_adversarial(self, d): def loss_adversarial(self, d):
return 1.0 - T.log(d[args.batch_size:]).mean() return 1.0 - T.log(1E-6 + d[args.batch_size:]).mean()
def loss_discriminator(self, d): def loss_discriminator(self, d):
return T.mean(T.log(d[args.batch_size:]) + T.log(1.0 - d[:args.batch_size])) return T.mean(T.log(1E-6 + d[args.batch_size:]) + T.log(1E-6 + 1.0 - d[:args.batch_size]))
def compile(self): def compile(self):
# Helper function for rendering test images during training, or standalone non-training mode. # Helper function for rendering test images during training, or standalone non-training mode.
input_tensor = T.tensor4() input_tensor = T.tensor4()
input_layers = {self.network['img']: input_tensor} input_layers = {self.network['img']: input_tensor}
output = lasagne.layers.get_output([self.network[k] for k in ['img', 'out']], input_layers, deterministic=True) output = lasagne.layers.get_output([self.network[k] for k in ['img', 'seed', 'out']], input_layers, deterministic=True)
self.predict = theano.function([input_tensor], output) self.predict = theano.function([input_tensor], output)
if not args.train: return if not args.train: return
@ -404,8 +417,9 @@ class NeuralEnhancer(object):
l_r = l_min + 0.5 * (l_max - l_min) * (1.0 + math.cos(t_cur / t_i * math.pi)) l_r = l_min + 0.5 * (l_max - l_min) * (1.0 + math.cos(t_cur / t_i * math.pi))
t_cur += 1 t_cur += 1
l_r = 1E-4 l_r = 1E-4
self.model.gen_lr.set_value(l_r)
self.model.disc_lr.set_value(l_r) if epoch >= args.generator_start: self.model.gen_lr.set_value(l_r)
if epoch >= args.discriminator_start: self.model.disc_lr.set_value(l_r)
if t_cur >= t_i: if t_cur >= t_i:
t_cur, t_i = 0, int(t_i * t_mult) t_cur, t_i = 0, int(t_i * t_mult)
@ -433,20 +447,23 @@ class NeuralEnhancer(object):
real, fake = stats[:args.batch_size], stats[args.batch_size:] real, fake = stats[:args.batch_size], stats[args.batch_size:]
print(' - discriminator', real.mean(), len(np.where(real > 0.5)[0]), fake.mean(), len(np.where(fake < 0.5)[0])) print(' - discriminator', real.mean(), len(np.where(real > 0.5)[0]), fake.mean(), len(np.where(fake < 0.5)[0]))
if epoch == 0: if epoch == args.adversary_start-1:
print(' - adversary mode: generator engaging discriminator.')
self.model.adversary_weight.set_value(args.adversary_weight) self.model.adversary_weight.set_value(args.adversary_weight)
running = None running = None
except KeyboardInterrupt: except KeyboardInterrupt:
pass pass
print('\n{}Trained {}x super-resolution for {} epochs.{}'\ print('\n{}Trained {}x super-resolution for {} epochs.{}'\
.format(ansi.CYAN_B, 2**args.scales, epoch, ansi.CYAN)) .format(ansi.CYAN_B, 2**args.scales, epoch+1, ansi.CYAN))
self.model.save_generator() self.model.save_generator()
print(ansi.ENDC) print(ansi.ENDC)
def process(self, images): def process(self, image):
_, repro = self.model.predict(images) img = np.transpose(image / 255.0 - 0.5, (2, 0, 1))[np.newaxis].astype(np.float32)
return repro *_, repro = self.model.predict(img)
repro = np.transpose(repro[0] + 0.5, (1, 2, 0)).clip(0.0, 1.0)
return scipy.misc.toimage(repro * 255.0, cmin=0, cmax=255)
if __name__ == "__main__": if __name__ == "__main__":
@ -456,10 +473,5 @@ if __name__ == "__main__":
enhancer.train() enhancer.train()
for filename in args.files: for filename in args.files:
img = scipy.ndimage.imread(filename, mode='RGB') out = enhancer.process(scipy.ndimage.imread(filename, mode='RGB'))
img = np.transpose(img / 255.0 - 0.5, (2, 0, 1))[np.newaxis] out.save(os.path.splitext(filename)[0]+'_enhanced.png')
out = enhancer.process(img.astype(np.float32))
out = np.transpose((out[0] + 0.5) * 255.0, (1, 2, 0)).astype(np.uint8)
scipy.misc.imsave(os.path.splitext(filename)[0]+'_enhanced.png', out)

Loading…
Cancel
Save