|
|
|
|
@ -34,20 +34,22 @@ parser = argparse.ArgumentParser(description='Generate a new image by applying s
|
|
|
|
|
add_arg = parser.add_argument
|
|
|
|
|
add_arg('files', nargs='*', default=[])
|
|
|
|
|
add_arg('--train', default=False, action='store_true')
|
|
|
|
|
add_arg('--load', default=None, action='store_true')
|
|
|
|
|
add_arg('--save', default=None, action='store_true')
|
|
|
|
|
add_arg('--model', default='ne%ix.pkl.bz2', type=str)
|
|
|
|
|
add_arg('--batch-size', default=1, type=int)
|
|
|
|
|
add_arg('--batch-resolution', default=224, type=int)
|
|
|
|
|
add_arg('--epoch-size', default=36, type=int)
|
|
|
|
|
add_arg('--epochs', default=10, type=int)
|
|
|
|
|
add_arg('--generator-filters', default=256, type=int)
|
|
|
|
|
add_arg('--generator-blocks', default=4, type=int)
|
|
|
|
|
add_arg('--batch-size', default=15, type=int)
|
|
|
|
|
add_arg('--batch-resolution', default=192, type=int)
|
|
|
|
|
add_arg('--generator-filters', default=128, type=int)
|
|
|
|
|
add_arg('--generator-blocks', default=16, type=int)
|
|
|
|
|
add_arg('--generator-residual', default=2, type=int)
|
|
|
|
|
add_arg('--perceptual-layer', default='conv2_2', type=str)
|
|
|
|
|
add_arg('--perceptual-weight', default=1e0, type=float)
|
|
|
|
|
add_arg('--smoothness-weight', default=2e5, type=float)
|
|
|
|
|
add_arg('--adversary-weight', default=2e2, type=float)
|
|
|
|
|
add_arg('--adversary-weight', default=1e2, type=float)
|
|
|
|
|
add_arg('--epoch-size', default=36, type=int)
|
|
|
|
|
add_arg('--epochs', default=10, type=int)
|
|
|
|
|
add_arg('--generator-start', default=0, type=int)
|
|
|
|
|
add_arg('--discriminator-start',default=1, type=int)
|
|
|
|
|
add_arg('--adversarial-start', default=2, type=int)
|
|
|
|
|
add_arg('--scales', default=2, type=int, help='')
|
|
|
|
|
add_arg('--device', default='cpu', type=str, help='Name of the CPU/GPU number to use, for Theano.')
|
|
|
|
|
args = parser.parse_args()
|
|
|
|
|
@ -84,7 +86,7 @@ print("""{} {}Super Resolution for images and videos powered by Deep Learning!
|
|
|
|
|
# Load the underlying deep learning libraries based on the device specified. If you specify THEANO_FLAGS manually,
|
|
|
|
|
# the code assumes you know what you are doing and they are not overriden!
|
|
|
|
|
os.environ.setdefault('THEANO_FLAGS', 'floatX=float32,device={},force_device=True,allow_gc=True,'\
|
|
|
|
|
'print_active_device=False'.format(args.device))
|
|
|
|
|
'print_active_device=False,lib.cnmem=1.0'.format(args.device))
|
|
|
|
|
|
|
|
|
|
# Scientific & Imaging Libraries
|
|
|
|
|
import numpy as np
|
|
|
|
|
@ -130,7 +132,17 @@ class DataLoader(threading.Thread):
|
|
|
|
|
for i, f in enumerate(files[:args.batch_size]):
|
|
|
|
|
filename = os.path.join(self.cwd, f)
|
|
|
|
|
try:
|
|
|
|
|
img = cache.setdefault(f, scipy.ndimage.imread(filename, mode='RGB'))
|
|
|
|
|
if f not in cache:
|
|
|
|
|
if len(cache) > 3172:
|
|
|
|
|
del cache[random.choice(list(cache.keys()))]
|
|
|
|
|
|
|
|
|
|
img = scipy.ndimage.imread(filename, mode='RGB')
|
|
|
|
|
ratio = min(1024 / img.shape[0], 1024 / img.shape[1])
|
|
|
|
|
if ratio < 1.0:
|
|
|
|
|
img = scipy.misc.imresize(img, ratio, interp='bicubic')
|
|
|
|
|
cache[f] = img
|
|
|
|
|
else:
|
|
|
|
|
img = cache[f]
|
|
|
|
|
except Exception as e:
|
|
|
|
|
warn('Could not load `{}` as image.'.format(filename),
|
|
|
|
|
' - Try fixing or removing the file before next run.')
|
|
|
|
|
@ -188,6 +200,7 @@ class Model(object):
|
|
|
|
|
self.network['seed'] = PoolLayer(self.network['img'], pool_size=2**args.scales, mode='average_exc_pad')
|
|
|
|
|
else:
|
|
|
|
|
self.network['img'] = InputLayer((None, 3, None, None))
|
|
|
|
|
self.network['seed'] = self.network['img']
|
|
|
|
|
self.setup_generator(self.last_layer())
|
|
|
|
|
|
|
|
|
|
if args.train:
|
|
|
|
|
@ -267,10 +280,10 @@ class Model(object):
|
|
|
|
|
self.make_layer('disc1.1', batch_norm(self.network['conv1_2']), 64, filter_size=(5,5), stride=(2,2), pad=(2,2))
|
|
|
|
|
self.make_layer('disc1.2', self.last_layer(), 64, filter_size=(5,5), stride=(2,2), pad=(2,2))
|
|
|
|
|
self.make_layer('disc2', batch_norm(self.network['conv2_2']), 128, filter_size=(5,5), stride=(2,2), pad=(2,2))
|
|
|
|
|
# self.make_layer('disc3', batch_norm(self.network['conv3_2']), 256, filter_size=(3,3), stride=(1,1), pad=(1,1))
|
|
|
|
|
hypercolumn = ConcatLayer([self.network['disc1.2>'], self.network['disc2>']])
|
|
|
|
|
self.make_layer('disc4', hypercolumn, 128, filter_size=(3,3), stride=(1,1))
|
|
|
|
|
self.make_layer('disc5', self.last_layer(), 64, filter_size=(3,3), stride=(1,1))
|
|
|
|
|
self.make_layer('disc3', batch_norm(self.network['conv3_2']), 192, filter_size=(3,3), stride=(1,1), pad=(1,1))
|
|
|
|
|
hypercolumn = ConcatLayer([self.network['disc1.2>'], self.network['disc2>'], self.network['disc3>']])
|
|
|
|
|
self.make_layer('disc4', hypercolumn, 192, filter_size=(3,3), stride=(1,1))
|
|
|
|
|
self.make_layer('disc5', self.last_layer(), 96, filter_size=(3,3), stride=(1,1))
|
|
|
|
|
self.network['disc'] = batch_norm(ConvLayer(self.last_layer(), 1, filter_size=(1,1),
|
|
|
|
|
nonlinearity=lasagne.nonlinearities.sigmoid))
|
|
|
|
|
|
|
|
|
|
@ -307,7 +320,7 @@ class Model(object):
|
|
|
|
|
|
|
|
|
|
def load_generator(self):
|
|
|
|
|
filename = args.model % 2**args.scales
|
|
|
|
|
if not os.path.exists(filename) or not args.load: return
|
|
|
|
|
if not os.path.exists(filename): return
|
|
|
|
|
params = pickle.load(bz2.open(filename, 'rb'))
|
|
|
|
|
for k, l in self.list_generator_layers():
|
|
|
|
|
assert k in params, "Couldn't find layer `%s` in loaded model.'"
|
|
|
|
|
@ -328,16 +341,16 @@ class Model(object):
|
|
|
|
|
return T.mean(((x[:,:,:-1,:-1] - x[:,:,1:,:-1])**2 + (x[:,:,:-1,:-1] - x[:,:,:-1,1:])**2)**1.25)
|
|
|
|
|
|
|
|
|
|
def loss_adversarial(self, d):
|
|
|
|
|
return 1.0 - T.log(d[args.batch_size:]).mean()
|
|
|
|
|
return 1.0 - T.log(1E-6 + d[args.batch_size:]).mean()
|
|
|
|
|
|
|
|
|
|
def loss_discriminator(self, d):
|
|
|
|
|
return T.mean(T.log(d[args.batch_size:]) + T.log(1.0 - d[:args.batch_size]))
|
|
|
|
|
return T.mean(T.log(1E-6 + d[args.batch_size:]) + T.log(1E-6 + 1.0 - d[:args.batch_size]))
|
|
|
|
|
|
|
|
|
|
def compile(self):
|
|
|
|
|
# Helper function for rendering test images during training, or standalone non-training mode.
|
|
|
|
|
input_tensor = T.tensor4()
|
|
|
|
|
input_layers = {self.network['img']: input_tensor}
|
|
|
|
|
output = lasagne.layers.get_output([self.network[k] for k in ['img', 'out']], input_layers, deterministic=True)
|
|
|
|
|
output = lasagne.layers.get_output([self.network[k] for k in ['img', 'seed', 'out']], input_layers, deterministic=True)
|
|
|
|
|
self.predict = theano.function([input_tensor], output)
|
|
|
|
|
|
|
|
|
|
if not args.train: return
|
|
|
|
|
@ -404,8 +417,9 @@ class NeuralEnhancer(object):
|
|
|
|
|
l_r = l_min + 0.5 * (l_max - l_min) * (1.0 + math.cos(t_cur / t_i * math.pi))
|
|
|
|
|
t_cur += 1
|
|
|
|
|
l_r = 1E-4
|
|
|
|
|
self.model.gen_lr.set_value(l_r)
|
|
|
|
|
self.model.disc_lr.set_value(l_r)
|
|
|
|
|
|
|
|
|
|
if epoch >= args.generator_start: self.model.gen_lr.set_value(l_r)
|
|
|
|
|
if epoch >= args.discriminator_start: self.model.disc_lr.set_value(l_r)
|
|
|
|
|
|
|
|
|
|
if t_cur >= t_i:
|
|
|
|
|
t_cur, t_i = 0, int(t_i * t_mult)
|
|
|
|
|
@ -433,20 +447,23 @@ class NeuralEnhancer(object):
|
|
|
|
|
|
|
|
|
|
real, fake = stats[:args.batch_size], stats[args.batch_size:]
|
|
|
|
|
print(' - discriminator', real.mean(), len(np.where(real > 0.5)[0]), fake.mean(), len(np.where(fake < 0.5)[0]))
|
|
|
|
|
if epoch == 0:
|
|
|
|
|
if epoch == args.adversary_start-1:
|
|
|
|
|
print(' - adversary mode: generator engaging discriminator.')
|
|
|
|
|
self.model.adversary_weight.set_value(args.adversary_weight)
|
|
|
|
|
running = None
|
|
|
|
|
except KeyboardInterrupt:
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
print('\n{}Trained {}x super-resolution for {} epochs.{}'\
|
|
|
|
|
.format(ansi.CYAN_B, 2**args.scales, epoch, ansi.CYAN))
|
|
|
|
|
.format(ansi.CYAN_B, 2**args.scales, epoch+1, ansi.CYAN))
|
|
|
|
|
self.model.save_generator()
|
|
|
|
|
print(ansi.ENDC)
|
|
|
|
|
|
|
|
|
|
def process(self, images):
|
|
|
|
|
_, repro = self.model.predict(images)
|
|
|
|
|
return repro
|
|
|
|
|
def process(self, image):
|
|
|
|
|
img = np.transpose(image / 255.0 - 0.5, (2, 0, 1))[np.newaxis].astype(np.float32)
|
|
|
|
|
*_, repro = self.model.predict(img)
|
|
|
|
|
repro = np.transpose(repro[0] + 0.5, (1, 2, 0)).clip(0.0, 1.0)
|
|
|
|
|
return scipy.misc.toimage(repro * 255.0, cmin=0, cmax=255)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
@ -456,10 +473,5 @@ if __name__ == "__main__":
|
|
|
|
|
enhancer.train()
|
|
|
|
|
|
|
|
|
|
for filename in args.files:
|
|
|
|
|
img = scipy.ndimage.imread(filename, mode='RGB')
|
|
|
|
|
img = np.transpose(img / 255.0 - 0.5, (2, 0, 1))[np.newaxis]
|
|
|
|
|
|
|
|
|
|
out = enhancer.process(img.astype(np.float32))
|
|
|
|
|
|
|
|
|
|
out = np.transpose((out[0] + 0.5) * 255.0, (1, 2, 0)).astype(np.uint8)
|
|
|
|
|
scipy.misc.imsave(os.path.splitext(filename)[0]+'_enhanced.png', out)
|
|
|
|
|
out = enhancer.process(scipy.ndimage.imread(filename, mode='RGB'))
|
|
|
|
|
out.save(os.path.splitext(filename)[0]+'_enhanced.png')
|
|
|
|
|
|