|
|
|
|
@ -14,7 +14,7 @@
|
|
|
|
|
# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
|
|
|
|
#
|
|
|
|
|
|
|
|
|
|
__version__ = '0.1'
|
|
|
|
|
__version__ = '0.2'
|
|
|
|
|
|
|
|
|
|
import io
|
|
|
|
|
import os
|
|
|
|
|
@ -159,10 +159,10 @@ class DataLoader(threading.Thread):
|
|
|
|
|
filename = os.path.join(self.cwd, f)
|
|
|
|
|
try:
|
|
|
|
|
orig = PIL.Image.open(filename).convert('RGB')
|
|
|
|
|
if all(s > args.batch_shape * 2 for s in orig.size):
|
|
|
|
|
orig = orig.resize((orig.size[0]//2, orig.size[1]//2), resample=PIL.Image.LANCZOS)
|
|
|
|
|
if any(s < args.batch_shape * 2 for s in orig.size):
|
|
|
|
|
raise ValueError('Image is too small for training with size {}'.format(img.shape))
|
|
|
|
|
# if all(s > args.batch_shape * 2 for s in orig.size):
|
|
|
|
|
# orig = orig.resize((orig.size[0]//2, orig.size[1]//2), resample=PIL.Image.LANCZOS)
|
|
|
|
|
if any(s < args.batch_shape for s in orig.size):
|
|
|
|
|
raise ValueError('Image is too small for training with size {}'.format(orig.size))
|
|
|
|
|
except Exception as e:
|
|
|
|
|
warn('Could not load `{}` as image.'.format(filename),
|
|
|
|
|
' - Try fixing or removing the file before next run.')
|
|
|
|
|
@ -468,9 +468,9 @@ class NeuralEnhancer(object):
|
|
|
|
|
def show_progress(self, orign, scald, repro):
|
|
|
|
|
os.makedirs('valid', exist_ok=True)
|
|
|
|
|
for i in range(args.batch_size):
|
|
|
|
|
self.imsave('valid/%03i_origin.png' % i, orign[i])
|
|
|
|
|
self.imsave('valid/%03i_pixels.png' % i, scald[i])
|
|
|
|
|
self.imsave('valid/%03i_reprod.png' % i, repro[i])
|
|
|
|
|
self.imsave('valid/%s_%03i_origin.png' % (args.model, i), orign[i])
|
|
|
|
|
self.imsave('valid/%s_%03i_pixels.png' % (args.model, i), scald[i])
|
|
|
|
|
self.imsave('valid/%s_%03i_reprod.png' % (args.model, i), repro[i])
|
|
|
|
|
|
|
|
|
|
def decay_learning_rate(self):
|
|
|
|
|
l_r, t_cur = args.learning_rate, 0
|
|
|
|
|
@ -510,7 +510,7 @@ class NeuralEnhancer(object):
|
|
|
|
|
stats /= args.epoch_size
|
|
|
|
|
totals, labels = [sum(total)] + list(total), ['total', 'prcpt', 'smthn', 'advrs']
|
|
|
|
|
gen_info = ['{}{}{}={:4.2e}'.format(ansi.WHITE_B, k, ansi.ENDC, v) for k, v in zip(labels, totals)]
|
|
|
|
|
print('\rEpoch #{} at {:4.1f}s, lr={:4.2e}{}'.format(epoch+1, time.time()-start, l_r, ' '*(args.epoch_size-60)))
|
|
|
|
|
print('\rEpoch #{} at {:4.1f}s, lr={:4.2e}{}'.format(epoch+1, time.time()-start, l_r, ' '*(args.epoch_size-35)))
|
|
|
|
|
print(' - generator {}'.format(' '.join(gen_info)))
|
|
|
|
|
|
|
|
|
|
real, fake = stats[:args.batch_size], stats[args.batch_size:]
|
|
|
|
|
|