From 87304c93a6d373d3e3474f50967deb8a4ae39fa5 Mon Sep 17 00:00:00 2001 From: "Alex J. Champandard" Date: Fri, 28 Oct 2016 04:27:17 +0200 Subject: [PATCH] Use traditional learning rate decay rather than fast restarts; it works better when training a continuously adapting GAN. --- enhance.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/enhance.py b/enhance.py index fc6a913..eaa31ea 100644 --- a/enhance.py +++ b/enhance.py @@ -418,22 +418,17 @@ class NeuralEnhancer(object): self.imsave('valid/%03i_pixels.png' % i, scald[i]) self.imsave('valid/%03i_reprod.png' % i, repro[i]) - def decay_with_restart(self): - l_min, l_max, l_mult = 1E-7, 1E-3, 0.5 - t_cur, t_i, t_mult = 10, 10, 1 + def decay_learning_rate(self): + l_r, t_cur = args.learning_rate, 0 while True: - yield l_min + 0.5 * (l_max - l_min) * (1.0 + math.cos(t_cur / t_i * math.pi)) + yield l_r if t_cur > 0 else l_r * 0.1 t_cur += 1 - - if t_cur > t_i: - t_cur, t_i = 0, int(t_i * t_mult) - l_max = max(l_max * l_mult, 1e-12) - l_min = max(l_min * l_mult, 1e-8) + if t_cur % args.learning_period == 0: l_r *= args.learning_decay def train(self): images = np.zeros((args.batch_size, 3, args.batch_resolution, args.batch_resolution), dtype=np.float32) - learning_rate = self.decay_with_restart() + learning_rate = self.decay_learning_rate() try: running, start = None, time.time() for epoch in range(args.epochs): @@ -450,8 +445,8 @@ class NeuralEnhancer(object): total = total + losses if total is not None else losses l = np.sum(losses) assert not np.isnan(losses).any() - running = l if running is None else running * 0.9 + 0.1 * l - print('↑' if l > running else '↓', end=' ', flush=True) + running = l if running is None else running * 0.95 + 0.05 * l + print('↑' if l > running else '↓', end='', flush=True) orign, scald, repro = self.model.predict(images) self.show_progress(orign, scald, repro) @@ -468,6 +463,7 @@ class NeuralEnhancer(object): print(' - adversary mode: generator 
engaging discriminator.') self.model.adversary_weight.set_value(args.adversary_weight) running = None + except KeyboardInterrupt: pass