diff --git a/enhance.py b/enhance.py index 27ba310..0b9f12f 100644 --- a/enhance.py +++ b/enhance.py @@ -37,7 +37,7 @@ add_arg = parser.add_argument add_arg('files', nargs='*', default=[]) add_arg('--scales', default=2, type=int, help='How many times to perform 2x upsampling.') add_arg('--model', default='medium', type=str, help='Name of the neural network to load/save.') -add_arg('--train', default=False, action='store_true', help='Learn new or fine-tune a neural network.') +add_arg('--train', default=False, type=str, help='File pattern to load for training.') add_arg('--batch-resolution', default=192, type=int, help='Resolution of images in training batch.') add_arg('--batch-size', default=15, type=int, help='Number of images per training batch.') add_arg('--buffer-size', default=1500, type=int, help='Total image fragments kept in cache.') @@ -79,7 +79,7 @@ class ansi: ENDC = '\033[0m' def error(message, *lines): - string = "\n{}ERROR: " + message + "{}\n" + "\n".join(lines) + "{}\n" + string = "\n{}ERROR: " + message + "{}\n" + "\n".join(lines) + ("{}\n" if lines else "{}") print(string.format(ansi.RED_B, ansi.RED, ansi.ENDC)) sys.exit(-1) @@ -129,7 +129,11 @@ class DataLoader(threading.Thread): self.data_copied = threading.Event() self.resolution = args.batch_resolution - self.cache = np.zeros((args.buffer_size, 3, self.resolution, self.resolution), dtype=np.float32) + self.buffer = np.zeros((args.buffer_size, 3, self.resolution, self.resolution), dtype=np.float32) + self.files = glob.glob(args.train) + if len(self.files) == 0: + error("There were no files found to train from searching for `{}`".format(args.train), + " - Try putting all your images in one folder and using `--train=data/*.jpg`") self.available = set(range(args.buffer_size)) self.ready = set() @@ -138,11 +142,10 @@ class DataLoader(threading.Thread): self.start() def run(self): - files = glob.glob('dataset/*/*.jpg') while True: - random.shuffle(files) + random.shuffle(self.files) - for f in files: + for f in self.files: filename = os.path.join(self.cwd, f) try: img = scipy.ndimage.imread(filename, mode='RGB') @@ -163,7 +166,7 @@ class DataLoader(threading.Thread): self.data_copied.clear() i = self.available.pop() - self.cache[i] = np.transpose(copy / 255.0 - 0.5, (2, 0, 1)) + self.buffer[i] = np.transpose(copy / 255.0 - 0.5, (2, 0, 1)) self.ready.add(i) if len(self.ready) >= args.batch_size: @@ -174,7 +177,7 @@ class DataLoader(threading.Thread): self.data_ready.clear() for i, j in enumerate(random.sample(self.ready, args.batch_size)): - output[i] = self.cache[j] + output[i] = self.buffer[j] self.available.add(j) self.data_copied.set() @@ -415,10 +418,11 @@ class NeuralEnhancer(object): print('{}Training {} epochs on random image sections with batch size {}.{}'\ .format(ansi.BLUE_B, args.epochs, args.batch_size, ansi.BLUE)) else: + if len(args.files) == 0: error("Specify the image(s) to enhance on the command-line.") print('{}Enhancing {} image(s) specified on the command-line.{}'\ .format(ansi.BLUE_B, len(args.files), ansi.BLUE)) - self.thread = DataLoader() + self.thread = DataLoader() if args.train else None self.model = Model() print('{}'.format(ansi.ENDC))