commit
df69507ef2
@ -0,0 +1,62 @@
|
|||||||
|
from typing import List
|
||||||
|
|
||||||
|
import albumentations as albu
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
|
def get_transforms(size: int, scope: str = 'geometric', crop='random'):
|
||||||
|
augs = {'strong': albu.Compose([albu.HorizontalFlip(),
|
||||||
|
albu.ShiftScaleRotate(shift_limit=0.0, scale_limit=0.2, rotate_limit=20, p=.4),
|
||||||
|
albu.ElasticTransform(),
|
||||||
|
albu.OpticalDistortion(),
|
||||||
|
albu.OneOf([
|
||||||
|
albu.CLAHE(clip_limit=2),
|
||||||
|
albu.IAASharpen(),
|
||||||
|
albu.IAAEmboss(),
|
||||||
|
albu.RandomBrightnessContrast(),
|
||||||
|
albu.RandomGamma()
|
||||||
|
], p=0.5),
|
||||||
|
albu.OneOf([
|
||||||
|
albu.RGBShift(),
|
||||||
|
albu.HueSaturationValue(),
|
||||||
|
], p=0.5),
|
||||||
|
]),
|
||||||
|
'weak': albu.Compose([albu.HorizontalFlip(),
|
||||||
|
]),
|
||||||
|
'geometric': albu.Compose([albu.HorizontalFlip(),
|
||||||
|
albu.VerticalFlip(),
|
||||||
|
albu.RandomRotate90(),
|
||||||
|
]),
|
||||||
|
'None': None
|
||||||
|
}
|
||||||
|
|
||||||
|
aug_fn = augs[scope]
|
||||||
|
crop_fn = {'random': albu.RandomCrop(size, size, always_apply=True),
|
||||||
|
'center': albu.CenterCrop(size, size, always_apply=True)}[crop]
|
||||||
|
|
||||||
|
pipeline = albu.Compose([aug_fn, crop_fn], additional_targets={'target': 'image'})
|
||||||
|
|
||||||
|
|
||||||
|
def process(a, b):
|
||||||
|
r = pipeline(image=a, target=b)
|
||||||
|
return r['image'], r['target']
|
||||||
|
|
||||||
|
return process
|
||||||
|
|
||||||
|
|
||||||
|
def get_normalize():
|
||||||
|
transform = transforms.Compose([
|
||||||
|
transforms.ToTensor()
|
||||||
|
])
|
||||||
|
|
||||||
|
def process(a, b):
|
||||||
|
image = transform(a).permute(1, 2, 0) - 0.5
|
||||||
|
target = transform(b).permute(1, 2, 0) - 0.5
|
||||||
|
return image, target
|
||||||
|
|
||||||
|
return process
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -0,0 +1,40 @@
|
|||||||
|
---
|
||||||
|
experiment_desc: Stripformer_gopro
|
||||||
|
|
||||||
|
train:
|
||||||
|
files_a: ./datasets/GoPro/train/blur/**/*.png
|
||||||
|
files_b: ./datasets/GoPro/train/sharp/**/*.png
|
||||||
|
size: &SIZE 512
|
||||||
|
crop: random
|
||||||
|
preload: &PRELOAD false
|
||||||
|
preload_size: &PRELOAD_SIZE 0
|
||||||
|
bounds: [0, 1]
|
||||||
|
scope: geometric
|
||||||
|
|
||||||
|
val:
|
||||||
|
files_a: ./datasets/GoPro/test/blur/**/*.png
|
||||||
|
files_b: ./datasets/GoPro/test/sharp/**/*.png
|
||||||
|
size: *SIZE
|
||||||
|
scope: None
|
||||||
|
crop: random
|
||||||
|
preload: *PRELOAD
|
||||||
|
preload_size: *PRELOAD_SIZE
|
||||||
|
bounds: [0, 1]
|
||||||
|
|
||||||
|
model:
|
||||||
|
g_name: Stripformer
|
||||||
|
content_loss: Stripformer_Loss
|
||||||
|
|
||||||
|
num_epochs: 1000
|
||||||
|
train_batches_per_epoch: 2103
|
||||||
|
val_batches_per_epoch: 1111
|
||||||
|
batch_size: 8
|
||||||
|
image_size: [512, 512]
|
||||||
|
|
||||||
|
optimizer:
|
||||||
|
name: adam
|
||||||
|
lr: 0.0001
|
||||||
|
scheduler:
|
||||||
|
name: cosine
|
||||||
|
start_epoch: 50
|
||||||
|
min_lr: 0.0000001
|
||||||
@ -0,0 +1,40 @@
|
|||||||
|
---
|
||||||
|
experiment_desc: Stripformer_pretrained
|
||||||
|
|
||||||
|
train:
|
||||||
|
files_a: ./datasets/GoPro/train/blur/**/*.png
|
||||||
|
files_b: ./datasets/GoPro/train/sharp/**/*.png
|
||||||
|
size: &SIZE 256
|
||||||
|
crop: random
|
||||||
|
preload: &PRELOAD false
|
||||||
|
preload_size: &PRELOAD_SIZE 0
|
||||||
|
bounds: [0, 1]
|
||||||
|
scope: geometric
|
||||||
|
|
||||||
|
val:
|
||||||
|
files_a: ./datasets/GoPro/test/blur/**/*.png
|
||||||
|
files_b: ./datasets/GoPro/test/sharp/**/*.png
|
||||||
|
size: *SIZE
|
||||||
|
scope: None
|
||||||
|
crop: random
|
||||||
|
preload: *PRELOAD
|
||||||
|
preload_size: *PRELOAD_SIZE
|
||||||
|
bounds: [0, 1]
|
||||||
|
|
||||||
|
model:
|
||||||
|
g_name: Stripformer
|
||||||
|
content_loss: Stripformer_Loss
|
||||||
|
|
||||||
|
num_epochs: 3000
|
||||||
|
train_batches_per_epoch: 2103
|
||||||
|
val_batches_per_epoch: 1111
|
||||||
|
batch_size: 8
|
||||||
|
image_size: [256, 256]
|
||||||
|
|
||||||
|
optimizer:
|
||||||
|
name: adam
|
||||||
|
lr: 0.0001
|
||||||
|
scheduler:
|
||||||
|
name: cosine
|
||||||
|
start_epoch: 50
|
||||||
|
min_lr: 0.0000001
|
||||||
@ -0,0 +1,140 @@
|
|||||||
|
import os
|
||||||
|
from copy import deepcopy
|
||||||
|
from functools import partial
|
||||||
|
from glob import glob
|
||||||
|
from hashlib import sha1
|
||||||
|
from typing import Callable, Iterable, Optional, Tuple
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from glog import logger
|
||||||
|
from joblib import Parallel, cpu_count, delayed
|
||||||
|
from skimage.io import imread
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
import aug
|
||||||
|
|
||||||
|
|
||||||
|
def subsample(data: Iterable, bounds: Tuple[float, float], hash_fn: Callable, n_buckets=100, salt='', verbose=True):
|
||||||
|
data = list(data)
|
||||||
|
buckets = split_into_buckets(data, n_buckets=n_buckets, salt=salt, hash_fn=hash_fn)
|
||||||
|
|
||||||
|
lower_bound, upper_bound = [x * n_buckets for x in bounds]
|
||||||
|
msg = f'Subsampling buckets from {lower_bound} to {upper_bound}, total buckets number is {n_buckets}'
|
||||||
|
if salt:
|
||||||
|
msg += f'; salt is {salt}'
|
||||||
|
if verbose:
|
||||||
|
logger.info(msg)
|
||||||
|
return np.array([sample for bucket, sample in zip(buckets, data) if lower_bound <= bucket < upper_bound])
|
||||||
|
|
||||||
|
|
||||||
|
def hash_from_paths(x: Tuple[str, str], salt: str = '') -> str:
|
||||||
|
path_a, path_b = x
|
||||||
|
names = ''.join(map(os.path.basename, (path_a, path_b)))
|
||||||
|
return sha1(f'{names}_{salt}'.encode()).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def split_into_buckets(data: Iterable, n_buckets: int, hash_fn: Callable, salt=''):
|
||||||
|
hashes = map(partial(hash_fn, salt=salt), data)
|
||||||
|
return np.array([int(x, 16) % n_buckets for x in hashes])
|
||||||
|
|
||||||
|
|
||||||
|
def _read_img(x: str):
|
||||||
|
img = cv2.imread(x)
|
||||||
|
if img is None:
|
||||||
|
logger.warning(f'Can not read image {x} with OpenCV, switching to scikit-image')
|
||||||
|
img = imread(x)
|
||||||
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
class PairedDataset(Dataset):
|
||||||
|
def __init__(self,
|
||||||
|
files_a: Tuple[str],
|
||||||
|
files_b: Tuple[str],
|
||||||
|
transform_fn: Callable,
|
||||||
|
normalize_fn: Callable,
|
||||||
|
corrupt_fn: Optional[Callable] = None,
|
||||||
|
preload: bool = True,
|
||||||
|
preload_size: Optional[int] = 0,
|
||||||
|
verbose=True):
|
||||||
|
|
||||||
|
assert len(files_a) == len(files_b)
|
||||||
|
|
||||||
|
self.preload = preload
|
||||||
|
self.data_a = files_a
|
||||||
|
self.data_b = files_b
|
||||||
|
self.verbose = verbose
|
||||||
|
self.corrupt_fn = corrupt_fn
|
||||||
|
self.transform_fn = transform_fn
|
||||||
|
self.normalize_fn = normalize_fn
|
||||||
|
logger.info(f'Dataset has been created with {len(self.data_a)} samples')
|
||||||
|
|
||||||
|
if preload:
|
||||||
|
preload_fn = partial(self._bulk_preload, preload_size=preload_size)
|
||||||
|
if files_a == files_b:
|
||||||
|
self.data_a = self.data_b = preload_fn(self.data_a)
|
||||||
|
else:
|
||||||
|
self.data_a, self.data_b = map(preload_fn, (self.data_a, self.data_b))
|
||||||
|
self.preload = True
|
||||||
|
|
||||||
|
def _bulk_preload(self, data: Iterable[str], preload_size: int):
|
||||||
|
jobs = [delayed(self._preload)(x, preload_size=preload_size) for x in data]
|
||||||
|
jobs = tqdm(jobs, desc='preloading images', disable=not self.verbose)
|
||||||
|
return Parallel(n_jobs=cpu_count(), backend='threading')(jobs)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _preload(x: str, preload_size: int):
|
||||||
|
img = _read_img(x)
|
||||||
|
if preload_size:
|
||||||
|
h, w, *_ = img.shape
|
||||||
|
h_scale = preload_size / h
|
||||||
|
w_scale = preload_size / w
|
||||||
|
scale = max(h_scale, w_scale)
|
||||||
|
img = cv2.resize(img, fx=scale, fy=scale, dsize=None)
|
||||||
|
assert min(img.shape[:2]) >= preload_size, f'weird img shape: {img.shape}'
|
||||||
|
return img
|
||||||
|
|
||||||
|
def _preprocess(self, img, res):
|
||||||
|
def transpose(x):
|
||||||
|
return np.transpose(x, (2, 0, 1))
|
||||||
|
|
||||||
|
return map(transpose, self.normalize_fn(img, res))
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.data_a)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
a, b = self.data_a[idx], self.data_b[idx]
|
||||||
|
if not self.preload:
|
||||||
|
a, b = map(_read_img, (a, b))
|
||||||
|
a, b = self.transform_fn(a, b)
|
||||||
|
if self.corrupt_fn is not None:
|
||||||
|
a = self.corrupt_fn(a)
|
||||||
|
a, b = self._preprocess(a, b)
|
||||||
|
return {'a': a, 'b': b}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_config(config):
|
||||||
|
config = deepcopy(config)
|
||||||
|
files_a, files_b = map(lambda x: sorted(glob(config[x], recursive=True)), ('files_a', 'files_b'))
|
||||||
|
transform_fn = aug.get_transforms(size=config['size'], scope=config['scope'], crop=config['crop'])
|
||||||
|
normalize_fn = aug.get_normalize()
|
||||||
|
|
||||||
|
hash_fn = hash_from_paths
|
||||||
|
# ToDo: add more hash functions
|
||||||
|
verbose = config.get('verbose', True)
|
||||||
|
data = subsample(data=zip(files_a, files_b),
|
||||||
|
bounds=config.get('bounds', (0, 1)),
|
||||||
|
hash_fn=hash_fn,
|
||||||
|
verbose=verbose)
|
||||||
|
|
||||||
|
files_a, files_b = map(list, zip(*data))
|
||||||
|
|
||||||
|
return PairedDataset(files_a=files_a,
|
||||||
|
files_b=files_b,
|
||||||
|
preload=config['preload'],
|
||||||
|
preload_size=config['preload_size'],
|
||||||
|
normalize_fn=normalize_fn,
|
||||||
|
transform_fn=transform_fn,
|
||||||
|
verbose=verbose)
|
||||||
@ -0,0 +1 @@
|
|||||||
|
Download 'GoPro' datasets and put the datasets into folder './datasets'
|
||||||
@ -0,0 +1,101 @@
|
|||||||
|
import os
|
||||||
|
from skimage import io
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from skimage.metrics import structural_similarity
|
||||||
|
import concurrent.futures
|
||||||
|
|
||||||
|
def image_align(deblurred, gt):
|
||||||
|
# this function is based on kohler evaluation code
|
||||||
|
z = deblurred
|
||||||
|
c = np.ones_like(z)
|
||||||
|
x = gt
|
||||||
|
|
||||||
|
zs = (np.sum(x * z) / np.sum(z * z)) * z # simple intensity matching
|
||||||
|
|
||||||
|
warp_mode = cv2.MOTION_HOMOGRAPHY
|
||||||
|
warp_matrix = np.eye(3, 3, dtype=np.float32)
|
||||||
|
|
||||||
|
# Specify the number of iterations.
|
||||||
|
number_of_iterations = 100
|
||||||
|
|
||||||
|
termination_eps = 0
|
||||||
|
|
||||||
|
criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT,
|
||||||
|
number_of_iterations, termination_eps)
|
||||||
|
|
||||||
|
# Run the ECC algorithm. The results are stored in warp_matrix.
|
||||||
|
(cc, warp_matrix) = cv2.findTransformECC(cv2.cvtColor(x, cv2.COLOR_RGB2GRAY), cv2.cvtColor(zs, cv2.COLOR_RGB2GRAY),
|
||||||
|
warp_matrix, warp_mode, criteria, inputMask=None, gaussFiltSize=5)
|
||||||
|
|
||||||
|
target_shape = x.shape
|
||||||
|
shift = warp_matrix
|
||||||
|
|
||||||
|
zr = cv2.warpPerspective(
|
||||||
|
zs,
|
||||||
|
warp_matrix,
|
||||||
|
(target_shape[1], target_shape[0]),
|
||||||
|
flags=cv2.INTER_CUBIC + cv2.WARP_INVERSE_MAP,
|
||||||
|
borderMode=cv2.BORDER_REFLECT)
|
||||||
|
|
||||||
|
cr = cv2.warpPerspective(
|
||||||
|
np.ones_like(zs, dtype='float32'),
|
||||||
|
warp_matrix,
|
||||||
|
(target_shape[1], target_shape[0]),
|
||||||
|
flags=cv2.INTER_NEAREST + cv2.WARP_INVERSE_MAP,
|
||||||
|
borderMode=cv2.BORDER_CONSTANT,
|
||||||
|
borderValue=0)
|
||||||
|
|
||||||
|
zr = zr * cr
|
||||||
|
xr = x * cr
|
||||||
|
|
||||||
|
return zr, xr, cr, shift
|
||||||
|
|
||||||
|
def compute_psnr(image_true, image_test, image_mask, data_range=None):
|
||||||
|
# this function is based on skimage.metrics.peak_signal_noise_ratio
|
||||||
|
err = np.sum((image_true - image_test) ** 2, dtype=np.float64) / np.sum(image_mask)
|
||||||
|
return 10 * np.log10((data_range ** 2) / err)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_ssim(tar_img, prd_img, cr1):
|
||||||
|
ssim_pre, ssim_map = structural_similarity(tar_img, prd_img, multichannel=True, gaussian_weights=True,
|
||||||
|
use_sample_covariance=False, data_range=1.0, full=True)
|
||||||
|
ssim_map = ssim_map * cr1
|
||||||
|
r = int(3.5 * 1.5 + 0.5) # radius as in ndimage
|
||||||
|
win_size = 2 * r + 1
|
||||||
|
pad = (win_size - 1) // 2
|
||||||
|
ssim = ssim_map[pad:-pad, pad:-pad, :]
|
||||||
|
crop_cr1 = cr1[pad:-pad, pad:-pad, :]
|
||||||
|
ssim = ssim.sum(axis=0).sum(axis=0) / crop_cr1.sum(axis=0).sum(axis=0)
|
||||||
|
ssim = np.mean(ssim)
|
||||||
|
return ssim
|
||||||
|
|
||||||
|
total_psnr = 0.
|
||||||
|
total_ssim = 0.
|
||||||
|
count = 0
|
||||||
|
img_path = './out/Stripformer_realblur_J_results'
|
||||||
|
gt_path = './datasets/Realblur_J/test/sharp'
|
||||||
|
print(img_path)
|
||||||
|
for file in os.listdir(img_path):
|
||||||
|
for img_name in os.listdir(img_path + '/' + file):
|
||||||
|
count += 1
|
||||||
|
number = img_name.split('_')[1]
|
||||||
|
gt_name = 'gt_' + number
|
||||||
|
img_dir = img_path + '/' + file + '/' + img_name
|
||||||
|
gt_dir = gt_path + '/' + file + '/' + gt_name
|
||||||
|
with concurrent.futures.ProcessPoolExecutor(max_workers=10) as executor:
|
||||||
|
tar_img = io.imread(gt_dir)
|
||||||
|
prd_img = io.imread(img_dir)
|
||||||
|
tar_img = tar_img.astype(np.float32) / 255.0
|
||||||
|
prd_img = prd_img.astype(np.float32) / 255.0
|
||||||
|
prd_img, tar_img, cr1, shift = image_align(prd_img, tar_img)
|
||||||
|
PSNR = compute_psnr(tar_img, prd_img, cr1, data_range=1)
|
||||||
|
SSIM = compute_ssim(tar_img, prd_img, cr1)
|
||||||
|
total_psnr += PSNR
|
||||||
|
total_ssim += SSIM
|
||||||
|
print(count, PSNR)
|
||||||
|
|
||||||
|
print('PSNR:', total_psnr / count)
|
||||||
|
print('SSIM:', total_ssim / count)
|
||||||
|
print(img_path)
|
||||||
|
|
||||||
@ -0,0 +1,101 @@
|
|||||||
|
import os
|
||||||
|
from skimage import io
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from skimage.metrics import structural_similarity
|
||||||
|
import concurrent.futures
|
||||||
|
|
||||||
|
def image_align(deblurred, gt):
|
||||||
|
# this function is based on kohler evaluation code
|
||||||
|
z = deblurred
|
||||||
|
c = np.ones_like(z)
|
||||||
|
x = gt
|
||||||
|
|
||||||
|
zs = (np.sum(x * z) / np.sum(z * z)) * z # simple intensity matching
|
||||||
|
|
||||||
|
warp_mode = cv2.MOTION_HOMOGRAPHY
|
||||||
|
warp_matrix = np.eye(3, 3, dtype=np.float32)
|
||||||
|
|
||||||
|
# Specify the number of iterations.
|
||||||
|
number_of_iterations = 100
|
||||||
|
|
||||||
|
termination_eps = 0
|
||||||
|
|
||||||
|
criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT,
|
||||||
|
number_of_iterations, termination_eps)
|
||||||
|
|
||||||
|
# Run the ECC algorithm. The results are stored in warp_matrix.
|
||||||
|
(cc, warp_matrix) = cv2.findTransformECC(cv2.cvtColor(x, cv2.COLOR_RGB2GRAY), cv2.cvtColor(zs, cv2.COLOR_RGB2GRAY),
|
||||||
|
warp_matrix, warp_mode, criteria, inputMask=None, gaussFiltSize=5)
|
||||||
|
|
||||||
|
target_shape = x.shape
|
||||||
|
shift = warp_matrix
|
||||||
|
|
||||||
|
zr = cv2.warpPerspective(
|
||||||
|
zs,
|
||||||
|
warp_matrix,
|
||||||
|
(target_shape[1], target_shape[0]),
|
||||||
|
flags=cv2.INTER_CUBIC + cv2.WARP_INVERSE_MAP,
|
||||||
|
borderMode=cv2.BORDER_REFLECT)
|
||||||
|
|
||||||
|
cr = cv2.warpPerspective(
|
||||||
|
np.ones_like(zs, dtype='float32'),
|
||||||
|
warp_matrix,
|
||||||
|
(target_shape[1], target_shape[0]),
|
||||||
|
flags=cv2.INTER_NEAREST + cv2.WARP_INVERSE_MAP,
|
||||||
|
borderMode=cv2.BORDER_CONSTANT,
|
||||||
|
borderValue=0)
|
||||||
|
|
||||||
|
zr = zr * cr
|
||||||
|
xr = x * cr
|
||||||
|
|
||||||
|
return zr, xr, cr, shift
|
||||||
|
|
||||||
|
def compute_psnr(image_true, image_test, image_mask, data_range=None):
|
||||||
|
# this function is based on skimage.metrics.peak_signal_noise_ratio
|
||||||
|
err = np.sum((image_true - image_test) ** 2, dtype=np.float64) / np.sum(image_mask)
|
||||||
|
return 10 * np.log10((data_range ** 2) / err)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_ssim(tar_img, prd_img, cr1):
|
||||||
|
ssim_pre, ssim_map = structural_similarity(tar_img, prd_img, multichannel=True, gaussian_weights=True,
|
||||||
|
use_sample_covariance=False, data_range=1.0, full=True)
|
||||||
|
ssim_map = ssim_map * cr1
|
||||||
|
r = int(3.5 * 1.5 + 0.5) # radius as in ndimage
|
||||||
|
win_size = 2 * r + 1
|
||||||
|
pad = (win_size - 1) // 2
|
||||||
|
ssim = ssim_map[pad:-pad, pad:-pad, :]
|
||||||
|
crop_cr1 = cr1[pad:-pad, pad:-pad, :]
|
||||||
|
ssim = ssim.sum(axis=0).sum(axis=0) / crop_cr1.sum(axis=0).sum(axis=0)
|
||||||
|
ssim = np.mean(ssim)
|
||||||
|
return ssim
|
||||||
|
|
||||||
|
total_psnr = 0.
|
||||||
|
total_ssim = 0.
|
||||||
|
count = 0
|
||||||
|
img_path = './out/Stripformer_realblur_R_results'
|
||||||
|
gt_path = './datasets/Realblur_R/test/sharp'
|
||||||
|
print(img_path)
|
||||||
|
for file in os.listdir(img_path):
|
||||||
|
for img_name in os.listdir(img_path + '/' + file):
|
||||||
|
count += 1
|
||||||
|
number = img_name.split('_')[1]
|
||||||
|
gt_name = 'gt_' + number
|
||||||
|
img_dir = img_path + '/' + file + '/' + img_name
|
||||||
|
gt_dir = gt_path + '/' + file + '/' + gt_name
|
||||||
|
with concurrent.futures.ProcessPoolExecutor(max_workers=10) as executor:
|
||||||
|
tar_img = io.imread(gt_dir)
|
||||||
|
prd_img = io.imread(img_dir)
|
||||||
|
tar_img = tar_img.astype(np.float32) / 255.0
|
||||||
|
prd_img = prd_img.astype(np.float32) / 255.0
|
||||||
|
prd_img, tar_img, cr1, shift = image_align(prd_img, tar_img)
|
||||||
|
PSNR = compute_psnr(tar_img, prd_img, cr1, data_range=1)
|
||||||
|
SSIM = compute_ssim(tar_img, prd_img, cr1)
|
||||||
|
total_psnr += PSNR
|
||||||
|
total_ssim += SSIM
|
||||||
|
print(count, PSNR)
|
||||||
|
|
||||||
|
print('PSNR:', total_psnr / count)
|
||||||
|
print('SSIM:', total_ssim / count)
|
||||||
|
print(img_path)
|
||||||
|
|
||||||
@ -0,0 +1,60 @@
|
|||||||
|
p = genpath('.\out\stripformer_10_patches');% GoPro Deblur Results
|
||||||
|
gt = genpath('.\datasets\GoPro\test\sharp');% GoPro GT Results
|
||||||
|
|
||||||
|
length_p = size(p,2);
|
||||||
|
path = {};
|
||||||
|
temp = [];
|
||||||
|
for i = 1:length_p
|
||||||
|
if p(i) ~= ';'
|
||||||
|
temp = [temp p(i)];
|
||||||
|
else
|
||||||
|
temp = [temp '\'];
|
||||||
|
path = [path ; temp];
|
||||||
|
temp = [];
|
||||||
|
end
|
||||||
|
end
|
||||||
|
clear p length_p temp;
|
||||||
|
length_gt = size(gt,2);
|
||||||
|
path_gt = {};
|
||||||
|
temp_gt = [];
|
||||||
|
for i = 1:length_gt
|
||||||
|
if gt(i) ~= ';'
|
||||||
|
temp_gt = [temp_gt gt(i)];
|
||||||
|
else
|
||||||
|
temp_gt = [temp_gt '\'];
|
||||||
|
path_gt = [path_gt ; temp_gt];
|
||||||
|
temp_gt = [];
|
||||||
|
end
|
||||||
|
end
|
||||||
|
clear gt length_gt temp_gt;
|
||||||
|
|
||||||
|
file_num = size(path,1);
|
||||||
|
total_psnr = 0;
|
||||||
|
n = 0;
|
||||||
|
total_ssim = 0;
|
||||||
|
for i = 1:file_num
|
||||||
|
file_path = path{i};
|
||||||
|
gt_file_path = path_gt{i};
|
||||||
|
img_path_list = dir(strcat(file_path,'*.png'));
|
||||||
|
gt_path_list = dir(strcat(gt_file_path,'*.png'));
|
||||||
|
img_num = length(img_path_list);
|
||||||
|
if img_num > 0
|
||||||
|
for j = 1:img_num
|
||||||
|
image_name = img_path_list(j).name;
|
||||||
|
gt_name = gt_path_list(j).name;
|
||||||
|
image = imread(strcat(file_path,image_name));
|
||||||
|
gt = imread(strcat(gt_file_path,gt_name));
|
||||||
|
size(image);
|
||||||
|
size(gt);
|
||||||
|
peaksnr = psnr(image,gt);
|
||||||
|
ssimval = ssim(image,gt);
|
||||||
|
total_psnr = total_psnr + peaksnr;
|
||||||
|
total_ssim = total_ssim + ssimval;
|
||||||
|
n = n + 1
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
psnr = total_psnr / n
|
||||||
|
ssim = total_ssim / n
|
||||||
|
close all;clear all;
|
||||||
|
|
||||||
@ -0,0 +1,60 @@
|
|||||||
|
p = genpath('.\out\Stripformer_HIDE_result');% HIDE Deblur Results
|
||||||
|
gt = genpath('.\datasets\HIDE\sharp');% HIDE GT Results
|
||||||
|
|
||||||
|
length_p = size(p,2);
|
||||||
|
path = {};
|
||||||
|
temp = [];
|
||||||
|
for i = 1:length_p
|
||||||
|
if p(i) ~= ';'
|
||||||
|
temp = [temp p(i)];
|
||||||
|
else
|
||||||
|
temp = [temp '\'];
|
||||||
|
path = [path ; temp];
|
||||||
|
temp = [];
|
||||||
|
end
|
||||||
|
end
|
||||||
|
clear p length_p temp;
|
||||||
|
length_gt = size(gt,2);
|
||||||
|
path_gt = {};
|
||||||
|
temp_gt = [];
|
||||||
|
for i = 1:length_gt
|
||||||
|
if gt(i) ~= ';'
|
||||||
|
temp_gt = [temp_gt gt(i)];
|
||||||
|
else
|
||||||
|
temp_gt = [temp_gt '\'];
|
||||||
|
path_gt = [path_gt ; temp_gt];
|
||||||
|
temp_gt = [];
|
||||||
|
end
|
||||||
|
end
|
||||||
|
clear gt length_gt temp_gt;
|
||||||
|
|
||||||
|
file_num = size(path,1);
|
||||||
|
total_psnr = 0;
|
||||||
|
n = 0;
|
||||||
|
total_ssim = 0;
|
||||||
|
for i = 1:file_num
|
||||||
|
file_path = path{i};
|
||||||
|
gt_file_path = path_gt{i};
|
||||||
|
img_path_list = dir(strcat(file_path,'*.png'));
|
||||||
|
gt_path_list = dir(strcat(gt_file_path,'*.png'));
|
||||||
|
img_num = length(img_path_list);
|
||||||
|
if img_num > 0
|
||||||
|
for j = 1:img_num
|
||||||
|
image_name = img_path_list(j).name;
|
||||||
|
gt_name = gt_path_list(j).name;
|
||||||
|
image = imread(strcat(file_path,image_name));
|
||||||
|
gt = imread(strcat(gt_file_path,gt_name));
|
||||||
|
size(image);
|
||||||
|
size(gt);
|
||||||
|
peaksnr = psnr(image,gt);
|
||||||
|
ssimval = ssim(image,gt);
|
||||||
|
total_psnr = total_psnr + peaksnr;
|
||||||
|
total_ssim = total_ssim + ssimval;
|
||||||
|
n = n + 1
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
psnr = total_psnr / n
|
||||||
|
ssim = total_ssim / n
|
||||||
|
close all;clear all;
|
||||||
|
|
||||||
@ -0,0 +1,55 @@
|
|||||||
|
import logging
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from tensorboardX import SummaryWriter
|
||||||
|
|
||||||
|
WINDOW_SIZE = 100
|
||||||
|
|
||||||
|
|
||||||
|
class MetricCounter:
|
||||||
|
def __init__(self, exp_name):
|
||||||
|
self.writer = SummaryWriter(exp_name)
|
||||||
|
logging.basicConfig(filename='{}.log'.format(exp_name), level=logging.DEBUG)
|
||||||
|
self.metrics = defaultdict(list)
|
||||||
|
self.images = defaultdict(list)
|
||||||
|
self.best_metric = 0
|
||||||
|
|
||||||
|
def add_image(self, x: np.ndarray, tag: str):
|
||||||
|
self.images[tag].append(x)
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
self.metrics = defaultdict(list)
|
||||||
|
self.images = defaultdict(list)
|
||||||
|
|
||||||
|
def add_losses(self, l_G):
|
||||||
|
for name, value in zip(('G_loss', None), (l_G, None)):
|
||||||
|
self.metrics[name].append(value)
|
||||||
|
|
||||||
|
def add_metrics(self, psnr, ssim):
|
||||||
|
for name, value in zip(('PSNR', 'SSIM'),
|
||||||
|
(psnr, ssim)):
|
||||||
|
self.metrics[name].append(value)
|
||||||
|
|
||||||
|
def loss_message(self):
|
||||||
|
metrics = ((k, np.mean(self.metrics[k][-WINDOW_SIZE:])) for k in ('G_loss', 'PSNR', 'SSIM'))
|
||||||
|
return '; '.join(map(lambda x: f'{x[0]}={x[1]:.4f}', metrics))
|
||||||
|
|
||||||
|
def write_to_tensorboard(self, epoch_num, validation=False):
|
||||||
|
scalar_prefix = 'Validation' if validation else 'Train'
|
||||||
|
for tag in ('G_loss', 'SSIM', 'PSNR'):
|
||||||
|
self.writer.add_scalar(f'{scalar_prefix}_{tag}', np.mean(self.metrics[tag]), global_step=epoch_num)
|
||||||
|
for tag in self.images:
|
||||||
|
imgs = self.images[tag]
|
||||||
|
if imgs:
|
||||||
|
imgs = np.array(imgs)
|
||||||
|
self.writer.add_images(tag, imgs[:, :, :, ::-1].astype('float32') / 255, dataformats='NHWC',
|
||||||
|
global_step=epoch_num)
|
||||||
|
self.images[tag] = []
|
||||||
|
|
||||||
|
def update_best_model(self):
|
||||||
|
cur_metric = np.mean(self.metrics['PSNR'])
|
||||||
|
if self.best_metric < cur_metric:
|
||||||
|
self.best_metric = cur_metric
|
||||||
|
return True
|
||||||
|
return False
|
||||||
@ -0,0 +1,374 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import math
|
||||||
|
|
||||||
|
class Embeddings(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(Embeddings, self).__init__()
|
||||||
|
|
||||||
|
self.activation = nn.LeakyReLU(0.2, True)
|
||||||
|
|
||||||
|
self.en_layer1_1 = nn.Sequential(
|
||||||
|
nn.Conv2d(3, 64, kernel_size=3, padding=1),
|
||||||
|
self.activation,
|
||||||
|
)
|
||||||
|
self.en_layer1_2 = nn.Sequential(
|
||||||
|
nn.Conv2d(64, 64, kernel_size=3, padding=1),
|
||||||
|
self.activation,
|
||||||
|
nn.Conv2d(64, 64, kernel_size=3, padding=1))
|
||||||
|
self.en_layer1_3 = nn.Sequential(
|
||||||
|
nn.Conv2d(64, 64, kernel_size=3, padding=1),
|
||||||
|
self.activation,
|
||||||
|
nn.Conv2d(64, 64, kernel_size=3, padding=1))
|
||||||
|
self.en_layer1_4 = nn.Sequential(
|
||||||
|
nn.Conv2d(64, 64, kernel_size=3, padding=1),
|
||||||
|
self.activation,
|
||||||
|
nn.Conv2d(64, 64, kernel_size=3, padding=1))
|
||||||
|
|
||||||
|
self.en_layer2_1 = nn.Sequential(
|
||||||
|
nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
|
||||||
|
self.activation,
|
||||||
|
)
|
||||||
|
self.en_layer2_2 = nn.Sequential(
|
||||||
|
nn.Conv2d(128, 128, kernel_size=3, padding=1),
|
||||||
|
self.activation,
|
||||||
|
nn.Conv2d(128, 128, kernel_size=3, padding=1))
|
||||||
|
self.en_layer2_3 = nn.Sequential(
|
||||||
|
nn.Conv2d(128, 128, kernel_size=3, padding=1),
|
||||||
|
self.activation,
|
||||||
|
nn.Conv2d(128, 128, kernel_size=3, padding=1))
|
||||||
|
self.en_layer2_4 = nn.Sequential(
|
||||||
|
nn.Conv2d(128, 128, kernel_size=3, padding=1),
|
||||||
|
self.activation,
|
||||||
|
nn.Conv2d(128, 128, kernel_size=3, padding=1))
|
||||||
|
|
||||||
|
|
||||||
|
self.en_layer3_1 = nn.Sequential(
|
||||||
|
nn.Conv2d(128, 320, kernel_size=3, stride=2, padding=1),
|
||||||
|
self.activation,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
|
||||||
|
hx = self.en_layer1_1(x)
|
||||||
|
hx = self.activation(self.en_layer1_2(hx) + hx)
|
||||||
|
hx = self.activation(self.en_layer1_3(hx) + hx)
|
||||||
|
hx = self.activation(self.en_layer1_4(hx) + hx)
|
||||||
|
residual_1 = hx
|
||||||
|
hx = self.en_layer2_1(hx)
|
||||||
|
hx = self.activation(self.en_layer2_2(hx) + hx)
|
||||||
|
hx = self.activation(self.en_layer2_3(hx) + hx)
|
||||||
|
hx = self.activation(self.en_layer2_4(hx) + hx)
|
||||||
|
residual_2 = hx
|
||||||
|
hx = self.en_layer3_1(hx)
|
||||||
|
|
||||||
|
return hx, residual_1, residual_2
|
||||||
|
|
||||||
|
|
||||||
|
class Embeddings_output(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(Embeddings_output, self).__init__()
|
||||||
|
|
||||||
|
self.activation = nn.LeakyReLU(0.2, True)
|
||||||
|
|
||||||
|
self.de_layer3_1 = nn.Sequential(
|
||||||
|
nn.ConvTranspose2d(320, 192, kernel_size=4, stride=2, padding=1),
|
||||||
|
self.activation,
|
||||||
|
)
|
||||||
|
head_num = 3
|
||||||
|
dim = 192
|
||||||
|
|
||||||
|
self.de_layer2_2 = nn.Sequential(
|
||||||
|
nn.Conv2d(192+128, 192, kernel_size=1, padding=0),
|
||||||
|
self.activation,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.de_block_1 = Intra_SA(dim, head_num)
|
||||||
|
self.de_block_2 = Inter_SA(dim, head_num)
|
||||||
|
self.de_block_3 = Intra_SA(dim, head_num)
|
||||||
|
self.de_block_4 = Inter_SA(dim, head_num)
|
||||||
|
self.de_block_5 = Intra_SA(dim, head_num)
|
||||||
|
self.de_block_6 = Inter_SA(dim, head_num)
|
||||||
|
|
||||||
|
|
||||||
|
self.de_layer2_1 = nn.Sequential(
|
||||||
|
nn.ConvTranspose2d(192, 64, kernel_size=4, stride=2, padding=1),
|
||||||
|
self.activation,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.de_layer1_3 = nn.Sequential(
|
||||||
|
nn.Conv2d(128, 64, kernel_size=1, padding=0),
|
||||||
|
self.activation,
|
||||||
|
nn.Conv2d(64, 64, kernel_size=3, padding=1))
|
||||||
|
self.de_layer1_2 = nn.Sequential(
|
||||||
|
nn.Conv2d(64, 64, kernel_size=3, padding=1),
|
||||||
|
self.activation,
|
||||||
|
nn.Conv2d(64, 64, kernel_size=3, padding=1))
|
||||||
|
self.de_layer1_1 = nn.Sequential(
|
||||||
|
nn.Conv2d(64, 3, kernel_size=3, padding=1),
|
||||||
|
self.activation
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x, residual_1, residual_2):
|
||||||
|
|
||||||
|
|
||||||
|
hx = self.de_layer3_1(x)
|
||||||
|
|
||||||
|
hx = self.de_layer2_2(torch.cat((hx, residual_2), dim = 1))
|
||||||
|
hx = self.de_block_1(hx)
|
||||||
|
hx = self.de_block_2(hx)
|
||||||
|
hx = self.de_block_3(hx)
|
||||||
|
hx = self.de_block_4(hx)
|
||||||
|
hx = self.de_block_5(hx)
|
||||||
|
hx = self.de_block_6(hx)
|
||||||
|
hx = self.de_layer2_1(hx)
|
||||||
|
|
||||||
|
hx = self.activation(self.de_layer1_3(torch.cat((hx, residual_1), dim = 1)) + hx)
|
||||||
|
hx = self.activation(self.de_layer1_2(hx) + hx)
|
||||||
|
hx = self.de_layer1_1(hx)
|
||||||
|
|
||||||
|
return hx
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
def __init__(self, head_num):
|
||||||
|
super(Attention, self).__init__()
|
||||||
|
self.num_attention_heads = head_num
|
||||||
|
self.softmax = nn.Softmax(dim=-1)
|
||||||
|
|
||||||
|
def transpose_for_scores(self, x):
|
||||||
|
B, N, C = x.size()
|
||||||
|
attention_head_size = int(C / self.num_attention_heads)
|
||||||
|
new_x_shape = x.size()[:-1] + (self.num_attention_heads, attention_head_size)
|
||||||
|
x = x.view(*new_x_shape)
|
||||||
|
return x.permute(0, 2, 1, 3).contiguous()
|
||||||
|
|
||||||
|
def forward(self, query_layer, key_layer, value_layer):
|
||||||
|
B, N, C = query_layer.size()
|
||||||
|
query_layer = self.transpose_for_scores(query_layer)
|
||||||
|
key_layer = self.transpose_for_scores(key_layer)
|
||||||
|
value_layer = self.transpose_for_scores(value_layer)
|
||||||
|
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||||
|
_, _, _, d = query_layer.size()
|
||||||
|
attention_scores = attention_scores / math.sqrt(d)
|
||||||
|
attention_probs = self.softmax(attention_scores)
|
||||||
|
context_layer = torch.matmul(attention_probs, value_layer)
|
||||||
|
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||||
|
new_context_layer_shape = context_layer.size()[:-2] + (C,)
|
||||||
|
attention_out = context_layer.view(*new_context_layer_shape)
|
||||||
|
|
||||||
|
return attention_out
|
||||||
|
|
||||||
|
|
||||||
|
class Mlp(nn.Module):
|
||||||
|
def __init__(self, hidden_size):
|
||||||
|
super(Mlp, self).__init__()
|
||||||
|
self.fc1 = nn.Linear(hidden_size, 4*hidden_size)
|
||||||
|
self.fc2 = nn.Linear(4*hidden_size, hidden_size)
|
||||||
|
self.act_fn = torch.nn.functional.gelu
|
||||||
|
self._init_weights()
|
||||||
|
|
||||||
|
def _init_weights(self):
|
||||||
|
nn.init.xavier_uniform_(self.fc1.weight)
|
||||||
|
nn.init.xavier_uniform_(self.fc2.weight)
|
||||||
|
nn.init.normal_(self.fc1.bias, std=1e-6)
|
||||||
|
nn.init.normal_(self.fc2.bias, std=1e-6)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.fc1(x)
|
||||||
|
x = self.act_fn(x)
|
||||||
|
x = self.fc2(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
# CPE (Conditional Positional Embedding)
|
||||||
|
class PEG(nn.Module):
|
||||||
|
def __init__(self, hidden_size):
|
||||||
|
super(PEG, self).__init__()
|
||||||
|
self.PEG = nn.Conv2d(hidden_size, hidden_size, kernel_size=3, padding=1, groups=hidden_size)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.PEG(x) + x
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Intra_SA(nn.Module):
|
||||||
|
def __init__(self, dim, head_num):
|
||||||
|
super(Intra_SA, self).__init__()
|
||||||
|
self.hidden_size = dim // 2
|
||||||
|
self.head_num = head_num
|
||||||
|
self.attention_norm = nn.LayerNorm(dim)
|
||||||
|
self.conv_input = nn.Conv2d(dim, dim, kernel_size=1, padding=0)
|
||||||
|
self.qkv_local_h = nn.Linear(self.hidden_size, self.hidden_size * 3) # qkv_h
|
||||||
|
self.qkv_local_v = nn.Linear(self.hidden_size, self.hidden_size * 3) # qkv_v
|
||||||
|
self.fuse_out = nn.Conv2d(dim, dim, kernel_size=1, padding=0)
|
||||||
|
self.ffn_norm = nn.LayerNorm(dim)
|
||||||
|
self.ffn = Mlp(dim)
|
||||||
|
self.attn = Attention(head_num=self.head_num)
|
||||||
|
self.PEG = PEG(dim)
|
||||||
|
def forward(self, x):
|
||||||
|
h = x
|
||||||
|
B, C, H, W = x.size()
|
||||||
|
|
||||||
|
x = x.view(B, C, H*W).permute(0, 2, 1).contiguous()
|
||||||
|
x = self.attention_norm(x).permute(0, 2, 1).contiguous()
|
||||||
|
x = x.view(B, C, H, W)
|
||||||
|
|
||||||
|
x_input = torch.chunk(self.conv_input(x), 2, dim=1)
|
||||||
|
feature_h = (x_input[0]).permute(0, 2, 3, 1).contiguous()
|
||||||
|
feature_h = feature_h.view(B * H, W, C//2)
|
||||||
|
feature_v = (x_input[1]).permute(0, 3, 2, 1).contiguous()
|
||||||
|
feature_v = feature_v.view(B * W, H, C//2)
|
||||||
|
qkv_h = torch.chunk(self.qkv_local_h(feature_h), 3, dim=2)
|
||||||
|
qkv_v = torch.chunk(self.qkv_local_v(feature_v), 3, dim=2)
|
||||||
|
q_h, k_h, v_h = qkv_h[0], qkv_h[1], qkv_h[2]
|
||||||
|
q_v, k_v, v_v = qkv_v[0], qkv_v[1], qkv_v[2]
|
||||||
|
|
||||||
|
if H == W:
|
||||||
|
query = torch.cat((q_h, q_v), dim=0)
|
||||||
|
key = torch.cat((k_h, k_v), dim=0)
|
||||||
|
value = torch.cat((v_h, v_v), dim=0)
|
||||||
|
attention_output = self.attn(query, key, value)
|
||||||
|
attention_output = torch.chunk(attention_output, 2, dim=0)
|
||||||
|
attention_output_h = attention_output[0]
|
||||||
|
attention_output_v = attention_output[1]
|
||||||
|
attention_output_h = attention_output_h.view(B, H, W, C//2).permute(0, 3, 1, 2).contiguous()
|
||||||
|
attention_output_v = attention_output_v.view(B, W, H, C//2).permute(0, 3, 2, 1).contiguous()
|
||||||
|
attn_out = self.fuse_out(torch.cat((attention_output_h, attention_output_v), dim=1))
|
||||||
|
else:
|
||||||
|
attention_output_h = self.attn(q_h, k_h, v_h)
|
||||||
|
attention_output_v = self.attn(q_v, k_v, v_v)
|
||||||
|
attention_output_h = attention_output_h.view(B, H, W, C//2).permute(0, 3, 1, 2).contiguous()
|
||||||
|
attention_output_v = attention_output_v.view(B, W, H, C//2).permute(0, 3, 2, 1).contiguous()
|
||||||
|
attn_out = self.fuse_out(torch.cat((attention_output_h, attention_output_v), dim=1))
|
||||||
|
|
||||||
|
x = attn_out + h
|
||||||
|
x = x.view(B, C, H*W).permute(0, 2, 1).contiguous()
|
||||||
|
h = x
|
||||||
|
x = self.ffn_norm(x)
|
||||||
|
x = self.ffn(x)
|
||||||
|
x = x + h
|
||||||
|
x = x.permute(0, 2, 1).contiguous()
|
||||||
|
x = x.view(B, C, H, W)
|
||||||
|
|
||||||
|
x = self.PEG(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
class Inter_SA(nn.Module):
|
||||||
|
def __init__(self,dim, head_num):
|
||||||
|
super(Inter_SA, self).__init__()
|
||||||
|
self.hidden_size = dim
|
||||||
|
self.head_num = head_num
|
||||||
|
self.attention_norm = nn.LayerNorm(self.hidden_size)
|
||||||
|
self.conv_input = nn.Conv2d(self.hidden_size, self.hidden_size, kernel_size=1, padding=0)
|
||||||
|
self.conv_h = nn.Conv2d(self.hidden_size//2, 3 * (self.hidden_size//2), kernel_size=1, padding=0) # qkv_h
|
||||||
|
self.conv_v = nn.Conv2d(self.hidden_size//2, 3 * (self.hidden_size//2), kernel_size=1, padding=0) # qkv_v
|
||||||
|
self.ffn_norm = nn.LayerNorm(self.hidden_size)
|
||||||
|
self.ffn = Mlp(self.hidden_size)
|
||||||
|
self.fuse_out = nn.Conv2d(self.hidden_size, self.hidden_size, kernel_size=1, padding=0)
|
||||||
|
self.attn = Attention(head_num=self.head_num)
|
||||||
|
self.PEG = PEG(dim)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
h = x
|
||||||
|
B, C, H, W = x.size()
|
||||||
|
|
||||||
|
x = x.view(B, C, H*W).permute(0, 2, 1).contiguous()
|
||||||
|
x = self.attention_norm(x).permute(0, 2, 1).contiguous()
|
||||||
|
x = x.view(B, C, H, W)
|
||||||
|
|
||||||
|
x_input = torch.chunk(self.conv_input(x), 2, dim=1)
|
||||||
|
feature_h = torch.chunk(self.conv_h(x_input[0]), 3, dim=1)
|
||||||
|
feature_v = torch.chunk(self.conv_v(x_input[1]), 3, dim=1)
|
||||||
|
query_h, key_h, value_h = feature_h[0], feature_h[1], feature_h[2]
|
||||||
|
query_v, key_v, value_v = feature_v[0], feature_v[1], feature_v[2]
|
||||||
|
|
||||||
|
horizontal_groups = torch.cat((query_h, key_h, value_h), dim=0)
|
||||||
|
horizontal_groups = horizontal_groups.permute(0, 2, 1, 3).contiguous()
|
||||||
|
horizontal_groups = horizontal_groups.view(3*B, H, -1)
|
||||||
|
horizontal_groups = torch.chunk(horizontal_groups, 3, dim=0)
|
||||||
|
query_h, key_h, value_h = horizontal_groups[0], horizontal_groups[1], horizontal_groups[2]
|
||||||
|
|
||||||
|
vertical_groups = torch.cat((query_v, key_v, value_v), dim=0)
|
||||||
|
vertical_groups = vertical_groups.permute(0, 3, 1, 2).contiguous()
|
||||||
|
vertical_groups = vertical_groups.view(3*B, W, -1)
|
||||||
|
vertical_groups = torch.chunk(vertical_groups, 3, dim=0)
|
||||||
|
query_v, key_v, value_v = vertical_groups[0], vertical_groups[1], vertical_groups[2]
|
||||||
|
|
||||||
|
|
||||||
|
if H == W:
|
||||||
|
query = torch.cat((query_h, query_v), dim=0)
|
||||||
|
key = torch.cat((key_h, key_v), dim=0)
|
||||||
|
value = torch.cat((value_h, value_v), dim=0)
|
||||||
|
attention_output = self.attn(query, key, value)
|
||||||
|
attention_output = torch.chunk(attention_output, 2, dim=0)
|
||||||
|
attention_output_h = attention_output[0]
|
||||||
|
attention_output_v = attention_output[1]
|
||||||
|
attention_output_h = attention_output_h.view(B, H, C//2, W).permute(0, 2, 1, 3).contiguous()
|
||||||
|
attention_output_v = attention_output_v.view(B, W, C//2, H).permute(0, 2, 3, 1).contiguous()
|
||||||
|
attn_out = self.fuse_out(torch.cat((attention_output_h, attention_output_v), dim=1))
|
||||||
|
else:
|
||||||
|
attention_output_h = self.attn(query_h, key_h, value_h)
|
||||||
|
attention_output_v = self.attn(query_v, key_v, value_v)
|
||||||
|
attention_output_h = attention_output_h.view(B, H, C//2, W).permute(0, 2, 1, 3).contiguous()
|
||||||
|
attention_output_v = attention_output_v.view(B, W, C//2, H).permute(0, 2, 3, 1).contiguous()
|
||||||
|
attn_out = self.fuse_out(torch.cat((attention_output_h, attention_output_v), dim=1))
|
||||||
|
|
||||||
|
x = attn_out + h
|
||||||
|
x = x.view(B, C, H*W).permute(0, 2, 1).contiguous()
|
||||||
|
h = x
|
||||||
|
x = self.ffn_norm(x)
|
||||||
|
x = self.ffn(x)
|
||||||
|
x = x + h
|
||||||
|
x = x.permute(0, 2, 1).contiguous()
|
||||||
|
x = x.view(B, C, H, W)
|
||||||
|
|
||||||
|
x = self.PEG(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
class Stripformer(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(Stripformer, self).__init__()
|
||||||
|
|
||||||
|
self.encoder = Embeddings()
|
||||||
|
head_num = 5
|
||||||
|
dim = 320
|
||||||
|
self.Trans_block_1 = Intra_SA(dim, head_num)
|
||||||
|
self.Trans_block_2 = Inter_SA(dim, head_num)
|
||||||
|
self.Trans_block_3 = Intra_SA(dim, head_num)
|
||||||
|
self.Trans_block_4 = Inter_SA(dim, head_num)
|
||||||
|
self.Trans_block_5 = Intra_SA(dim, head_num)
|
||||||
|
self.Trans_block_6 = Inter_SA(dim, head_num)
|
||||||
|
self.Trans_block_7 = Intra_SA(dim, head_num)
|
||||||
|
self.Trans_block_8 = Inter_SA(dim, head_num)
|
||||||
|
self.Trans_block_9 = Intra_SA(dim, head_num)
|
||||||
|
self.Trans_block_10 = Inter_SA(dim, head_num)
|
||||||
|
self.Trans_block_11 = Intra_SA(dim, head_num)
|
||||||
|
self.Trans_block_12 = Inter_SA(dim, head_num)
|
||||||
|
self.decoder = Embeddings_output()
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
|
||||||
|
hx, residual_1, residual_2 = self.encoder(x)
|
||||||
|
hx = self.Trans_block_1(hx)
|
||||||
|
hx = self.Trans_block_2(hx)
|
||||||
|
hx = self.Trans_block_3(hx)
|
||||||
|
hx = self.Trans_block_4(hx)
|
||||||
|
hx = self.Trans_block_5(hx)
|
||||||
|
hx = self.Trans_block_6(hx)
|
||||||
|
hx = self.Trans_block_7(hx)
|
||||||
|
hx = self.Trans_block_8(hx)
|
||||||
|
hx = self.Trans_block_9(hx)
|
||||||
|
hx = self.Trans_block_10(hx)
|
||||||
|
hx = self.Trans_block_11(hx)
|
||||||
|
hx = self.Trans_block_12(hx)
|
||||||
|
hx = self.decoder(hx, residual_1, residual_2)
|
||||||
|
|
||||||
|
return hx + x
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -0,0 +1,131 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torchvision.models as models
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
class Vgg19(torch.nn.Module):
|
||||||
|
def __init__(self, requires_grad=False):
|
||||||
|
super(Vgg19, self).__init__()
|
||||||
|
vgg_pretrained_features = models.vgg19(pretrained=True).features
|
||||||
|
self.slice1 = torch.nn.Sequential()
|
||||||
|
|
||||||
|
for x in range(12):
|
||||||
|
self.slice1.add_module(str(x), vgg_pretrained_features[x])
|
||||||
|
|
||||||
|
if not requires_grad:
|
||||||
|
for param in self.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
def forward(self, X):
|
||||||
|
h_relu1 = self.slice1(X)
|
||||||
|
return h_relu1
|
||||||
|
|
||||||
|
class ContrastLoss(nn.Module):
|
||||||
|
def __init__(self, ablation=False):
|
||||||
|
|
||||||
|
super(ContrastLoss, self).__init__()
|
||||||
|
self.vgg = Vgg19().cuda()
|
||||||
|
self.l1 = nn.L1Loss()
|
||||||
|
self.ab = ablation
|
||||||
|
self.down_sample_4 = nn.Upsample(scale_factor=1 / 4, mode='bilinear')
|
||||||
|
def forward(self, restore, sharp, blur):
|
||||||
|
B, C, H, W = restore.size()
|
||||||
|
restore_vgg, sharp_vgg, blur_vgg = self.vgg(restore), self.vgg(sharp), self.vgg(blur)
|
||||||
|
|
||||||
|
# filter out sharp regions
|
||||||
|
threshold = 0.01
|
||||||
|
mask = torch.mean(torch.abs(sharp-blur), dim=1).view(B, 1, H, W)
|
||||||
|
mask[mask <= threshold] = 0
|
||||||
|
mask[mask > threshold] = 1
|
||||||
|
mask = self.down_sample_4(mask)
|
||||||
|
d_ap = torch.mean(torch.abs((restore_vgg - sharp_vgg.detach())), dim=1).view(B, 1, H//4, W//4)
|
||||||
|
d_an = torch.mean(torch.abs((restore_vgg - blur_vgg.detach())), dim=1).view(B, 1, H//4, W//4)
|
||||||
|
mask_size = torch.sum(mask)
|
||||||
|
contrastive = torch.sum((d_ap / (d_an + 1e-7)) * mask) / mask_size
|
||||||
|
|
||||||
|
return contrastive
|
||||||
|
|
||||||
|
|
||||||
|
class ContrastLoss_Ori(nn.Module):
|
||||||
|
def __init__(self, ablation=False):
|
||||||
|
super(ContrastLoss_Ori, self).__init__()
|
||||||
|
self.vgg = Vgg19().cuda()
|
||||||
|
self.l1 = nn.L1Loss()
|
||||||
|
self.ab = ablation
|
||||||
|
|
||||||
|
def forward(self, restore, sharp, blur):
|
||||||
|
|
||||||
|
restore_vgg, sharp_vgg, blur_vgg = self.vgg(restore), self.vgg(sharp), self.vgg(blur)
|
||||||
|
d_ap = self.l1(restore_vgg, sharp_vgg.detach())
|
||||||
|
d_an = self.l1(restore_vgg, blur_vgg.detach())
|
||||||
|
contrastive_loss = d_ap / (d_an + 1e-7)
|
||||||
|
|
||||||
|
return contrastive_loss
|
||||||
|
|
||||||
|
class CharbonnierLoss(nn.Module):
|
||||||
|
"""Charbonnier Loss (L1)"""
|
||||||
|
|
||||||
|
def __init__(self, eps=1e-3):
|
||||||
|
super(CharbonnierLoss, self).__init__()
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
def forward(self, x, y):
|
||||||
|
diff = x - y
|
||||||
|
# loss = torch.sum(torch.sqrt(diff * diff + self.eps))
|
||||||
|
loss = torch.mean(torch.sqrt((diff * diff) + (self.eps * self.eps)))
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
class EdgeLoss(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(EdgeLoss, self).__init__()
|
||||||
|
k = torch.Tensor([[.05, .25, .4, .25, .05]])
|
||||||
|
self.kernel = torch.matmul(k.t(), k).unsqueeze(0).repeat(3, 1, 1, 1)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
self.kernel = self.kernel.cuda()
|
||||||
|
self.loss = CharbonnierLoss()
|
||||||
|
|
||||||
|
def conv_gauss(self, img):
|
||||||
|
n_channels, _, kw, kh = self.kernel.shape
|
||||||
|
img = F.pad(img, (kw // 2, kh // 2, kw // 2, kh // 2), mode='replicate')
|
||||||
|
return F.conv2d(img, self.kernel, groups=n_channels)
|
||||||
|
|
||||||
|
def laplacian_kernel(self, current):
|
||||||
|
filtered = self.conv_gauss(current) # filter
|
||||||
|
down = filtered[:, :, ::2, ::2] # downsample
|
||||||
|
new_filter = torch.zeros_like(filtered)
|
||||||
|
new_filter[:, :, ::2, ::2] = down * 4 # upsample
|
||||||
|
filtered = self.conv_gauss(new_filter) # filter
|
||||||
|
diff = current - filtered
|
||||||
|
return diff
|
||||||
|
|
||||||
|
def forward(self, x, y):
|
||||||
|
# x = torch.clamp(x + 0.5, min = 0,max = 1)
|
||||||
|
# y = torch.clamp(y + 0.5, min = 0,max = 1)
|
||||||
|
loss = self.loss(self.laplacian_kernel(x), self.laplacian_kernel(y))
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
class Stripformer_Loss(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, ):
|
||||||
|
super(Stripformer_Loss, self).__init__()
|
||||||
|
|
||||||
|
self.char = CharbonnierLoss()
|
||||||
|
self.edge = EdgeLoss()
|
||||||
|
self.contrastive = ContrastLoss()
|
||||||
|
|
||||||
|
def forward(self, restore, sharp, blur):
|
||||||
|
char = self.char(restore, sharp)
|
||||||
|
edge = 0.05 * self.edge(restore, sharp)
|
||||||
|
contrastive = 0.0005 * self.contrastive(restore, sharp, blur)
|
||||||
|
loss = char + edge + contrastive
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
def get_loss(model):
|
||||||
|
if model['content_loss'] == 'Stripformer_Loss':
|
||||||
|
content_loss = Stripformer_Loss()
|
||||||
|
else:
|
||||||
|
raise ValueError("ContentLoss [%s] not recognized." % model['content_loss'])
|
||||||
|
return content_loss
|
||||||
@ -0,0 +1,35 @@
|
|||||||
|
import numpy as np
|
||||||
|
import torch.nn as nn
|
||||||
|
from skimage.measure import compare_ssim as SSIM
|
||||||
|
|
||||||
|
from util.metrics import PSNR
|
||||||
|
|
||||||
|
|
||||||
|
class DeblurModel(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(DeblurModel, self).__init__()
|
||||||
|
|
||||||
|
def get_input(self, data):
|
||||||
|
img = data['a']
|
||||||
|
inputs = img
|
||||||
|
targets = data['b']
|
||||||
|
inputs, targets = inputs.cuda(), targets.cuda()
|
||||||
|
return inputs, targets
|
||||||
|
|
||||||
|
def tensor2im(self, image_tensor, imtype=np.uint8):
|
||||||
|
image_numpy = image_tensor[0].cpu().float().numpy()
|
||||||
|
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 0.5) * 255.0
|
||||||
|
return image_numpy
|
||||||
|
|
||||||
|
def get_images_and_metrics(self, inp, output, target) -> (float, float, np.ndarray):
|
||||||
|
inp = self.tensor2im(inp)
|
||||||
|
fake = self.tensor2im(output.data)
|
||||||
|
real = self.tensor2im(target.data)
|
||||||
|
psnr = PSNR(fake, real)
|
||||||
|
ssim = SSIM(fake.astype('uint8'), real.astype('uint8'), multichannel=True)
|
||||||
|
vis_img = np.hstack((inp, fake, real))
|
||||||
|
return psnr, ssim, vis_img
|
||||||
|
|
||||||
|
|
||||||
|
def get_model(model_config):
|
||||||
|
return DeblurModel()
|
||||||
@ -0,0 +1,13 @@
|
|||||||
|
import torch.nn as nn
|
||||||
|
from models.Stripformer import Stripformer
|
||||||
|
|
||||||
|
def get_generator(model_config):
|
||||||
|
generator_name = model_config['g_name']
|
||||||
|
if generator_name == 'Stripformer':
|
||||||
|
model_g = Stripformer()
|
||||||
|
else:
|
||||||
|
raise ValueError("Generator Network [%s] not recognized." % generator_name)
|
||||||
|
return nn.DataParallel(model_g)
|
||||||
|
|
||||||
|
def get_nets(model_config):
|
||||||
|
return get_generator(model_config)
|
||||||
@ -0,0 +1 @@
|
|||||||
|
testing results are created in this folder
|
||||||
@ -0,0 +1,68 @@
|
|||||||
|
from __future__ import print_function
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import cv2
|
||||||
|
import yaml
|
||||||
|
import os
|
||||||
|
from torch.autograd import Variable
|
||||||
|
from models.networks import get_generator
|
||||||
|
import torchvision
|
||||||
|
import time
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser('Test an image')
|
||||||
|
parser.add_argument('--weights_path', required=True, help='Weights path')
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
args = get_args()
|
||||||
|
with open('config/config_Stripformer_gopro.yaml') as cfg:
|
||||||
|
config = yaml.load(cfg)
|
||||||
|
blur_path = './datasets/GoPro/test/blur/'
|
||||||
|
out_path = './out/Stripformer_GoPro_results'
|
||||||
|
if not os.path.isdir(out_path):
|
||||||
|
os.mkdir(out_path)
|
||||||
|
model = get_generator(config['model'])
|
||||||
|
model.load_state_dict(torch.load(args.weights_path))
|
||||||
|
model = model.cuda()
|
||||||
|
|
||||||
|
test_time = 0
|
||||||
|
iteration = 0
|
||||||
|
total_image_number = 1111
|
||||||
|
|
||||||
|
# warm-up
|
||||||
|
warm_up = 0
|
||||||
|
print('Hardware warm-up')
|
||||||
|
for file in os.listdir(blur_path):
|
||||||
|
for img_name in os.listdir(blur_path + '/' + file):
|
||||||
|
warm_up += 1
|
||||||
|
img = cv2.imread(blur_path + '/' + file + '/' + img_name)
|
||||||
|
img_tensor = torch.from_numpy(np.transpose(img / 255, (2, 0, 1)).astype('float32')) - 0.5
|
||||||
|
with torch.no_grad():
|
||||||
|
img_tensor = Variable(img_tensor.unsqueeze(0)).cuda()
|
||||||
|
result_image = model(img_tensor)
|
||||||
|
if warm_up == 20:
|
||||||
|
break
|
||||||
|
break
|
||||||
|
|
||||||
|
for file in os.listdir(blur_path):
|
||||||
|
if not os.path.isdir(out_path + '/' + file):
|
||||||
|
os.mkdir(out_path + '/' + file)
|
||||||
|
for img_name in os.listdir(blur_path + '/' + file):
|
||||||
|
img = cv2.imread(blur_path + '/' + file + '/' + img_name)
|
||||||
|
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||||
|
img_tensor = torch.from_numpy(np.transpose(img / 255, (2, 0, 1)).astype('float32')) - 0.5
|
||||||
|
with torch.no_grad():
|
||||||
|
iteration += 1
|
||||||
|
img_tensor = Variable(img_tensor.unsqueeze(0)).cuda()
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
|
result_image = model(img_tensor)
|
||||||
|
stop = time.time()
|
||||||
|
print('Image:{}/{}, CNN Runtime:{:.4f}'.format(iteration, total_image_number, (stop - start)))
|
||||||
|
test_time += stop - start
|
||||||
|
print('Average Runtime:{:.4f}'.format(test_time / float(iteration)))
|
||||||
|
result_image = result_image + 0.5
|
||||||
|
out_file_name = out_path + '/' + file + '/' + img_name
|
||||||
|
torchvision.utils.save_image(result_image, out_file_name)
|
||||||
@ -0,0 +1,64 @@
|
|||||||
|
from __future__ import print_function
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import cv2
|
||||||
|
import yaml
|
||||||
|
import os
|
||||||
|
from torch.autograd import Variable
|
||||||
|
from models.networks import get_generator
|
||||||
|
import torchvision
|
||||||
|
import time
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser('Test an image')
|
||||||
|
parser.add_argument('--weights_path', required=True, help='Weights path')
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
args = get_args()
|
||||||
|
with open('config/config_Stripformer_gopro.yaml') as cfg:
|
||||||
|
config = yaml.load(cfg)
|
||||||
|
blur_path = './datasets/HIDE/blur/'
|
||||||
|
out_path = './out/Stripformer_HIDE_results'
|
||||||
|
if not os.path.isdir(out_path):
|
||||||
|
os.mkdir(out_path)
|
||||||
|
model = get_generator(config['model'])
|
||||||
|
model.load_state_dict(torch.load(args.weights_path))
|
||||||
|
model = model.cuda()
|
||||||
|
test_time = 0
|
||||||
|
iteration = 0
|
||||||
|
total_image_number = 2025
|
||||||
|
|
||||||
|
# warm up
|
||||||
|
warm_up = 0
|
||||||
|
print('Hardware warm-up')
|
||||||
|
for img_name in os.listdir(blur_path):
|
||||||
|
warm_up += 1
|
||||||
|
img = cv2.imread(blur_path + '/' + img_name)
|
||||||
|
img_tensor = torch.from_numpy(np.transpose(img / 255, (2, 0, 1)).astype('float32')) - 0.5
|
||||||
|
with torch.no_grad():
|
||||||
|
img_tensor = Variable(img_tensor.unsqueeze(0)).cuda()
|
||||||
|
result_image = model(img_tensor)
|
||||||
|
if warm_up == 20:
|
||||||
|
break
|
||||||
|
break
|
||||||
|
|
||||||
|
for img_name in os.listdir(blur_path):
|
||||||
|
img = cv2.imread(blur_path + '/' + img_name)
|
||||||
|
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||||
|
img_tensor = torch.from_numpy(np.transpose(img / 255, (2, 0, 1)).astype('float32')) - 0.5
|
||||||
|
with torch.no_grad():
|
||||||
|
iteration += 1
|
||||||
|
img_tensor = Variable(img_tensor.unsqueeze(0)).cuda()
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
|
result_image = model(img_tensor)
|
||||||
|
stop = time.time()
|
||||||
|
|
||||||
|
print('Image:{}/{}, CNN Runtime:{:.4f}'.format(iteration, total_image_number, (stop - start)))
|
||||||
|
test_time += stop - start
|
||||||
|
print('Average Runtime:{:.4f}'.format(test_time / float(iteration)))
|
||||||
|
result_image = result_image + 0.5
|
||||||
|
out_file_name = out_path + '/' + img_name
|
||||||
|
torchvision.utils.save_image(result_image, out_file_name)
|
||||||
@ -0,0 +1,90 @@
|
|||||||
|
from __future__ import print_function
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import cv2
|
||||||
|
import yaml
|
||||||
|
import os
|
||||||
|
from torch.autograd import Variable
|
||||||
|
from models.networks import get_generator
|
||||||
|
import torchvision
|
||||||
|
import time
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser('Test an image')
|
||||||
|
parser.add_argument('--weights_path', required=True, help='Weights path')
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
args = get_args()
|
||||||
|
with open('config/config_Stripformer_gopro.yaml') as cfg:
|
||||||
|
config = yaml.load(cfg)
|
||||||
|
blur_path = './datasets/Realblur_J/test/blur/'
|
||||||
|
out_path = './out/Stripformer_realblur_J_results'
|
||||||
|
if not os.path.isdir(out_path):
|
||||||
|
os.mkdir(out_path)
|
||||||
|
model = get_generator(config['model'])
|
||||||
|
model.load_state_dict(torch.load(args.weights_path))
|
||||||
|
model = model.cuda()
|
||||||
|
test_time = 0
|
||||||
|
iteration = 0
|
||||||
|
total_image_number = 980
|
||||||
|
|
||||||
|
# warm up
|
||||||
|
warm_up = 0
|
||||||
|
print('Hardware warm-up')
|
||||||
|
for file in os.listdir(blur_path):
|
||||||
|
if not os.path.isdir(out_path + '/' + file):
|
||||||
|
os.mkdir(out_path + '/' + file)
|
||||||
|
for img_name in os.listdir(blur_path + '/' + file):
|
||||||
|
warm_up += 1
|
||||||
|
img = cv2.imread(blur_path + '/' + file + '/' + img_name)
|
||||||
|
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||||
|
img_tensor = torch.from_numpy(np.transpose(img / 255, (2, 0, 1)).astype('float32')) - 0.5
|
||||||
|
with torch.no_grad():
|
||||||
|
img_tensor = Variable(img_tensor.unsqueeze(0)).cuda()
|
||||||
|
factor = 8
|
||||||
|
h, w = img_tensor.shape[2], img_tensor.shape[3]
|
||||||
|
H, W = ((h + factor) // factor) * factor, ((w + factor) // factor) * factor
|
||||||
|
padh = H - h if h % factor != 0 else 0
|
||||||
|
padw = W - w if w % factor != 0 else 0
|
||||||
|
img_tensor = F.pad(img_tensor, (0, padw, 0, padh), 'reflect')
|
||||||
|
result_image = model(img_tensor)
|
||||||
|
if warm_up == 20:
|
||||||
|
break
|
||||||
|
break
|
||||||
|
|
||||||
|
for file in os.listdir(blur_path):
|
||||||
|
if not os.path.isdir(out_path + '/' + file):
|
||||||
|
os.mkdir(out_path + '/' + file)
|
||||||
|
for img_name in os.listdir(blur_path + '/' + file):
|
||||||
|
img = cv2.imread(blur_path + '/' + file + '/' + img_name)
|
||||||
|
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||||
|
img_tensor = torch.from_numpy(np.transpose(img / 255, (2, 0, 1)).astype('float32')) - 0.5
|
||||||
|
with torch.no_grad():
|
||||||
|
iteration += 1
|
||||||
|
img_tensor = Variable(img_tensor.unsqueeze(0)).cuda()
|
||||||
|
|
||||||
|
factor = 8
|
||||||
|
h, w = img_tensor.shape[2], img_tensor.shape[3]
|
||||||
|
H, W = ((h + factor) // factor) * factor, ((w + factor) // factor) * factor
|
||||||
|
padh = H - h if h % factor != 0 else 0
|
||||||
|
padw = W - w if w % factor != 0 else 0
|
||||||
|
img_tensor = F.pad(img_tensor, (0, padw, 0, padh), 'reflect')
|
||||||
|
H, W = img_tensor.shape[2], img_tensor.shape[3]
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
|
_output = model(img_tensor)
|
||||||
|
stop = time.time()
|
||||||
|
|
||||||
|
result_image = _output[:, :, :h, :w]
|
||||||
|
result_image = torch.clamp(result_image, -0.5, 0.5)
|
||||||
|
result_image = result_image + 0.5
|
||||||
|
|
||||||
|
test_time += stop - start
|
||||||
|
print('Image:{}/{}, CNN Runtime:{:.4f}'.format(iteration, total_image_number, (stop - start)))
|
||||||
|
print('Average Runtime:{:.4f}'.format(test_time / float(iteration)))
|
||||||
|
out_file_name = out_path + '/' + file + '/' + img_name
|
||||||
|
torchvision.utils.save_image(result_image, out_file_name)
|
||||||
|
|
||||||
@ -0,0 +1,90 @@
|
|||||||
|
from __future__ import print_function
|
||||||
|
import argparse
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import cv2
|
||||||
|
import yaml
|
||||||
|
import os
|
||||||
|
from torch.autograd import Variable
|
||||||
|
from models.networks import get_generator
|
||||||
|
import torchvision
|
||||||
|
import time
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser('Test an image')
|
||||||
|
parser.add_argument('--weights_path', required=True, help='Weights path')
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
args = get_args()
|
||||||
|
with open('config/config_Stripformer_gopro.yaml') as cfg:
|
||||||
|
config = yaml.load(cfg)
|
||||||
|
blur_path = './datasets/Realblur_R/test/blur/'
|
||||||
|
out_path = './out/Stripformer_realblur_R_results'
|
||||||
|
model = get_generator(config['model'])
|
||||||
|
model.load_state_dict(torch.load(args.weights_path))
|
||||||
|
model = model.cuda()
|
||||||
|
if not os.path.isdir(out_path):
|
||||||
|
os.mkdir(out_path)
|
||||||
|
test_time = 0
|
||||||
|
iteration = 0
|
||||||
|
total_image_number = 980
|
||||||
|
|
||||||
|
# warm up
|
||||||
|
warm_up = 0
|
||||||
|
print('Hardware warm-up')
|
||||||
|
for file in os.listdir(blur_path):
|
||||||
|
if not os.path.isdir(out_path + '/' + file):
|
||||||
|
os.mkdir(out_path + '/' + file)
|
||||||
|
for img_name in os.listdir(blur_path + '/' + file):
|
||||||
|
warm_up += 1
|
||||||
|
img = cv2.imread(blur_path + '/' + file + '/' + img_name)
|
||||||
|
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||||
|
img_tensor = torch.from_numpy(np.transpose(img / 255, (2, 0, 1)).astype('float32')) - 0.5
|
||||||
|
with torch.no_grad():
|
||||||
|
img_tensor = Variable(img_tensor.unsqueeze(0)).cuda()
|
||||||
|
factor = 8
|
||||||
|
h, w = img_tensor.shape[2], img_tensor.shape[3]
|
||||||
|
H, W = ((h + factor) // factor) * factor, ((w + factor) // factor) * factor
|
||||||
|
padh = H - h if h % factor != 0 else 0
|
||||||
|
padw = W - w if w % factor != 0 else 0
|
||||||
|
img_tensor = F.pad(img_tensor, (0, padw, 0, padh), 'reflect')
|
||||||
|
result_image = model(img_tensor)
|
||||||
|
if warm_up == 20:
|
||||||
|
break
|
||||||
|
break
|
||||||
|
|
||||||
|
for file in os.listdir(blur_path):
|
||||||
|
if not os.path.isdir(out_path + '/' + file):
|
||||||
|
os.mkdir(out_path + '/' + file)
|
||||||
|
for img_name in os.listdir(blur_path + '/' + file):
|
||||||
|
img = cv2.imread(blur_path + '/' + file + '/' + img_name)
|
||||||
|
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||||
|
img_tensor = torch.from_numpy(np.transpose(img / 255, (2, 0, 1)).astype('float32')) - 0.5
|
||||||
|
with torch.no_grad():
|
||||||
|
iteration += 1
|
||||||
|
img_tensor = Variable(img_tensor.unsqueeze(0)).cuda()
|
||||||
|
|
||||||
|
factor = 8
|
||||||
|
h, w = img_tensor.shape[2], img_tensor.shape[3]
|
||||||
|
H, W = ((h + factor) // factor) * factor, ((w + factor) // factor) * factor
|
||||||
|
padh = H - h if h % factor != 0 else 0
|
||||||
|
padw = W - w if w % factor != 0 else 0
|
||||||
|
img_tensor = F.pad(img_tensor, (0, padw, 0, padh), 'reflect')
|
||||||
|
H, W = img_tensor.shape[2], img_tensor.shape[3]
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
|
_output = model(img_tensor)
|
||||||
|
stop = time.time()
|
||||||
|
|
||||||
|
result_image = _output[:, :, :h, :w]
|
||||||
|
result_image = torch.clamp(result_image, -0.5, 0.5)
|
||||||
|
result_image = result_image + 0.5
|
||||||
|
|
||||||
|
test_time += stop - start
|
||||||
|
print('Image:{}/{}, CNN Runtime:{:.4f}'.format(iteration, total_image_number, (stop - start)))
|
||||||
|
print('Average Runtime:{:.4f}'.format(test_time / float(iteration)))
|
||||||
|
out_file_name = out_path + '/' + file + '/' + img_name
|
||||||
|
torchvision.utils.save_image(result_image, out_file_name)
|
||||||
|
|
||||||
@ -0,0 +1,59 @@
|
|||||||
|
import math
|
||||||
|
|
||||||
|
from torch.optim import lr_scheduler
|
||||||
|
|
||||||
|
|
||||||
|
class WarmRestart(lr_scheduler.CosineAnnealingLR):
|
||||||
|
"""This class implements Stochastic Gradient Descent with Warm Restarts(SGDR): https://arxiv.org/abs/1608.03983.
|
||||||
|
|
||||||
|
Set the learning rate of each parameter group using a cosine annealing schedule, When last_epoch=-1, sets initial lr as lr.
|
||||||
|
This can't support scheduler.step(epoch). please keep epoch=None.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, optimizer, T_max=30, T_mult=1, eta_min=0, last_epoch=-1):
|
||||||
|
"""implements SGDR
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
----------
|
||||||
|
T_max : int
|
||||||
|
Maximum number of epochs.
|
||||||
|
T_mult : int
|
||||||
|
Multiplicative factor of T_max.
|
||||||
|
eta_min : int
|
||||||
|
Minimum learning rate. Default: 0.
|
||||||
|
last_epoch : int
|
||||||
|
The index of last epoch. Default: -1.
|
||||||
|
"""
|
||||||
|
self.T_mult = T_mult
|
||||||
|
super().__init__(optimizer, T_max, eta_min, last_epoch)
|
||||||
|
|
||||||
|
def get_lr(self):
|
||||||
|
if self.last_epoch == self.T_max:
|
||||||
|
self.last_epoch = 0
|
||||||
|
self.T_max *= self.T_mult
|
||||||
|
return [self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / 2 for
|
||||||
|
base_lr in self.base_lrs]
|
||||||
|
|
||||||
|
|
||||||
|
class LinearDecay(lr_scheduler._LRScheduler):
|
||||||
|
"""This class implements LinearDecay
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, optimizer, num_epochs, start_epoch=0, min_lr=0, last_epoch=-1):
|
||||||
|
"""implements LinearDecay
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
----------
|
||||||
|
|
||||||
|
"""
|
||||||
|
self.num_epochs = num_epochs
|
||||||
|
self.start_epoch = start_epoch
|
||||||
|
self.min_lr = min_lr
|
||||||
|
super().__init__(optimizer, last_epoch)
|
||||||
|
|
||||||
|
def get_lr(self):
|
||||||
|
if self.last_epoch < self.start_epoch:
|
||||||
|
return self.base_lrs
|
||||||
|
return [base_lr - ((base_lr - self.min_lr) / self.num_epochs) * (self.last_epoch - self.start_epoch) for
|
||||||
|
base_lr in self.base_lrs]
|
||||||
@ -0,0 +1,170 @@
|
|||||||
|
import logging
|
||||||
|
from functools import partial
|
||||||
|
import os
|
||||||
|
import cv2
|
||||||
|
import torch
|
||||||
|
import torch.optim as optim
|
||||||
|
import tqdm
|
||||||
|
import yaml
|
||||||
|
from joblib import cpu_count
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
import random
|
||||||
|
from dataset import PairedDataset
|
||||||
|
from metric_counter import MetricCounter
|
||||||
|
from models.losses import get_loss
|
||||||
|
from models.models import get_model
|
||||||
|
from models.networks import get_nets
|
||||||
|
import numpy as np
|
||||||
|
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||||
|
cv2.setNumThreads(0)
|
||||||
|
|
||||||
|
|
||||||
|
class Trainer:
|
||||||
|
def __init__(self, config, train: DataLoader, val: DataLoader):
|
||||||
|
self.config = config
|
||||||
|
self.train_dataset = train
|
||||||
|
self.val_dataset = val
|
||||||
|
self.metric_counter = MetricCounter(config['experiment_desc'])
|
||||||
|
|
||||||
|
def train(self):
|
||||||
|
self._init_params()
|
||||||
|
start_epoch = 0
|
||||||
|
if os.path.exists('last_Stripformer_gopro.pth'):
|
||||||
|
print('load_pretrained')
|
||||||
|
training_state = (torch.load('last_Stripformer_gopro.pth'))
|
||||||
|
start_epoch = training_state['epoch']
|
||||||
|
new_weight = self.netG.state_dict()
|
||||||
|
new_weight.update(training_state['model_state'])
|
||||||
|
self.netG.load_state_dict(new_weight)
|
||||||
|
new_optimizer = self.optimizer_G.state_dict()
|
||||||
|
new_optimizer.update(training_state['optimizer_state'])
|
||||||
|
self.optimizer_G.load_state_dict(new_optimizer)
|
||||||
|
new_scheduler = self.scheduler_G.state_dict()
|
||||||
|
new_scheduler.update(training_state['scheduler_state'])
|
||||||
|
self.scheduler_G.load_state_dict(new_scheduler)
|
||||||
|
else:
|
||||||
|
print('load_GoPro_pretrained')
|
||||||
|
training_state = (torch.load('final_Stripformer_pretrained.pth'))
|
||||||
|
new_weight = self.netG.state_dict()
|
||||||
|
new_weight.update(training_state)
|
||||||
|
self.netG.load_state_dict(new_weight)
|
||||||
|
|
||||||
|
|
||||||
|
for epoch in range(start_epoch, config['num_epochs']):
|
||||||
|
self._run_epoch(epoch)
|
||||||
|
if epoch % 30 == 0 or epoch == (config['num_epochs']-1):
|
||||||
|
self._validate(epoch)
|
||||||
|
self.scheduler_G.step()
|
||||||
|
|
||||||
|
scheduler_state = self.scheduler_G.state_dict()
|
||||||
|
training_state = {'epoch': epoch, 'model_state': self.netG.state_dict(),
|
||||||
|
'scheduler_state': scheduler_state, 'optimizer_state': self.optimizer_G.state_dict()}
|
||||||
|
if self.metric_counter.update_best_model():
|
||||||
|
torch.save(training_state['model_state'], 'best_{}.pth'.format(self.config['experiment_desc']))
|
||||||
|
|
||||||
|
if epoch % 200 == 0:
|
||||||
|
torch.save(training_state, 'last_{}_{}.pth'.format(self.config['experiment_desc'], epoch))
|
||||||
|
|
||||||
|
if epoch == (config['num_epochs']-1):
|
||||||
|
torch.save(training_state['model_state'], 'final_{}.pth'.format(self.config['experiment_desc']))
|
||||||
|
|
||||||
|
torch.save(training_state, 'last_{}.pth'.format(self.config['experiment_desc']))
|
||||||
|
logging.debug("Experiment Name: %s, Epoch: %d, Loss: %s" % (
|
||||||
|
self.config['experiment_desc'], epoch, self.metric_counter.loss_message()))
|
||||||
|
|
||||||
|
def _run_epoch(self, epoch):
|
||||||
|
self.metric_counter.clear()
|
||||||
|
for param_group in self.optimizer_G.param_groups:
|
||||||
|
lr = param_group['lr']
|
||||||
|
|
||||||
|
epoch_size = config.get('train_batches_per_epoch') or len(self.train_dataset)
|
||||||
|
tq = tqdm.tqdm(self.train_dataset, total=epoch_size)
|
||||||
|
tq.set_description('Epoch {}, lr {}'.format(epoch, lr))
|
||||||
|
i = 0
|
||||||
|
for data in tq:
|
||||||
|
inputs, targets = self.model.get_input(data)
|
||||||
|
outputs = self.netG(inputs)
|
||||||
|
self.optimizer_G.zero_grad()
|
||||||
|
loss_G = self.criterionG(outputs, targets, inputs)
|
||||||
|
loss_G.backward()
|
||||||
|
self.optimizer_G.step()
|
||||||
|
self.metric_counter.add_losses(loss_G.item())
|
||||||
|
curr_psnr, curr_ssim, img_for_vis = self.model.get_images_and_metrics(inputs, outputs, targets)
|
||||||
|
self.metric_counter.add_metrics(curr_psnr, curr_ssim)
|
||||||
|
tq.set_postfix(loss=self.metric_counter.loss_message())
|
||||||
|
if not i:
|
||||||
|
self.metric_counter.add_image(img_for_vis, tag='train')
|
||||||
|
i += 1
|
||||||
|
if i > epoch_size:
|
||||||
|
break
|
||||||
|
tq.close()
|
||||||
|
self.metric_counter.write_to_tensorboard(epoch)
|
||||||
|
|
||||||
|
def _validate(self, epoch):
|
||||||
|
self.metric_counter.clear()
|
||||||
|
epoch_size = config.get('val_batches_per_epoch') or len(self.val_dataset)
|
||||||
|
tq = tqdm.tqdm(self.val_dataset, total=epoch_size)
|
||||||
|
tq.set_description('Validation')
|
||||||
|
i = 0
|
||||||
|
for data in tq:
|
||||||
|
with torch.no_grad():
|
||||||
|
inputs, targets = self.model.get_input(data)
|
||||||
|
outputs = self.netG(inputs)
|
||||||
|
loss_G = self.criterionG(outputs, targets, inputs)
|
||||||
|
self.metric_counter.add_losses(loss_G.item())
|
||||||
|
curr_psnr, curr_ssim, img_for_vis = self.model.get_images_and_metrics(inputs, outputs, targets)
|
||||||
|
self.metric_counter.add_metrics(curr_psnr, curr_ssim)
|
||||||
|
if not i:
|
||||||
|
self.metric_counter.add_image(img_for_vis, tag='val')
|
||||||
|
i += 1
|
||||||
|
if i > epoch_size:
|
||||||
|
break
|
||||||
|
tq.close()
|
||||||
|
self.metric_counter.write_to_tensorboard(epoch, validation=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_optim(self, params):
|
||||||
|
if self.config['optimizer']['name'] == 'adam':
|
||||||
|
optimizer = optim.Adam(params, lr=self.config['optimizer']['lr'])
|
||||||
|
else:
|
||||||
|
raise ValueError("Optimizer [%s] not recognized." % self.config['optimizer']['name'])
|
||||||
|
return optimizer
|
||||||
|
|
||||||
|
def _get_scheduler(self, optimizer):
|
||||||
|
if self.config['scheduler']['name'] == 'cosine':
|
||||||
|
scheduler = CosineAnnealingLR(optimizer, T_max=self.config['num_epochs'], eta_min=self.config['scheduler']['min_lr'])
|
||||||
|
else:
|
||||||
|
raise ValueError("Scheduler [%s] not recognized." % self.config['scheduler']['name'])
|
||||||
|
return scheduler
|
||||||
|
|
||||||
|
def _init_params(self):
|
||||||
|
self.criterionG = get_loss(self.config['model'])
|
||||||
|
self.netG = get_nets(self.config['model'])
|
||||||
|
self.netG.cuda()
|
||||||
|
self.model = get_model(self.config['model'])
|
||||||
|
self.optimizer_G = self._get_optim(filter(lambda p: p.requires_grad, self.netG.parameters()))
|
||||||
|
self.scheduler_G = self._get_scheduler(self.optimizer_G)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
with open('config/config_Stripformer_gopro.yaml', 'r') as f:
|
||||||
|
config = yaml.load(f)
|
||||||
|
# setup
|
||||||
|
torch.backends.cudnn.enabled = True
|
||||||
|
torch.backends.cudnn.benchmark = True
|
||||||
|
|
||||||
|
# set random seed
|
||||||
|
seed = 666
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
random.seed(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
|
||||||
|
batch_size = config.pop('batch_size')
|
||||||
|
get_dataloader = partial(DataLoader, batch_size=batch_size, num_workers=cpu_count(), shuffle=True, drop_last=False)
|
||||||
|
|
||||||
|
datasets = map(config.pop, ('train', 'val'))
|
||||||
|
datasets = map(PairedDataset.from_config, datasets)
|
||||||
|
train, val = map(get_dataloader, datasets)
|
||||||
|
trainer = Trainer(config, train=train, val=val)
|
||||||
|
trainer.train()
|
||||||
@ -0,0 +1,164 @@
|
|||||||
|
import logging
|
||||||
|
from functools import partial
|
||||||
|
import os
|
||||||
|
import cv2
|
||||||
|
import torch
|
||||||
|
import torch.optim as optim
|
||||||
|
import tqdm
|
||||||
|
import yaml
|
||||||
|
from joblib import cpu_count
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
import random
|
||||||
|
from dataset import PairedDataset
|
||||||
|
from metric_counter import MetricCounter
|
||||||
|
from models.losses import get_loss
|
||||||
|
from models.models import get_model
|
||||||
|
from models.networks import get_nets
|
||||||
|
import numpy as np
|
||||||
|
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||||
|
cv2.setNumThreads(0)
|
||||||
|
|
||||||
|
|
||||||
|
class Trainer:
|
||||||
|
def __init__(self, config, train: DataLoader, val: DataLoader):
|
||||||
|
self.config = config
|
||||||
|
self.train_dataset = train
|
||||||
|
self.val_dataset = val
|
||||||
|
self.metric_counter = MetricCounter(config['experiment_desc'])
|
||||||
|
|
||||||
|
def train(self):
|
||||||
|
self._init_params()
|
||||||
|
start_epoch = 0
|
||||||
|
if os.path.exists('last_Stripformer_pretrained.pth'):
|
||||||
|
print('load_pretrained')
|
||||||
|
training_state = (torch.load('last_Stripformer_pretrained.pth'))
|
||||||
|
start_epoch = training_state['epoch']
|
||||||
|
new_weight = self.netG.state_dict()
|
||||||
|
new_weight.update(training_state['model_state'])
|
||||||
|
self.netG.load_state_dict(new_weight)
|
||||||
|
new_optimizer = self.optimizer_G.state_dict()
|
||||||
|
new_optimizer.update(training_state['optimizer_state'])
|
||||||
|
self.optimizer_G.load_state_dict(new_optimizer)
|
||||||
|
new_scheduler = self.scheduler_G.state_dict()
|
||||||
|
new_scheduler.update(training_state['scheduler_state'])
|
||||||
|
self.scheduler_G.load_state_dict(new_scheduler)
|
||||||
|
|
||||||
|
|
||||||
|
for epoch in range(start_epoch, config['num_epochs']):
|
||||||
|
self._run_epoch(epoch)
|
||||||
|
if epoch % 30 == 0 or epoch == (config['num_epochs']-1):
|
||||||
|
self._validate(epoch)
|
||||||
|
self.scheduler_G.step()
|
||||||
|
|
||||||
|
scheduler_state = self.scheduler_G.state_dict()
|
||||||
|
training_state = {'epoch': epoch, 'model_state': self.netG.state_dict(),
|
||||||
|
'scheduler_state': scheduler_state, 'optimizer_state': self.optimizer_G.state_dict()}
|
||||||
|
if self.metric_counter.update_best_model():
|
||||||
|
torch.save(training_state['model_state'], 'best_{}.pth'.format(self.config['experiment_desc']))
|
||||||
|
|
||||||
|
if epoch % 300 == 0:
|
||||||
|
torch.save(training_state, 'last_{}_{}.pth'.format(self.config['experiment_desc'], epoch))
|
||||||
|
|
||||||
|
if epoch == (config['num_epochs']-1):
|
||||||
|
torch.save(training_state['model_state'], 'final_{}.pth'.format(self.config['experiment_desc']))
|
||||||
|
|
||||||
|
torch.save(training_state, 'last_{}.pth'.format(self.config['experiment_desc']))
|
||||||
|
logging.debug("Experiment Name: %s, Epoch: %d, Loss: %s" % (
|
||||||
|
self.config['experiment_desc'], epoch, self.metric_counter.loss_message()))
|
||||||
|
|
||||||
|
def _run_epoch(self, epoch):
|
||||||
|
self.metric_counter.clear()
|
||||||
|
for param_group in self.optimizer_G.param_groups:
|
||||||
|
lr = param_group['lr']
|
||||||
|
|
||||||
|
epoch_size = config.get('train_batches_per_epoch') or len(self.train_dataset)
|
||||||
|
tq = tqdm.tqdm(self.train_dataset, total=epoch_size)
|
||||||
|
tq.set_description('Epoch {}, lr {}'.format(epoch, lr))
|
||||||
|
i = 0
|
||||||
|
for data in tq:
|
||||||
|
inputs, targets = self.model.get_input(data)
|
||||||
|
outputs = self.netG(inputs)
|
||||||
|
self.optimizer_G.zero_grad()
|
||||||
|
loss_G = self.criterionG(outputs, targets, inputs)
|
||||||
|
loss_G.backward()
|
||||||
|
self.optimizer_G.step()
|
||||||
|
self.metric_counter.add_losses(loss_G.item())
|
||||||
|
curr_psnr, curr_ssim, img_for_vis = self.model.get_images_and_metrics(inputs, outputs, targets)
|
||||||
|
self.metric_counter.add_metrics(curr_psnr, curr_ssim)
|
||||||
|
tq.set_postfix(loss=self.metric_counter.loss_message())
|
||||||
|
if not i:
|
||||||
|
self.metric_counter.add_image(img_for_vis, tag='train')
|
||||||
|
i += 1
|
||||||
|
if i > epoch_size:
|
||||||
|
break
|
||||||
|
tq.close()
|
||||||
|
self.metric_counter.write_to_tensorboard(epoch)
|
||||||
|
|
||||||
|
def _validate(self, epoch):
|
||||||
|
self.metric_counter.clear()
|
||||||
|
epoch_size = config.get('val_batches_per_epoch') or len(self.val_dataset)
|
||||||
|
tq = tqdm.tqdm(self.val_dataset, total=epoch_size)
|
||||||
|
tq.set_description('Validation')
|
||||||
|
i = 0
|
||||||
|
for data in tq:
|
||||||
|
with torch.no_grad():
|
||||||
|
inputs, targets = self.model.get_input(data)
|
||||||
|
outputs = self.netG(inputs)
|
||||||
|
loss_G = self.criterionG(outputs, targets, inputs)
|
||||||
|
self.metric_counter.add_losses(loss_G.item())
|
||||||
|
curr_psnr, curr_ssim, img_for_vis = self.model.get_images_and_metrics(inputs, outputs, targets)
|
||||||
|
self.metric_counter.add_metrics(curr_psnr, curr_ssim)
|
||||||
|
if not i:
|
||||||
|
self.metric_counter.add_image(img_for_vis, tag='val')
|
||||||
|
i += 1
|
||||||
|
if i > epoch_size:
|
||||||
|
break
|
||||||
|
tq.close()
|
||||||
|
self.metric_counter.write_to_tensorboard(epoch, validation=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_optim(self, params):
|
||||||
|
if self.config['optimizer']['name'] == 'adam':
|
||||||
|
optimizer = optim.Adam(params, lr=self.config['optimizer']['lr'])
|
||||||
|
else:
|
||||||
|
raise ValueError("Optimizer [%s] not recognized." % self.config['optimizer']['name'])
|
||||||
|
return optimizer
|
||||||
|
|
||||||
|
def _get_scheduler(self, optimizer):
|
||||||
|
if self.config['scheduler']['name'] == 'cosine':
|
||||||
|
scheduler = CosineAnnealingLR(optimizer, T_max=self.config['num_epochs'], eta_min=self.config['scheduler']['min_lr'])
|
||||||
|
else:
|
||||||
|
raise ValueError("Scheduler [%s] not recognized." % self.config['scheduler']['name'])
|
||||||
|
return scheduler
|
||||||
|
|
||||||
|
def _init_params(self):
|
||||||
|
self.criterionG = get_loss(self.config['model'])
|
||||||
|
self.netG = get_nets(self.config['model'])
|
||||||
|
self.netG.cuda()
|
||||||
|
self.model = get_model(self.config['model'])
|
||||||
|
self.optimizer_G = self._get_optim(filter(lambda p: p.requires_grad, self.netG.parameters()))
|
||||||
|
self.scheduler_G = self._get_scheduler(self.optimizer_G)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
with open('config/config_Stripformer_pretrained.yaml', 'r') as f:
|
||||||
|
config = yaml.load(f)
|
||||||
|
# setup
|
||||||
|
torch.backends.cudnn.enabled = True
|
||||||
|
torch.backends.cudnn.benchmark = True
|
||||||
|
|
||||||
|
# set random seed
|
||||||
|
seed = 666
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
random.seed(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
|
||||||
|
batch_size = config.pop('batch_size')
|
||||||
|
get_dataloader = partial(DataLoader, batch_size=batch_size, num_workers=cpu_count(), shuffle=True, drop_last=False)
|
||||||
|
|
||||||
|
datasets = map(config.pop, ('train', 'val'))
|
||||||
|
datasets = map(PairedDataset.from_config, datasets)
|
||||||
|
train, val = map(get_dataloader, datasets)
|
||||||
|
trainer = Trainer(config, train=train, val=val)
|
||||||
|
trainer.train()
|
||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,66 @@
|
|||||||
|
import dominate
|
||||||
|
from dominate.tags import *
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
class HTML:
|
||||||
|
def __init__(self, web_dir, title, image_subdir='', reflesh=0):
|
||||||
|
self.title = title
|
||||||
|
self.web_dir = web_dir
|
||||||
|
# self.img_dir = os.path.join(self.web_dir, )
|
||||||
|
self.img_subdir = image_subdir
|
||||||
|
self.img_dir = os.path.join(self.web_dir, image_subdir)
|
||||||
|
if not os.path.exists(self.web_dir):
|
||||||
|
os.makedirs(self.web_dir)
|
||||||
|
if not os.path.exists(self.img_dir):
|
||||||
|
os.makedirs(self.img_dir)
|
||||||
|
# print(self.img_dir)
|
||||||
|
|
||||||
|
self.doc = dominate.document(title=title)
|
||||||
|
if reflesh > 0:
|
||||||
|
with self.doc.head:
|
||||||
|
meta(http_equiv="reflesh", content=str(reflesh))
|
||||||
|
|
||||||
|
def get_image_dir(self):
|
||||||
|
return self.img_dir
|
||||||
|
|
||||||
|
def add_header(self, str):
|
||||||
|
with self.doc:
|
||||||
|
h3(str)
|
||||||
|
|
||||||
|
def add_table(self, border=1):
|
||||||
|
self.t = table(border=border, style="table-layout: fixed;")
|
||||||
|
self.doc.add(self.t)
|
||||||
|
|
||||||
|
def add_images(self, ims, txts, links, width=400):
|
||||||
|
self.add_table()
|
||||||
|
with self.t:
|
||||||
|
with tr():
|
||||||
|
for im, txt, link in zip(ims, txts, links):
|
||||||
|
with td(style="word-wrap: break-word;", halign="center", valign="top"):
|
||||||
|
with p():
|
||||||
|
with a(href=os.path.join(link)):
|
||||||
|
img(style="width:%dpx" % width, src=os.path.join(im))
|
||||||
|
br()
|
||||||
|
p(txt)
|
||||||
|
|
||||||
|
def save(self,file='index'):
|
||||||
|
html_file = '%s/%s.html' % (self.web_dir,file)
|
||||||
|
f = open(html_file, 'wt')
|
||||||
|
f.write(self.doc.render())
|
||||||
|
f.close()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
html = HTML('web/', 'test_html')
|
||||||
|
html.add_header('hello world')
|
||||||
|
|
||||||
|
ims = []
|
||||||
|
txts = []
|
||||||
|
links = []
|
||||||
|
for n in range(4):
|
||||||
|
ims.append('image_%d.png' % n)
|
||||||
|
txts.append('text_%d' % n)
|
||||||
|
links.append('image_%d.png' % n)
|
||||||
|
html.add_images(ims, txts, links)
|
||||||
|
html.save()
|
||||||
@ -0,0 +1,33 @@
|
|||||||
|
import random
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch.autograd import Variable
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
|
|
||||||
|
class ImagePool():
|
||||||
|
def __init__(self, pool_size):
|
||||||
|
self.pool_size = pool_size
|
||||||
|
self.sample_size = pool_size
|
||||||
|
if self.pool_size > 0:
|
||||||
|
self.num_imgs = 0
|
||||||
|
self.images = deque()
|
||||||
|
|
||||||
|
def add(self, images):
|
||||||
|
if self.pool_size == 0:
|
||||||
|
return images
|
||||||
|
for image in images.data:
|
||||||
|
image = torch.unsqueeze(image, 0)
|
||||||
|
if self.num_imgs < self.pool_size:
|
||||||
|
self.num_imgs = self.num_imgs + 1
|
||||||
|
self.images.append(image)
|
||||||
|
else:
|
||||||
|
self.images.popleft()
|
||||||
|
self.images.append(image)
|
||||||
|
|
||||||
|
def query(self):
|
||||||
|
if len(self.images) > self.sample_size:
|
||||||
|
return_images = list(random.sample(self.images, self.sample_size))
|
||||||
|
else:
|
||||||
|
return_images = list(self.images)
|
||||||
|
return torch.cat(return_images, 0)
|
||||||
@ -0,0 +1,54 @@
|
|||||||
|
import math
|
||||||
|
from math import exp
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.autograd import Variable
|
||||||
|
|
||||||
|
|
||||||
|
def gaussian(window_size, sigma):
|
||||||
|
gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
|
||||||
|
return gauss / gauss.sum()
|
||||||
|
|
||||||
|
|
||||||
|
def create_window(window_size, channel):
|
||||||
|
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
|
||||||
|
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
|
||||||
|
window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
|
||||||
|
return window
|
||||||
|
|
||||||
|
|
||||||
|
def SSIM(img1, img2):
|
||||||
|
(_, channel, _, _) = img1.size()
|
||||||
|
window_size = 11
|
||||||
|
window = create_window(window_size, channel)
|
||||||
|
|
||||||
|
if img1.is_cuda:
|
||||||
|
window = window.cuda(img1.get_device())
|
||||||
|
window = window.type_as(img1)
|
||||||
|
|
||||||
|
mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
|
||||||
|
mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
|
||||||
|
|
||||||
|
mu1_sq = mu1.pow(2)
|
||||||
|
mu2_sq = mu2.pow(2)
|
||||||
|
mu1_mu2 = mu1 * mu2
|
||||||
|
|
||||||
|
sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
|
||||||
|
sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
|
||||||
|
sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
|
||||||
|
|
||||||
|
C1 = 0.01 ** 2
|
||||||
|
C2 = 0.03 ** 2
|
||||||
|
|
||||||
|
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
|
||||||
|
return ssim_map.mean()
|
||||||
|
|
||||||
|
|
||||||
|
def PSNR(img1, img2):
|
||||||
|
mse = np.mean((img1 / 255. - img2 / 255.) ** 2)
|
||||||
|
if mse == 0:
|
||||||
|
return 100
|
||||||
|
PIXEL_MAX = 1
|
||||||
|
return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))
|
||||||
@ -0,0 +1,48 @@
|
|||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import torch
|
||||||
|
|
||||||
|
def load_image(path):
|
||||||
|
if(path[-3:] == 'dng'):
|
||||||
|
import rawpy
|
||||||
|
with rawpy.imread(path) as raw:
|
||||||
|
img = raw.postprocess()
|
||||||
|
elif(path[-3:]=='bmp' or path[-3:]=='jpg' or path[-3:]=='png'):
|
||||||
|
import cv2
|
||||||
|
return cv2.imread(path)[:,:,::-1]
|
||||||
|
else:
|
||||||
|
img = (255*plt.imread(path)[:,:,:3]).astype('uint8')
|
||||||
|
|
||||||
|
return img
|
||||||
|
|
||||||
|
def save_image(image_numpy, image_path, ):
|
||||||
|
image_pil = Image.fromarray(image_numpy)
|
||||||
|
image_pil.save(image_path)
|
||||||
|
|
||||||
|
def mkdirs(paths):
|
||||||
|
if isinstance(paths, list) and not isinstance(paths, str):
|
||||||
|
for path in paths:
|
||||||
|
mkdir(path)
|
||||||
|
else:
|
||||||
|
mkdir(paths)
|
||||||
|
|
||||||
|
def mkdir(path):
|
||||||
|
if not os.path.exists(path):
|
||||||
|
os.makedirs(path)
|
||||||
|
|
||||||
|
|
||||||
|
def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
|
||||||
|
# def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.):
|
||||||
|
image_numpy = image_tensor[0].cpu().float().numpy()
|
||||||
|
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
|
||||||
|
return image_numpy.astype(imtype)
|
||||||
|
|
||||||
|
def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
|
||||||
|
# def im2tensor(image, imtype=np.uint8, cent=1., factor=1.):
|
||||||
|
return torch.Tensor((image / factor - cent)
|
||||||
|
[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
|
||||||
@ -0,0 +1,216 @@
|
|||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from . import util
|
||||||
|
from . import html
|
||||||
|
# from pdb import set_trace as st
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import math
|
||||||
|
# from IPython import embed
|
||||||
|
|
||||||
|
def zoom_to_res(img,res=256,order=0,axis=0):
|
||||||
|
# img 3xXxX
|
||||||
|
from scipy.ndimage import zoom
|
||||||
|
zoom_factor = res/img.shape[1]
|
||||||
|
if(axis==0):
|
||||||
|
return zoom(img,[1,zoom_factor,zoom_factor],order=order)
|
||||||
|
elif(axis==2):
|
||||||
|
return zoom(img,[zoom_factor,zoom_factor,1],order=order)
|
||||||
|
|
||||||
|
class Visualizer():
|
||||||
|
def __init__(self, opt):
|
||||||
|
# self.opt = opt
|
||||||
|
self.display_id = opt.display_id
|
||||||
|
# self.use_html = opt.is_train and not opt.no_html
|
||||||
|
self.win_size = opt.display_winsize
|
||||||
|
self.name = opt.name
|
||||||
|
self.display_cnt = 0 # display_current_results counter
|
||||||
|
self.display_cnt_high = 0
|
||||||
|
self.use_html = opt.use_html
|
||||||
|
|
||||||
|
if self.display_id > 0:
|
||||||
|
import visdom
|
||||||
|
self.vis = visdom.Visdom(port = opt.display_port)
|
||||||
|
|
||||||
|
self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
|
||||||
|
util.mkdirs([self.web_dir,])
|
||||||
|
if self.use_html:
|
||||||
|
self.img_dir = os.path.join(self.web_dir, 'images')
|
||||||
|
print('create web directory %s...' % self.web_dir)
|
||||||
|
util.mkdirs([self.img_dir,])
|
||||||
|
|
||||||
|
# |visuals|: dictionary of images to display or save
|
||||||
|
def display_current_results(self, visuals, epoch, nrows=None, res=256):
|
||||||
|
if self.display_id > 0: # show images in the browser
|
||||||
|
title = self.name
|
||||||
|
if(nrows is None):
|
||||||
|
nrows = int(math.ceil(len(visuals.items()) / 2.0))
|
||||||
|
images = []
|
||||||
|
idx = 0
|
||||||
|
for label, image_numpy in visuals.items():
|
||||||
|
title += " | " if idx % nrows == 0 else ", "
|
||||||
|
title += label
|
||||||
|
img = image_numpy.transpose([2, 0, 1])
|
||||||
|
img = zoom_to_res(img,res=res,order=0)
|
||||||
|
images.append(img)
|
||||||
|
idx += 1
|
||||||
|
if len(visuals.items()) % 2 != 0:
|
||||||
|
white_image = np.ones_like(image_numpy.transpose([2, 0, 1]))*255
|
||||||
|
white_image = zoom_to_res(white_image,res=res,order=0)
|
||||||
|
images.append(white_image)
|
||||||
|
self.vis.images(images, nrow=nrows, win=self.display_id + 1,
|
||||||
|
opts=dict(title=title))
|
||||||
|
|
||||||
|
if self.use_html: # save images to a html file
|
||||||
|
for label, image_numpy in visuals.items():
|
||||||
|
img_path = os.path.join(self.img_dir, 'epoch%.3d_cnt%.6d_%s.png' % (epoch, self.display_cnt, label))
|
||||||
|
util.save_image(zoom_to_res(image_numpy, res=res, axis=2), img_path)
|
||||||
|
|
||||||
|
self.display_cnt += 1
|
||||||
|
self.display_cnt_high = np.maximum(self.display_cnt_high, self.display_cnt)
|
||||||
|
|
||||||
|
# update website
|
||||||
|
webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, reflesh=1)
|
||||||
|
for n in range(epoch, 0, -1):
|
||||||
|
webpage.add_header('epoch [%d]' % n)
|
||||||
|
if(n==epoch):
|
||||||
|
high = self.display_cnt
|
||||||
|
else:
|
||||||
|
high = self.display_cnt_high
|
||||||
|
for c in range(high-1,-1,-1):
|
||||||
|
ims = []
|
||||||
|
txts = []
|
||||||
|
links = []
|
||||||
|
|
||||||
|
for label, image_numpy in visuals.items():
|
||||||
|
img_path = 'epoch%.3d_cnt%.6d_%s.png' % (n, c, label)
|
||||||
|
ims.append(os.path.join('images',img_path))
|
||||||
|
txts.append(label)
|
||||||
|
links.append(os.path.join('images',img_path))
|
||||||
|
webpage.add_images(ims, txts, links, width=self.win_size)
|
||||||
|
webpage.save()
|
||||||
|
|
||||||
|
# save errors into a directory
|
||||||
|
def plot_current_errors_save(self, epoch, counter_ratio, opt, errors,keys='+ALL',name='loss', to_plot=False):
|
||||||
|
if not hasattr(self, 'plot_data'):
|
||||||
|
self.plot_data = {'X':[],'Y':[], 'legend':list(errors.keys())}
|
||||||
|
self.plot_data['X'].append(epoch + counter_ratio)
|
||||||
|
self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']])
|
||||||
|
|
||||||
|
# embed()
|
||||||
|
if(keys=='+ALL'):
|
||||||
|
plot_keys = self.plot_data['legend']
|
||||||
|
else:
|
||||||
|
plot_keys = keys
|
||||||
|
|
||||||
|
if(to_plot):
|
||||||
|
(f,ax) = plt.subplots(1,1)
|
||||||
|
for (k,kname) in enumerate(plot_keys):
|
||||||
|
kk = np.where(np.array(self.plot_data['legend'])==kname)[0][0]
|
||||||
|
x = self.plot_data['X']
|
||||||
|
y = np.array(self.plot_data['Y'])[:,kk]
|
||||||
|
if(to_plot):
|
||||||
|
ax.plot(x, y, 'o-', label=kname)
|
||||||
|
np.save(os.path.join(self.web_dir,'%s_x')%kname,x)
|
||||||
|
np.save(os.path.join(self.web_dir,'%s_y')%kname,y)
|
||||||
|
|
||||||
|
if(to_plot):
|
||||||
|
plt.legend(loc=0,fontsize='small')
|
||||||
|
plt.xlabel('epoch')
|
||||||
|
plt.ylabel('Value')
|
||||||
|
f.savefig(os.path.join(self.web_dir,'%s.png'%name))
|
||||||
|
f.clf()
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
# errors: dictionary of error labels and values
|
||||||
|
def plot_current_errors(self, epoch, counter_ratio, opt, errors):
|
||||||
|
if not hasattr(self, 'plot_data'):
|
||||||
|
self.plot_data = {'X':[],'Y':[], 'legend':list(errors.keys())}
|
||||||
|
self.plot_data['X'].append(epoch + counter_ratio)
|
||||||
|
self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']])
|
||||||
|
self.vis.line(
|
||||||
|
X=np.stack([np.array(self.plot_data['X'])]*len(self.plot_data['legend']),1),
|
||||||
|
Y=np.array(self.plot_data['Y']),
|
||||||
|
opts={
|
||||||
|
'title': self.name + ' loss over time',
|
||||||
|
'legend': self.plot_data['legend'],
|
||||||
|
'xlabel': 'epoch',
|
||||||
|
'ylabel': 'loss'},
|
||||||
|
win=self.display_id)
|
||||||
|
|
||||||
|
# errors: same format as |errors| of plotCurrentErrors
|
||||||
|
def print_current_errors(self, epoch, i, errors, t, t2=-1, t2o=-1, fid=None):
|
||||||
|
message = '(ep: %d, it: %d, t: %.3f[s], ept: %.2f/%.2f[h]) ' % (epoch, i, t, t2o, t2)
|
||||||
|
message += (', ').join(['%s: %.3f' % (k, v) for k, v in errors.items()])
|
||||||
|
|
||||||
|
print(message)
|
||||||
|
if(fid is not None):
|
||||||
|
fid.write('%s\n'%message)
|
||||||
|
|
||||||
|
|
||||||
|
# save image to the disk
|
||||||
|
def save_images_simple(self, webpage, images, names, in_txts, prefix='', res=256):
|
||||||
|
image_dir = webpage.get_image_dir()
|
||||||
|
ims = []
|
||||||
|
txts = []
|
||||||
|
links = []
|
||||||
|
|
||||||
|
for name, image_numpy, txt in zip(names, images, in_txts):
|
||||||
|
image_name = '%s_%s.png' % (prefix, name)
|
||||||
|
save_path = os.path.join(image_dir, image_name)
|
||||||
|
if(res is not None):
|
||||||
|
util.save_image(zoom_to_res(image_numpy,res=res,axis=2), save_path)
|
||||||
|
else:
|
||||||
|
util.save_image(image_numpy, save_path)
|
||||||
|
|
||||||
|
ims.append(os.path.join(webpage.img_subdir,image_name))
|
||||||
|
# txts.append(name)
|
||||||
|
txts.append(txt)
|
||||||
|
links.append(os.path.join(webpage.img_subdir,image_name))
|
||||||
|
# embed()
|
||||||
|
webpage.add_images(ims, txts, links, width=self.win_size)
|
||||||
|
|
||||||
|
# save image to the disk
|
||||||
|
def save_images(self, webpage, images, names, image_path, title=''):
|
||||||
|
image_dir = webpage.get_image_dir()
|
||||||
|
# short_path = ntpath.basename(image_path)
|
||||||
|
# name = os.path.splitext(short_path)[0]
|
||||||
|
# name = short_path
|
||||||
|
# webpage.add_header('%s, %s' % (name, title))
|
||||||
|
ims = []
|
||||||
|
txts = []
|
||||||
|
links = []
|
||||||
|
|
||||||
|
for label, image_numpy in zip(names, images):
|
||||||
|
image_name = '%s.jpg' % (label,)
|
||||||
|
save_path = os.path.join(image_dir, image_name)
|
||||||
|
util.save_image(image_numpy, save_path)
|
||||||
|
|
||||||
|
ims.append(image_name)
|
||||||
|
txts.append(label)
|
||||||
|
links.append(image_name)
|
||||||
|
webpage.add_images(ims, txts, links, width=self.win_size)
|
||||||
|
|
||||||
|
# save image to the disk
|
||||||
|
# def save_images(self, webpage, visuals, image_path, short=False):
|
||||||
|
# image_dir = webpage.get_image_dir()
|
||||||
|
# if short:
|
||||||
|
# short_path = ntpath.basename(image_path)
|
||||||
|
# name = os.path.splitext(short_path)[0]
|
||||||
|
# else:
|
||||||
|
# name = image_path
|
||||||
|
|
||||||
|
# webpage.add_header(name)
|
||||||
|
# ims = []
|
||||||
|
# txts = []
|
||||||
|
# links = []
|
||||||
|
|
||||||
|
# for label, image_numpy in visuals.items():
|
||||||
|
# image_name = '%s_%s.png' % (name, label)
|
||||||
|
# save_path = os.path.join(image_dir, image_name)
|
||||||
|
# util.save_image(image_numpy, save_path)
|
||||||
|
|
||||||
|
# ims.append(image_name)
|
||||||
|
# txts.append(label)
|
||||||
|
# links.append(image_name)
|
||||||
|
# webpage.add_images(ims, txts, links, width=self.win_size)
|
||||||
Loading…
Reference in new issue