From 8f5167d235ad831e7a7c38faae2d4231bf39a8ce Mon Sep 17 00:00:00 2001 From: Tom White Date: Mon, 31 Oct 2016 02:13:02 +1300 Subject: [PATCH] Fix enhancer.process to pass img, seed --- enhance.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/enhance.py b/enhance.py index 3bce683..c73213a 100644 --- a/enhance.py +++ b/enhance.py @@ -221,12 +221,8 @@ class Model(object): def __init__(self): self.network = collections.OrderedDict() - if args.train: - self.network['img'] = InputLayer((None, 3, None, None)) - self.network['seed'] = InputLayer((None, 3, None, None)) - else: - self.network['img'] = InputLayer((None, 3, None, None)) - self.network['seed'] = self.network['img'] + self.network['img'] = InputLayer((None, 3, None, None)) + self.network['seed'] = InputLayer((None, 3, None, None)) config, params = self.load_model() self.setup_generator(self.last_layer(), config) @@ -507,7 +503,7 @@ class NeuralEnhancer(object): def process(self, image): img = np.transpose(image / 255.0 - 0.5, (2, 0, 1))[np.newaxis].astype(np.float32) - *_, repro = self.model.predict(img) + *_, repro = self.model.predict(img, img) repro = np.transpose(repro[0] + 0.5, (1, 2, 0)).clip(0.0, 1.0) return scipy.misc.toimage(repro * 255.0, cmin=0, cmax=255)