Move generation of seeds out of training network

This moves the generation of the image seeds out of the
training network and into the DataLoader. Currently seeds
are computed as bilinear downsamplings of the original image.

This is almost functionally equivalent to the version
it replaces, but opens up new possibilities at training
time because the seeds are now decoupled from the network.
For example, seeds could be made with different interpolations
or even with other transformations such as image compression.
main
Tom White 9 years ago
parent f68f04fb1c
commit 37cb208374

@ -130,7 +130,10 @@ class DataLoader(threading.Thread):
self.data_copied = threading.Event() self.data_copied = threading.Event()
self.resolution = args.batch_resolution self.resolution = args.batch_resolution
self.seed_resolution = int(args.batch_resolution / 2**args.scales)
self.buffer = np.zeros((args.buffer_size, 3, self.resolution, self.resolution), dtype=np.float32) self.buffer = np.zeros((args.buffer_size, 3, self.resolution, self.resolution), dtype=np.float32)
self.seed_buffer = np.zeros((args.buffer_size, 3, self.seed_resolution, self.seed_resolution), dtype=np.float32)
self.files = glob.glob(args.train) self.files = glob.glob(args.train)
if len(self.files) == 0: if len(self.files) == 0:
error("There were no files found to train from searching for `{}`".format(args.train), error("There were no files found to train from searching for `{}`".format(args.train),
@ -168,17 +171,23 @@ class DataLoader(threading.Thread):
i = self.available.pop() i = self.available.pop()
self.buffer[i] = np.transpose(copy / 255.0 - 0.5, (2, 0, 1)) self.buffer[i] = np.transpose(copy / 255.0 - 0.5, (2, 0, 1))
seed_copy = scipy.misc.imresize(copy,
size=(self.seed_resolution, self.seed_resolution),
interp='bilinear')
self.seed_buffer[i] = np.transpose(seed_copy / 255.0 - 0.5, (2, 0, 1))
self.ready.add(i) self.ready.add(i)
if len(self.ready) >= args.batch_size: if len(self.ready) >= args.batch_size:
self.data_ready.set() self.data_ready.set()
def copy(self, output): def copy(self, images_out, seeds_out):
self.data_ready.wait() self.data_ready.wait()
self.data_ready.clear() self.data_ready.clear()
for i, j in enumerate(random.sample(self.ready, args.batch_size)): for i, j in enumerate(random.sample(self.ready, args.batch_size)):
output[i] = self.buffer[j] images_out[i] = self.buffer[j]
seeds_out[i] = self.seed_buffer[j]
self.available.add(j) self.available.add(j)
self.data_copied.set() self.data_copied.set()
@ -214,7 +223,7 @@ class Model(object):
self.network = collections.OrderedDict() self.network = collections.OrderedDict()
if args.train: if args.train:
self.network['img'] = InputLayer((None, 3, None, None)) self.network['img'] = InputLayer((None, 3, None, None))
self.network['seed'] = PoolLayer(self.network['img'], pool_size=2**args.scales, mode='average_exc_pad') self.network['seed'] = InputLayer((None, 3, None, None))
else: else:
self.network['img'] = InputLayer((None, 3, None, None)) self.network['img'] = InputLayer((None, 3, None, None))
self.network['seed'] = self.network['img'] self.network['seed'] = self.network['img']
@ -380,9 +389,10 @@ class Model(object):
def compile(self): def compile(self):
# Helper function for rendering test images during training, or standalone non-training mode. # Helper function for rendering test images during training, or standalone non-training mode.
input_tensor = T.tensor4() input_tensor = T.tensor4()
input_layers = {self.network['img']: input_tensor} seed_tensor = T.tensor4()
input_layers = {self.network['img']: input_tensor, self.network['seed']: seed_tensor}
output = lasagne.layers.get_output([self.network[k] for k in ['img', 'seed', 'out']], input_layers, deterministic=True) output = lasagne.layers.get_output([self.network[k] for k in ['img', 'seed', 'out']], input_layers, deterministic=True)
self.predict = theano.function([input_tensor], output) self.predict = theano.function([input_tensor, seed_tensor], output)
if not args.train: return if not args.train: return
@ -408,7 +418,7 @@ class Model(object):
# Combined Theano function for updating both generator and discriminator at the same time. # Combined Theano function for updating both generator and discriminator at the same time.
updates = collections.OrderedDict(list(gen_updates.items()) + list(disc_updates.items())) updates = collections.OrderedDict(list(gen_updates.items()) + list(disc_updates.items()))
self.fit = theano.function([input_tensor], gen_losses + [disc_out.mean(axis=(1,2,3))], updates=updates) self.fit = theano.function([input_tensor, seed_tensor], gen_losses + [disc_out.mean(axis=(1,2,3))], updates=updates)
@ -448,7 +458,9 @@ class NeuralEnhancer(object):
if t_cur % args.learning_period == 0: l_r *= args.learning_decay if t_cur % args.learning_period == 0: l_r *= args.learning_decay
def train(self): def train(self):
seed_size = int(args.batch_resolution / 2**args.scales)
images = np.zeros((args.batch_size, 3, args.batch_resolution, args.batch_resolution), dtype=np.float32) images = np.zeros((args.batch_size, 3, args.batch_resolution, args.batch_resolution), dtype=np.float32)
seeds = np.zeros((args.batch_size, 3, seed_size, seed_size), dtype=np.float32)
learning_rate = self.decay_learning_rate() learning_rate = self.decay_learning_rate()
try: try:
running, start = None, time.time() running, start = None, time.time()
@ -459,8 +471,8 @@ class NeuralEnhancer(object):
if epoch >= args.discriminator_start: self.model.disc_lr.set_value(l_r) if epoch >= args.discriminator_start: self.model.disc_lr.set_value(l_r)
for _ in range(args.epoch_size): for _ in range(args.epoch_size):
self.thread.copy(images) self.thread.copy(images, seeds)
output = self.model.fit(images) output = self.model.fit(images, seeds)
losses = np.array(output[:3], dtype=np.float32) losses = np.array(output[:3], dtype=np.float32)
stats = (stats + output[3]) if stats is not None else output[3] stats = (stats + output[3]) if stats is not None else output[3]
total = total + losses if total is not None else losses total = total + losses if total is not None else losses
@ -469,7 +481,7 @@ class NeuralEnhancer(object):
running = l if running is None else running * 0.95 + 0.05 * l running = l if running is None else running * 0.95 + 0.05 * l
print('' if l > running else '', end='', flush=True) print('' if l > running else '', end='', flush=True)
orign, scald, repro = self.model.predict(images) orign, scald, repro = self.model.predict(images, seeds)
self.show_progress(orign, scald, repro) self.show_progress(orign, scald, repro)
total /= args.epoch_size total /= args.epoch_size
stats /= args.epoch_size stats /= args.epoch_size

Loading…
Cancel
Save