diff --git a/enhance.py b/enhance.py index 73ba82e..2dd36a2 100755 --- a/enhance.py +++ b/enhance.py @@ -188,7 +188,7 @@ class DataLoader(threading.Thread): if args.train_noise: seed += scipy.random.normal(scale=args.train_noise, size=(seed.shape[0], seed.shape[1], 1)) ** 4.0 - for _ in range(args.buffer_similar): + for _ in range(seed.shape[0] * seed.shape[1] // self.seed_shape * 2): h = random.randint(0, seed.shape[0] - self.seed_shape) w = random.randint(0, seed.shape[1] - self.seed_shape) seed_chunk = seed[h:h+self.seed_shape, w:w+self.seed_shape]