Merge pull request #12 from dribnet/generic_seeds

Move generation of seeds out of training network.
main
Alex J. Champandard 9 years ago committed by GitHub
commit 2b67daedb6

@@ -130,7 +130,10 @@ class DataLoader(threading.Thread):
self.data_copied = threading.Event() self.data_copied = threading.Event()
self.resolution = args.batch_resolution self.resolution = args.batch_resolution
self.seed_resolution = int(args.batch_resolution / 2**args.scales)
self.buffer = np.zeros((args.buffer_size, 3, self.resolution, self.resolution), dtype=np.float32) self.buffer = np.zeros((args.buffer_size, 3, self.resolution, self.resolution), dtype=np.float32)
self.seed_buffer = np.zeros((args.buffer_size, 3, self.seed_resolution, self.seed_resolution), dtype=np.float32)
self.files = glob.glob(args.train) self.files = glob.glob(args.train)
if len(self.files) == 0: if len(self.files) == 0:
error("There were no files found to train from searching for `{}`".format(args.train), error("There were no files found to train from searching for `{}`".format(args.train),
@@ -168,17 +171,23 @@ class DataLoader(threading.Thread):
i = self.available.pop() i = self.available.pop()
self.buffer[i] = np.transpose(copy / 255.0 - 0.5, (2, 0, 1)) self.buffer[i] = np.transpose(copy / 255.0 - 0.5, (2, 0, 1))
seed_copy = scipy.misc.imresize(copy,
size=(self.seed_resolution, self.seed_resolution),
interp='bilinear')
self.seed_buffer[i] = np.transpose(seed_copy / 255.0 - 0.5, (2, 0, 1))
self.ready.add(i) self.ready.add(i)
if len(self.ready) >= args.batch_size: if len(self.ready) >= args.batch_size:
self.data_ready.set() self.data_ready.set()
def copy(self, output): def copy(self, images_out, seeds_out):
self.data_ready.wait() self.data_ready.wait()
self.data_ready.clear() self.data_ready.clear()
for i, j in enumerate(random.sample(self.ready, args.batch_size)): for i, j in enumerate(random.sample(self.ready, args.batch_size)):
output[i] = self.buffer[j] images_out[i] = self.buffer[j]
seeds_out[i] = self.seed_buffer[j]
self.available.add(j) self.available.add(j)
self.data_copied.set() self.data_copied.set()
@@ -212,12 +221,8 @@ class Model(object):
def __init__(self): def __init__(self):
self.network = collections.OrderedDict() self.network = collections.OrderedDict()
if args.train: self.network['img'] = InputLayer((None, 3, None, None))
self.network['img'] = InputLayer((None, 3, None, None)) self.network['seed'] = InputLayer((None, 3, None, None))
self.network['seed'] = PoolLayer(self.network['img'], pool_size=2**args.scales, mode='average_exc_pad')
else:
self.network['img'] = InputLayer((None, 3, None, None))
self.network['seed'] = self.network['img']
config, params = self.load_model() config, params = self.load_model()
self.setup_generator(self.last_layer(), config) self.setup_generator(self.last_layer(), config)
@@ -380,9 +385,10 @@ class Model(object):
def compile(self): def compile(self):
# Helper function for rendering test images during training, or standalone non-training mode. # Helper function for rendering test images during training, or standalone non-training mode.
input_tensor = T.tensor4() input_tensor = T.tensor4()
input_layers = {self.network['img']: input_tensor} seed_tensor = T.tensor4()
input_layers = {self.network['img']: input_tensor, self.network['seed']: seed_tensor}
output = lasagne.layers.get_output([self.network[k] for k in ['img', 'seed', 'out']], input_layers, deterministic=True) output = lasagne.layers.get_output([self.network[k] for k in ['img', 'seed', 'out']], input_layers, deterministic=True)
self.predict = theano.function([input_tensor], output) self.predict = theano.function([input_tensor, seed_tensor], output)
if not args.train: return if not args.train: return
@@ -408,7 +414,7 @@ class Model(object):
# Combined Theano function for updating both generator and discriminator at the same time. # Combined Theano function for updating both generator and discriminator at the same time.
updates = collections.OrderedDict(list(gen_updates.items()) + list(disc_updates.items())) updates = collections.OrderedDict(list(gen_updates.items()) + list(disc_updates.items()))
self.fit = theano.function([input_tensor], gen_losses + [disc_out.mean(axis=(1,2,3))], updates=updates) self.fit = theano.function([input_tensor, seed_tensor], gen_losses + [disc_out.mean(axis=(1,2,3))], updates=updates)
@@ -449,7 +455,9 @@ class NeuralEnhancer(object):
if t_cur % args.learning_period == 0: l_r *= args.learning_decay if t_cur % args.learning_period == 0: l_r *= args.learning_decay
def train(self): def train(self):
seed_size = int(args.batch_resolution / 2**args.scales)
images = np.zeros((args.batch_size, 3, args.batch_resolution, args.batch_resolution), dtype=np.float32) images = np.zeros((args.batch_size, 3, args.batch_resolution, args.batch_resolution), dtype=np.float32)
seeds = np.zeros((args.batch_size, 3, seed_size, seed_size), dtype=np.float32)
learning_rate = self.decay_learning_rate() learning_rate = self.decay_learning_rate()
try: try:
running, start = None, time.time() running, start = None, time.time()
@@ -460,8 +468,8 @@ class NeuralEnhancer(object):
if epoch >= args.discriminator_start: self.model.disc_lr.set_value(l_r) if epoch >= args.discriminator_start: self.model.disc_lr.set_value(l_r)
for _ in range(args.epoch_size): for _ in range(args.epoch_size):
self.thread.copy(images) self.thread.copy(images, seeds)
output = self.model.fit(images) output = self.model.fit(images, seeds)
losses = np.array(output[:3], dtype=np.float32) losses = np.array(output[:3], dtype=np.float32)
stats = (stats + output[3]) if stats is not None else output[3] stats = (stats + output[3]) if stats is not None else output[3]
total = total + losses if total is not None else losses total = total + losses if total is not None else losses
@@ -470,7 +478,7 @@ class NeuralEnhancer(object):
running = l if running is None else running * 0.95 + 0.05 * l running = l if running is None else running * 0.95 + 0.05 * l
print('' if l > running else '', end='', flush=True) print('' if l > running else '', end='', flush=True)
orign, scald, repro = self.model.predict(images) orign, scald, repro = self.model.predict(images, seeds)
self.show_progress(orign, scald, repro) self.show_progress(orign, scald, repro)
total /= args.epoch_size total /= args.epoch_size
stats /= args.epoch_size stats /= args.epoch_size
@@ -498,7 +506,7 @@ class NeuralEnhancer(object):
def process(self, image): def process(self, image):
img = np.transpose(image / 255.0 - 0.5, (2, 0, 1))[np.newaxis].astype(np.float32) img = np.transpose(image / 255.0 - 0.5, (2, 0, 1))[np.newaxis].astype(np.float32)
*_, repro = self.model.predict(img) *_, repro = self.model.predict(img, img)
repro = np.transpose(repro[0] + 0.5, (1, 2, 0)).clip(0.0, 1.0) repro = np.transpose(repro[0] + 0.5, (1, 2, 0)).clip(0.0, 1.0)
return scipy.misc.toimage(repro * 255.0, cmin=0, cmax=255) return scipy.misc.toimage(repro * 255.0, cmin=0, cmax=255)

Loading…
Cancel
Save