Move generation of seeds out of training network

This moves the generation of the image seeds out of the
training network and into the DataLoader. Currently seeds
are computed as bilinear downsamplings of the original image.

This is almost functionly equivalent to the version
it replaces, but opens up new possibilities at training
time because the seeds are now decoupled from the netork.
For example, seeds could be made with different interpolations
or even with other transformations such as image compression.
main
Tom White 9 years ago
parent f68f04fb1c
commit 37cb208374

@ -130,7 +130,10 @@ class DataLoader(threading.Thread):
self.data_copied = threading.Event()
self.resolution = args.batch_resolution
self.seed_resolution = int(args.batch_resolution / 2**args.scales)
self.buffer = np.zeros((args.buffer_size, 3, self.resolution, self.resolution), dtype=np.float32)
self.seed_buffer = np.zeros((args.buffer_size, 3, self.seed_resolution, self.seed_resolution), dtype=np.float32)
self.files = glob.glob(args.train)
if len(self.files) == 0:
error("There were no files found to train from searching for `{}`".format(args.train),
@ -168,17 +171,23 @@ class DataLoader(threading.Thread):
i = self.available.pop()
self.buffer[i] = np.transpose(copy / 255.0 - 0.5, (2, 0, 1))
seed_copy = scipy.misc.imresize(copy,
size=(self.seed_resolution, self.seed_resolution),
interp='bilinear')
self.seed_buffer[i] = np.transpose(seed_copy / 255.0 - 0.5, (2, 0, 1))
self.ready.add(i)
if len(self.ready) >= args.batch_size:
self.data_ready.set()
def copy(self, output):
def copy(self, images_out, seeds_out):
self.data_ready.wait()
self.data_ready.clear()
for i, j in enumerate(random.sample(self.ready, args.batch_size)):
output[i] = self.buffer[j]
images_out[i] = self.buffer[j]
seeds_out[i] = self.seed_buffer[j]
self.available.add(j)
self.data_copied.set()
@ -214,7 +223,7 @@ class Model(object):
self.network = collections.OrderedDict()
if args.train:
self.network['img'] = InputLayer((None, 3, None, None))
self.network['seed'] = PoolLayer(self.network['img'], pool_size=2**args.scales, mode='average_exc_pad')
self.network['seed'] = InputLayer((None, 3, None, None))
else:
self.network['img'] = InputLayer((None, 3, None, None))
self.network['seed'] = self.network['img']
@ -380,9 +389,10 @@ class Model(object):
def compile(self):
# Helper function for rendering test images during training, or standalone non-training mode.
input_tensor = T.tensor4()
input_layers = {self.network['img']: input_tensor}
seed_tensor = T.tensor4()
input_layers = {self.network['img']: input_tensor, self.network['seed']: seed_tensor}
output = lasagne.layers.get_output([self.network[k] for k in ['img', 'seed', 'out']], input_layers, deterministic=True)
self.predict = theano.function([input_tensor], output)
self.predict = theano.function([input_tensor, seed_tensor], output)
if not args.train: return
@ -408,7 +418,7 @@ class Model(object):
# Combined Theano function for updating both generator and discriminator at the same time.
updates = collections.OrderedDict(list(gen_updates.items()) + list(disc_updates.items()))
self.fit = theano.function([input_tensor], gen_losses + [disc_out.mean(axis=(1,2,3))], updates=updates)
self.fit = theano.function([input_tensor, seed_tensor], gen_losses + [disc_out.mean(axis=(1,2,3))], updates=updates)
@ -448,7 +458,9 @@ class NeuralEnhancer(object):
if t_cur % args.learning_period == 0: l_r *= args.learning_decay
def train(self):
seed_size = int(args.batch_resolution / 2**args.scales)
images = np.zeros((args.batch_size, 3, args.batch_resolution, args.batch_resolution), dtype=np.float32)
seeds = np.zeros((args.batch_size, 3, seed_size, seed_size), dtype=np.float32)
learning_rate = self.decay_learning_rate()
try:
running, start = None, time.time()
@ -459,8 +471,8 @@ class NeuralEnhancer(object):
if epoch >= args.discriminator_start: self.model.disc_lr.set_value(l_r)
for _ in range(args.epoch_size):
self.thread.copy(images)
output = self.model.fit(images)
self.thread.copy(images, seeds)
output = self.model.fit(images, seeds)
losses = np.array(output[:3], dtype=np.float32)
stats = (stats + output[3]) if stats is not None else output[3]
total = total + losses if total is not None else losses
@ -469,7 +481,7 @@ class NeuralEnhancer(object):
running = l if running is None else running * 0.95 + 0.05 * l
print('' if l > running else '', end='', flush=True)
orign, scald, repro = self.model.predict(images)
orign, scald, repro = self.model.predict(images, seeds)
self.show_progress(orign, scald, repro)
total /= args.epoch_size
stats /= args.epoch_size

Loading…
Cancel
Save