Improve display and filenames for saving output.

main
Alex J. Champandard 9 years ago
parent 93e5a41d9a
commit 7924cc4a85

@ -14,7 +14,7 @@
# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
#
__version__ = '0.1'
__version__ = '0.2'
import io
import os
@ -159,10 +159,10 @@ class DataLoader(threading.Thread):
filename = os.path.join(self.cwd, f)
try:
orig = PIL.Image.open(filename).convert('RGB')
if all(s > args.batch_shape * 2 for s in orig.size):
orig = orig.resize((orig.size[0]//2, orig.size[1]//2), resample=PIL.Image.LANCZOS)
if any(s < args.batch_shape * 2 for s in orig.size):
raise ValueError('Image is too small for training with size {}'.format(img.shape))
# if all(s > args.batch_shape * 2 for s in orig.size):
# orig = orig.resize((orig.size[0]//2, orig.size[1]//2), resample=PIL.Image.LANCZOS)
if any(s < args.batch_shape for s in orig.size):
raise ValueError('Image is too small for training with size {}'.format(orig.size))
except Exception as e:
warn('Could not load `{}` as image.'.format(filename),
' - Try fixing or removing the file before next run.')
@ -468,9 +468,9 @@ class NeuralEnhancer(object):
def show_progress(self, orign, scald, repro):
os.makedirs('valid', exist_ok=True)
for i in range(args.batch_size):
self.imsave('valid/%03i_origin.png' % i, orign[i])
self.imsave('valid/%03i_pixels.png' % i, scald[i])
self.imsave('valid/%03i_reprod.png' % i, repro[i])
self.imsave('valid/%s_%03i_origin.png' % (args.model, i), orign[i])
self.imsave('valid/%s_%03i_pixels.png' % (args.model, i), scald[i])
self.imsave('valid/%s_%03i_reprod.png' % (args.model, i), repro[i])
def decay_learning_rate(self):
l_r, t_cur = args.learning_rate, 0
@ -510,7 +510,7 @@ class NeuralEnhancer(object):
stats /= args.epoch_size
totals, labels = [sum(total)] + list(total), ['total', 'prcpt', 'smthn', 'advrs']
gen_info = ['{}{}{}={:4.2e}'.format(ansi.WHITE_B, k, ansi.ENDC, v) for k, v in zip(labels, totals)]
print('\rEpoch #{} at {:4.1f}s, lr={:4.2e}{}'.format(epoch+1, time.time()-start, l_r, ' '*(args.epoch_size-60)))
print('\rEpoch #{} at {:4.1f}s, lr={:4.2e}{}'.format(epoch+1, time.time()-start, l_r, ' '*(args.epoch_size-35)))
print(' - generator {}'.format(' '.join(gen_info)))
real, fake = stats[:args.batch_size], stats[args.batch_size:]

Loading…
Cancel
Save