You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
43 lines
1.5 KiB
43 lines
1.5 KiB
import os
|
|
import argparse
|
|
import tensorflow as tf
|
|
import model
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line options of the deblurring test script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) lets argparse read
            ``sys.argv[1:]``, which matches a plain ``parse_args()`` call.

    Returns:
        argparse.Namespace with the attributes ``gpu`` (str), ``model``
        (str), ``input_path`` (str), ``max_height`` (int) and
        ``max_width`` (int).
    """
    parser = argparse.ArgumentParser(description='deblur arguments')
    parser.add_argument(
        '--gpu', type=str, default='0',
        help="set gpu id, or use 'cpu' (or leave it blank) for cpu")
    parser.add_argument(
        '--model', type=str, default='default',
        help='choose the model trained on default data or all data')
    parser.add_argument(
        '--input_path', type=str, default='./testing_imgs',
        help='path of testing folder or path of one testing image')
    parser.add_argument(
        '--max_height', type=int, default=720,
        help='max height for the input tensor, should be multiples of 16')
    parser.add_argument(
        '--max_width', type=int, default=1280,
        help='max width for the input tensor, should be multiples of 16')
    return parser.parse_args(argv)
|
|
|
|
|
|
def main(_):
    """Configure the device and checkpoint, then run the deblurring test.

    Args:
        _: Ignored. Present so the function can be invoked with one
            positional argument by the app runner used at the bottom of
            this file.
    """
    args = parse_args()

    # Set gpu id or leave it blank for cpu: the literal value 'cpu' is
    # translated to an empty device list, anything else is passed through.
    os.environ['CUDA_VISIBLE_DEVICES'] = '' if args.gpu == 'cpu' else args.gpu

    # Choose the model trained on default data or all data: every value
    # other than 'default' selects the 'alldata' checkpoint directory.
    checkpoint_name = 'default' if args.model == 'default' else 'alldata'
    model_path = os.path.join('checkpoints', checkpoint_name)

    deblur = model.DEBLUR(args)
    deblur.build(model_path)
    deblur.test()
|
|
|
|
if __name__ == '__main__':
    # Script entry point: hand control to TensorFlow's app runner, which
    # calls this module's main() with the remaining command-line arguments.
    tf.app.run()