|
|
|
|
@ -418,22 +418,17 @@ class NeuralEnhancer(object):
|
|
|
|
|
self.imsave('valid/%03i_pixels.png' % i, scald[i])
|
|
|
|
|
self.imsave('valid/%03i_reprod.png' % i, repro[i])
|
|
|
|
|
|
|
|
|
|
def decay_with_restart(self):
|
|
|
|
|
l_min, l_max, l_mult = 1E-7, 1E-3, 0.5
|
|
|
|
|
t_cur, t_i, t_mult = 10, 10, 1
|
|
|
|
|
def decay_learning_rate(self):
|
|
|
|
|
l_r, t_cur = args.learning_rate, 0
|
|
|
|
|
|
|
|
|
|
while True:
|
|
|
|
|
yield l_min + 0.5 * (l_max - l_min) * (1.0 + math.cos(t_cur / t_i * math.pi))
|
|
|
|
|
yield l_r if t_cur > 0 else l_r * 0.1
|
|
|
|
|
t_cur += 1
|
|
|
|
|
|
|
|
|
|
if t_cur > t_i:
|
|
|
|
|
t_cur, t_i = 0, int(t_i * t_mult)
|
|
|
|
|
l_max = max(l_max * l_mult, 1e-12)
|
|
|
|
|
l_min = max(l_min * l_mult, 1e-8)
|
|
|
|
|
if t_cur % args.learning_period == 0: l_r *= args.learning_decay
|
|
|
|
|
|
|
|
|
|
def train(self):
|
|
|
|
|
images = np.zeros((args.batch_size, 3, args.batch_resolution, args.batch_resolution), dtype=np.float32)
|
|
|
|
|
learning_rate = self.decay_with_restart()
|
|
|
|
|
learning_rate = self.decay_learning_rate()
|
|
|
|
|
try:
|
|
|
|
|
running, start = None, time.time()
|
|
|
|
|
for epoch in range(args.epochs):
|
|
|
|
|
@ -450,8 +445,8 @@ class NeuralEnhancer(object):
|
|
|
|
|
total = total + losses if total is not None else losses
|
|
|
|
|
l = np.sum(losses)
|
|
|
|
|
assert not np.isnan(losses).any()
|
|
|
|
|
running = l if running is None else running * 0.9 + 0.1 * l
|
|
|
|
|
print('↑' if l > running else '↓', end=' ', flush=True)
|
|
|
|
|
running = l if running is None else running * 0.95 + 0.05 * l
|
|
|
|
|
print('↑' if l > running else '↓', end='', flush=True)
|
|
|
|
|
|
|
|
|
|
orign, scald, repro = self.model.predict(images)
|
|
|
|
|
self.show_progress(orign, scald, repro)
|
|
|
|
|
@ -468,6 +463,7 @@ class NeuralEnhancer(object):
|
|
|
|
|
print(' - adversary mode: generator engaging discriminator.')
|
|
|
|
|
self.model.adversary_weight.set_value(args.adversary_weight)
|
|
|
|
|
running = None
|
|
|
|
|
|
|
|
|
|
except KeyboardInterrupt:
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|