Use traditional learning rate decay rather than fast restarts; it works better when training a continuously adapting GAN.

main
Alex J. Champandard 9 years ago
parent 3809e9b02a
commit 87304c93a6
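
For context, the removed `decay_with_restart` generator implements the cosine-annealing-with-warm-restarts pattern (the "fast restarts" of the commit message, in the style of SGDR): the rate sweeps from `l_max` down to `l_min` along a half cosine, then jumps back up with shrunken bounds. The replacement `decay_learning_rate` is a plain step decay driven by the `args` namespace. Below is a minimal self-contained sketch of both schedules; `step_decay`'s default values are illustrative assumptions, since the real ones come from the repository's command-line arguments:

```python
import math
import itertools

def cosine_with_restarts(l_min=1e-7, l_max=1e-3, l_mult=0.5, t_i=10, t_mult=1):
    # The removed schedule: half-cosine sweep from l_max down to l_min over
    # t_i steps, then a "fast restart" back up with both bounds halved.
    t_cur = t_i
    while True:
        yield l_min + 0.5 * (l_max - l_min) * (1.0 + math.cos(t_cur / t_i * math.pi))
        t_cur += 1
        if t_cur > t_i:
            t_cur, t_i = 0, int(t_i * t_mult)
            l_max = max(l_max * l_mult, 1e-12)
            l_min = max(l_min * l_mult, 1e-8)

def step_decay(base_rate=1e-4, period=75, decay=0.5):
    # The new schedule: one warm-up tick at 10% of the base rate, then the
    # full rate, multiplied by `decay` every `period` steps. The parameters
    # stand in for args.learning_rate, args.learning_period and
    # args.learning_decay; these defaults are assumptions, not repo values.
    l_r, t_cur = base_rate, 0
    while True:
        yield l_r if t_cur > 0 else l_r * 0.1
        t_cur += 1
        if t_cur % period == 0:
            l_r *= decay

# Peek at the first few values of each schedule.
print(list(itertools.islice(cosine_with_restarts(), 4)))
print(list(itertools.islice(step_decay(), 4)))
```

Note the warm-up tick: the first value the new generator yields is a tenth of the base rate, which eases both networks in before full-rate updates begin.
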

@@ -418,22 +418,17 @@ class NeuralEnhancer(object):
             self.imsave('valid/%03i_pixels.png' % i, scald[i])
             self.imsave('valid/%03i_reprod.png' % i, repro[i])
 
-    def decay_with_restart(self):
-        l_min, l_max, l_mult = 1E-7, 1E-3, 0.5
-        t_cur, t_i, t_mult = 10, 10, 1
+    def decay_learning_rate(self):
+        l_r, t_cur = args.learning_rate, 0
         while True:
-            yield l_min + 0.5 * (l_max - l_min) * (1.0 + math.cos(t_cur / t_i * math.pi))
+            yield l_r if t_cur > 0 else l_r * 0.1
             t_cur += 1
-            if t_cur > t_i:
-                t_cur, t_i = 0, int(t_i * t_mult)
-                l_max = max(l_max * l_mult, 1e-12)
-                l_min = max(l_min * l_mult, 1e-8)
+            if t_cur % args.learning_period == 0: l_r *= args.learning_decay
 
     def train(self):
         images = np.zeros((args.batch_size, 3, args.batch_resolution, args.batch_resolution), dtype=np.float32)
-        learning_rate = self.decay_with_restart()
+        learning_rate = self.decay_learning_rate()
         try:
             running, start = None, time.time()
             for epoch in range(args.epochs):
@@ -450,7 +445,7 @@ class NeuralEnhancer(object):
                 total = total + losses if total is not None else losses
                 l = np.sum(losses)
                 assert not np.isnan(losses).any()
-                running = l if running is None else running * 0.9 + 0.1 * l
+                running = l if running is None else running * 0.95 + 0.05 * l
                 print('↑' if l > running else '↓', end='', flush=True)
 
             orign, scald, repro = self.model.predict(images)
@@ -468,6 +463,7 @@ class NeuralEnhancer(object):
                     print('  - adversary mode: generator engaging discriminator.')
                     self.model.adversary_weight.set_value(args.adversary_weight)
+                    running = None
 
         except KeyboardInterrupt:
             pass
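
Two smaller changes ride along: the exponential moving average behind the per-batch progress glyphs now smooths more heavily (0.95/0.05 instead of 0.9/0.1), and it is reset to `None` when adversary mode engages, so the indicator re-baselines against the new combined loss. A toy sketch of that smoothing, with made-up loss values:

```python
# EMA of the per-batch loss, as used for the up/down progress glyphs.
# The commit changes the smoothing from 0.9/0.1 to 0.95/0.05, so the
# baseline reacts more slowly to individual noisy batches.
losses = [1.0, 0.8, 1.2, 0.7]   # made-up values for illustration only
running = None
for l in losses:
    running = l if running is None else running * 0.95 + 0.05 * l
    print('↑' if l > running else '↓', end='', flush=True)
```
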
