commit df69507ef2
@@ -0,0 +1,62 @@
from typing import List

import albumentations as albu
from torchvision import transforms


def get_transforms(size: int, scope: str = 'geometric', crop='random'):
    augs = {'strong': albu.Compose([albu.HorizontalFlip(),
                                    albu.ShiftScaleRotate(shift_limit=0.0, scale_limit=0.2, rotate_limit=20, p=.4),
                                    albu.ElasticTransform(),
                                    albu.OpticalDistortion(),
                                    albu.OneOf([
                                        albu.CLAHE(clip_limit=2),
                                        albu.IAASharpen(),
                                        albu.IAAEmboss(),
                                        albu.RandomBrightnessContrast(),
                                        albu.RandomGamma()
                                    ], p=0.5),
                                    albu.OneOf([
                                        albu.RGBShift(),
                                        albu.HueSaturationValue(),
                                    ], p=0.5),
                                    ]),
            'weak': albu.Compose([albu.HorizontalFlip(),
                                  ]),
            'geometric': albu.Compose([albu.HorizontalFlip(),
                                       albu.VerticalFlip(),
                                       albu.RandomRotate90(),
                                       ]),
            'None': None
            }

    aug_fn = augs[scope]
    crop_fn = {'random': albu.RandomCrop(size, size, always_apply=True),
               'center': albu.CenterCrop(size, size, always_apply=True)}[crop]

    pipeline = albu.Compose([aug_fn, crop_fn], additional_targets={'target': 'image'})

    def process(a, b):
        r = pipeline(image=a, target=b)
        return r['image'], r['target']

    return process


def get_normalize():
    transform = transforms.Compose([
        transforms.ToTensor()
    ])

    def process(a, b):
        image = transform(a).permute(1, 2, 0) - 0.5
        target = transform(b).permute(1, 2, 0) - 0.5
        return image, target

    return process

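A usage sketch for the two factories above, assuming uint8 HWC numpy inputs (the image sizes here are placeholders): get_transforms returns a function that applies the same random crop and flips to a paired blurry/sharp sample, and get_normalize shifts both into [-0.5, 0.5].

    import numpy as np
    import aug

    transform_fn = aug.get_transforms(size=256, scope='geometric', crop='random')
    normalize_fn = aug.get_normalize()

    blur = np.zeros((720, 1280, 3), dtype=np.uint8)    # hypothetical paired images
    sharp = np.zeros((720, 1280, 3), dtype=np.uint8)

    a, b = transform_fn(blur, sharp)   # identical random crop/flips on both images
    a, b = normalize_fn(a, b)          # HWC torch tensors shifted to [-0.5, 0.5]
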
@@ -0,0 +1,40 @@
---
experiment_desc: Stripformer_gopro

train:
  files_a: ./datasets/GoPro/train/blur/**/*.png
  files_b: ./datasets/GoPro/train/sharp/**/*.png
  size: &SIZE 512
  crop: random
  preload: &PRELOAD false
  preload_size: &PRELOAD_SIZE 0
  bounds: [0, 1]
  scope: geometric

val:
  files_a: ./datasets/GoPro/test/blur/**/*.png
  files_b: ./datasets/GoPro/test/sharp/**/*.png
  size: *SIZE
  scope: None
  crop: random
  preload: *PRELOAD
  preload_size: *PRELOAD_SIZE
  bounds: [0, 1]

model:
  g_name: Stripformer
  content_loss: Stripformer_Loss

num_epochs: 1000
train_batches_per_epoch: 2103
val_batches_per_epoch: 1111
batch_size: 8
image_size: [512, 512]

optimizer:
  name: adam
  lr: 0.0001
scheduler:
  name: cosine
  start_epoch: 50
  min_lr: 0.0000001

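The &SIZE, &PRELOAD, and &PRELOAD_SIZE entries are standard YAML anchors; the *SIZE-style aliases in the val block reuse the train values, so the two sections cannot drift apart. A quick self-check, assuming PyYAML is installed:

    import yaml

    with open('config/config_Stripformer_gopro.yaml') as cfg:
        config = yaml.safe_load(cfg)

    # Aliases are resolved at load time, so train and val stay in sync.
    assert config['val']['size'] == config['train']['size'] == 512
    assert config['val']['preload'] is False
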
@@ -0,0 +1,40 @@
---
experiment_desc: Stripformer_pretrained

train:
  files_a: ./datasets/GoPro/train/blur/**/*.png
  files_b: ./datasets/GoPro/train/sharp/**/*.png
  size: &SIZE 256
  crop: random
  preload: &PRELOAD false
  preload_size: &PRELOAD_SIZE 0
  bounds: [0, 1]
  scope: geometric

val:
  files_a: ./datasets/GoPro/test/blur/**/*.png
  files_b: ./datasets/GoPro/test/sharp/**/*.png
  size: *SIZE
  scope: None
  crop: random
  preload: *PRELOAD
  preload_size: *PRELOAD_SIZE
  bounds: [0, 1]

model:
  g_name: Stripformer
  content_loss: Stripformer_Loss

num_epochs: 3000
train_batches_per_epoch: 2103
val_batches_per_epoch: 1111
batch_size: 8
image_size: [256, 256]

optimizer:
  name: adam
  lr: 0.0001
scheduler:
  name: cosine
  start_epoch: 50
  min_lr: 0.0000001

@@ -0,0 +1,140 @@
import os
from copy import deepcopy
from functools import partial
from glob import glob
from hashlib import sha1
from typing import Callable, Iterable, Optional, Tuple

import cv2
import numpy as np
from glog import logger
from joblib import Parallel, cpu_count, delayed
from skimage.io import imread
from torch.utils.data import Dataset
from tqdm import tqdm

import aug


def subsample(data: Iterable, bounds: Tuple[float, float], hash_fn: Callable, n_buckets=100, salt='', verbose=True):
    data = list(data)
    buckets = split_into_buckets(data, n_buckets=n_buckets, salt=salt, hash_fn=hash_fn)

    lower_bound, upper_bound = [x * n_buckets for x in bounds]
    msg = f'Subsampling buckets from {lower_bound} to {upper_bound}, total buckets number is {n_buckets}'
    if salt:
        msg += f'; salt is {salt}'
    if verbose:
        logger.info(msg)
    return np.array([sample for bucket, sample in zip(buckets, data) if lower_bound <= bucket < upper_bound])


def hash_from_paths(x: Tuple[str, str], salt: str = '') -> str:
    path_a, path_b = x
    names = ''.join(map(os.path.basename, (path_a, path_b)))
    return sha1(f'{names}_{salt}'.encode()).hexdigest()


def split_into_buckets(data: Iterable, n_buckets: int, hash_fn: Callable, salt=''):
    hashes = map(partial(hash_fn, salt=salt), data)
    return np.array([int(x, 16) % n_buckets for x in hashes])


def _read_img(x: str):
    img = cv2.imread(x)
    if img is None:
        logger.warning(f'Can not read image {x} with OpenCV, switching to scikit-image')
        img = imread(x)
    return img


class PairedDataset(Dataset):
    def __init__(self,
                 files_a: Tuple[str],
                 files_b: Tuple[str],
                 transform_fn: Callable,
                 normalize_fn: Callable,
                 corrupt_fn: Optional[Callable] = None,
                 preload: bool = True,
                 preload_size: Optional[int] = 0,
                 verbose=True):

        assert len(files_a) == len(files_b)

        self.preload = preload
        self.data_a = files_a
        self.data_b = files_b
        self.verbose = verbose
        self.corrupt_fn = corrupt_fn
        self.transform_fn = transform_fn
        self.normalize_fn = normalize_fn
        logger.info(f'Dataset has been created with {len(self.data_a)} samples')

        if preload:
            preload_fn = partial(self._bulk_preload, preload_size=preload_size)
            if files_a == files_b:
                self.data_a = self.data_b = preload_fn(self.data_a)
            else:
                self.data_a, self.data_b = map(preload_fn, (self.data_a, self.data_b))
            self.preload = True

    def _bulk_preload(self, data: Iterable[str], preload_size: int):
        jobs = [delayed(self._preload)(x, preload_size=preload_size) for x in data]
        jobs = tqdm(jobs, desc='preloading images', disable=not self.verbose)
        return Parallel(n_jobs=cpu_count(), backend='threading')(jobs)

    @staticmethod
    def _preload(x: str, preload_size: int):
        img = _read_img(x)
        if preload_size:
            h, w, *_ = img.shape
            h_scale = preload_size / h
            w_scale = preload_size / w
            scale = max(h_scale, w_scale)
            img = cv2.resize(img, fx=scale, fy=scale, dsize=None)
            assert min(img.shape[:2]) >= preload_size, f'weird img shape: {img.shape}'
        return img

    def _preprocess(self, img, res):
        def transpose(x):
            return np.transpose(x, (2, 0, 1))

        return map(transpose, self.normalize_fn(img, res))

    def __len__(self):
        return len(self.data_a)

    def __getitem__(self, idx):
        a, b = self.data_a[idx], self.data_b[idx]
        if not self.preload:
            a, b = map(_read_img, (a, b))
        a, b = self.transform_fn(a, b)
        if self.corrupt_fn is not None:
            a = self.corrupt_fn(a)
        a, b = self._preprocess(a, b)
        return {'a': a, 'b': b}

    @staticmethod
    def from_config(config):
        config = deepcopy(config)
        files_a, files_b = map(lambda x: sorted(glob(config[x], recursive=True)), ('files_a', 'files_b'))
        transform_fn = aug.get_transforms(size=config['size'], scope=config['scope'], crop=config['crop'])
        normalize_fn = aug.get_normalize()

        hash_fn = hash_from_paths
        # ToDo: add more hash functions
        verbose = config.get('verbose', True)
        data = subsample(data=zip(files_a, files_b),
                         bounds=config.get('bounds', (0, 1)),
                         hash_fn=hash_fn,
                         verbose=verbose)

        files_a, files_b = map(list, zip(*data))

        return PairedDataset(files_a=files_a,
                             files_b=files_b,
                             preload=config['preload'],
                             preload_size=config['preload_size'],
                             normalize_fn=normalize_fn,
                             transform_fn=transform_fn,
                             verbose=verbose)

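A minimal sketch of how PairedDataset.from_config plugs into a PyTorch DataLoader, assuming config holds the parsed YAML from above:

    from torch.utils.data import DataLoader

    from dataset import PairedDataset

    train_set = PairedDataset.from_config(config['train'])
    train_loader = DataLoader(train_set, batch_size=config['batch_size'], shuffle=True)

    batch = next(iter(train_loader))
    # 'a' is the blurred input, 'b' the sharp target; both are CHW in [-0.5, 0.5]
    print(batch['a'].shape, batch['b'].shape)
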
@@ -0,0 +1 @@
Download the 'GoPro' dataset and put it into the folder './datasets'
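The glob patterns in the configs imply a specific layout under ./datasets; a quick sanity check after unpacking (the patterns are copied from config_Stripformer_gopro.yaml, a count of zero means the dataset landed in the wrong place):

    from glob import glob

    for pattern in ('./datasets/GoPro/train/blur/**/*.png',
                    './datasets/GoPro/train/sharp/**/*.png',
                    './datasets/GoPro/test/blur/**/*.png',
                    './datasets/GoPro/test/sharp/**/*.png'):
        print(pattern, len(glob(pattern, recursive=True)))
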
@@ -0,0 +1,101 @@
import os
from skimage import io
import cv2
import numpy as np
from skimage.metrics import structural_similarity
import concurrent.futures


def image_align(deblurred, gt):
    # this function is based on kohler evaluation code
    z = deblurred
    c = np.ones_like(z)
    x = gt

    zs = (np.sum(x * z) / np.sum(z * z)) * z  # simple intensity matching

    warp_mode = cv2.MOTION_HOMOGRAPHY
    warp_matrix = np.eye(3, 3, dtype=np.float32)

    # Specify the number of iterations.
    number_of_iterations = 100

    termination_eps = 0

    criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT,
                number_of_iterations, termination_eps)

    # Run the ECC algorithm. The results are stored in warp_matrix.
    (cc, warp_matrix) = cv2.findTransformECC(cv2.cvtColor(x, cv2.COLOR_RGB2GRAY), cv2.cvtColor(zs, cv2.COLOR_RGB2GRAY),
                                             warp_matrix, warp_mode, criteria, inputMask=None, gaussFiltSize=5)

    target_shape = x.shape
    shift = warp_matrix

    zr = cv2.warpPerspective(
        zs,
        warp_matrix,
        (target_shape[1], target_shape[0]),
        flags=cv2.INTER_CUBIC + cv2.WARP_INVERSE_MAP,
        borderMode=cv2.BORDER_REFLECT)

    cr = cv2.warpPerspective(
        np.ones_like(zs, dtype='float32'),
        warp_matrix,
        (target_shape[1], target_shape[0]),
        flags=cv2.INTER_NEAREST + cv2.WARP_INVERSE_MAP,
        borderMode=cv2.BORDER_CONSTANT,
        borderValue=0)

    zr = zr * cr
    xr = x * cr

    return zr, xr, cr, shift


def compute_psnr(image_true, image_test, image_mask, data_range=None):
    # this function is based on skimage.metrics.peak_signal_noise_ratio
    err = np.sum((image_true - image_test) ** 2, dtype=np.float64) / np.sum(image_mask)
    return 10 * np.log10((data_range ** 2) / err)


def compute_ssim(tar_img, prd_img, cr1):
    ssim_pre, ssim_map = structural_similarity(tar_img, prd_img, multichannel=True, gaussian_weights=True,
                                               use_sample_covariance=False, data_range=1.0, full=True)
    ssim_map = ssim_map * cr1
    r = int(3.5 * 1.5 + 0.5)  # radius as in ndimage
    win_size = 2 * r + 1
    pad = (win_size - 1) // 2
    ssim = ssim_map[pad:-pad, pad:-pad, :]
    crop_cr1 = cr1[pad:-pad, pad:-pad, :]
    ssim = ssim.sum(axis=0).sum(axis=0) / crop_cr1.sum(axis=0).sum(axis=0)
    ssim = np.mean(ssim)
    return ssim


total_psnr = 0.
total_ssim = 0.
count = 0
img_path = './out/Stripformer_realblur_J_results'
gt_path = './datasets/Realblur_J/test/sharp'
print(img_path)
for file in os.listdir(img_path):
    for img_name in os.listdir(img_path + '/' + file):
        count += 1
        number = img_name.split('_')[1]
        gt_name = 'gt_' + number
        img_dir = img_path + '/' + file + '/' + img_name
        gt_dir = gt_path + '/' + file + '/' + gt_name
        with concurrent.futures.ProcessPoolExecutor(max_workers=10) as executor:
            tar_img = io.imread(gt_dir)
            prd_img = io.imread(img_dir)
            tar_img = tar_img.astype(np.float32) / 255.0
            prd_img = prd_img.astype(np.float32) / 255.0
            prd_img, tar_img, cr1, shift = image_align(prd_img, tar_img)
            PSNR = compute_psnr(tar_img, prd_img, cr1, data_range=1)
            SSIM = compute_ssim(tar_img, prd_img, cr1)
            total_psnr += PSNR
            total_ssim += SSIM
            print(count, PSNR)

print('PSNR:', total_psnr / count)
print('SSIM:', total_ssim / count)
print(img_path)

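Since image_align already multiplies both images by the coverage mask cr, compute_psnr above is the standard PSNR restricted to valid pixels: err = sum((x - z)^2) / sum(mask). With an all-ones mask it reduces to skimage's definition; a small self-check under that assumption:

    import numpy as np
    from skimage.metrics import peak_signal_noise_ratio

    rng = np.random.default_rng(0)
    a = rng.random((32, 32, 3))
    b = rng.random((32, 32, 3))
    mask = np.ones_like(a)

    err = np.sum((a - b) ** 2) / np.sum(mask)     # MSE over valid pixels only
    masked_psnr = 10 * np.log10(1.0 ** 2 / err)   # data_range = 1
    assert np.isclose(masked_psnr, peak_signal_noise_ratio(a, b, data_range=1))
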
@@ -0,0 +1,101 @@
import os
from skimage import io
import cv2
import numpy as np
from skimage.metrics import structural_similarity
import concurrent.futures


def image_align(deblurred, gt):
    # this function is based on kohler evaluation code
    z = deblurred
    c = np.ones_like(z)
    x = gt

    zs = (np.sum(x * z) / np.sum(z * z)) * z  # simple intensity matching

    warp_mode = cv2.MOTION_HOMOGRAPHY
    warp_matrix = np.eye(3, 3, dtype=np.float32)

    # Specify the number of iterations.
    number_of_iterations = 100

    termination_eps = 0

    criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT,
                number_of_iterations, termination_eps)

    # Run the ECC algorithm. The results are stored in warp_matrix.
    (cc, warp_matrix) = cv2.findTransformECC(cv2.cvtColor(x, cv2.COLOR_RGB2GRAY), cv2.cvtColor(zs, cv2.COLOR_RGB2GRAY),
                                             warp_matrix, warp_mode, criteria, inputMask=None, gaussFiltSize=5)

    target_shape = x.shape
    shift = warp_matrix

    zr = cv2.warpPerspective(
        zs,
        warp_matrix,
        (target_shape[1], target_shape[0]),
        flags=cv2.INTER_CUBIC + cv2.WARP_INVERSE_MAP,
        borderMode=cv2.BORDER_REFLECT)

    cr = cv2.warpPerspective(
        np.ones_like(zs, dtype='float32'),
        warp_matrix,
        (target_shape[1], target_shape[0]),
        flags=cv2.INTER_NEAREST + cv2.WARP_INVERSE_MAP,
        borderMode=cv2.BORDER_CONSTANT,
        borderValue=0)

    zr = zr * cr
    xr = x * cr

    return zr, xr, cr, shift


def compute_psnr(image_true, image_test, image_mask, data_range=None):
    # this function is based on skimage.metrics.peak_signal_noise_ratio
    err = np.sum((image_true - image_test) ** 2, dtype=np.float64) / np.sum(image_mask)
    return 10 * np.log10((data_range ** 2) / err)


def compute_ssim(tar_img, prd_img, cr1):
    ssim_pre, ssim_map = structural_similarity(tar_img, prd_img, multichannel=True, gaussian_weights=True,
                                               use_sample_covariance=False, data_range=1.0, full=True)
    ssim_map = ssim_map * cr1
    r = int(3.5 * 1.5 + 0.5)  # radius as in ndimage
    win_size = 2 * r + 1
    pad = (win_size - 1) // 2
    ssim = ssim_map[pad:-pad, pad:-pad, :]
    crop_cr1 = cr1[pad:-pad, pad:-pad, :]
    ssim = ssim.sum(axis=0).sum(axis=0) / crop_cr1.sum(axis=0).sum(axis=0)
    ssim = np.mean(ssim)
    return ssim


total_psnr = 0.
total_ssim = 0.
count = 0
img_path = './out/Stripformer_realblur_R_results'
gt_path = './datasets/Realblur_R/test/sharp'
print(img_path)
for file in os.listdir(img_path):
    for img_name in os.listdir(img_path + '/' + file):
        count += 1
        number = img_name.split('_')[1]
        gt_name = 'gt_' + number
        img_dir = img_path + '/' + file + '/' + img_name
        gt_dir = gt_path + '/' + file + '/' + gt_name
        with concurrent.futures.ProcessPoolExecutor(max_workers=10) as executor:
            tar_img = io.imread(gt_dir)
            prd_img = io.imread(img_dir)
            tar_img = tar_img.astype(np.float32) / 255.0
            prd_img = prd_img.astype(np.float32) / 255.0
            prd_img, tar_img, cr1, shift = image_align(prd_img, tar_img)
            PSNR = compute_psnr(tar_img, prd_img, cr1, data_range=1)
            SSIM = compute_ssim(tar_img, prd_img, cr1)
            total_psnr += PSNR
            total_ssim += SSIM
            print(count, PSNR)

print('PSNR:', total_psnr / count)
print('SSIM:', total_ssim / count)
print(img_path)

@@ -0,0 +1,60 @@
p = genpath('.\out\stripformer_10_patches'); % GoPro Deblur Results
gt = genpath('.\datasets\GoPro\test\sharp'); % GoPro GT Results

length_p = size(p,2);
path = {};
temp = [];
for i = 1:length_p
    if p(i) ~= ';'
        temp = [temp p(i)];
    else
        temp = [temp '\'];
        path = [path ; temp];
        temp = [];
    end
end
clear p length_p temp;
length_gt = size(gt,2);
path_gt = {};
temp_gt = [];
for i = 1:length_gt
    if gt(i) ~= ';'
        temp_gt = [temp_gt gt(i)];
    else
        temp_gt = [temp_gt '\'];
        path_gt = [path_gt ; temp_gt];
        temp_gt = [];
    end
end
clear gt length_gt temp_gt;

file_num = size(path,1);
total_psnr = 0;
n = 0;
total_ssim = 0;
for i = 1:file_num
    file_path = path{i};
    gt_file_path = path_gt{i};
    img_path_list = dir(strcat(file_path,'*.png'));
    gt_path_list = dir(strcat(gt_file_path,'*.png'));
    img_num = length(img_path_list);
    if img_num > 0
        for j = 1:img_num
            image_name = img_path_list(j).name;
            gt_name = gt_path_list(j).name;
            image = imread(strcat(file_path,image_name));
            gt = imread(strcat(gt_file_path,gt_name));
            size(image);
            size(gt);
            peaksnr = psnr(image,gt);
            ssimval = ssim(image,gt);
            total_psnr = total_psnr + peaksnr;
            total_ssim = total_ssim + ssimval;
            n = n + 1
        end
    end
end
psnr = total_psnr / n
ssim = total_ssim / n
close all; clear all;

@@ -0,0 +1,60 @@
p = genpath('.\out\Stripformer_HIDE_result'); % HIDE Deblur Results
gt = genpath('.\datasets\HIDE\sharp'); % HIDE GT Results

length_p = size(p,2);
path = {};
temp = [];
for i = 1:length_p
    if p(i) ~= ';'
        temp = [temp p(i)];
    else
        temp = [temp '\'];
        path = [path ; temp];
        temp = [];
    end
end
clear p length_p temp;
length_gt = size(gt,2);
path_gt = {};
temp_gt = [];
for i = 1:length_gt
    if gt(i) ~= ';'
        temp_gt = [temp_gt gt(i)];
    else
        temp_gt = [temp_gt '\'];
        path_gt = [path_gt ; temp_gt];
        temp_gt = [];
    end
end
clear gt length_gt temp_gt;

file_num = size(path,1);
total_psnr = 0;
n = 0;
total_ssim = 0;
for i = 1:file_num
    file_path = path{i};
    gt_file_path = path_gt{i};
    img_path_list = dir(strcat(file_path,'*.png'));
    gt_path_list = dir(strcat(gt_file_path,'*.png'));
    img_num = length(img_path_list);
    if img_num > 0
        for j = 1:img_num
            image_name = img_path_list(j).name;
            gt_name = gt_path_list(j).name;
            image = imread(strcat(file_path,image_name));
            gt = imread(strcat(gt_file_path,gt_name));
            size(image);
            size(gt);
            peaksnr = psnr(image,gt);
            ssimval = ssim(image,gt);
            total_psnr = total_psnr + peaksnr;
            total_ssim = total_ssim + ssimval;
            n = n + 1
        end
    end
end
psnr = total_psnr / n
ssim = total_ssim / n
close all; clear all;

@@ -0,0 +1,55 @@
import logging
from collections import defaultdict

import numpy as np
from tensorboardX import SummaryWriter

WINDOW_SIZE = 100


class MetricCounter:
    def __init__(self, exp_name):
        self.writer = SummaryWriter(exp_name)
        logging.basicConfig(filename='{}.log'.format(exp_name), level=logging.DEBUG)
        self.metrics = defaultdict(list)
        self.images = defaultdict(list)
        self.best_metric = 0

    def add_image(self, x: np.ndarray, tag: str):
        self.images[tag].append(x)

    def clear(self):
        self.metrics = defaultdict(list)
        self.images = defaultdict(list)

    def add_losses(self, l_G):
        for name, value in zip(('G_loss', None), (l_G, None)):
            self.metrics[name].append(value)

    def add_metrics(self, psnr, ssim):
        for name, value in zip(('PSNR', 'SSIM'),
                               (psnr, ssim)):
            self.metrics[name].append(value)

    def loss_message(self):
        metrics = ((k, np.mean(self.metrics[k][-WINDOW_SIZE:])) for k in ('G_loss', 'PSNR', 'SSIM'))
        return '; '.join(map(lambda x: f'{x[0]}={x[1]:.4f}', metrics))

    def write_to_tensorboard(self, epoch_num, validation=False):
        scalar_prefix = 'Validation' if validation else 'Train'
        for tag in ('G_loss', 'SSIM', 'PSNR'):
            self.writer.add_scalar(f'{scalar_prefix}_{tag}', np.mean(self.metrics[tag]), global_step=epoch_num)
        for tag in self.images:
            imgs = self.images[tag]
            if imgs:
                imgs = np.array(imgs)
                self.writer.add_images(tag, imgs[:, :, :, ::-1].astype('float32') / 255, dataformats='NHWC',
                                       global_step=epoch_num)
                self.images[tag] = []

    def update_best_model(self):
        cur_metric = np.mean(self.metrics['PSNR'])
        if self.best_metric < cur_metric:
            self.best_metric = cur_metric
            return True
        return False

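A sketch of the counter's per-epoch lifecycle as the trainer drives it, assuming tensorboardX is installed (the experiment name and metric values are illustrative):

    from metric_counter import MetricCounter

    mc = MetricCounter('Stripformer_gopro')
    mc.clear()                              # start of an epoch
    mc.add_losses(0.123)                    # after each batch
    mc.add_metrics(psnr=30.5, ssim=0.91)
    print(mc.loss_message())                # windowed means: G_loss=...; PSNR=...; SSIM=...
    mc.write_to_tensorboard(epoch_num=0)    # flush scalars (and any queued images)
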
@@ -0,0 +1,374 @@
import torch
import torch.nn as nn
import math


class Embeddings(nn.Module):
    def __init__(self):
        super(Embeddings, self).__init__()

        self.activation = nn.LeakyReLU(0.2, True)

        self.en_layer1_1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            self.activation,
        )
        self.en_layer1_2 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            self.activation,
            nn.Conv2d(64, 64, kernel_size=3, padding=1))
        self.en_layer1_3 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            self.activation,
            nn.Conv2d(64, 64, kernel_size=3, padding=1))
        self.en_layer1_4 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            self.activation,
            nn.Conv2d(64, 64, kernel_size=3, padding=1))

        self.en_layer2_1 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            self.activation,
        )
        self.en_layer2_2 = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            self.activation,
            nn.Conv2d(128, 128, kernel_size=3, padding=1))
        self.en_layer2_3 = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            self.activation,
            nn.Conv2d(128, 128, kernel_size=3, padding=1))
        self.en_layer2_4 = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            self.activation,
            nn.Conv2d(128, 128, kernel_size=3, padding=1))

        self.en_layer3_1 = nn.Sequential(
            nn.Conv2d(128, 320, kernel_size=3, stride=2, padding=1),
            self.activation,
        )

    def forward(self, x):
        hx = self.en_layer1_1(x)
        hx = self.activation(self.en_layer1_2(hx) + hx)
        hx = self.activation(self.en_layer1_3(hx) + hx)
        hx = self.activation(self.en_layer1_4(hx) + hx)
        residual_1 = hx
        hx = self.en_layer2_1(hx)
        hx = self.activation(self.en_layer2_2(hx) + hx)
        hx = self.activation(self.en_layer2_3(hx) + hx)
        hx = self.activation(self.en_layer2_4(hx) + hx)
        residual_2 = hx
        hx = self.en_layer3_1(hx)

        return hx, residual_1, residual_2


class Embeddings_output(nn.Module):
    def __init__(self):
        super(Embeddings_output, self).__init__()

        self.activation = nn.LeakyReLU(0.2, True)

        self.de_layer3_1 = nn.Sequential(
            nn.ConvTranspose2d(320, 192, kernel_size=4, stride=2, padding=1),
            self.activation,
        )
        head_num = 3
        dim = 192

        self.de_layer2_2 = nn.Sequential(
            nn.Conv2d(192 + 128, 192, kernel_size=1, padding=0),
            self.activation,
        )

        self.de_block_1 = Intra_SA(dim, head_num)
        self.de_block_2 = Inter_SA(dim, head_num)
        self.de_block_3 = Intra_SA(dim, head_num)
        self.de_block_4 = Inter_SA(dim, head_num)
        self.de_block_5 = Intra_SA(dim, head_num)
        self.de_block_6 = Inter_SA(dim, head_num)

        self.de_layer2_1 = nn.Sequential(
            nn.ConvTranspose2d(192, 64, kernel_size=4, stride=2, padding=1),
            self.activation,
        )

        self.de_layer1_3 = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=1, padding=0),
            self.activation,
            nn.Conv2d(64, 64, kernel_size=3, padding=1))
        self.de_layer1_2 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            self.activation,
            nn.Conv2d(64, 64, kernel_size=3, padding=1))
        self.de_layer1_1 = nn.Sequential(
            nn.Conv2d(64, 3, kernel_size=3, padding=1),
            self.activation
        )

    def forward(self, x, residual_1, residual_2):
        hx = self.de_layer3_1(x)

        hx = self.de_layer2_2(torch.cat((hx, residual_2), dim=1))
        hx = self.de_block_1(hx)
        hx = self.de_block_2(hx)
        hx = self.de_block_3(hx)
        hx = self.de_block_4(hx)
        hx = self.de_block_5(hx)
        hx = self.de_block_6(hx)
        hx = self.de_layer2_1(hx)

        hx = self.activation(self.de_layer1_3(torch.cat((hx, residual_1), dim=1)) + hx)
        hx = self.activation(self.de_layer1_2(hx) + hx)
        hx = self.de_layer1_1(hx)

        return hx


class Attention(nn.Module):
    def __init__(self, head_num):
        super(Attention, self).__init__()
        self.num_attention_heads = head_num
        self.softmax = nn.Softmax(dim=-1)

    def transpose_for_scores(self, x):
        B, N, C = x.size()
        attention_head_size = int(C / self.num_attention_heads)
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3).contiguous()

    def forward(self, query_layer, key_layer, value_layer):
        B, N, C = query_layer.size()
        query_layer = self.transpose_for_scores(query_layer)
        key_layer = self.transpose_for_scores(key_layer)
        value_layer = self.transpose_for_scores(value_layer)
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        _, _, _, d = query_layer.size()
        attention_scores = attention_scores / math.sqrt(d)
        attention_probs = self.softmax(attention_scores)
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (C,)
        attention_out = context_layer.view(*new_context_layer_shape)

        return attention_out


class Mlp(nn.Module):
    def __init__(self, hidden_size):
        super(Mlp, self).__init__()
        self.fc1 = nn.Linear(hidden_size, 4 * hidden_size)
        self.fc2 = nn.Linear(4 * hidden_size, hidden_size)
        self.act_fn = torch.nn.functional.gelu
        self._init_weights()

    def _init_weights(self):
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.normal_(self.fc1.bias, std=1e-6)
        nn.init.normal_(self.fc2.bias, std=1e-6)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act_fn(x)
        x = self.fc2(x)
        return x


# CPE (Conditional Positional Embedding)
class PEG(nn.Module):
    def __init__(self, hidden_size):
        super(PEG, self).__init__()
        self.PEG = nn.Conv2d(hidden_size, hidden_size, kernel_size=3, padding=1, groups=hidden_size)

    def forward(self, x):
        x = self.PEG(x) + x
        return x


class Intra_SA(nn.Module):
    def __init__(self, dim, head_num):
        super(Intra_SA, self).__init__()
        self.hidden_size = dim // 2
        self.head_num = head_num
        self.attention_norm = nn.LayerNorm(dim)
        self.conv_input = nn.Conv2d(dim, dim, kernel_size=1, padding=0)
        self.qkv_local_h = nn.Linear(self.hidden_size, self.hidden_size * 3)  # qkv_h
        self.qkv_local_v = nn.Linear(self.hidden_size, self.hidden_size * 3)  # qkv_v
        self.fuse_out = nn.Conv2d(dim, dim, kernel_size=1, padding=0)
        self.ffn_norm = nn.LayerNorm(dim)
        self.ffn = Mlp(dim)
        self.attn = Attention(head_num=self.head_num)
        self.PEG = PEG(dim)

    def forward(self, x):
        h = x
        B, C, H, W = x.size()

        x = x.view(B, C, H * W).permute(0, 2, 1).contiguous()
        x = self.attention_norm(x).permute(0, 2, 1).contiguous()
        x = x.view(B, C, H, W)

        x_input = torch.chunk(self.conv_input(x), 2, dim=1)
        feature_h = (x_input[0]).permute(0, 2, 3, 1).contiguous()
        feature_h = feature_h.view(B * H, W, C // 2)
        feature_v = (x_input[1]).permute(0, 3, 2, 1).contiguous()
        feature_v = feature_v.view(B * W, H, C // 2)
        qkv_h = torch.chunk(self.qkv_local_h(feature_h), 3, dim=2)
        qkv_v = torch.chunk(self.qkv_local_v(feature_v), 3, dim=2)
        q_h, k_h, v_h = qkv_h[0], qkv_h[1], qkv_h[2]
        q_v, k_v, v_v = qkv_v[0], qkv_v[1], qkv_v[2]

        if H == W:
            query = torch.cat((q_h, q_v), dim=0)
            key = torch.cat((k_h, k_v), dim=0)
            value = torch.cat((v_h, v_v), dim=0)
            attention_output = self.attn(query, key, value)
            attention_output = torch.chunk(attention_output, 2, dim=0)
            attention_output_h = attention_output[0]
            attention_output_v = attention_output[1]
            attention_output_h = attention_output_h.view(B, H, W, C // 2).permute(0, 3, 1, 2).contiguous()
            attention_output_v = attention_output_v.view(B, W, H, C // 2).permute(0, 3, 2, 1).contiguous()
            attn_out = self.fuse_out(torch.cat((attention_output_h, attention_output_v), dim=1))
        else:
            attention_output_h = self.attn(q_h, k_h, v_h)
            attention_output_v = self.attn(q_v, k_v, v_v)
            attention_output_h = attention_output_h.view(B, H, W, C // 2).permute(0, 3, 1, 2).contiguous()
            attention_output_v = attention_output_v.view(B, W, H, C // 2).permute(0, 3, 2, 1).contiguous()
            attn_out = self.fuse_out(torch.cat((attention_output_h, attention_output_v), dim=1))

        x = attn_out + h
        x = x.view(B, C, H * W).permute(0, 2, 1).contiguous()
        h = x
        x = self.ffn_norm(x)
        x = self.ffn(x)
        x = x + h
        x = x.permute(0, 2, 1).contiguous()
        x = x.view(B, C, H, W)

        x = self.PEG(x)

        return x


class Inter_SA(nn.Module):
    def __init__(self, dim, head_num):
        super(Inter_SA, self).__init__()
        self.hidden_size = dim
        self.head_num = head_num
        self.attention_norm = nn.LayerNorm(self.hidden_size)
        self.conv_input = nn.Conv2d(self.hidden_size, self.hidden_size, kernel_size=1, padding=0)
        self.conv_h = nn.Conv2d(self.hidden_size // 2, 3 * (self.hidden_size // 2), kernel_size=1, padding=0)  # qkv_h
        self.conv_v = nn.Conv2d(self.hidden_size // 2, 3 * (self.hidden_size // 2), kernel_size=1, padding=0)  # qkv_v
        self.ffn_norm = nn.LayerNorm(self.hidden_size)
        self.ffn = Mlp(self.hidden_size)
        self.fuse_out = nn.Conv2d(self.hidden_size, self.hidden_size, kernel_size=1, padding=0)
        self.attn = Attention(head_num=self.head_num)
        self.PEG = PEG(dim)

    def forward(self, x):
        h = x
        B, C, H, W = x.size()

        x = x.view(B, C, H * W).permute(0, 2, 1).contiguous()
        x = self.attention_norm(x).permute(0, 2, 1).contiguous()
        x = x.view(B, C, H, W)

        x_input = torch.chunk(self.conv_input(x), 2, dim=1)
        feature_h = torch.chunk(self.conv_h(x_input[0]), 3, dim=1)
        feature_v = torch.chunk(self.conv_v(x_input[1]), 3, dim=1)
        query_h, key_h, value_h = feature_h[0], feature_h[1], feature_h[2]
        query_v, key_v, value_v = feature_v[0], feature_v[1], feature_v[2]

        horizontal_groups = torch.cat((query_h, key_h, value_h), dim=0)
        horizontal_groups = horizontal_groups.permute(0, 2, 1, 3).contiguous()
        horizontal_groups = horizontal_groups.view(3 * B, H, -1)
        horizontal_groups = torch.chunk(horizontal_groups, 3, dim=0)
        query_h, key_h, value_h = horizontal_groups[0], horizontal_groups[1], horizontal_groups[2]

        vertical_groups = torch.cat((query_v, key_v, value_v), dim=0)
        vertical_groups = vertical_groups.permute(0, 3, 1, 2).contiguous()
        vertical_groups = vertical_groups.view(3 * B, W, -1)
        vertical_groups = torch.chunk(vertical_groups, 3, dim=0)
        query_v, key_v, value_v = vertical_groups[0], vertical_groups[1], vertical_groups[2]

        if H == W:
            query = torch.cat((query_h, query_v), dim=0)
            key = torch.cat((key_h, key_v), dim=0)
            value = torch.cat((value_h, value_v), dim=0)
            attention_output = self.attn(query, key, value)
            attention_output = torch.chunk(attention_output, 2, dim=0)
            attention_output_h = attention_output[0]
            attention_output_v = attention_output[1]
            attention_output_h = attention_output_h.view(B, H, C // 2, W).permute(0, 2, 1, 3).contiguous()
            attention_output_v = attention_output_v.view(B, W, C // 2, H).permute(0, 2, 3, 1).contiguous()
            attn_out = self.fuse_out(torch.cat((attention_output_h, attention_output_v), dim=1))
        else:
            attention_output_h = self.attn(query_h, key_h, value_h)
            attention_output_v = self.attn(query_v, key_v, value_v)
            attention_output_h = attention_output_h.view(B, H, C // 2, W).permute(0, 2, 1, 3).contiguous()
            attention_output_v = attention_output_v.view(B, W, C // 2, H).permute(0, 2, 3, 1).contiguous()
            attn_out = self.fuse_out(torch.cat((attention_output_h, attention_output_v), dim=1))

        x = attn_out + h
        x = x.view(B, C, H * W).permute(0, 2, 1).contiguous()
        h = x
        x = self.ffn_norm(x)
        x = self.ffn(x)
        x = x + h
        x = x.permute(0, 2, 1).contiguous()
        x = x.view(B, C, H, W)

        x = self.PEG(x)

        return x


class Stripformer(nn.Module):
    def __init__(self):
        super(Stripformer, self).__init__()

        self.encoder = Embeddings()
        head_num = 5
        dim = 320
        self.Trans_block_1 = Intra_SA(dim, head_num)
        self.Trans_block_2 = Inter_SA(dim, head_num)
        self.Trans_block_3 = Intra_SA(dim, head_num)
        self.Trans_block_4 = Inter_SA(dim, head_num)
        self.Trans_block_5 = Intra_SA(dim, head_num)
        self.Trans_block_6 = Inter_SA(dim, head_num)
        self.Trans_block_7 = Intra_SA(dim, head_num)
        self.Trans_block_8 = Inter_SA(dim, head_num)
        self.Trans_block_9 = Intra_SA(dim, head_num)
        self.Trans_block_10 = Inter_SA(dim, head_num)
        self.Trans_block_11 = Intra_SA(dim, head_num)
        self.Trans_block_12 = Inter_SA(dim, head_num)
        self.decoder = Embeddings_output()

    def forward(self, x):
        hx, residual_1, residual_2 = self.encoder(x)
        hx = self.Trans_block_1(hx)
        hx = self.Trans_block_2(hx)
        hx = self.Trans_block_3(hx)
        hx = self.Trans_block_4(hx)
        hx = self.Trans_block_5(hx)
        hx = self.Trans_block_6(hx)
        hx = self.Trans_block_7(hx)
        hx = self.Trans_block_8(hx)
        hx = self.Trans_block_9(hx)
        hx = self.Trans_block_10(hx)
        hx = self.Trans_block_11(hx)
        hx = self.Trans_block_12(hx)
        hx = self.decoder(hx, residual_1, residual_2)

        return hx + x

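A shape check for the network above, runnable on CPU with random weights; a 256x256 input keeps H == W inside the blocks, so they take the fused single-attention branch:

    import torch
    from models.Stripformer import Stripformer

    model = Stripformer().eval()
    x = torch.randn(1, 3, 256, 256)
    with torch.no_grad():
        y = model(x)
    print(y.shape)  # torch.Size([1, 3, 256, 256]) -- the model predicts a residual and adds it to x
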
@@ -0,0 +1,131 @@
import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F


class Vgg19(torch.nn.Module):
    def __init__(self, requires_grad=False):
        super(Vgg19, self).__init__()
        vgg_pretrained_features = models.vgg19(pretrained=True).features
        self.slice1 = torch.nn.Sequential()

        for x in range(12):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])

        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h_relu1 = self.slice1(X)
        return h_relu1


class ContrastLoss(nn.Module):
    def __init__(self, ablation=False):
        super(ContrastLoss, self).__init__()
        self.vgg = Vgg19().cuda()
        self.l1 = nn.L1Loss()
        self.ab = ablation
        self.down_sample_4 = nn.Upsample(scale_factor=1 / 4, mode='bilinear')

    def forward(self, restore, sharp, blur):
        B, C, H, W = restore.size()
        restore_vgg, sharp_vgg, blur_vgg = self.vgg(restore), self.vgg(sharp), self.vgg(blur)

        # filter out sharp regions
        threshold = 0.01
        mask = torch.mean(torch.abs(sharp - blur), dim=1).view(B, 1, H, W)
        mask[mask <= threshold] = 0
        mask[mask > threshold] = 1
        mask = self.down_sample_4(mask)
        d_ap = torch.mean(torch.abs((restore_vgg - sharp_vgg.detach())), dim=1).view(B, 1, H // 4, W // 4)
        d_an = torch.mean(torch.abs((restore_vgg - blur_vgg.detach())), dim=1).view(B, 1, H // 4, W // 4)
        mask_size = torch.sum(mask)
        contrastive = torch.sum((d_ap / (d_an + 1e-7)) * mask) / mask_size

        return contrastive


class ContrastLoss_Ori(nn.Module):
    def __init__(self, ablation=False):
        super(ContrastLoss_Ori, self).__init__()
        self.vgg = Vgg19().cuda()
        self.l1 = nn.L1Loss()
        self.ab = ablation

    def forward(self, restore, sharp, blur):
        restore_vgg, sharp_vgg, blur_vgg = self.vgg(restore), self.vgg(sharp), self.vgg(blur)
        d_ap = self.l1(restore_vgg, sharp_vgg.detach())
        d_an = self.l1(restore_vgg, blur_vgg.detach())
        contrastive_loss = d_ap / (d_an + 1e-7)

        return contrastive_loss


class CharbonnierLoss(nn.Module):
    """Charbonnier Loss (L1)"""

    def __init__(self, eps=1e-3):
        super(CharbonnierLoss, self).__init__()
        self.eps = eps

    def forward(self, x, y):
        diff = x - y
        # loss = torch.sum(torch.sqrt(diff * diff + self.eps))
        loss = torch.mean(torch.sqrt((diff * diff) + (self.eps * self.eps)))
        return loss


class EdgeLoss(nn.Module):
    def __init__(self):
        super(EdgeLoss, self).__init__()
        k = torch.Tensor([[.05, .25, .4, .25, .05]])
        self.kernel = torch.matmul(k.t(), k).unsqueeze(0).repeat(3, 1, 1, 1)
        if torch.cuda.is_available():
            self.kernel = self.kernel.cuda()
        self.loss = CharbonnierLoss()

    def conv_gauss(self, img):
        n_channels, _, kw, kh = self.kernel.shape
        img = F.pad(img, (kw // 2, kh // 2, kw // 2, kh // 2), mode='replicate')
        return F.conv2d(img, self.kernel, groups=n_channels)

    def laplacian_kernel(self, current):
        filtered = self.conv_gauss(current)     # filter
        down = filtered[:, :, ::2, ::2]         # downsample
        new_filter = torch.zeros_like(filtered)
        new_filter[:, :, ::2, ::2] = down * 4   # upsample
        filtered = self.conv_gauss(new_filter)  # filter
        diff = current - filtered
        return diff

    def forward(self, x, y):
        # x = torch.clamp(x + 0.5, min=0, max=1)
        # y = torch.clamp(y + 0.5, min=0, max=1)
        loss = self.loss(self.laplacian_kernel(x), self.laplacian_kernel(y))
        return loss


class Stripformer_Loss(nn.Module):

    def __init__(self):
        super(Stripformer_Loss, self).__init__()

        self.char = CharbonnierLoss()
        self.edge = EdgeLoss()
        self.contrastive = ContrastLoss()

    def forward(self, restore, sharp, blur):
        char = self.char(restore, sharp)
        edge = 0.05 * self.edge(restore, sharp)
        contrastive = 0.0005 * self.contrastive(restore, sharp, blur)
        loss = char + edge + contrastive
        return loss


def get_loss(model):
    if model['content_loss'] == 'Stripformer_Loss':
        content_loss = Stripformer_Loss()
    else:
        raise ValueError("ContentLoss [%s] not recognized." % model['content_loss'])
    return content_loss

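A usage sketch for the combined loss, assuming a CUDA device (ContrastLoss moves its VGG-19 feature extractor to the GPU) and that torchvision's pretrained weights can be downloaded; the tensors are random stand-ins for model output, ground truth, and blurry input:

    import torch
    from models.losses import get_loss

    criterion = get_loss({'content_loss': 'Stripformer_Loss'})
    restore = torch.rand(1, 3, 256, 256, device='cuda', requires_grad=True)
    sharp = torch.rand(1, 3, 256, 256, device='cuda')
    blur = torch.rand(1, 3, 256, 256, device='cuda')
    loss = criterion(restore, sharp, blur)  # Charbonnier + 0.05 * edge + 0.0005 * contrastive
    loss.backward()
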
@@ -0,0 +1,35 @@
import numpy as np
import torch.nn as nn
from skimage.measure import compare_ssim as SSIM

from util.metrics import PSNR


class DeblurModel(nn.Module):
    def __init__(self):
        super(DeblurModel, self).__init__()

    def get_input(self, data):
        img = data['a']
        inputs = img
        targets = data['b']
        inputs, targets = inputs.cuda(), targets.cuda()
        return inputs, targets

    def tensor2im(self, image_tensor, imtype=np.uint8):
        image_numpy = image_tensor[0].cpu().float().numpy()
        image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 0.5) * 255.0
        return image_numpy

    def get_images_and_metrics(self, inp, output, target) -> (float, float, np.ndarray):
        inp = self.tensor2im(inp)
        fake = self.tensor2im(output.data)
        real = self.tensor2im(target.data)
        psnr = PSNR(fake, real)
        ssim = SSIM(fake.astype('uint8'), real.astype('uint8'), multichannel=True)
        vis_img = np.hstack((inp, fake, real))
        return psnr, ssim, vis_img


def get_model(model_config):
    return DeblurModel()

@@ -0,0 +1,13 @@
import torch.nn as nn
from models.Stripformer import Stripformer


def get_generator(model_config):
    generator_name = model_config['g_name']
    if generator_name == 'Stripformer':
        model_g = Stripformer()
    else:
        raise ValueError("Generator Network [%s] not recognized." % generator_name)
    return nn.DataParallel(model_g)


def get_nets(model_config):
    return get_generator(model_config)

@@ -0,0 +1 @@
Testing results are created in this folder.
@@ -0,0 +1,68 @@
from __future__ import print_function
import numpy as np
import torch
import cv2
import yaml
import os
from torch.autograd import Variable
from models.networks import get_generator
import torchvision
import time
import argparse


def get_args():
    parser = argparse.ArgumentParser('Test an image')
    parser.add_argument('--weights_path', required=True, help='Weights path')
    return parser.parse_args()


if __name__ == '__main__':
    args = get_args()
    with open('config/config_Stripformer_gopro.yaml') as cfg:
        config = yaml.safe_load(cfg)
    blur_path = './datasets/GoPro/test/blur/'
    out_path = './out/Stripformer_GoPro_results'
    if not os.path.isdir(out_path):
        os.mkdir(out_path)
    model = get_generator(config['model'])
    model.load_state_dict(torch.load(args.weights_path))
    model = model.cuda()

    test_time = 0
    iteration = 0
    total_image_number = 1111

    # warm-up
    warm_up = 0
    print('Hardware warm-up')
    for file in os.listdir(blur_path):
        for img_name in os.listdir(blur_path + '/' + file):
            warm_up += 1
            img = cv2.imread(blur_path + '/' + file + '/' + img_name)
            img_tensor = torch.from_numpy(np.transpose(img / 255, (2, 0, 1)).astype('float32')) - 0.5
            with torch.no_grad():
                img_tensor = Variable(img_tensor.unsqueeze(0)).cuda()
                result_image = model(img_tensor)
            if warm_up == 20:
                break
        break

    for file in os.listdir(blur_path):
        if not os.path.isdir(out_path + '/' + file):
            os.mkdir(out_path + '/' + file)
        for img_name in os.listdir(blur_path + '/' + file):
            img = cv2.imread(blur_path + '/' + file + '/' + img_name)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img_tensor = torch.from_numpy(np.transpose(img / 255, (2, 0, 1)).astype('float32')) - 0.5
            with torch.no_grad():
                iteration += 1
                img_tensor = Variable(img_tensor.unsqueeze(0)).cuda()

                start = time.time()
                result_image = model(img_tensor)
                stop = time.time()
                print('Image:{}/{}, CNN Runtime:{:.4f}'.format(iteration, total_image_number, (stop - start)))
                test_time += stop - start
                print('Average Runtime:{:.4f}'.format(test_time / float(iteration)))
                result_image = result_image + 0.5
            out_file_name = out_path + '/' + file + '/' + img_name
            torchvision.utils.save_image(result_image, out_file_name)

@@ -0,0 +1,64 @@
from __future__ import print_function
import numpy as np
import torch
import cv2
import yaml
import os
from torch.autograd import Variable
from models.networks import get_generator
import torchvision
import time
import argparse


def get_args():
    parser = argparse.ArgumentParser('Test an image')
    parser.add_argument('--weights_path', required=True, help='Weights path')
    return parser.parse_args()


if __name__ == '__main__':
    args = get_args()
    with open('config/config_Stripformer_gopro.yaml') as cfg:
        config = yaml.safe_load(cfg)
    blur_path = './datasets/HIDE/blur/'
    out_path = './out/Stripformer_HIDE_results'
    if not os.path.isdir(out_path):
        os.mkdir(out_path)
    model = get_generator(config['model'])
    model.load_state_dict(torch.load(args.weights_path))
    model = model.cuda()
    test_time = 0
    iteration = 0
    total_image_number = 2025

    # warm up
    warm_up = 0
    print('Hardware warm-up')
    for img_name in os.listdir(blur_path):
        warm_up += 1
        img = cv2.imread(blur_path + '/' + img_name)
        img_tensor = torch.from_numpy(np.transpose(img / 255, (2, 0, 1)).astype('float32')) - 0.5
        with torch.no_grad():
            img_tensor = Variable(img_tensor.unsqueeze(0)).cuda()
            result_image = model(img_tensor)
        if warm_up == 20:
            break

    for img_name in os.listdir(blur_path):
        img = cv2.imread(blur_path + '/' + img_name)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img_tensor = torch.from_numpy(np.transpose(img / 255, (2, 0, 1)).astype('float32')) - 0.5
        with torch.no_grad():
            iteration += 1
            img_tensor = Variable(img_tensor.unsqueeze(0)).cuda()

            start = time.time()
            result_image = model(img_tensor)
            stop = time.time()

            print('Image:{}/{}, CNN Runtime:{:.4f}'.format(iteration, total_image_number, (stop - start)))
            test_time += stop - start
            print('Average Runtime:{:.4f}'.format(test_time / float(iteration)))
            result_image = result_image + 0.5
        out_file_name = out_path + '/' + img_name
        torchvision.utils.save_image(result_image, out_file_name)

@@ -0,0 +1,90 @@
from __future__ import print_function
import numpy as np
import torch
import cv2
import yaml
import os
from torch.autograd import Variable
from models.networks import get_generator
import torchvision
import time
import torch.nn.functional as F
import argparse


def get_args():
    parser = argparse.ArgumentParser('Test an image')
    parser.add_argument('--weights_path', required=True, help='Weights path')
    return parser.parse_args()


if __name__ == '__main__':
    args = get_args()
    with open('config/config_Stripformer_gopro.yaml') as cfg:
        config = yaml.safe_load(cfg)
    blur_path = './datasets/Realblur_J/test/blur/'
    out_path = './out/Stripformer_realblur_J_results'
    if not os.path.isdir(out_path):
        os.mkdir(out_path)
    model = get_generator(config['model'])
    model.load_state_dict(torch.load(args.weights_path))
    model = model.cuda()
    test_time = 0
    iteration = 0
    total_image_number = 980

    # warm up
    warm_up = 0
    print('Hardware warm-up')
    for file in os.listdir(blur_path):
        if not os.path.isdir(out_path + '/' + file):
            os.mkdir(out_path + '/' + file)
        for img_name in os.listdir(blur_path + '/' + file):
            warm_up += 1
            img = cv2.imread(blur_path + '/' + file + '/' + img_name)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img_tensor = torch.from_numpy(np.transpose(img / 255, (2, 0, 1)).astype('float32')) - 0.5
            with torch.no_grad():
                img_tensor = Variable(img_tensor.unsqueeze(0)).cuda()
                factor = 8
                h, w = img_tensor.shape[2], img_tensor.shape[3]
                H, W = ((h + factor) // factor) * factor, ((w + factor) // factor) * factor
                padh = H - h if h % factor != 0 else 0
                padw = W - w if w % factor != 0 else 0
                img_tensor = F.pad(img_tensor, (0, padw, 0, padh), 'reflect')
                result_image = model(img_tensor)
            if warm_up == 20:
                break
        break

    for file in os.listdir(blur_path):
        if not os.path.isdir(out_path + '/' + file):
            os.mkdir(out_path + '/' + file)
        for img_name in os.listdir(blur_path + '/' + file):
            img = cv2.imread(blur_path + '/' + file + '/' + img_name)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img_tensor = torch.from_numpy(np.transpose(img / 255, (2, 0, 1)).astype('float32')) - 0.5
            with torch.no_grad():
                iteration += 1
                img_tensor = Variable(img_tensor.unsqueeze(0)).cuda()

                factor = 8
                h, w = img_tensor.shape[2], img_tensor.shape[3]
                H, W = ((h + factor) // factor) * factor, ((w + factor) // factor) * factor
                padh = H - h if h % factor != 0 else 0
                padw = W - w if w % factor != 0 else 0
                img_tensor = F.pad(img_tensor, (0, padw, 0, padh), 'reflect')
                H, W = img_tensor.shape[2], img_tensor.shape[3]

                start = time.time()
                _output = model(img_tensor)
                stop = time.time()

                result_image = _output[:, :, :h, :w]
                result_image = torch.clamp(result_image, -0.5, 0.5)
                result_image = result_image + 0.5

                test_time += stop - start
                print('Image:{}/{}, CNN Runtime:{:.4f}'.format(iteration, total_image_number, (stop - start)))
                print('Average Runtime:{:.4f}'.format(test_time / float(iteration)))
            out_file_name = out_path + '/' + file + '/' + img_name
            torchvision.utils.save_image(result_image, out_file_name)

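The factor-of-8 padding above exists because the RealBlur images have arbitrary sizes while the network downsamples and re-upsamples with stride-2 layers; inputs are reflect-padded up to a multiple of 8 and the output is cropped back. A standalone sketch of the same arithmetic with a hypothetical odd-sized input:

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 3, 679, 1013)
    factor = 8
    h, w = x.shape[2], x.shape[3]
    H, W = ((h + factor) // factor) * factor, ((w + factor) // factor) * factor
    padh = H - h if h % factor != 0 else 0
    padw = W - w if w % factor != 0 else 0
    x_pad = F.pad(x, (0, padw, 0, padh), 'reflect')   # pad right/bottom only
    assert x_pad.shape[2] % factor == 0 and x_pad.shape[3] % factor == 0
    y = x_pad[:, :, :h, :w]                           # crop back to the original size
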
@@ -0,0 +1,90 @@
from __future__ import print_function
import argparse
import numpy as np
import torch
import cv2
import yaml
import os
from torch.autograd import Variable
from models.networks import get_generator
import torchvision
import time
import torch.nn.functional as F


def get_args():
    parser = argparse.ArgumentParser('Test an image')
    parser.add_argument('--weights_path', required=True, help='Weights path')
    return parser.parse_args()


if __name__ == '__main__':
    args = get_args()
    with open('config/config_Stripformer_gopro.yaml') as cfg:
        config = yaml.safe_load(cfg)
    blur_path = './datasets/Realblur_R/test/blur/'
    out_path = './out/Stripformer_realblur_R_results'
    model = get_generator(config['model'])
    model.load_state_dict(torch.load(args.weights_path))
    model = model.cuda()
    if not os.path.isdir(out_path):
        os.mkdir(out_path)
    test_time = 0
    iteration = 0
    total_image_number = 980

    # warm up
    warm_up = 0
    print('Hardware warm-up')
    for file in os.listdir(blur_path):
        if not os.path.isdir(out_path + '/' + file):
            os.mkdir(out_path + '/' + file)
        for img_name in os.listdir(blur_path + '/' + file):
            warm_up += 1
            img = cv2.imread(blur_path + '/' + file + '/' + img_name)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img_tensor = torch.from_numpy(np.transpose(img / 255, (2, 0, 1)).astype('float32')) - 0.5
            with torch.no_grad():
                img_tensor = Variable(img_tensor.unsqueeze(0)).cuda()
                factor = 8
                h, w = img_tensor.shape[2], img_tensor.shape[3]
                H, W = ((h + factor) // factor) * factor, ((w + factor) // factor) * factor
                padh = H - h if h % factor != 0 else 0
                padw = W - w if w % factor != 0 else 0
                img_tensor = F.pad(img_tensor, (0, padw, 0, padh), 'reflect')
                result_image = model(img_tensor)
            if warm_up == 20:
                break
        break

    for file in os.listdir(blur_path):
        if not os.path.isdir(out_path + '/' + file):
            os.mkdir(out_path + '/' + file)
        for img_name in os.listdir(blur_path + '/' + file):
            img = cv2.imread(blur_path + '/' + file + '/' + img_name)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img_tensor = torch.from_numpy(np.transpose(img / 255, (2, 0, 1)).astype('float32')) - 0.5
            with torch.no_grad():
                iteration += 1
                img_tensor = Variable(img_tensor.unsqueeze(0)).cuda()

                factor = 8
                h, w = img_tensor.shape[2], img_tensor.shape[3]
                H, W = ((h + factor) // factor) * factor, ((w + factor) // factor) * factor
                padh = H - h if h % factor != 0 else 0
                padw = W - w if w % factor != 0 else 0
                img_tensor = F.pad(img_tensor, (0, padw, 0, padh), 'reflect')
                H, W = img_tensor.shape[2], img_tensor.shape[3]

                start = time.time()
                _output = model(img_tensor)
                stop = time.time()

                result_image = _output[:, :, :h, :w]
                result_image = torch.clamp(result_image, -0.5, 0.5)
                result_image = result_image + 0.5

                test_time += stop - start
                print('Image:{}/{}, CNN Runtime:{:.4f}'.format(iteration, total_image_number, (stop - start)))
                print('Average Runtime:{:.4f}'.format(test_time / float(iteration)))
            out_file_name = out_path + '/' + file + '/' + img_name
            torchvision.utils.save_image(result_image, out_file_name)

@ -0,0 +1,59 @@
|
||||
import math
|
||||
|
||||
from torch.optim import lr_scheduler
|
||||
|
||||
|
||||
class WarmRestart(lr_scheduler.CosineAnnealingLR):
|
||||
"""This class implements Stochastic Gradient Descent with Warm Restarts(SGDR): https://arxiv.org/abs/1608.03983.
|
||||
|
||||
Set the learning rate of each parameter group using a cosine annealing schedule, When last_epoch=-1, sets initial lr as lr.
|
||||
This can't support scheduler.step(epoch). please keep epoch=None.
|
||||
"""
|
||||
|
||||
def __init__(self, optimizer, T_max=30, T_mult=1, eta_min=0, last_epoch=-1):
|
||||
"""implements SGDR
|
||||
|
||||
Parameters:
|
||||
----------
|
||||
T_max : int
|
||||
Maximum number of epochs.
|
||||
T_mult : int
|
||||
Multiplicative factor of T_max.
|
||||
eta_min : int
|
||||
Minimum learning rate. Default: 0.
|
||||
last_epoch : int
|
||||
The index of last epoch. Default: -1.
|
||||
"""
|
||||
self.T_mult = T_mult
|
||||
super().__init__(optimizer, T_max, eta_min, last_epoch)
|
||||
|
||||
def get_lr(self):
|
||||
if self.last_epoch == self.T_max:
|
||||
self.last_epoch = 0
|
||||
self.T_max *= self.T_mult
|
||||
return [self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / 2 for
|
||||
base_lr in self.base_lrs]


class LinearDecay(lr_scheduler._LRScheduler):
    """Keeps the lr constant until start_epoch, then decays it linearly toward min_lr over num_epochs."""

    def __init__(self, optimizer, num_epochs, start_epoch=0, min_lr=0, last_epoch=-1):
        """Implements LinearDecay.

        Parameters
        ----------
        num_epochs : int
            Number of epochs over which the lr is decayed.
        start_epoch : int
            Epoch at which the decay begins. Default: 0.
        min_lr : float
            Learning rate targeted at the end of the decay. Default: 0.
        last_epoch : int
            The index of the last epoch. Default: -1.
        """
        self.num_epochs = num_epochs
        self.start_epoch = start_epoch
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch < self.start_epoch:
            return self.base_lrs
        # Note: the ramp is not clamped at min_lr; callers should stop stepping
        # once last_epoch reaches start_epoch + num_epochs.
        return [base_lr - ((base_lr - self.min_lr) / self.num_epochs) * (self.last_epoch - self.start_epoch)
                for base_lr in self.base_lrs]
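
# A quick numeric check of the linear ramp (values are illustrative, not taken
# from any config): the lr stays flat for start_epoch epochs, then loses
# (base_lr - min_lr) / num_epochs per epoch.
def _linear_decay_demo():
    import torch
    opt = torch.optim.Adam([torch.zeros(1, requires_grad=True)], lr=1e-3)
    sched = LinearDecay(opt, num_epochs=100, start_epoch=10, min_lr=0)
    for _ in range(30):
        sched.step()
    # lr = 1e-3 - (1e-3 / 100) * (30 - 10) = 8e-4
    print(opt.param_groups[0]['lr'])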
@ -0,0 +1,170 @@
import logging
import os
import random
from functools import partial

import cv2
import numpy as np
import torch
import torch.optim as optim
import tqdm
import yaml
from joblib import cpu_count
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader

from dataset import PairedDataset
from metric_counter import MetricCounter
from models.losses import get_loss
from models.models import get_model
from models.networks import get_nets

cv2.setNumThreads(0)


class Trainer:
    def __init__(self, config, train: DataLoader, val: DataLoader):
        self.config = config
        self.train_dataset = train
        self.val_dataset = val
        self.metric_counter = MetricCounter(config['experiment_desc'])

    def train(self):
        self._init_params()
        start_epoch = 0
        if os.path.exists('last_Stripformer_gopro.pth'):
            # Resume: restore model, optimizer, and scheduler state from the last checkpoint.
            print('load_pretrained')
            training_state = torch.load('last_Stripformer_gopro.pth')
            start_epoch = training_state['epoch']
            new_weight = self.netG.state_dict()
            new_weight.update(training_state['model_state'])
            self.netG.load_state_dict(new_weight)
            new_optimizer = self.optimizer_G.state_dict()
            new_optimizer.update(training_state['optimizer_state'])
            self.optimizer_G.load_state_dict(new_optimizer)
            new_scheduler = self.scheduler_G.state_dict()
            new_scheduler.update(training_state['scheduler_state'])
            self.scheduler_G.load_state_dict(new_scheduler)
        else:
            # First run: initialize from the weights pre-trained on GoPro patches.
            print('load_GoPro_pretrained')
            training_state = torch.load('final_Stripformer_pretrained.pth')
            new_weight = self.netG.state_dict()
            new_weight.update(training_state)
            self.netG.load_state_dict(new_weight)

        for epoch in range(start_epoch, self.config['num_epochs']):
            self._run_epoch(epoch)
            if epoch % 30 == 0 or epoch == (self.config['num_epochs'] - 1):
                self._validate(epoch)
            self.scheduler_G.step()

            scheduler_state = self.scheduler_G.state_dict()
            training_state = {'epoch': epoch, 'model_state': self.netG.state_dict(),
                              'scheduler_state': scheduler_state, 'optimizer_state': self.optimizer_G.state_dict()}
            if self.metric_counter.update_best_model():
                torch.save(training_state['model_state'], 'best_{}.pth'.format(self.config['experiment_desc']))

            if epoch % 200 == 0:
                torch.save(training_state, 'last_{}_{}.pth'.format(self.config['experiment_desc'], epoch))

            if epoch == (self.config['num_epochs'] - 1):
                torch.save(training_state['model_state'], 'final_{}.pth'.format(self.config['experiment_desc']))

            torch.save(training_state, 'last_{}.pth'.format(self.config['experiment_desc']))
            logging.debug("Experiment Name: %s, Epoch: %d, Loss: %s" % (
                self.config['experiment_desc'], epoch, self.metric_counter.loss_message()))

    def _run_epoch(self, epoch):
        self.metric_counter.clear()
        for param_group in self.optimizer_G.param_groups:
            lr = param_group['lr']

        epoch_size = self.config.get('train_batches_per_epoch') or len(self.train_dataset)
        tq = tqdm.tqdm(self.train_dataset, total=epoch_size)
        tq.set_description('Epoch {}, lr {}'.format(epoch, lr))
        i = 0
        for data in tq:
            inputs, targets = self.model.get_input(data)
            outputs = self.netG(inputs)
            self.optimizer_G.zero_grad()
            loss_G = self.criterionG(outputs, targets, inputs)
            loss_G.backward()
            self.optimizer_G.step()
            self.metric_counter.add_losses(loss_G.item())
            curr_psnr, curr_ssim, img_for_vis = self.model.get_images_and_metrics(inputs, outputs, targets)
            self.metric_counter.add_metrics(curr_psnr, curr_ssim)
            tq.set_postfix(loss=self.metric_counter.loss_message())
            if not i:
                self.metric_counter.add_image(img_for_vis, tag='train')
            i += 1
            if i > epoch_size:
                break
        tq.close()
        self.metric_counter.write_to_tensorboard(epoch)

    def _validate(self, epoch):
        self.metric_counter.clear()
        epoch_size = self.config.get('val_batches_per_epoch') or len(self.val_dataset)
        tq = tqdm.tqdm(self.val_dataset, total=epoch_size)
        tq.set_description('Validation')
        i = 0
        for data in tq:
            with torch.no_grad():
                inputs, targets = self.model.get_input(data)
                outputs = self.netG(inputs)
                loss_G = self.criterionG(outputs, targets, inputs)
                self.metric_counter.add_losses(loss_G.item())
                curr_psnr, curr_ssim, img_for_vis = self.model.get_images_and_metrics(inputs, outputs, targets)
                self.metric_counter.add_metrics(curr_psnr, curr_ssim)
                if not i:
                    self.metric_counter.add_image(img_for_vis, tag='val')
                i += 1
                if i > epoch_size:
                    break
        tq.close()
        self.metric_counter.write_to_tensorboard(epoch, validation=True)

    def _get_optim(self, params):
        if self.config['optimizer']['name'] == 'adam':
            optimizer = optim.Adam(params, lr=self.config['optimizer']['lr'])
        else:
            raise ValueError("Optimizer [%s] not recognized." % self.config['optimizer']['name'])
        return optimizer

    def _get_scheduler(self, optimizer):
        if self.config['scheduler']['name'] == 'cosine':
            scheduler = CosineAnnealingLR(optimizer, T_max=self.config['num_epochs'],
                                          eta_min=self.config['scheduler']['min_lr'])
        else:
            raise ValueError("Scheduler [%s] not recognized." % self.config['scheduler']['name'])
        return scheduler

    def _init_params(self):
        self.criterionG = get_loss(self.config['model'])
        self.netG = get_nets(self.config['model'])
        self.netG.cuda()
        self.model = get_model(self.config['model'])
        self.optimizer_G = self._get_optim(filter(lambda p: p.requires_grad, self.netG.parameters()))
        self.scheduler_G = self._get_scheduler(self.optimizer_G)


if __name__ == '__main__':
    with open('config/config_Stripformer_gopro.yaml', 'r') as f:
        config = yaml.safe_load(f)
    # setup
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    # set random seed
    seed = 666
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

    batch_size = config.pop('batch_size')
    get_dataloader = partial(DataLoader, batch_size=batch_size, num_workers=cpu_count(), shuffle=True, drop_last=False)

    datasets = map(config.pop, ('train', 'val'))
    datasets = map(PairedDataset.from_config, datasets)
    train, val = map(get_dataloader, datasets)
    trainer = Trainer(config, train=train, val=val)
    trainer.train()
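
# One detail worth noting in the resume logic above: checkpoints are merged
# into a copy of the current state_dict before load_state_dict, so a checkpoint
# carrying only a subset of the keys still loads without a strict-mode error.
# A minimal sketch of that pattern (module and checkpoint here are hypothetical):
def _partial_load_demo():
    import torch
    net = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
    partial_ckpt = {'0.bias': torch.zeros(4)}  # hypothetical checkpoint with one key

    merged = net.state_dict()
    merged.update(partial_ckpt)   # overwrite just the keys the checkpoint provides
    net.load_state_dict(merged)   # strict load succeeds: every expected key is present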
@ -0,0 +1,164 @@
import logging
import os
import random
from functools import partial

import cv2
import numpy as np
import torch
import torch.optim as optim
import tqdm
import yaml
from joblib import cpu_count
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader

from dataset import PairedDataset
from metric_counter import MetricCounter
from models.losses import get_loss
from models.models import get_model
from models.networks import get_nets

cv2.setNumThreads(0)


class Trainer:
    def __init__(self, config, train: DataLoader, val: DataLoader):
        self.config = config
        self.train_dataset = train
        self.val_dataset = val
        self.metric_counter = MetricCounter(config['experiment_desc'])

    def train(self):
        self._init_params()
        start_epoch = 0
        if os.path.exists('last_Stripformer_pretrained.pth'):
            # Resume: restore model, optimizer, and scheduler state from the last checkpoint.
            print('load_pretrained')
            training_state = torch.load('last_Stripformer_pretrained.pth')
            start_epoch = training_state['epoch']
            new_weight = self.netG.state_dict()
            new_weight.update(training_state['model_state'])
            self.netG.load_state_dict(new_weight)
            new_optimizer = self.optimizer_G.state_dict()
            new_optimizer.update(training_state['optimizer_state'])
            self.optimizer_G.load_state_dict(new_optimizer)
            new_scheduler = self.scheduler_G.state_dict()
            new_scheduler.update(training_state['scheduler_state'])
            self.scheduler_G.load_state_dict(new_scheduler)

        for epoch in range(start_epoch, self.config['num_epochs']):
            self._run_epoch(epoch)
            if epoch % 30 == 0 or epoch == (self.config['num_epochs'] - 1):
                self._validate(epoch)
            self.scheduler_G.step()

            scheduler_state = self.scheduler_G.state_dict()
            training_state = {'epoch': epoch, 'model_state': self.netG.state_dict(),
                              'scheduler_state': scheduler_state, 'optimizer_state': self.optimizer_G.state_dict()}
            if self.metric_counter.update_best_model():
                torch.save(training_state['model_state'], 'best_{}.pth'.format(self.config['experiment_desc']))

            if epoch % 300 == 0:
                torch.save(training_state, 'last_{}_{}.pth'.format(self.config['experiment_desc'], epoch))

            if epoch == (self.config['num_epochs'] - 1):
                torch.save(training_state['model_state'], 'final_{}.pth'.format(self.config['experiment_desc']))

            torch.save(training_state, 'last_{}.pth'.format(self.config['experiment_desc']))
            logging.debug("Experiment Name: %s, Epoch: %d, Loss: %s" % (
                self.config['experiment_desc'], epoch, self.metric_counter.loss_message()))

    def _run_epoch(self, epoch):
        self.metric_counter.clear()
        for param_group in self.optimizer_G.param_groups:
            lr = param_group['lr']

        epoch_size = self.config.get('train_batches_per_epoch') or len(self.train_dataset)
        tq = tqdm.tqdm(self.train_dataset, total=epoch_size)
        tq.set_description('Epoch {}, lr {}'.format(epoch, lr))
        i = 0
        for data in tq:
            inputs, targets = self.model.get_input(data)
            outputs = self.netG(inputs)
            self.optimizer_G.zero_grad()
            loss_G = self.criterionG(outputs, targets, inputs)
            loss_G.backward()
            self.optimizer_G.step()
            self.metric_counter.add_losses(loss_G.item())
            curr_psnr, curr_ssim, img_for_vis = self.model.get_images_and_metrics(inputs, outputs, targets)
            self.metric_counter.add_metrics(curr_psnr, curr_ssim)
            tq.set_postfix(loss=self.metric_counter.loss_message())
            if not i:
                self.metric_counter.add_image(img_for_vis, tag='train')
            i += 1
            if i > epoch_size:
                break
        tq.close()
        self.metric_counter.write_to_tensorboard(epoch)

    def _validate(self, epoch):
        self.metric_counter.clear()
        epoch_size = self.config.get('val_batches_per_epoch') or len(self.val_dataset)
        tq = tqdm.tqdm(self.val_dataset, total=epoch_size)
        tq.set_description('Validation')
        i = 0
        for data in tq:
            with torch.no_grad():
                inputs, targets = self.model.get_input(data)
                outputs = self.netG(inputs)
                loss_G = self.criterionG(outputs, targets, inputs)
                self.metric_counter.add_losses(loss_G.item())
                curr_psnr, curr_ssim, img_for_vis = self.model.get_images_and_metrics(inputs, outputs, targets)
                self.metric_counter.add_metrics(curr_psnr, curr_ssim)
                if not i:
                    self.metric_counter.add_image(img_for_vis, tag='val')
                i += 1
                if i > epoch_size:
                    break
        tq.close()
        self.metric_counter.write_to_tensorboard(epoch, validation=True)

    def _get_optim(self, params):
        if self.config['optimizer']['name'] == 'adam':
            optimizer = optim.Adam(params, lr=self.config['optimizer']['lr'])
        else:
            raise ValueError("Optimizer [%s] not recognized." % self.config['optimizer']['name'])
        return optimizer

    def _get_scheduler(self, optimizer):
        if self.config['scheduler']['name'] == 'cosine':
            scheduler = CosineAnnealingLR(optimizer, T_max=self.config['num_epochs'],
                                          eta_min=self.config['scheduler']['min_lr'])
        else:
            raise ValueError("Scheduler [%s] not recognized." % self.config['scheduler']['name'])
        return scheduler

    def _init_params(self):
        self.criterionG = get_loss(self.config['model'])
        self.netG = get_nets(self.config['model'])
        self.netG.cuda()
        self.model = get_model(self.config['model'])
        self.optimizer_G = self._get_optim(filter(lambda p: p.requires_grad, self.netG.parameters()))
        self.scheduler_G = self._get_scheduler(self.optimizer_G)


if __name__ == '__main__':
    with open('config/config_Stripformer_pretrained.yaml', 'r') as f:
        config = yaml.safe_load(f)
    # setup
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    # set random seed
    seed = 666
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

    batch_size = config.pop('batch_size')
    get_dataloader = partial(DataLoader, batch_size=batch_size, num_workers=cpu_count(), shuffle=True, drop_last=False)

    datasets = map(config.pop, ('train', 'val'))
    datasets = map(PairedDataset.from_config, datasets)
    train, val = map(get_dataloader, datasets)
    trainer = Trainer(config, train=train, val=val)
    trainer.train()
@ -0,0 +1,66 @@
import os

import dominate
from dominate.tags import *


class HTML:
    def __init__(self, web_dir, title, image_subdir='', refresh=0):
        self.title = title
        self.web_dir = web_dir
        self.img_subdir = image_subdir
        self.img_dir = os.path.join(self.web_dir, image_subdir)
        if not os.path.exists(self.web_dir):
            os.makedirs(self.web_dir)
        if not os.path.exists(self.img_dir):
            os.makedirs(self.img_dir)

        self.doc = dominate.document(title=title)
        if refresh > 0:
            # Auto-reload the page every `refresh` seconds.
            with self.doc.head:
                meta(http_equiv="refresh", content=str(refresh))

    def get_image_dir(self):
        return self.img_dir

    def add_header(self, text):
        with self.doc:
            h3(text)

    def add_table(self, border=1):
        self.t = table(border=border, style="table-layout: fixed;")
        self.doc.add(self.t)

    def add_images(self, ims, txts, links, width=400):
        self.add_table()
        with self.t:
            with tr():
                for im, txt, link in zip(ims, txts, links):
                    with td(style="word-wrap: break-word;", halign="center", valign="top"):
                        with p():
                            with a(href=os.path.join(link)):
                                img(style="width:%dpx" % width, src=os.path.join(im))
                            br()
                            p(txt)

    def save(self, file='index'):
        html_file = '%s/%s.html' % (self.web_dir, file)
        with open(html_file, 'wt') as f:
            f.write(self.doc.render())


if __name__ == '__main__':
    html = HTML('web/', 'test_html')
    html.add_header('hello world')

    ims = []
    txts = []
    links = []
    for n in range(4):
        ims.append('image_%d.png' % n)
        txts.append('text_%d' % n)
        links.append('image_%d.png' % n)
    html.add_images(ims, txts, links)
    html.save()
@ -0,0 +1,33 @@
import random
from collections import deque

import torch


class ImagePool():
    def __init__(self, pool_size):
        self.pool_size = pool_size
        self.sample_size = pool_size
        if self.pool_size > 0:
            self.num_imgs = 0
            self.images = deque()

    def add(self, images):
        if self.pool_size == 0:
            return images
        for image in images.data:
            image = torch.unsqueeze(image, 0)
            if self.num_imgs < self.pool_size:
                self.num_imgs = self.num_imgs + 1
                self.images.append(image)
            else:
                # Pool is full: drop the oldest image and append the newest.
                self.images.popleft()
                self.images.append(image)

    def query(self):
        if len(self.images) > self.sample_size:
            return_images = list(random.sample(self.images, self.sample_size))
        else:
            return_images = list(self.images)
        return torch.cat(return_images, 0)
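
# A short usage sketch for ImagePool (tensor shapes are illustrative): batches
# of generated images go in via add, and query returns up to pool_size images
# drawn from the history buffer.
def _image_pool_demo():
    import torch
    pool = ImagePool(pool_size=50)
    fake_batch = torch.randn(4, 3, 64, 64)  # stand-in for generator output

    pool.add(fake_batch)    # .data works on plain tensors as well as Variables
    history = pool.query()  # (4, 3, 64, 64) here, since the pool holds only 4 images
    print(history.shape)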
@ -0,0 +1,54 @@
import math
from math import exp

import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import Variable


def gaussian(window_size, sigma):
    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
    return gauss / gauss.sum()


def create_window(window_size, channel):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
    return window


def SSIM(img1, img2):
    (_, channel, _, _) = img1.size()
    window_size = 11
    window = create_window(window_size, channel)

    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)

    # Local means via a per-channel Gaussian filter.
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    # Local variances and covariance.
    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2

    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
    return ssim_map.mean()


def PSNR(img1, img2):
    # Inputs are expected in [0, 255]; they are rescaled to [0, 1] before the MSE.
    mse = np.mean((img1 / 255. - img2 / 255.) ** 2)
    if mse == 0:
        return 100
    PIXEL_MAX = 1
    return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))
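
# A quick smoke test for both metrics (inputs are illustrative; note PSNR
# expects arrays in [0, 255], since it rescales by 255 before comparing
# against PIXEL_MAX = 1).
def _metrics_demo():
    import numpy as np
    import torch
    a = torch.rand(1, 3, 64, 64)
    print(SSIM(a, a).item())  # ~1.0: identical images

    img = np.random.rand(64, 64, 3) * 255
    print(PSNR(img, img))     # 100, the hard-coded mse == 0 ceiling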
@ -0,0 +1,48 @@
from __future__ import print_function

import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image


def load_image(path):
    if path[-3:] == 'dng':
        import rawpy
        with rawpy.imread(path) as raw:
            img = raw.postprocess()
    elif path[-3:] == 'bmp' or path[-3:] == 'jpg' or path[-3:] == 'png':
        import cv2
        # OpenCV loads BGR; reverse the channel axis to get RGB.
        return cv2.imread(path)[:, :, ::-1]
    else:
        img = (255 * plt.imread(path)[:, :, :3]).astype('uint8')

    return img


def save_image(image_numpy, image_path):
    image_pil = Image.fromarray(image_numpy)
    image_pil.save(image_path)


def mkdirs(paths):
    if isinstance(paths, list) and not isinstance(paths, str):
        for path in paths:
            mkdir(path)
    else:
        mkdir(paths)


def mkdir(path):
    if not os.path.exists(path):
        os.makedirs(path)


def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255. / 2.):
    # Map a (1, C, H, W) tensor in [-1, 1] to an HxWxC uint8 image in [0, 255].
    image_numpy = image_tensor[0].cpu().float().numpy()
    image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
    return image_numpy.astype(imtype)


def im2tensor(image, imtype=np.uint8, cent=1., factor=255. / 2.):
    # Inverse of tensor2im: an HxWxC image in [0, 255] back to a (1, C, H, W) tensor in [-1, 1].
    return torch.Tensor((image / factor - cent)
                        [:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
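
# tensor2im and im2tensor are inverses under the default cent=1, factor=255/2
# mapping between [-1, 1] tensors and [0, 255] images; a round-trip sketch
# (shapes are illustrative):
def _tensor_image_roundtrip_demo():
    import torch
    t = torch.rand(1, 3, 8, 8) * 2 - 1  # (1, C, H, W) in [-1, 1]
    img = tensor2im(t)                   # (H, W, C) uint8 in [0, 255]
    back = im2tensor(img)                # (1, C, H, W) float, roughly the original t
    print(img.shape, img.dtype, back.shape)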
@ -0,0 +1,216 @@
import math
import os
import time

import matplotlib.pyplot as plt
import numpy as np

from . import html
from . import util


def zoom_to_res(img, res=256, order=0, axis=0):
    # img: 3xXxX (axis=0) or XxXx3 (axis=2); rescale the spatial dims to `res`.
    from scipy.ndimage import zoom
    zoom_factor = res / img.shape[1]
    if axis == 0:
        return zoom(img, [1, zoom_factor, zoom_factor], order=order)
    elif axis == 2:
        return zoom(img, [zoom_factor, zoom_factor, 1], order=order)


class Visualizer():
    def __init__(self, opt):
        self.display_id = opt.display_id
        self.win_size = opt.display_winsize
        self.name = opt.name
        self.display_cnt = 0  # display_current_results counter
        self.display_cnt_high = 0
        self.use_html = opt.use_html

        if self.display_id > 0:
            import visdom
            self.vis = visdom.Visdom(port=opt.display_port)

        self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
        util.mkdirs([self.web_dir, ])
        if self.use_html:
            self.img_dir = os.path.join(self.web_dir, 'images')
            print('create web directory %s...' % self.web_dir)
            util.mkdirs([self.img_dir, ])

    # |visuals|: dictionary of images to display or save
    def display_current_results(self, visuals, epoch, nrows=None, res=256):
        if self.display_id > 0:  # show images in the browser
            title = self.name
            if nrows is None:
                nrows = int(math.ceil(len(visuals.items()) / 2.0))
            images = []
            idx = 0
            for label, image_numpy in visuals.items():
                title += " | " if idx % nrows == 0 else ", "
                title += label
                img = image_numpy.transpose([2, 0, 1])
                img = zoom_to_res(img, res=res, order=0)
                images.append(img)
                idx += 1
            # Pad with a white tile so the grid has an even number of images.
            if len(visuals.items()) % 2 != 0:
                white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
                white_image = zoom_to_res(white_image, res=res, order=0)
                images.append(white_image)
            self.vis.images(images, nrow=nrows, win=self.display_id + 1,
                            opts=dict(title=title))

        if self.use_html:  # save images to a html file
            for label, image_numpy in visuals.items():
                img_path = os.path.join(self.img_dir, 'epoch%.3d_cnt%.6d_%s.png' % (epoch, self.display_cnt, label))
                util.save_image(zoom_to_res(image_numpy, res=res, axis=2), img_path)

            self.display_cnt += 1
            self.display_cnt_high = np.maximum(self.display_cnt_high, self.display_cnt)

            # update website
            webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1)
            for n in range(epoch, 0, -1):
                webpage.add_header('epoch [%d]' % n)
                if n == epoch:
                    high = self.display_cnt
                else:
                    high = self.display_cnt_high
                for c in range(high - 1, -1, -1):
                    ims = []
                    txts = []
                    links = []

                    for label, image_numpy in visuals.items():
                        img_path = 'epoch%.3d_cnt%.6d_%s.png' % (n, c, label)
                        ims.append(os.path.join('images', img_path))
                        txts.append(label)
                        links.append(os.path.join('images', img_path))
                    webpage.add_images(ims, txts, links, width=self.win_size)
            webpage.save()

    # save errors into a directory
    def plot_current_errors_save(self, epoch, counter_ratio, opt, errors, keys='+ALL', name='loss', to_plot=False):
        if not hasattr(self, 'plot_data'):
            self.plot_data = {'X': [], 'Y': [], 'legend': list(errors.keys())}
        self.plot_data['X'].append(epoch + counter_ratio)
        self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']])

        if keys == '+ALL':
            plot_keys = self.plot_data['legend']
        else:
            plot_keys = keys

        if to_plot:
            (f, ax) = plt.subplots(1, 1)
        for (k, kname) in enumerate(plot_keys):
            kk = np.where(np.array(self.plot_data['legend']) == kname)[0][0]
            x = self.plot_data['X']
            y = np.array(self.plot_data['Y'])[:, kk]
            if to_plot:
                ax.plot(x, y, 'o-', label=kname)
            # Always dump the raw curves, even when no figure is drawn.
            np.save(os.path.join(self.web_dir, '%s_x' % kname), x)
            np.save(os.path.join(self.web_dir, '%s_y' % kname), y)

        if to_plot:
            plt.legend(loc=0, fontsize='small')
            plt.xlabel('epoch')
            plt.ylabel('Value')
            f.savefig(os.path.join(self.web_dir, '%s.png' % name))
            f.clf()
            plt.close()

    # errors: dictionary of error labels and values
    def plot_current_errors(self, epoch, counter_ratio, opt, errors):
        if not hasattr(self, 'plot_data'):
            self.plot_data = {'X': [], 'Y': [], 'legend': list(errors.keys())}
        self.plot_data['X'].append(epoch + counter_ratio)
        self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']])
        self.vis.line(
            X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
            Y=np.array(self.plot_data['Y']),
            opts={
                'title': self.name + ' loss over time',
                'legend': self.plot_data['legend'],
                'xlabel': 'epoch',
                'ylabel': 'loss'},
            win=self.display_id)

    # errors: same format as |errors| of plotCurrentErrors
    def print_current_errors(self, epoch, i, errors, t, t2=-1, t2o=-1, fid=None):
        message = '(ep: %d, it: %d, t: %.3f[s], ept: %.2f/%.2f[h]) ' % (epoch, i, t, t2o, t2)
        message += (', ').join(['%s: %.3f' % (k, v) for k, v in errors.items()])

        print(message)
        if fid is not None:
            fid.write('%s\n' % message)

    # save image to the disk
    def save_images_simple(self, webpage, images, names, in_txts, prefix='', res=256):
        image_dir = webpage.get_image_dir()
        ims = []
        txts = []
        links = []

        for name, image_numpy, txt in zip(names, images, in_txts):
            image_name = '%s_%s.png' % (prefix, name)
            save_path = os.path.join(image_dir, image_name)
            if res is not None:
                util.save_image(zoom_to_res(image_numpy, res=res, axis=2), save_path)
            else:
                util.save_image(image_numpy, save_path)

            ims.append(os.path.join(webpage.img_subdir, image_name))
            txts.append(txt)
            links.append(os.path.join(webpage.img_subdir, image_name))
        webpage.add_images(ims, txts, links, width=self.win_size)

    # save image to the disk
    def save_images(self, webpage, images, names, image_path, title=''):
        image_dir = webpage.get_image_dir()
        ims = []
        txts = []
        links = []

        for label, image_numpy in zip(names, images):
            image_name = '%s.jpg' % (label,)
            save_path = os.path.join(image_dir, image_name)
            util.save_image(image_numpy, save_path)

            ims.append(image_name)
            txts.append(label)
            links.append(image_name)
        webpage.add_images(ims, txts, links, width=self.win_size)