|
|
|
|
@ -368,9 +368,9 @@ class Model(object):
|
|
|
|
|
name = list(self.network.keys())[list(self.network.values()).index(l)]
|
|
|
|
|
yield (name, l)
|
|
|
|
|
|
|
|
|
|
def get_filename(self):
|
|
|
|
|
def get_filename(self, absolute=False):
|
|
|
|
|
filename = 'ne%ix-%s-%s-%s.pkl.bz2' % (args.zoom, args.type, args.model, __version__)
|
|
|
|
|
return os.path.join(os.path.dirname(__file__), filename)
|
|
|
|
|
return os.path.join(os.path.dirname(__file__), filename) if absolute else filename
|
|
|
|
|
|
|
|
|
|
def save_generator(self):
|
|
|
|
|
def cast(p): return p.get_value().astype(np.float16)
|
|
|
|
|
@ -378,11 +378,11 @@ class Model(object):
|
|
|
|
|
config = {k: getattr(args, k) for k in ['generator_blocks', 'generator_residual', 'generator_filters'] + \
|
|
|
|
|
['generator_upscale', 'generator_downscale']}
|
|
|
|
|
|
|
|
|
|
pickle.dump((config, params), bz2.open(self.get_filename(), 'wb'))
|
|
|
|
|
pickle.dump((config, params), bz2.open(self.get_filename(absolute=True), 'wb'))
|
|
|
|
|
print(' - Saved model as `{}` after training.'.format(self.get_filename()))
|
|
|
|
|
|
|
|
|
|
def load_model(self):
|
|
|
|
|
if not os.path.exists(self.get_filename()):
|
|
|
|
|
if not os.path.exists(self.get_filename(absolute=True)):
|
|
|
|
|
if args.train: return {}, {}
|
|
|
|
|
error("Model file with pre-trained convolution layers not found. Download it here...",
|
|
|
|
|
"https://github.com/alexjc/neural-enhance/releases/download/v%s/%s"%(__version__, self.get_filename()))
|
|
|
|
|
|