Add argument for specifying training images, cleaned up file handling.

main
Alex J. Champandard 9 years ago
parent 1c38f2ca31
commit 4c55c48f62

@ -37,7 +37,7 @@ add_arg = parser.add_argument
add_arg('files', nargs='*', default=[])
add_arg('--scales', default=2, type=int, help='How many times to perform 2x upsampling.')
add_arg('--model', default='medium', type=str, help='Name of the neural network to load/save.')
add_arg('--train', default=False, action='store_true', help='Learn new or fine-tune a neural network.')
add_arg('--train', default=False, type=str, help='File pattern to load for training.')
add_arg('--batch-resolution', default=192, type=int, help='Resolution of images in training batch.')
add_arg('--batch-size', default=15, type=int, help='Number of images per training batch.')
add_arg('--buffer-size', default=1500, type=int, help='Total image fragments kept in cache.')
@ -79,7 +79,7 @@ class ansi:
ENDC = '\033[0m'
def error(message, *lines):
string = "\n{}ERROR: " + message + "{}\n" + "\n".join(lines) + "{}\n"
string = "\n{}ERROR: " + message + "{}\n" + "\n".join(lines) + ("{}\n" if lines else "{}")
print(string.format(ansi.RED_B, ansi.RED, ansi.ENDC))
sys.exit(-1)
@ -129,7 +129,11 @@ class DataLoader(threading.Thread):
self.data_copied = threading.Event()
self.resolution = args.batch_resolution
self.cache = np.zeros((args.buffer_size, 3, self.resolution, self.resolution), dtype=np.float32)
self.buffer = np.zeros((args.buffer_size, 3, self.resolution, self.resolution), dtype=np.float32)
self.files = glob.glob(args.train)
if len(self.files) == 0:
error("There were no files found to train from searching for `{}`".format(args.train),
" - Try putting all your images in one folder and using `--train=data/*.jpg`")
self.available = set(range(args.buffer_size))
self.ready = set()
@ -138,11 +142,10 @@ class DataLoader(threading.Thread):
self.start()
def run(self):
files = glob.glob('dataset/*/*.jpg')
while True:
random.shuffle(files)
random.shuffle(self.files)
for f in files:
for f in self.files:
filename = os.path.join(self.cwd, f)
try:
img = scipy.ndimage.imread(filename, mode='RGB')
@ -163,7 +166,7 @@ class DataLoader(threading.Thread):
self.data_copied.clear()
i = self.available.pop()
self.cache[i] = np.transpose(copy / 255.0 - 0.5, (2, 0, 1))
self.buffer[i] = np.transpose(copy / 255.0 - 0.5, (2, 0, 1))
self.ready.add(i)
if len(self.ready) >= args.batch_size:
@ -174,7 +177,7 @@ class DataLoader(threading.Thread):
self.data_ready.clear()
for i, j in enumerate(random.sample(self.ready, args.batch_size)):
output[i] = self.cache[j]
output[i] = self.buffer[j]
self.available.add(j)
self.data_copied.set()
@ -415,10 +418,11 @@ class NeuralEnhancer(object):
print('{}Training {} epochs on random image sections with batch size {}.{}'\
.format(ansi.BLUE_B, args.epochs, args.batch_size, ansi.BLUE))
else:
if len(args.files) == 0: error("Specify the image(s) to enhance on the command-line.")
print('{}Enhancing {} image(s) specified on the command-line.{}'\
.format(ansi.BLUE_B, len(args.files), ansi.BLUE))
self.thread = DataLoader()
self.thread = DataLoader() if args.train else None
self.model = Model()
print('{}'.format(ansi.ENDC))

Loading…
Cancel
Save