diff --git a/enhance.py b/enhance.py index 82c15d8..25b5689 100644 --- a/enhance.py +++ b/enhance.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python3 + #!/usr/bin/env python3 """ _ _ _ __ ___ _ _ _ __ __ _| | ___ _ __ | |__ __ _ _ __ ___ ___ | '_ \ / _ \ | | | '__/ _` | | / _ \ '_ \| '_ \ / _` | '_ \ / __/ _ \ @@ -122,7 +122,7 @@ class DataLoader(threading.Thread): def run(self): files, cache = glob.glob('train/*.jpg'), {} while True: - # random.shuffle(files) + random.shuffle(files) for i, f in enumerate(files[:args.batch_size]): filename = os.path.join(self.cwd, f) try: @@ -133,9 +133,9 @@ class DataLoader(threading.Thread): files.remove(f) continue - # if random.choice([True, False]): img[:,:] = img[:,::-1] - h = (img.shape[0] - self.resolution) // 2 # random.randint(0, img.shape[0] - self.resolution) - w = (img.shape[1] - self.resolution) // 2 # random.randint(0, img.shape[1] - self.resolution) + if random.choice([True, False]): img[:,:] = img[:,::-1] + h = random.randint(0, img.shape[0] - self.resolution) + w = random.randint(0, img.shape[1] - self.resolution) img = img[h:h+self.resolution, w:w+self.resolution] self.images[i] = np.transpose(img / 255.0 - 0.5, (2, 0, 1)) @@ -187,6 +187,8 @@ class Model(object): self.setup_perceptual(concatenated) self.load_perceptual() self.setup_discriminator() + self.load_generator() + self.compile() #------------------------------------------------------------------------------------------------------------------ @@ -285,14 +287,18 @@ class Model(object): def save_generator(self): def cast(p): return p.get_value().astype(np.float16) params = {k: [cast(p) for p in l.get_params()] for (k, l) in self.list_generator_layers()} - pickle.dump(params, bz2.open(args.model % 2**args.scales, 'wb')) + filename = args.model % 2**args.scales + pickle.dump(params, bz2.open(filename, 'wb')) + print(' - Saved model as `{}` after training.'.format(filename)) def load_generator(self): filename = args.model % 2**args.scales if not os.path.exists(filename): return params = pickle.load(bz2.open(filename, 'rb')) for k, l in self.list_generator_layers(): + if k not in params: continue (p.set_value(v) for p, v in zip(l.get_params(), params[k])) + print(' - Loaded file `{}` with trained model.'.format(filename)) #------------------------------------------------------------------------------------------------------------------ # Training & Loss Functions @@ -367,15 +373,13 @@ class NeuralEnhancer(object): self.imsave('valid/%03i_repro.png' % i, repro[i]) def train(self): - self.model.load_generator() - images = np.zeros((args.batch_size, 3, args.batch_resolution, args.batch_resolution), dtype=np.float32) l_min, l_max, l_mult = 1E-7, 1E-3, 0.2 t_cur, t_i, t_mult = 120, 150, 1 - i, running = 0, None + i, running, start = 0, None, time.time() for k in range(args.epochs): - total, start = None, time.time() + total = None for _ in range(args.epoch_size): i += 1 l_r = l_min + 0.5 * (l_max - l_min) * (1.0 + math.cos(t_cur / t_i * math.pi)) @@ -400,14 +404,17 @@ class NeuralEnhancer(object): total /= args.epoch_size totals, labels = [sum(total)] + list(total), ['total', 'prcpt', 'smthn', 'advrs'] losses = ['{}{}{}={:4.2e}'.format(ansi.WHITE_B, k, ansi.ENDC, v) for k, v in zip(labels, totals)] - print('\rEpoch #{} in {:4.1f}s{}'.format(k+1, time.time()-start, ' '*args.epoch_size)) - print(' - losses {}\n'.format(' '.join(losses))) + print('\rEpoch #{} at {:4.1f}s{}'.format(k+1, time.time()-start, ' '*args.epoch_size)) + print(' - losses {}'.format(' '.join(losses))) # print(stats[:args.batch_size].mean(), stats[args.batch_size:].mean()) # if k == 0: self.model.disc_lr.set_value(l_r) # if k == 1: self.model.adversary_weight.set_value(args.adversary_weight) + print('\n{}Trained {}x super-resolution for {} epochs.{}'\ + .format(ansi.CYAN_B, 2**args.scales, args.epochs, ansi.CYAN)) self.model.save_generator() + print(ansi.ENDC) if __name__ == "__main__":