You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
91 lines
3.6 KiB
91 lines
3.6 KiB
from __future__ import print_function
|
|
import argparse
|
|
import numpy as np
|
|
import torch
|
|
import cv2
|
|
import yaml
|
|
import os
|
|
from torch.autograd import Variable
|
|
from models.networks import get_generator
|
|
import torchvision
|
|
import time
|
|
import torch.nn.functional as F
|
|
|
|
def get_args(argv=None):
    """Parse the command-line options of the test script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, so existing
            ``get_args()`` callers behave as before.

    Returns:
        argparse.Namespace: holds ``weights_path``, the required path of
        the weights file passed via ``--weights_path``.
    """
    parser = argparse.ArgumentParser('Test an image')
    parser.add_argument('--weights_path', required=True, help='Weights path')
    return parser.parse_args(argv)
|
|
|
|
def pad_to_multiple(img_tensor, factor=8):
    """Reflect-pad a 4-D tensor so its last two dims are multiples of ``factor``.

    Padding is added only on the bottom and right, so the original content
    stays at ``[:, :, :h, :w]`` and can be cropped back after inference.
    Dimensions that are already divisible by ``factor`` are left unchanged.

    Args:
        img_tensor: tensor indexed as (dim0, dim1, height, width).
        factor: positive integer the padded height/width must be divisible by.

    Returns:
        The padded tensor.
    """
    h, w = img_tensor.shape[2], img_tensor.shape[3]
    padh = (factor - h % factor) % factor
    padw = (factor - w % factor) % factor
    return F.pad(img_tensor, (0, padw, 0, padh), 'reflect')


def load_image_tensor(img_path):
    """Load an image file as a 1x3xHxW float32 tensor shifted by -0.5.

    The image is read with OpenCV, converted from BGR to RGB, scaled by
    1/255 and shifted by -0.5.

    Args:
        img_path: path of the image file to read.

    Returns:
        The CPU tensor with a leading batch dimension, or ``None`` when
        OpenCV cannot read the file.
    """
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        return None
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img_tensor = torch.from_numpy(np.transpose(img / 255, (2, 0, 1)).astype('float32')) - 0.5
    return img_tensor.unsqueeze(0)


def main():
    """Run the model over every image under ``blur_path`` and time the forward pass.

    Results are written to ``out_path`` using the same ``<scene>/<image>``
    layout as the input directory. Per-image and running-average forward-pass
    times are printed. Requires a CUDA device.
    """
    args = get_args()
    with open('config/config_Stripformer_gopro.yaml') as cfg:
        config = yaml.safe_load(cfg)
    blur_path = './datasets/Realblur_R/test/blur/'
    out_path = './out/Stripformer_realblur_R_results'
    factor = 8
    warm_up_images = 20

    model = get_generator(config['model'])
    # NOTE(review): torch.load unpickles the file; only load trusted weights.
    model.load_state_dict(torch.load(args.weights_path))
    model = model.cuda()
    # NOTE(review): model.eval() is never called in this script; confirm the
    # generator has no layers whose behaviour depends on train/eval mode.

    # makedirs also creates the missing parent ('./out') that os.mkdir would
    # fail on, and tolerates an already existing directory.
    os.makedirs(out_path, exist_ok=True)

    # Only sub-directories are scenes; sorting gives a deterministic order.
    scenes = sorted(
        name for name in os.listdir(blur_path)
        if os.path.isdir(os.path.join(blur_path, name))
    )
    jobs = []
    for scene in scenes:
        os.makedirs(os.path.join(out_path, scene), exist_ok=True)
        for img_name in sorted(os.listdir(os.path.join(blur_path, scene))):
            jobs.append((scene, img_name))
    total_image_number = len(jobs)

    # Warm up: run a few untimed forward passes so one-time CUDA
    # initialisation cost does not end up in the measurements.
    print('Hardware warm-up')
    with torch.no_grad():
        for scene, img_name in jobs[:warm_up_images]:
            img_tensor = load_image_tensor(os.path.join(blur_path, scene, img_name))
            if img_tensor is None:
                continue
            model(pad_to_multiple(img_tensor.cuda(), factor))
    torch.cuda.synchronize()

    test_time = 0
    iteration = 0
    for scene, img_name in jobs:
        img_path = os.path.join(blur_path, scene, img_name)
        img_tensor = load_image_tensor(img_path)
        if img_tensor is None:
            print('Skipping unreadable image: {}'.format(img_path))
            continue

        iteration += 1
        with torch.no_grad():
            img_tensor = img_tensor.cuda()
            h, w = img_tensor.shape[2], img_tensor.shape[3]
            img_tensor = pad_to_multiple(img_tensor, factor)

            # CUDA kernels run asynchronously: synchronise before reading the
            # clock on both sides so the interval covers the whole forward
            # pass and nothing queued earlier (upload, padding).
            torch.cuda.synchronize()
            start = time.perf_counter()
            _output = model(img_tensor)
            torch.cuda.synchronize()
            stop = time.perf_counter()

            # Crop the padding away and map [-0.5, 0.5] back to [0, 1].
            result_image = _output[:, :, :h, :w]
            result_image = torch.clamp(result_image, -0.5, 0.5)
            result_image = result_image + 0.5

        test_time += stop - start
        print('Image:{}/{}, CNN Runtime:{:.4f}'.format(iteration, total_image_number, (stop - start)))
        print('Average Runtime:{:.4f}'.format(test_time / float(iteration)))
        out_file_name = os.path.join(out_path, scene, img_name)
        torchvision.utils.save_image(result_image, out_file_name)


if __name__ == '__main__':
    main()
|
|
|