Switch to sub-pixel deconvolution layer.

main
Alex J. Champandard 9 years ago
parent 9f13695050
commit 749b467f94

@ -34,12 +34,12 @@ add_arg = parser.add_argument
add_arg('--batch-size', default=15, type=int)
add_arg('--batch-resolution', default=256, type=int)
add_arg('--epoch-size', default=36, type=int)
add_arg('--epochs', default=100, type=int)
add_arg('--network-filters', default=128, type=int)
add_arg('--network-blocks', default=4, type=int)
add_arg('--perceptual-layer', default='mse', type=str)
add_arg('--epochs', default=15, type=int)
add_arg('--generator-filters', default=128, type=int)
add_arg('--generator-blocks', default=4, type=int)
add_arg('--perceptual-layer', default='conv2_2', type=str)
add_arg('--perceptual-weight', default=1e0, type=float)
add_arg('--smoothness-weight', default=0.0, type=float)
add_arg('--smoothness-weight', default=1e4, type=float)
add_arg('--adversary-weight', default=0.0, type=float)
add_arg('--scales', default=1, type=int, help='')
add_arg('--device', default='gpu0', type=str, help='Name of the CPU/GPU number to use, for Theano.')
@ -121,31 +121,23 @@ class DataLoader(threading.Thread):
while True:
random.shuffle(files)
for i, f in enumerate(files[:args.batch_size]):
filename = os.path.join(self.cwd, f)
try:
if f in cache:
img = cache[f]
else:
img = scipy.ndimage.imread(filename, mode='RGB')
cache[f] = img
img = cache.setdefault(f, scipy.ndimage.imread(filename, mode='RGB'))
except Exception as e:
warn('Could not load `{}` as image.'.format(filename),
' - Try fixing or removing the file before next run.')
files.remove(f)
continue
if random.choice([True, False]):
img[:,:] = img[:,::-1]
if random.choice([True, False]): img[:,:] = img[:,::-1]
h = random.randint(0, img.shape[0] - self.resolution)
w = random.randint(0, img.shape[1] - self.resolution)
img = img[h:h+self.resolution, w:w+self.resolution]
self.images[i] = np.transpose(img / 255.0 - 0.5, (2, 0, 1))
self.data_ready.set()
self.data_copied.wait()
self.data_copied.clear()
@ -160,6 +152,27 @@ class DataLoader(threading.Thread):
#----------------------------------------------------------------------------------------------------------------------
# Convolution Networks
#----------------------------------------------------------------------------------------------------------------------
class SubpixelShuffle(lasagne.layers.Layer):
"""Based on the code by ajbrock: https://github.com/ajbrock/Neural-Photo-Editor/
"""
def __init__(self, incoming, channels, upscale, **kwargs):
super(SubpixelShuffle, self).__init__(incoming, **kwargs)
self.upscale = upscale
self.channels = channels
def get_output_shape_for(self, input_shape):
def up(d): return self.upscale * d if d else d
return (input_shape[0], self.channels, up(input_shape[2]), up(input_shape[3]))
def get_output_for(self, input, deterministic=False, **kwargs):
out, r = T.zeros(self.get_output_shape_for(input.shape)), self.upscale
for y, x in itertools.product(range(r), repeat=2):
out=T.inc_subtensor(out[:,:,y::r,x::r], input[:,r*y+x::r*r,:,:])
return out
class Model(object):
def __init__(self):
@ -177,25 +190,30 @@ class Model(object):
def last_layer(self):
return list(self.network.values())[-1]
def make_block(self, input, units):
l1 = batch_norm(ConvLayer(input, units, filter_size=(3,3), stride=(1,1), pad=1))
l2 = batch_norm(ConvLayer(l1, units, filter_size=(3,3), stride=(1,1), pad=1))
return ElemwiseSumLayer([input, l2])
def make_layer(self, input, units, filter_size=(3,3), stride=(1,1), pad=(1,1), nl=None):
return ConvLayer(input, units, filter_size=filter_size, stride=stride, pad=pad,
nonlinearity=nl or lasagne.nonlinearities.rectify)
def setup_generator(self, input):
f = args.network_filters
self.network['iter.0'] = ConvLayer(input, f, filter_size=(1,1), stride=(1,1), pad=0)
units = args.generator_filters
self.network['iter.0'] = self.make_layer(input, units, filter_size=(5,5), pad=(2,2))
for i in range(0, args.network_blocks):
self.network['iter.%i'%(i+1)] = self.make_block(self.last_layer(), f)
for i in range(0, args.generator_blocks):
self.network['iter.%i'%(i+1)] = self.make_block(self.last_layer(), units)
for i in range(args.scales, 0, -1):
self.network['scale%i.2'%i] = DeconvLayer(self.last_layer(), f, filter_size=(4,4), stride=(2,2), crop=1)
self.network['scale%i.1'%i] = ConvLayer(self.network['scale%i.2'%i], f, filter_size=(3,3), pad=1)
self.network['scale%i.3'%i] = self.make_layer(self.last_layer(), units*2)
self.network['scale%i.2'%i] = SubpixelShuffle(self.network['scale%i.3'%i], units//2, 2)
self.network['scale%i.1'%i] = self.make_layer(self.network['scale%i.2'%i], units)
self.network['out'] = ConvLayer(self.last_layer(), 3, filter_size=(1,1), stride=(1,1), pad=0, b=None,
self.network['out'] = ConvLayer(self.last_layer(), 3, filter_size=(5,5), stride=(1,1), pad=(2,2),
nonlinearity=lasagne.nonlinearities.tanh)
def make_block(self, input, units):
l1 = batch_norm(ConvLayer(input, units, filter_size=(3,3), stride=(1,1), pad=1))
l2 = batch_norm(ConvLayer(l1, units, filter_size=(3,3), stride=(1,1), pad=1))
return ElemwiseSumLayer([input, l2])
def setup_discriminator(self):
self.network['disc1'] = ConvLayer(self.network['conv1_2'], 64, filter_size=(7,7), stride=(4,4), pad=(3,3))
self.network['disc2'] = ConvLayer(self.network['conv2_2'], 128, filter_size=(5,5), stride=(2,2), pad=(2,2))
@ -300,7 +318,7 @@ class NeuralEnhancer(object):
self.thread = DataLoader()
self.model = Model()
print('\n{}'.format(ansi.ENDC))
print('{}'.format(ansi.ENDC))
def imsave(self, fn, img):
img = np.transpose(img + 0.5, (1, 2, 0)).clip(0.0, 1.0)
@ -316,7 +334,7 @@ class NeuralEnhancer(object):
images = np.zeros((args.batch_size, 3, args.batch_resolution, args.batch_resolution), dtype=np.float32)
l_min, l_max, l_mult = 1E-7, 1E-3, 0.2
t_cur, t_i, t_mult = 0, 150, 1
t_cur, t_i, t_mult = 120, 150, 1
i, running = 0, None
for _ in range(args.epochs):
@ -339,12 +357,12 @@ class NeuralEnhancer(object):
l = np.sum(losses)
running = l if running is None else running * 0.9 + 0.1 * l
print('' if l >= running else '', end=' ', flush=True)
print('' if l > running else '', end=' ', flush=True)
self.show_progress(*self.model.predict(images))
total = total / args.epoch_size
labels = ['{}={:4.2e}'.format(k, v) for k, v in zip(['prcpt', 'smthn', 'advrs'], total)]
print('\nLosses: total={:4.2e} {}'.format(sum(total), ' '.join(labels)))
print('\nLosses: total={:4.2e} {}'.format(sum(total), ' '.join(labels)))
if __name__ == "__main__":

Loading…
Cancel
Save