Add loading of images into a buffer, using multiple fragments per JPG loaded. Works well with larger datasets like OpenImages; training is now fully GPU-bound.

Alex J. Champandard 9 years ago
parent 619fad7f3c
commit 3809e9b02a
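The core idea of the change: each JPG is decoded once, then randomly flipped and cropped `--buffer-similar` times, and every crop is stored as an independent fragment in a fixed-size cache of `--buffer-size` entries, so a single disk read feeds many training samples. A minimal standalone sketch of the fragment extraction (illustrative only, not code from this commit):

# Hypothetical sketch: deriving several training fragments from one decoded image.
import random
import numpy as np

def fragments(img, resolution=192, count=5):
    for _ in range(count):
        copy = img[:, ::-1] if random.choice([True, False]) else img  # random horizontal flip
        h = random.randint(0, copy.shape[0] - resolution)
        w = random.randint(0, copy.shape[1] - resolution)
        yield copy[h:h+resolution, w:w+resolution]  # one resolution-sized crop

img = np.zeros((480, 640, 3), dtype=np.uint8)  # stands in for a decoded JPG
print(len(list(fragments(img))))               # -> 5 fragments from a single read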

@@ -36,12 +36,18 @@ add_arg('files', nargs='*', default=[])
 add_arg('--scales', default=2, type=int, help='How many times to perform 2x upsampling.')
 add_arg('--model', default='ne%ix.pkl.bz2', type=str, help='Name of the neural network to load/save.')
 add_arg('--train', default=False, action='store_true', help='Learn new or fine-tune a neural network.')
-add_arg('--batch-size', default=15, type=int, help='Number of images per training batch.')
 add_arg('--batch-resolution', default=192, type=int, help='Resolution of images in training batch.')
+add_arg('--batch-size', default=15, type=int, help='Number of images per training batch.')
+add_arg('--buffer-size', default=1500, type=int, help='Total image fragments kept in cache.')
+add_arg('--buffer-similar', default=5, type=int, help='Fragments cached for each image loaded.')
 add_arg('--epochs', default=10, type=int, help='Total number of iterations in training.')
-add_arg('--epoch-size', default=36, type=int, help='Number of batches trained in an epoch.')
-add_arg('--generator-filters', default=128, type=int, help='Number of convolution units in network.')
-add_arg('--generator-blocks', default=16, type=int, help='Number of residual blocks in total.')
+add_arg('--epoch-size', default=72, type=int, help='Number of batches trained in an epoch.')
+add_arg('--learning-rate', default=1E-4, type=float, help='Parameter for the ADAM optimizer.')
+add_arg('--learning-period', default=50, type=int, help='How often to decay the learning rate.')
+add_arg('--learning-decay', default=0.5, type=float, help='How much to decay the learning rate.')
+add_arg('--generator-filters', default=[64], nargs='+', type=int, help='Number of convolution units in network.')
+add_arg('--generator-blocks', default=12, type=int, help='Number of residual blocks per iteration.')
+add_arg('--generator-iters', default=1, type=int, help='Number of iterations in total.')
+add_arg('--generator-residual', default=2, type=int, help='Number of layers in a residual block.')
 add_arg('--perceptual-layer', default='conv2_2', type=str, help='Which VGG layer to use as loss component.')
 add_arg('--perceptual-weight', default=1e0, type=float, help='Weight for VGG-layer perceptual loss.')
@@ -50,7 +56,7 @@ add_arg('--adversary-weight', default=1e2, type=float, help='Weight
 add_arg('--generator-start', default=0, type=int, help='Epoch count to start training generator.')
 add_arg('--discriminator-start',default=1, type=int, help='Epoch count to update the discriminator.')
 add_arg('--adversarial-start', default=2, type=int, help='Epoch for generator to use discriminator.')
-add_arg('--device', default='gpu0', type=str, help='Name of the CPU/GPU to use, for Theano.')
+add_arg('--device', default='cpu', type=str, help='Name of the CPU/GPU to use, for Theano.')
 args = parser.parse_args()
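For reference, a hypothetical training invocation exercising the new options (the script name, dataset layout, and values are assumptions, not part of this commit). Note that `--generator-filters` now takes one or more values via nargs='+', and `--device` defaults to 'cpu', so GPU training must be requested explicitly:

python3 enhance.py --train --device gpu0 \
    --batch-size 15 --buffer-size 1500 --buffer-similar 5 \
    --generator-filters 64 --generator-blocks 12 --learning-rate 1e-4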
@@ -119,50 +125,54 @@ class DataLoader(threading.Thread):
         self.data_copied = threading.Event()
         self.resolution = args.batch_resolution
-        self.images = np.zeros((args.batch_size, 3, self.resolution, self.resolution), dtype=np.float32)
+        self.cache = np.zeros((args.buffer_size, 3, self.resolution, self.resolution), dtype=np.float32)
+        self.available = set(range(args.buffer_size))
+        self.ready = set()
         self.cwd = os.getcwd()
         self.start()
 
     def run(self):
-        files, cache = glob.glob('train/*.jpg'), {}
+        files = glob.glob('dataset/*/*.jpg')
         while True:
             random.shuffle(files)
-            for i, f in enumerate(files[:args.batch_size]):
+            for f in files:
                 filename = os.path.join(self.cwd, f)
                 try:
-                    if f not in cache:
-                        if len(cache) > 2048:
-                            del cache[random.choice(list(cache.keys()))]
-                        img = scipy.ndimage.imread(filename, mode='RGB')
-                        ratio = min(1024 / img.shape[0], 1024 / img.shape[1])
-                        if ratio < 1.0:
-                            img = scipy.misc.imresize(img, ratio, interp='bicubic')
-                        cache[f] = img
-                    else:
-                        img = cache[f]
+                    img = scipy.ndimage.imread(filename, mode='RGB')
                 except Exception as e:
                     warn('Could not load `{}` as image.'.format(filename),
                          ' - Try fixing or removing the file before next run.')
                     files.remove(f)
                     continue
-                if random.choice([True, False]): img[:,:] = img[:,::-1]
-                h = random.randint(0, img.shape[0] - self.resolution)
-                w = random.randint(0, img.shape[1] - self.resolution)
-                img = img[h:h+self.resolution, w:w+self.resolution]
-                self.images[i] = np.transpose(img / 255.0 - 0.5, (2, 0, 1))
-            self.data_ready.set()
-            self.data_copied.wait()
-            self.data_copied.clear()
+                for _ in range(args.buffer_similar):
+                    copy = img[:,::-1] if random.choice([True, False]) else img
+                    h = random.randint(0, copy.shape[0] - self.resolution)
+                    w = random.randint(0, copy.shape[1] - self.resolution)
+                    copy = copy[h:h+self.resolution, w:w+self.resolution]
+                    while len(self.available) == 0:
+                        self.data_copied.wait()
+                        self.data_copied.clear()
+                    i = self.available.pop()
+                    self.cache[i] = np.transpose(copy / 255.0 - 0.5, (2, 0, 1))
+                    self.ready.add(i)
+                if len(self.ready) >= args.batch_size:
+                    self.data_ready.set()
 
     def copy(self, output):
         self.data_ready.wait()
         self.data_ready.clear()
-        output[:] = self.images
+        for i, j in enumerate(random.sample(self.ready, args.batch_size)):
+            output[i] = self.cache[j]
+            self.available.add(j)
         self.data_copied.set()
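To make the synchronisation easier to follow, here is a self-contained sketch of the same producer/consumer protocol: `available` holds buffer slots the loader thread may overwrite, `ready` holds slots the training thread may sample, and the two threading.Event objects hand control back and forth. Names and sizes are illustrative, not the project's API:

import random, threading

BUFFER_SIZE, BATCH_SIZE = 8, 2

class Buffer(threading.Thread):
    def __init__(self):
        super().__init__(daemon=True)
        self.data_ready = threading.Event()        # consumer may sample a batch
        self.data_copied = threading.Event()       # producer may reuse freed slots
        self.cache = [None] * BUFFER_SIZE
        self.available = set(range(BUFFER_SIZE))   # slots safe to overwrite
        self.ready = set()                         # slots holding usable fragments

    def run(self):
        n = 0
        while True:
            while not self.available:              # buffer full: wait for the consumer
                self.data_copied.wait()
                self.data_copied.clear()
            i = self.available.pop()
            self.cache[i] = 'fragment-%i' % n      # stands in for a cropped image
            self.ready.add(i)
            n += 1
            if len(self.ready) >= BATCH_SIZE:
                self.data_ready.set()

    def copy(self, output):
        self.data_ready.wait()
        self.data_ready.clear()
        for i, j in enumerate(random.sample(list(self.ready), BATCH_SIZE)):
            output[i] = self.cache[j]
            self.available.add(j)                  # slot may now be refilled
        self.data_copied.set()

buf = Buffer()
buf.start()
batch = [None] * BATCH_SIZE
buf.copy(batch)
print(batch)                                       # e.g. ['fragment-3', 'fragment-0']

As in the commit, sampled slots go back into `available` but stay in `ready`, so a fragment can appear in several batches until the loader eventually overwrites its slot.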
