diff --git a/enhance.py b/enhance.py index 4f23967..ecd8afc 100755 --- a/enhance.py +++ b/enhance.py @@ -368,23 +368,26 @@ class Model(object): name = list(self.network.keys())[list(self.network.values()).index(l)] yield (name, l) + def get_filename(self): + filename = 'ne%ix-%s-%s-%s.pkl.bz2' % (args.zoom, args.type, args.model, __version__) + return os.path.join(os.path.dirname(__file__), filename) + def save_generator(self): def cast(p): return p.get_value().astype(np.float16) params = {k: [cast(p) for p in l.get_params()] for (k, l) in self.list_generator_layers()} config = {k: getattr(args, k) for k in ['generator_blocks', 'generator_residual', 'generator_filters'] + \ ['generator_upscale', 'generator_downscale']} - filename = 'ne%ix-%s-%s-%s.pkl.bz2' % (args.zoom, args.type, args.model, __version__) - pickle.dump((config, params), bz2.open(filename, 'wb')) - print(' - Saved model as `{}` after training.'.format(filename)) + + pickle.dump((config, params), bz2.open(self.get_filename(), 'wb')) + print(' - Saved model as `{}` after training.'.format(self.get_filename())) def load_model(self): - filename = 'ne%ix-%s-%s-%s.pkl.bz2' % (args.zoom, args.type, args.model, __version__) - if not os.path.exists(filename): + if not os.path.exists(self.get_filename()): if args.train: return {}, {} error("Model file with pre-trained convolution layers not found. Download it here...", - "https://github.com/alexjc/neural-enhance/releases/download/v%s/%s"%(__version__, filename)) - print(' - Loaded file `{}` with trained model.'.format(filename)) - return pickle.load(bz2.open(filename, 'rb')) + "https://github.com/alexjc/neural-enhance/releases/download/v%s/%s"%(__version__, self.get_filename())) + print(' - Loaded file `{}` with trained model.'.format(self.get_filename())) + return pickle.load(bz2.open(self.get_filename(), 'rb')) def load_generator(self, params): if len(params) == 0: return