|
|
|
|
@ -36,12 +36,18 @@ add_arg('files', nargs='*', default=[])
|
|
|
|
|
add_arg('--scales', default=2, type=int, help='How many times to perform 2x upsampling.')
|
|
|
|
|
add_arg('--model', default='ne%ix.pkl.bz2', type=str, help='Name of the neural network to load/save.')
|
|
|
|
|
add_arg('--train', default=False, action='store_true', help='Learn new or fine-tune a neural network.')
|
|
|
|
|
add_arg('--batch-size', default=15, type=int, help='Number of images per training batch.')
|
|
|
|
|
add_arg('--batch-resolution', default=192, type=int, help='Resolution of images in training batch.')
|
|
|
|
|
add_arg('--batch-size', default=15, type=int, help='Number of images per training batch.')
|
|
|
|
|
add_arg('--buffer-size', default=1500, type=int, help='Total image fragments kept in cache.')
|
|
|
|
|
add_arg('--buffer-similar', default=5, type=int, help='Fragments cached for each image loaded.')
|
|
|
|
|
add_arg('--epochs', default=10, type=int, help='Total number of iterations in training.')
|
|
|
|
|
add_arg('--epoch-size', default=36, type=int, help='Number of batches trained in an epoch.')
|
|
|
|
|
add_arg('--generator-filters', default=128, type=int, help='Number of convolution units in network.')
|
|
|
|
|
add_arg('--generator-blocks', default=16, type=int, help='Number of residual blocks in total.')
|
|
|
|
|
add_arg('--epoch-size', default=72, type=int, help='Number of batches trained in an epoch.')
|
|
|
|
|
add_arg('--learning-rate', default=1E-4, type=float, help='Parameter for the ADAM optimizer.')
|
|
|
|
|
add_arg('--learning-period', default=50, type=int, help='How often to decay the learning rate.')
|
|
|
|
|
add_arg('--learning-decay', default=0.5, type=float, help='How much to decay the learning rate.')
|
|
|
|
|
add_arg('--generator-filters', default=[64], nargs='+', type=int, help='Number of convolution units in network.')
|
|
|
|
|
add_arg('--generator-blocks', default=12, type=int, help='Number of residual blocks per iteration.')
|
|
|
|
|
add_arg('--generator-iters', default=1, type=int, help='Number of iterations in total.')
|
|
|
|
|
add_arg('--generator-residual', default=2, type=int, help='Number of layers in a residual block.')
|
|
|
|
|
add_arg('--perceptual-layer', default='conv2_2', type=str, help='Which VGG layer to use as loss component.')
|
|
|
|
|
add_arg('--perceptual-weight', default=1e0, type=float, help='Weight for VGG-layer perceptual loss.')
|
|
|
|
|
@ -50,7 +56,7 @@ add_arg('--adversary-weight', default=1e2, type=float, help='Weight
|
|
|
|
|
add_arg('--generator-start', default=0, type=int, help='Epoch count to start training generator.')
|
|
|
|
|
add_arg('--discriminator-start',default=1, type=int, help='Epoch count to update the discriminator.')
|
|
|
|
|
add_arg('--adversarial-start', default=2, type=int, help='Epoch for generator to use discriminator.')
|
|
|
|
|
add_arg('--device', default='gpu0', type=str, help='Name of the CPU/GPU to use, for Theano.')
|
|
|
|
|
add_arg('--device', default='cpu', type=str, help='Name of the CPU/GPU to use, for Theano.')
|
|
|
|
|
args = parser.parse_args()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -119,50 +125,54 @@ class DataLoader(threading.Thread):
|
|
|
|
|
self.data_copied = threading.Event()
|
|
|
|
|
|
|
|
|
|
self.resolution = args.batch_resolution
|
|
|
|
|
self.images = np.zeros((args.batch_size, 3, self.resolution, self.resolution), dtype=np.float32)
|
|
|
|
|
self.cache = np.zeros((args.buffer_size, 3, self.resolution, self.resolution), dtype=np.float32)
|
|
|
|
|
|
|
|
|
|
self.available = set(range(args.buffer_size))
|
|
|
|
|
self.ready = set()
|
|
|
|
|
|
|
|
|
|
self.cwd = os.getcwd()
|
|
|
|
|
self.start()
|
|
|
|
|
|
|
|
|
|
def run(self):
|
|
|
|
|
files, cache = glob.glob('train/*.jpg'), {}
|
|
|
|
|
files = glob.glob('dataset/*/*.jpg')
|
|
|
|
|
while True:
|
|
|
|
|
random.shuffle(files)
|
|
|
|
|
for i, f in enumerate(files[:args.batch_size]):
|
|
|
|
|
|
|
|
|
|
for f in files:
|
|
|
|
|
filename = os.path.join(self.cwd, f)
|
|
|
|
|
try:
|
|
|
|
|
if f not in cache:
|
|
|
|
|
if len(cache) > 2048:
|
|
|
|
|
del cache[random.choice(list(cache.keys()))]
|
|
|
|
|
|
|
|
|
|
img = scipy.ndimage.imread(filename, mode='RGB')
|
|
|
|
|
ratio = min(1024 / img.shape[0], 1024 / img.shape[1])
|
|
|
|
|
if ratio < 1.0:
|
|
|
|
|
img = scipy.misc.imresize(img, ratio, interp='bicubic')
|
|
|
|
|
cache[f] = img
|
|
|
|
|
else:
|
|
|
|
|
img = cache[f]
|
|
|
|
|
img = scipy.ndimage.imread(filename, mode='RGB')
|
|
|
|
|
except Exception as e:
|
|
|
|
|
warn('Could not load `{}` as image.'.format(filename),
|
|
|
|
|
' - Try fixing or removing the file before next run.')
|
|
|
|
|
files.remove(f)
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
if random.choice([True, False]): img[:,:] = img[:,::-1]
|
|
|
|
|
h = random.randint(0, img.shape[0] - self.resolution)
|
|
|
|
|
w = random.randint(0, img.shape[1] - self.resolution)
|
|
|
|
|
img = img[h:h+self.resolution, w:w+self.resolution]
|
|
|
|
|
self.images[i] = np.transpose(img / 255.0 - 0.5, (2, 0, 1))
|
|
|
|
|
for _ in range(args.buffer_similar):
|
|
|
|
|
copy = img[:,::-1] if random.choice([True, False]) else img
|
|
|
|
|
h = random.randint(0, copy.shape[0] - self.resolution)
|
|
|
|
|
w = random.randint(0, copy.shape[1] - self.resolution)
|
|
|
|
|
copy = copy[h:h+self.resolution, w:w+self.resolution]
|
|
|
|
|
|
|
|
|
|
self.data_ready.set()
|
|
|
|
|
self.data_copied.wait()
|
|
|
|
|
self.data_copied.clear()
|
|
|
|
|
while len(self.available) == 0:
|
|
|
|
|
self.data_copied.wait()
|
|
|
|
|
self.data_copied.clear()
|
|
|
|
|
|
|
|
|
|
i = self.available.pop()
|
|
|
|
|
self.cache[i] = np.transpose(copy / 255.0 - 0.5, (2, 0, 1))
|
|
|
|
|
self.ready.add(i)
|
|
|
|
|
|
|
|
|
|
if len(self.ready) >= args.batch_size:
|
|
|
|
|
self.data_ready.set()
|
|
|
|
|
|
|
|
|
|
def copy(self, output):
|
|
|
|
|
self.data_ready.wait()
|
|
|
|
|
self.data_ready.clear()
|
|
|
|
|
|
|
|
|
|
output[:] = self.images
|
|
|
|
|
for i, j in enumerate(random.sample(self.ready, args.batch_size)):
|
|
|
|
|
output[i] = self.cache[j]
|
|
|
|
|
self.available.add(j)
|
|
|
|
|
|
|
|
|
|
self.data_copied.set()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|