@@ -36,7 +36,7 @@ parser = argparse.ArgumentParser(description='Generate a new image by applying s
                                  formatter_class=argparse.ArgumentDefaultsHelpFormatter)
 add_arg = parser.add_argument
 add_arg('files', nargs='*', default=[])
-add_arg('--scales', default=2, type=int, help='How many times to perform 2x upsampling.')
+add_arg('--zoom', default=4, type=int, help='Resolution increase factor for inference.')
 add_arg('--model', default='small', type=str, help='Name of the neural network to load/save.')
 add_arg('--train', default=False, type=str, help='File pattern to load for training.')
 add_arg('--train-blur', default=None, type=float, help='Sigma value for gaussian blur preprocess.')
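
The old --scales flag counted 2x upsampling steps, while the new --zoom flag states the overall resolution factor directly. A quick sanity check of the assumed equivalence, using the two defaults shown above:

    scales = 2                  # old flag: number of 2x upsampling steps
    zoom = 4                    # new flag: overall resolution increase factor
    assert 2 ** scales == zoom  # both defaults describe the same 4x factor
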
@@ -50,8 +50,10 @@ add_arg('--batch-size', default=15, type=int, help='Number
 add_arg('--buffer-size', default=1500, type=int, help='Total image fragments kept in cache.')
 add_arg('--buffer-similar', default=5, type=int, help='Fragments cached for each image loaded.')
 add_arg('--learning-rate', default=1E-4, type=float, help='Parameter for the ADAM optimizer.')
-add_arg('--learning-period', default=50, type=int, help='How often to decay the learning rate.')
+add_arg('--learning-period', default=100, type=int, help='How often to decay the learning rate.')
 add_arg('--learning-decay', default=0.5, type=float, help='How much to decay the learning rate.')
+add_arg('--generator-upscale', default=2, type=int, help='Steps of 2x up-sampling as post-process.')
+add_arg('--generator-downscale',default=0, type=int, help='Steps of 2x down-sampling as preprocess.')
 add_arg('--generator-filters', default=[64], nargs='+', type=int, help='Number of convolution units in network.')
 add_arg('--generator-blocks', default=4, type=int, help='Number of residual blocks per iteration.')
 add_arg('--generator-residual', default=2, type=int, help='Number of layers in a residual block.')
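
The two new generator flags are later combined into the effective zoom factor as args.zoom = 2**(generator_upscale - generator_downscale) (see the setup_generator hunk below). A small worked example with assumed flag values:

    # Worked example of the zoom formula; the flag values are illustrative.
    for upscale, downscale in [(2, 0), (2, 1), (3, 1)]:
        zoom = 2 ** (upscale - downscale)
        print(upscale, downscale, zoom)   # zoom -> 4, 2, 4 respectively
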
@@ -59,7 +61,7 @@ add_arg('--perceptual-layer', default='conv2_2', type=str, help='Which
 add_arg('--perceptual-weight', default=1e0, type=float, help='Weight for VGG-layer perceptual loss.')
 add_arg('--discriminator-size', default=32, type=int, help='Multiplier for number of filters in D.')
 add_arg('--smoothness-weight', default=2e5, type=float, help='Weight of the total-variation loss.')
-add_arg('--adversary-weight', default=1e2, type=float, help='Weight of adversarial loss compoment.')
+add_arg('--adversary-weight', default=5e2, type=float, help='Weight of adversarial loss component.')
 add_arg('--generator-start', default=0, type=int, help='Epoch count to start training generator.')
 add_arg('--discriminator-start',default=1, type=int, help='Epoch count to update the discriminator.')
 add_arg('--adversarial-start', default=2, type=int, help='Epoch for generator to use discriminator.')
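
These flags weight individual loss terms, and this hunk raises the adversarial weight from 1e2 to 5e2. A hedged illustration of how such weights would scale a combined loss; the term values below are made-up placeholders, not real network outputs:

    perceptual_weight, smoothness_weight, adversary_weight = 1e0, 2e5, 5e2
    perceptual_term, smoothness_term, adversarial_term = 0.1, 1e-6, 0.01   # dummy values
    total = (perceptual_weight * perceptual_term
             + smoothness_weight * smoothness_term
             + adversary_weight * adversarial_term)
    print(total)   # 0.1 + 0.2 + 5.0 = 5.3
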
@@ -132,7 +134,7 @@ class DataLoader(threading.Thread):
         self.data_ready = threading.Event()
         self.data_copied = threading.Event()
 
-        self.orig_shape, self.seed_shape = args.batch_shape, int(args.batch_shape / 2**args.scales)
+        self.orig_shape, self.seed_shape = args.batch_shape, int(args.batch_shape / args.zoom)
 
         self.orig_buffer = np.zeros((args.buffer_size, 3, self.orig_shape, self.orig_shape), dtype=np.float32)
         self.seed_buffer = np.zeros((args.buffer_size, 3, self.seed_shape, self.seed_shape), dtype=np.float32)
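
The seed (low-resolution) fragment size is now derived from args.zoom instead of 2**args.scales. Illustrative shapes, assuming batch_shape=192 (that flag is not shown in this diff) and the new default zoom=4:

    import numpy as np

    batch_shape, zoom = 192, 4
    orig_shape, seed_shape = batch_shape, int(batch_shape / zoom)
    orig = np.zeros((3, orig_shape, orig_shape), dtype=np.float32)
    seed = np.zeros((3, seed_shape, seed_shape), dtype=np.float32)
    print(orig.shape, seed.shape)   # (3, 192, 192) (3, 48, 48)
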
@@ -199,9 +201,7 @@ class DataLoader(threading.Thread):
         for i, j in enumerate(random.sample(self.ready, args.batch_size)):
             origs_out[i] = self.orig_buffer[j]
             seeds_out[i] = self.seed_buffer[j]
 
             self.available.add(j)
 
         self.data_copied.set()
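
The data_ready/data_copied pair used above forms a two-event handshake between the loader thread and the training loop. A minimal, self-contained sketch of that pattern (stand-in code, not the project's actual classes):

    import threading

    data_ready, data_copied = threading.Event(), threading.Event()
    batch = []

    def loader():
        for i in range(3):
            batch.append(i)              # stand-in for filling the image buffers
            data_ready.set()             # signal that a batch is ready
            data_copied.wait()           # block until the consumer has copied it
            data_copied.clear()

    def consumer():
        for _ in range(3):
            data_ready.wait()
            data_ready.clear()
            print('copied', batch[-1])   # stand-in for copying into origs/seeds
            data_copied.set()

    t = threading.Thread(target=loader)
    t.start()
    consumer()
    t.join()
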
@@ -269,20 +269,26 @@ class Model(object):
 
     def setup_generator(self, input, config):
         for k, v in config.items(): setattr(args, k, v)
+        args.zoom = 2**(args.generator_upscale - args.generator_downscale)
+
         units_iter = extend(args.generator_filters)
         units = next(units_iter)
         self.make_layer('iter.0-A', input, units, filter_size=(5,5), pad=(2,2))
         self.make_layer('iter.0-B', self.last_layer(), units, filter_size=(5,5), pad=(2,2))
         self.network['iter.0'] = self.last_layer()
 
+        for i in range(0, args.generator_downscale):
+            self.make_layer('downscale%i'%i, self.last_layer(), next(units_iter), filter_size=(4,4), stride=(2,2))
+
+        units = next(units_iter)
         for i in range(0, args.generator_blocks):
             self.make_block('iter.%i'%(i+1), self.last_layer(), units)
 
-        for i in range(0, args.scales):
+        for i in range(0, args.generator_upscale):
             u = next(units_iter)
-            self.make_layer('scale%i.3'%i, self.last_layer(), u*4)
-            self.network['scale%i.2'%i] = SubpixelReshuffleLayer(self.last_layer(), u, 2)
-            self.make_layer('scale%i.1'%i, self.last_layer(), u)
+            self.make_layer('upscale%i.3'%i, self.last_layer(), u*4)
+            self.network['upscale%i.2'%i] = SubpixelReshuffleLayer(self.last_layer(), u, 2)
+            self.make_layer('upscale%i.1'%i, self.last_layer(), u)
 
         self.network['out'] = ConvLayer(self.last_layer(), 3, filter_size=(5,5), stride=(1,1), pad=(2,2),
                                         nonlinearity=lasagne.nonlinearities.tanh)
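
SubpixelReshuffleLayer turns u*4 feature channels into u channels at twice the spatial resolution (sub-pixel, or depth-to-space, upsampling). A NumPy reference of the general idea; the exact channel ordering of the Lasagne layer may differ:

    import numpy as np

    def depth_to_space(x, r=2):
        # x: (c*r*r, h, w) feature map -> (c, h*r, w*r), standard pixel-shuffle layout.
        c, h, w = x.shape
        out_c = c // (r * r)
        x = x.reshape(out_c, r, r, h, w)
        x = x.transpose(0, 3, 1, 4, 2)   # (out_c, h, r, w, r)
        return x.reshape(out_c, h * r, w * r)

    x = np.arange(4 * 3 * 3, dtype=np.float32).reshape(4, 3, 3)
    print(depth_to_space(x).shape)       # (1, 6, 6)
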
@@ -355,13 +361,14 @@ class Model(object):
     def save_generator(self):
         def cast(p): return p.get_value().astype(np.float16)
         params = {k: [cast(p) for p in l.get_params()] for (k, l) in self.list_generator_layers()}
-        config = {k: getattr(args, k) for k in ['generator_blocks', 'generator_residual', 'generator_filters']}
-        filename = 'ne%ix-%s-%s.pkl.bz2' % (2**args.scales, args.model, __version__)
+        config = {k: getattr(args, k) for k in ['generator_blocks', 'generator_residual', 'generator_filters'] + \
+                                               ['generator_upscale', 'generator_downscale']}
+        filename = 'ne%ix-%s-%s.pkl.bz2' % (args.zoom, args.model, __version__)
         pickle.dump((config, params), bz2.open(filename, 'wb'))
         print(' - Saved model as `{}` after training.'.format(filename))
 
     def load_model(self):
-        filename = 'ne%ix-%s-%s.pkl.bz2' % (2**args.scales, args.model, __version__)
+        filename = 'ne%ix-%s-%s.pkl.bz2' % (args.zoom, args.model, __version__)
         if not os.path.exists(filename):
             if args.train: return {}, {}
             error("Model file with pre-trained convolution layers not found. Download it here...",
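
save_generator now stores the up/downscale step counts in the pickled config, so the loader can rebuild the network at the right zoom factor. A hedged round-trip of the (config, params) tuple format, with dummy data and an invented filename:

    import bz2, pickle
    import numpy as np

    config = {'generator_blocks': 4, 'generator_residual': 2, 'generator_filters': [64],
              'generator_upscale': 2, 'generator_downscale': 0}
    params = {'iter.0-A': [np.zeros((3, 3), dtype=np.float16)]}   # placeholder weights
    filename = 'ne4x-example-0.1.pkl.bz2'                         # illustrative name only
    with bz2.open(filename, 'wb') as f:
        pickle.dump((config, params), f)
    with bz2.open(filename, 'rb') as f:
        loaded_config, loaded_params = pickle.load(f)
    assert loaded_config == config
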
@@ -431,7 +438,7 @@ class Model(object):
 
 class NeuralEnhancer(object):
 
-    def __init__(self):
+    def __init__(self, loader):
         if args.train:
             print('{}Training {} epochs on random image sections with batch size {}.{}'\
                 .format(ansi.BLUE_B, args.epochs, args.batch_size, ansi.BLUE))
@@ -440,8 +447,8 @@ class NeuralEnhancer(object):
             print('{}Enhancing {} image(s) specified on the command-line.{}'\
                 .format(ansi.BLUE_B, len(args.files), ansi.BLUE))
 
-        self.thread = DataLoader() if args.train else None
         self.model = Model()
+        self.thread = DataLoader() if loader else None
 
         print('{}'.format(ansi.ENDC))
 
@@ -466,7 +473,7 @@ class NeuralEnhancer(object):
             if t_cur % args.learning_period == 0: l_r *= args.learning_decay
 
     def train(self):
-        seed_size = int(args.batch_shape / 2**args.scales)
+        seed_size = int(args.batch_shape / args.zoom)
         images = np.zeros((args.batch_size, 3, args.batch_shape, args.batch_shape), dtype=np.float32)
         seeds = np.zeros((args.batch_size, 3, seed_size, seed_size), dtype=np.float32)
         learning_rate = self.decay_learning_rate()
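
With the new defaults (initial rate 1e-4, period 100, decay 0.5), the learning-period check above yields a step-decay schedule. An illustrative simulation; whether step 0 also triggers a decay depends on how the counter is initialised:

    learning_rate, period, decay = 1e-4, 100, 0.5
    for t_cur in range(1, 301):
        if t_cur % period == 0:
            learning_rate *= decay
    print(learning_rate)   # 1e-4 * 0.5**3 = 1.25e-05
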
@@ -512,7 +519,7 @@ class NeuralEnhancer(object):
             pass
 
         print('\n{}Trained {}x super-resolution for {} epochs.{}'\
-                .format(ansi.CYAN_B, 2**args.scales, epoch+1, ansi.CYAN))
+                .format(ansi.CYAN_B, args.zoom, epoch+1, ansi.CYAN))
         self.model.save_generator()
         print(ansi.ENDC)
 
@@ -524,11 +531,12 @@ class NeuralEnhancer(object):
 
 
 if __name__ == "__main__":
-    enhancer = NeuralEnhancer()
-
     if args.train:
+        args.zoom = 2**(args.generator_upscale - args.generator_downscale)
+        enhancer = NeuralEnhancer(loader=True)
         enhancer.train()
     else:
+        enhancer = NeuralEnhancer(loader=False)
         for filename in args.files:
             print(filename)
             img = scipy.ndimage.imread(filename, mode='RGB')
@@ -538,5 +546,5 @@ if __name__ == "__main__":
                 continue
 
             out = enhancer.process(img)
-            out.save(os.path.splitext(filename)[0]+'_ne%ix.png'%(2**args.scales))
+            out.save(os.path.splitext(filename)[0]+'_ne%ix.png'%args.zoom)
             print(ansi.ENDC)
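
Enhanced images are now named after args.zoom rather than 2**args.scales. A small check of the naming pattern with an assumed input path:

    import os

    filename, zoom = 'photos/cat.jpg', 4   # assumed example input
    out_name = os.path.splitext(filename)[0] + '_ne%ix.png' % zoom
    print(out_name)                        # photos/cat_ne4x.png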