|
|
|
|
@ -1,4 +1,4 @@
|
|
|
|
|
#!/usr/bin/env python3
|
|
|
|
|
#!/usr/bin/env python3
|
|
|
|
|
""" _ _
|
|
|
|
|
_ __ ___ _ _ _ __ __ _| | ___ _ __ | |__ __ _ _ __ ___ ___
|
|
|
|
|
| '_ \ / _ \ | | | '__/ _` | | / _ \ '_ \| '_ \ / _` | '_ \ / __/ _ \
|
|
|
|
|
@ -122,7 +122,7 @@ class DataLoader(threading.Thread):
|
|
|
|
|
def run(self):
|
|
|
|
|
files, cache = glob.glob('train/*.jpg'), {}
|
|
|
|
|
while True:
|
|
|
|
|
# random.shuffle(files)
|
|
|
|
|
random.shuffle(files)
|
|
|
|
|
for i, f in enumerate(files[:args.batch_size]):
|
|
|
|
|
filename = os.path.join(self.cwd, f)
|
|
|
|
|
try:
|
|
|
|
|
@ -133,9 +133,9 @@ class DataLoader(threading.Thread):
|
|
|
|
|
files.remove(f)
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
# if random.choice([True, False]): img[:,:] = img[:,::-1]
|
|
|
|
|
h = (img.shape[0] - self.resolution) // 2 # random.randint(0, img.shape[0] - self.resolution)
|
|
|
|
|
w = (img.shape[1] - self.resolution) // 2 # random.randint(0, img.shape[1] - self.resolution)
|
|
|
|
|
if random.choice([True, False]): img[:,:] = img[:,::-1]
|
|
|
|
|
h = random.randint(0, img.shape[0] - self.resolution)
|
|
|
|
|
w = random.randint(0, img.shape[1] - self.resolution)
|
|
|
|
|
img = img[h:h+self.resolution, w:w+self.resolution]
|
|
|
|
|
self.images[i] = np.transpose(img / 255.0 - 0.5, (2, 0, 1))
|
|
|
|
|
|
|
|
|
|
@ -187,6 +187,8 @@ class Model(object):
|
|
|
|
|
self.setup_perceptual(concatenated)
|
|
|
|
|
self.load_perceptual()
|
|
|
|
|
self.setup_discriminator()
|
|
|
|
|
self.load_generator()
|
|
|
|
|
|
|
|
|
|
self.compile()
|
|
|
|
|
|
|
|
|
|
#------------------------------------------------------------------------------------------------------------------
|
|
|
|
|
@ -285,14 +287,18 @@ class Model(object):
|
|
|
|
|
def save_generator(self):
|
|
|
|
|
def cast(p): return p.get_value().astype(np.float16)
|
|
|
|
|
params = {k: [cast(p) for p in l.get_params()] for (k, l) in self.list_generator_layers()}
|
|
|
|
|
pickle.dump(params, bz2.open(args.model % 2**args.scales, 'wb'))
|
|
|
|
|
filename = args.model % 2**args.scales
|
|
|
|
|
pickle.dump(params, bz2.open(filename, 'wb'))
|
|
|
|
|
print(' - Saved model as `{}` after training.'.format(filename))
|
|
|
|
|
|
|
|
|
|
def load_generator(self):
|
|
|
|
|
filename = args.model % 2**args.scales
|
|
|
|
|
if not os.path.exists(filename): return
|
|
|
|
|
params = pickle.load(bz2.open(filename, 'rb'))
|
|
|
|
|
for k, l in self.list_generator_layers():
|
|
|
|
|
if k not in params: continue
|
|
|
|
|
(p.set_value(v) for p, v in zip(l.get_params(), params[k]))
|
|
|
|
|
print(' - Loaded file `{}` with trained model.'.format(filename))
|
|
|
|
|
|
|
|
|
|
#------------------------------------------------------------------------------------------------------------------
|
|
|
|
|
# Training & Loss Functions
|
|
|
|
|
@ -367,15 +373,13 @@ class NeuralEnhancer(object):
|
|
|
|
|
self.imsave('valid/%03i_repro.png' % i, repro[i])
|
|
|
|
|
|
|
|
|
|
def train(self):
|
|
|
|
|
self.model.load_generator()
|
|
|
|
|
|
|
|
|
|
images = np.zeros((args.batch_size, 3, args.batch_resolution, args.batch_resolution), dtype=np.float32)
|
|
|
|
|
l_min, l_max, l_mult = 1E-7, 1E-3, 0.2
|
|
|
|
|
t_cur, t_i, t_mult = 120, 150, 1
|
|
|
|
|
|
|
|
|
|
i, running = 0, None
|
|
|
|
|
i, running, start = 0, None, time.time()
|
|
|
|
|
for k in range(args.epochs):
|
|
|
|
|
total, start = None, time.time()
|
|
|
|
|
total = None
|
|
|
|
|
for _ in range(args.epoch_size):
|
|
|
|
|
i += 1
|
|
|
|
|
l_r = l_min + 0.5 * (l_max - l_min) * (1.0 + math.cos(t_cur / t_i * math.pi))
|
|
|
|
|
@ -400,14 +404,17 @@ class NeuralEnhancer(object):
|
|
|
|
|
total /= args.epoch_size
|
|
|
|
|
totals, labels = [sum(total)] + list(total), ['total', 'prcpt', 'smthn', 'advrs']
|
|
|
|
|
losses = ['{}{}{}={:4.2e}'.format(ansi.WHITE_B, k, ansi.ENDC, v) for k, v in zip(labels, totals)]
|
|
|
|
|
print('\rEpoch #{} in {:4.1f}s{}'.format(k+1, time.time()-start, ' '*args.epoch_size))
|
|
|
|
|
print(' - losses {}\n'.format(' '.join(losses)))
|
|
|
|
|
print('\rEpoch #{} at {:4.1f}s{}'.format(k+1, time.time()-start, ' '*args.epoch_size))
|
|
|
|
|
print(' - losses {}'.format(' '.join(losses)))
|
|
|
|
|
|
|
|
|
|
# print(stats[:args.batch_size].mean(), stats[args.batch_size:].mean())
|
|
|
|
|
# if k == 0: self.model.disc_lr.set_value(l_r)
|
|
|
|
|
# if k == 1: self.model.adversary_weight.set_value(args.adversary_weight)
|
|
|
|
|
|
|
|
|
|
print('\n{}Trained {}x super-resolution for {} epochs.{}'\
|
|
|
|
|
.format(ansi.CYAN_B, 2**args.scales, args.epochs, ansi.CYAN))
|
|
|
|
|
self.model.save_generator()
|
|
|
|
|
print(ansi.ENDC)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
|