first version

commit df69507ef2

@@ -0,0 +1,62 @@
import albumentations as albu
from torchvision import transforms
def get_transforms(size: int, scope: str = 'geometric', crop='random'):
augs = {'strong': albu.Compose([albu.HorizontalFlip(),
albu.ShiftScaleRotate(shift_limit=0.0, scale_limit=0.2, rotate_limit=20, p=.4),
albu.ElasticTransform(),
albu.OpticalDistortion(),
albu.OneOf([
albu.CLAHE(clip_limit=2),
                                            albu.IAASharpen(),  # renamed albu.Sharpen in albumentations >= 1.0
                                            albu.IAAEmboss(),   # renamed albu.Emboss in albumentations >= 1.0
albu.RandomBrightnessContrast(),
albu.RandomGamma()
], p=0.5),
albu.OneOf([
albu.RGBShift(),
albu.HueSaturationValue(),
], p=0.5),
]),
'weak': albu.Compose([albu.HorizontalFlip(),
]),
'geometric': albu.Compose([albu.HorizontalFlip(),
albu.VerticalFlip(),
albu.RandomRotate90(),
]),
            'None': None  # the YAML configs pass the literal string 'None' to disable extra augmentation
}
aug_fn = augs[scope]
crop_fn = {'random': albu.RandomCrop(size, size, always_apply=True),
'center': albu.CenterCrop(size, size, always_apply=True)}[crop]
pipeline = albu.Compose([aug_fn, crop_fn], additional_targets={'target': 'image'})
def process(a, b):
r = pipeline(image=a, target=b)
return r['image'], r['target']
return process
def get_normalize():
transform = transforms.Compose([
transforms.ToTensor()
])
    def process(a, b):
        # ToTensor scales to [0, 1] and moves channels first; permuting back to
        # HWC and subtracting 0.5 centers the values in [-0.5, 0.5]
        image = transform(a).permute(1, 2, 0) - 0.5
        target = transform(b).permute(1, 2, 0) - 0.5
        return image, target
return process
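For orientation, a minimal sketch of how the two factories above compose; the arrays, shapes, and names below are placeholders, not part of the repository:

import numpy as np

transform_fn = get_transforms(size=256, scope='geometric', crop='random')
normalize_fn = get_normalize()

blur_img = np.random.randint(0, 256, (720, 1280, 3), dtype=np.uint8)   # stand-in blurry frame
sharp_img = np.random.randint(0, 256, (720, 1280, 3), dtype=np.uint8)  # stand-in sharp frame

a, b = transform_fn(blur_img, sharp_img)  # the same crop/flip is applied to both images
a, b = normalize_fn(a, b)                 # HWC tensors centered in [-0.5, 0.5]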

@@ -0,0 +1,40 @@
---
experiment_desc: Stripformer_gopro
train:
files_a: ./datasets/GoPro/train/blur/**/*.png
files_b: ./datasets/GoPro/train/sharp/**/*.png
size: &SIZE 512
crop: random
preload: &PRELOAD false
preload_size: &PRELOAD_SIZE 0
bounds: [0, 1]
scope: geometric
val:
files_a: ./datasets/GoPro/test/blur/**/*.png
files_b: ./datasets/GoPro/test/sharp/**/*.png
size: *SIZE
scope: None
crop: random
preload: *PRELOAD
preload_size: *PRELOAD_SIZE
bounds: [0, 1]
model:
g_name: Stripformer
content_loss: Stripformer_Loss
num_epochs: 1000
train_batches_per_epoch: 2103
val_batches_per_epoch: 1111
batch_size: 8
image_size: [512, 512]
optimizer:
name: adam
lr: 0.0001
scheduler:
name: cosine
start_epoch: 50
min_lr: 0.0000001

@@ -0,0 +1,40 @@
---
experiment_desc: Stripformer_pretrained
train:
files_a: ./datasets/GoPro/train/blur/**/*.png
files_b: ./datasets/GoPro/train/sharp/**/*.png
size: &SIZE 256
crop: random
preload: &PRELOAD false
preload_size: &PRELOAD_SIZE 0
bounds: [0, 1]
scope: geometric
val:
files_a: ./datasets/GoPro/test/blur/**/*.png
files_b: ./datasets/GoPro/test/sharp/**/*.png
size: *SIZE
scope: None
crop: random
preload: *PRELOAD
preload_size: *PRELOAD_SIZE
bounds: [0, 1]
model:
g_name: Stripformer
content_loss: Stripformer_Loss
num_epochs: 3000
train_batches_per_epoch: 2103
val_batches_per_epoch: 1111
batch_size: 8
image_size: [256, 256]
optimizer:
name: adam
lr: 0.0001
scheduler:
name: cosine
start_epoch: 50
min_lr: 0.0000001

@@ -0,0 +1,140 @@
import os
from copy import deepcopy
from functools import partial
from glob import glob
from hashlib import sha1
from typing import Callable, Iterable, Optional, Tuple
import cv2
import numpy as np
from glog import logger
from joblib import Parallel, cpu_count, delayed
from skimage.io import imread
from torch.utils.data import Dataset
from tqdm import tqdm
import aug
def subsample(data: Iterable, bounds: Tuple[float, float], hash_fn: Callable, n_buckets=100, salt='', verbose=True):
data = list(data)
buckets = split_into_buckets(data, n_buckets=n_buckets, salt=salt, hash_fn=hash_fn)
lower_bound, upper_bound = [x * n_buckets for x in bounds]
msg = f'Subsampling buckets from {lower_bound} to {upper_bound}, total buckets number is {n_buckets}'
if salt:
msg += f'; salt is {salt}'
if verbose:
logger.info(msg)
return np.array([sample for bucket, sample in zip(buckets, data) if lower_bound <= bucket < upper_bound])
def hash_from_paths(x: Tuple[str, str], salt: str = '') -> str:
path_a, path_b = x
names = ''.join(map(os.path.basename, (path_a, path_b)))
return sha1(f'{names}_{salt}'.encode()).hexdigest()
def split_into_buckets(data: Iterable, n_buckets: int, hash_fn: Callable, salt=''):
hashes = map(partial(hash_fn, salt=salt), data)
return np.array([int(x, 16) % n_buckets for x in hashes])
def _read_img(x: str):
img = cv2.imread(x)
if img is None:
        logger.warning(f'Cannot read image {x} with OpenCV, switching to scikit-image')
img = imread(x)
return img
class PairedDataset(Dataset):
def __init__(self,
files_a: Tuple[str],
files_b: Tuple[str],
transform_fn: Callable,
normalize_fn: Callable,
corrupt_fn: Optional[Callable] = None,
preload: bool = True,
preload_size: Optional[int] = 0,
verbose=True):
assert len(files_a) == len(files_b)
self.preload = preload
self.data_a = files_a
self.data_b = files_b
self.verbose = verbose
self.corrupt_fn = corrupt_fn
self.transform_fn = transform_fn
self.normalize_fn = normalize_fn
logger.info(f'Dataset has been created with {len(self.data_a)} samples')
if preload:
preload_fn = partial(self._bulk_preload, preload_size=preload_size)
if files_a == files_b:
self.data_a = self.data_b = preload_fn(self.data_a)
else:
self.data_a, self.data_b = map(preload_fn, (self.data_a, self.data_b))
self.preload = True
def _bulk_preload(self, data: Iterable[str], preload_size: int):
jobs = [delayed(self._preload)(x, preload_size=preload_size) for x in data]
jobs = tqdm(jobs, desc='preloading images', disable=not self.verbose)
return Parallel(n_jobs=cpu_count(), backend='threading')(jobs)
@staticmethod
def _preload(x: str, preload_size: int):
img = _read_img(x)
if preload_size:
h, w, *_ = img.shape
h_scale = preload_size / h
w_scale = preload_size / w
scale = max(h_scale, w_scale)
img = cv2.resize(img, fx=scale, fy=scale, dsize=None)
assert min(img.shape[:2]) >= preload_size, f'weird img shape: {img.shape}'
return img
def _preprocess(self, img, res):
def transpose(x):
return np.transpose(x, (2, 0, 1))
return map(transpose, self.normalize_fn(img, res))
def __len__(self):
return len(self.data_a)
def __getitem__(self, idx):
a, b = self.data_a[idx], self.data_b[idx]
if not self.preload:
a, b = map(_read_img, (a, b))
a, b = self.transform_fn(a, b)
if self.corrupt_fn is not None:
a = self.corrupt_fn(a)
a, b = self._preprocess(a, b)
return {'a': a, 'b': b}
@staticmethod
def from_config(config):
config = deepcopy(config)
files_a, files_b = map(lambda x: sorted(glob(config[x], recursive=True)), ('files_a', 'files_b'))
transform_fn = aug.get_transforms(size=config['size'], scope=config['scope'], crop=config['crop'])
normalize_fn = aug.get_normalize()
hash_fn = hash_from_paths
# ToDo: add more hash functions
verbose = config.get('verbose', True)
data = subsample(data=zip(files_a, files_b),
bounds=config.get('bounds', (0, 1)),
hash_fn=hash_fn,
verbose=verbose)
files_a, files_b = map(list, zip(*data))
return PairedDataset(files_a=files_a,
files_b=files_b,
preload=config['preload'],
preload_size=config['preload_size'],
normalize_fn=normalize_fn,
transform_fn=transform_fn,
verbose=verbose)
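A minimal sketch of building the dataset from an in-memory dict that mirrors the train: section of the configs; it assumes the GoPro files exist under ./datasets:

cfg = {
    'files_a': './datasets/GoPro/train/blur/**/*.png',
    'files_b': './datasets/GoPro/train/sharp/**/*.png',
    'size': 256, 'crop': 'random', 'scope': 'geometric',
    'preload': False, 'preload_size': 0, 'bounds': [0, 1],
}
dataset = PairedDataset.from_config(cfg)
sample = dataset[0]  # {'a': blurry CHW float array, 'b': sharp CHW float array}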

@@ -0,0 +1 @@
Download the GoPro dataset and put it in the './datasets' folder.
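The loaders expect the layout implied by the glob patterns in the configs:

datasets/
└── GoPro/
    ├── train/
    │   ├── blur/<sequence>/*.png
    │   └── sharp/<sequence>/*.png
    └── test/
        ├── blur/<sequence>/*.png
        └── sharp/<sequence>/*.png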

@@ -0,0 +1,101 @@
import os
from skimage import io
import cv2
import numpy as np
from skimage.metrics import structural_similarity
def image_align(deblurred, gt):
# this function is based on kohler evaluation code
z = deblurred
c = np.ones_like(z)
x = gt
zs = (np.sum(x * z) / np.sum(z * z)) * z # simple intensity matching
warp_mode = cv2.MOTION_HOMOGRAPHY
warp_matrix = np.eye(3, 3, dtype=np.float32)
# Specify the number of iterations.
number_of_iterations = 100
termination_eps = 0
criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT,
number_of_iterations, termination_eps)
# Run the ECC algorithm. The results are stored in warp_matrix.
(cc, warp_matrix) = cv2.findTransformECC(cv2.cvtColor(x, cv2.COLOR_RGB2GRAY), cv2.cvtColor(zs, cv2.COLOR_RGB2GRAY),
warp_matrix, warp_mode, criteria, inputMask=None, gaussFiltSize=5)
target_shape = x.shape
shift = warp_matrix
zr = cv2.warpPerspective(
zs,
warp_matrix,
(target_shape[1], target_shape[0]),
flags=cv2.INTER_CUBIC + cv2.WARP_INVERSE_MAP,
borderMode=cv2.BORDER_REFLECT)
cr = cv2.warpPerspective(
np.ones_like(zs, dtype='float32'),
warp_matrix,
(target_shape[1], target_shape[0]),
flags=cv2.INTER_NEAREST + cv2.WARP_INVERSE_MAP,
borderMode=cv2.BORDER_CONSTANT,
borderValue=0)
zr = zr * cr
xr = x * cr
return zr, xr, cr, shift
def compute_psnr(image_true, image_test, image_mask, data_range=None):
# this function is based on skimage.metrics.peak_signal_noise_ratio
err = np.sum((image_true - image_test) ** 2, dtype=np.float64) / np.sum(image_mask)
return 10 * np.log10((data_range ** 2) / err)
def compute_ssim(tar_img, prd_img, cr1):
ssim_pre, ssim_map = structural_similarity(tar_img, prd_img, multichannel=True, gaussian_weights=True,
use_sample_covariance=False, data_range=1.0, full=True)
ssim_map = ssim_map * cr1
r = int(3.5 * 1.5 + 0.5) # radius as in ndimage
win_size = 2 * r + 1
pad = (win_size - 1) // 2
ssim = ssim_map[pad:-pad, pad:-pad, :]
crop_cr1 = cr1[pad:-pad, pad:-pad, :]
ssim = ssim.sum(axis=0).sum(axis=0) / crop_cr1.sum(axis=0).sum(axis=0)
ssim = np.mean(ssim)
return ssim
total_psnr = 0.
total_ssim = 0.
count = 0
img_path = './out/Stripformer_realblur_J_results'
gt_path = './datasets/Realblur_J/test/sharp'
print(img_path)
for file in os.listdir(img_path):
for img_name in os.listdir(img_path + '/' + file):
count += 1
number = img_name.split('_')[1]
gt_name = 'gt_' + number
img_dir = img_path + '/' + file + '/' + img_name
gt_dir = gt_path + '/' + file + '/' + gt_name
        tar_img = io.imread(gt_dir)
        prd_img = io.imread(img_dir)
        tar_img = tar_img.astype(np.float32) / 255.0
        prd_img = prd_img.astype(np.float32) / 255.0
        # ECC alignment before computing the masked PSNR/SSIM
        prd_img, tar_img, cr1, shift = image_align(prd_img, tar_img)
        PSNR = compute_psnr(tar_img, prd_img, cr1, data_range=1)
        SSIM = compute_ssim(tar_img, prd_img, cr1)
        total_psnr += PSNR
        total_ssim += SSIM
        print(count, PSNR)
print('PSNR:', total_psnr / count)
print('SSIM:', total_ssim / count)
print(img_path)

@@ -0,0 +1,101 @@
import os
from skimage import io
import cv2
import numpy as np
from skimage.metrics import structural_similarity
def image_align(deblurred, gt):
# this function is based on kohler evaluation code
z = deblurred
c = np.ones_like(z)
x = gt
zs = (np.sum(x * z) / np.sum(z * z)) * z # simple intensity matching
warp_mode = cv2.MOTION_HOMOGRAPHY
warp_matrix = np.eye(3, 3, dtype=np.float32)
# Specify the number of iterations.
number_of_iterations = 100
termination_eps = 0
criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT,
number_of_iterations, termination_eps)
# Run the ECC algorithm. The results are stored in warp_matrix.
(cc, warp_matrix) = cv2.findTransformECC(cv2.cvtColor(x, cv2.COLOR_RGB2GRAY), cv2.cvtColor(zs, cv2.COLOR_RGB2GRAY),
warp_matrix, warp_mode, criteria, inputMask=None, gaussFiltSize=5)
target_shape = x.shape
shift = warp_matrix
zr = cv2.warpPerspective(
zs,
warp_matrix,
(target_shape[1], target_shape[0]),
flags=cv2.INTER_CUBIC + cv2.WARP_INVERSE_MAP,
borderMode=cv2.BORDER_REFLECT)
cr = cv2.warpPerspective(
np.ones_like(zs, dtype='float32'),
warp_matrix,
(target_shape[1], target_shape[0]),
flags=cv2.INTER_NEAREST + cv2.WARP_INVERSE_MAP,
borderMode=cv2.BORDER_CONSTANT,
borderValue=0)
zr = zr * cr
xr = x * cr
return zr, xr, cr, shift
def compute_psnr(image_true, image_test, image_mask, data_range=None):
# this function is based on skimage.metrics.peak_signal_noise_ratio
err = np.sum((image_true - image_test) ** 2, dtype=np.float64) / np.sum(image_mask)
return 10 * np.log10((data_range ** 2) / err)
def compute_ssim(tar_img, prd_img, cr1):
ssim_pre, ssim_map = structural_similarity(tar_img, prd_img, multichannel=True, gaussian_weights=True,
use_sample_covariance=False, data_range=1.0, full=True)
ssim_map = ssim_map * cr1
r = int(3.5 * 1.5 + 0.5) # radius as in ndimage
win_size = 2 * r + 1
pad = (win_size - 1) // 2
ssim = ssim_map[pad:-pad, pad:-pad, :]
crop_cr1 = cr1[pad:-pad, pad:-pad, :]
ssim = ssim.sum(axis=0).sum(axis=0) / crop_cr1.sum(axis=0).sum(axis=0)
ssim = np.mean(ssim)
return ssim
total_psnr = 0.
total_ssim = 0.
count = 0
img_path = './out/Stripformer_realblur_R_results'
gt_path = './datasets/Realblur_R/test/sharp'
print(img_path)
for file in os.listdir(img_path):
for img_name in os.listdir(img_path + '/' + file):
count += 1
number = img_name.split('_')[1]
gt_name = 'gt_' + number
img_dir = img_path + '/' + file + '/' + img_name
gt_dir = gt_path + '/' + file + '/' + gt_name
        tar_img = io.imread(gt_dir)
        prd_img = io.imread(img_dir)
        tar_img = tar_img.astype(np.float32) / 255.0
        prd_img = prd_img.astype(np.float32) / 255.0
        # ECC alignment before computing the masked PSNR/SSIM
        prd_img, tar_img, cr1, shift = image_align(prd_img, tar_img)
        PSNR = compute_psnr(tar_img, prd_img, cr1, data_range=1)
        SSIM = compute_ssim(tar_img, prd_img, cr1)
        total_psnr += PSNR
        total_ssim += SSIM
        print(count, PSNR)
print('PSNR:', total_psnr / count)
print('SSIM:', total_ssim / count)
print(img_path)

@@ -0,0 +1,60 @@
p = genpath('.\out\stripformer_10_patches');% GoPro Deblur Results
gt = genpath('.\datasets\GoPro\test\sharp');% GoPro GT Results
length_p = size(p,2);
path = {};
temp = [];
for i = 1:length_p
if p(i) ~= ';'
temp = [temp p(i)];
else
temp = [temp '\'];
path = [path ; temp];
temp = [];
end
end
clear p length_p temp;
length_gt = size(gt,2);
path_gt = {};
temp_gt = [];
for i = 1:length_gt
if gt(i) ~= ';'
temp_gt = [temp_gt gt(i)];
else
temp_gt = [temp_gt '\'];
path_gt = [path_gt ; temp_gt];
temp_gt = [];
end
end
clear gt length_gt temp_gt;
file_num = size(path,1);
total_psnr = 0;
n = 0;
total_ssim = 0;
for i = 1:file_num
file_path = path{i};
gt_file_path = path_gt{i};
img_path_list = dir(strcat(file_path,'*.png'));
gt_path_list = dir(strcat(gt_file_path,'*.png'));
img_num = length(img_path_list);
if img_num > 0
for j = 1:img_num
image_name = img_path_list(j).name;
gt_name = gt_path_list(j).name;
image = imread(strcat(file_path,image_name));
gt = imread(strcat(gt_file_path,gt_name));
            peaksnr = psnr(image,gt);
            ssimval = ssim(image,gt);
            total_psnr = total_psnr + peaksnr;
            total_ssim = total_ssim + ssimval;
            n = n + 1;
end
end
end
avg_psnr = total_psnr / n
avg_ssim = total_ssim / n
close all; clear all;

@@ -0,0 +1,60 @@
p = genpath('.\out\Stripformer_HIDE_result');% HIDE Deblur Results
gt = genpath('.\datasets\HIDE\sharp');% HIDE GT Results
length_p = size(p,2);
path = {};
temp = [];
for i = 1:length_p
if p(i) ~= ';'
temp = [temp p(i)];
else
temp = [temp '\'];
path = [path ; temp];
temp = [];
end
end
clear p length_p temp;
length_gt = size(gt,2);
path_gt = {};
temp_gt = [];
for i = 1:length_gt
if gt(i) ~= ';'
temp_gt = [temp_gt gt(i)];
else
temp_gt = [temp_gt '\'];
path_gt = [path_gt ; temp_gt];
temp_gt = [];
end
end
clear gt length_gt temp_gt;
file_num = size(path,1);
total_psnr = 0;
n = 0;
total_ssim = 0;
for i = 1:file_num
file_path = path{i};
gt_file_path = path_gt{i};
img_path_list = dir(strcat(file_path,'*.png'));
gt_path_list = dir(strcat(gt_file_path,'*.png'));
img_num = length(img_path_list);
if img_num > 0
for j = 1:img_num
image_name = img_path_list(j).name;
gt_name = gt_path_list(j).name;
image = imread(strcat(file_path,image_name));
gt = imread(strcat(gt_file_path,gt_name));
            peaksnr = psnr(image,gt);
            ssimval = ssim(image,gt);
            total_psnr = total_psnr + peaksnr;
            total_ssim = total_ssim + ssimval;
            n = n + 1;
end
end
end
avg_psnr = total_psnr / n
avg_ssim = total_ssim / n
close all; clear all;

@@ -0,0 +1,55 @@
import logging
from collections import defaultdict
import numpy as np
from tensorboardX import SummaryWriter
WINDOW_SIZE = 100
class MetricCounter:
def __init__(self, exp_name):
self.writer = SummaryWriter(exp_name)
logging.basicConfig(filename='{}.log'.format(exp_name), level=logging.DEBUG)
self.metrics = defaultdict(list)
self.images = defaultdict(list)
self.best_metric = 0
def add_image(self, x: np.ndarray, tag: str):
self.images[tag].append(x)
def clear(self):
self.metrics = defaultdict(list)
self.images = defaultdict(list)
    def add_losses(self, l_G):
        self.metrics['G_loss'].append(l_G)
def add_metrics(self, psnr, ssim):
for name, value in zip(('PSNR', 'SSIM'),
(psnr, ssim)):
self.metrics[name].append(value)
def loss_message(self):
metrics = ((k, np.mean(self.metrics[k][-WINDOW_SIZE:])) for k in ('G_loss', 'PSNR', 'SSIM'))
return '; '.join(map(lambda x: f'{x[0]}={x[1]:.4f}', metrics))
def write_to_tensorboard(self, epoch_num, validation=False):
scalar_prefix = 'Validation' if validation else 'Train'
for tag in ('G_loss', 'SSIM', 'PSNR'):
self.writer.add_scalar(f'{scalar_prefix}_{tag}', np.mean(self.metrics[tag]), global_step=epoch_num)
for tag in self.images:
imgs = self.images[tag]
if imgs:
imgs = np.array(imgs)
self.writer.add_images(tag, imgs[:, :, :, ::-1].astype('float32') / 255, dataformats='NHWC',
global_step=epoch_num)
self.images[tag] = []
def update_best_model(self):
cur_metric = np.mean(self.metrics['PSNR'])
if self.best_metric < cur_metric:
self.best_metric = cur_metric
return True
return False
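A minimal sketch of the intended call sequence during one epoch; the experiment name and metric values are placeholders:

mc = MetricCounter('demo_experiment')
mc.clear()
mc.add_losses(0.123)                  # per-batch generator loss
mc.add_metrics(psnr=29.5, ssim=0.91)  # per-batch quality metrics
print(mc.loss_message())              # 'G_loss=0.1230; PSNR=29.5000; SSIM=0.9100'
mc.write_to_tensorboard(epoch_num=0)
if mc.update_best_model():            # True when the running mean PSNR improves
    pass                              # the trainer saves a checkpoint here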

@@ -0,0 +1,374 @@
import torch
import torch.nn as nn
import math
class Embeddings(nn.Module):
def __init__(self):
super(Embeddings, self).__init__()
self.activation = nn.LeakyReLU(0.2, True)
self.en_layer1_1 = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=3, padding=1),
self.activation,
)
self.en_layer1_2 = nn.Sequential(
nn.Conv2d(64, 64, kernel_size=3, padding=1),
self.activation,
nn.Conv2d(64, 64, kernel_size=3, padding=1))
self.en_layer1_3 = nn.Sequential(
nn.Conv2d(64, 64, kernel_size=3, padding=1),
self.activation,
nn.Conv2d(64, 64, kernel_size=3, padding=1))
self.en_layer1_4 = nn.Sequential(
nn.Conv2d(64, 64, kernel_size=3, padding=1),
self.activation,
nn.Conv2d(64, 64, kernel_size=3, padding=1))
self.en_layer2_1 = nn.Sequential(
nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
self.activation,
)
self.en_layer2_2 = nn.Sequential(
nn.Conv2d(128, 128, kernel_size=3, padding=1),
self.activation,
nn.Conv2d(128, 128, kernel_size=3, padding=1))
self.en_layer2_3 = nn.Sequential(
nn.Conv2d(128, 128, kernel_size=3, padding=1),
self.activation,
nn.Conv2d(128, 128, kernel_size=3, padding=1))
self.en_layer2_4 = nn.Sequential(
nn.Conv2d(128, 128, kernel_size=3, padding=1),
self.activation,
nn.Conv2d(128, 128, kernel_size=3, padding=1))
self.en_layer3_1 = nn.Sequential(
nn.Conv2d(128, 320, kernel_size=3, stride=2, padding=1),
self.activation,
)
def forward(self, x):
hx = self.en_layer1_1(x)
hx = self.activation(self.en_layer1_2(hx) + hx)
hx = self.activation(self.en_layer1_3(hx) + hx)
hx = self.activation(self.en_layer1_4(hx) + hx)
residual_1 = hx
hx = self.en_layer2_1(hx)
hx = self.activation(self.en_layer2_2(hx) + hx)
hx = self.activation(self.en_layer2_3(hx) + hx)
hx = self.activation(self.en_layer2_4(hx) + hx)
residual_2 = hx
hx = self.en_layer3_1(hx)
return hx, residual_1, residual_2
class Embeddings_output(nn.Module):
def __init__(self):
super(Embeddings_output, self).__init__()
self.activation = nn.LeakyReLU(0.2, True)
self.de_layer3_1 = nn.Sequential(
nn.ConvTranspose2d(320, 192, kernel_size=4, stride=2, padding=1),
self.activation,
)
head_num = 3
dim = 192
self.de_layer2_2 = nn.Sequential(
nn.Conv2d(192+128, 192, kernel_size=1, padding=0),
self.activation,
)
self.de_block_1 = Intra_SA(dim, head_num)
self.de_block_2 = Inter_SA(dim, head_num)
self.de_block_3 = Intra_SA(dim, head_num)
self.de_block_4 = Inter_SA(dim, head_num)
self.de_block_5 = Intra_SA(dim, head_num)
self.de_block_6 = Inter_SA(dim, head_num)
self.de_layer2_1 = nn.Sequential(
nn.ConvTranspose2d(192, 64, kernel_size=4, stride=2, padding=1),
self.activation,
)
self.de_layer1_3 = nn.Sequential(
nn.Conv2d(128, 64, kernel_size=1, padding=0),
self.activation,
nn.Conv2d(64, 64, kernel_size=3, padding=1))
self.de_layer1_2 = nn.Sequential(
nn.Conv2d(64, 64, kernel_size=3, padding=1),
self.activation,
nn.Conv2d(64, 64, kernel_size=3, padding=1))
self.de_layer1_1 = nn.Sequential(
nn.Conv2d(64, 3, kernel_size=3, padding=1),
self.activation
)
def forward(self, x, residual_1, residual_2):
hx = self.de_layer3_1(x)
hx = self.de_layer2_2(torch.cat((hx, residual_2), dim = 1))
hx = self.de_block_1(hx)
hx = self.de_block_2(hx)
hx = self.de_block_3(hx)
hx = self.de_block_4(hx)
hx = self.de_block_5(hx)
hx = self.de_block_6(hx)
hx = self.de_layer2_1(hx)
hx = self.activation(self.de_layer1_3(torch.cat((hx, residual_1), dim = 1)) + hx)
hx = self.activation(self.de_layer1_2(hx) + hx)
hx = self.de_layer1_1(hx)
return hx
class Attention(nn.Module):
def __init__(self, head_num):
super(Attention, self).__init__()
self.num_attention_heads = head_num
self.softmax = nn.Softmax(dim=-1)
def transpose_for_scores(self, x):
B, N, C = x.size()
attention_head_size = int(C / self.num_attention_heads)
new_x_shape = x.size()[:-1] + (self.num_attention_heads, attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3).contiguous()
def forward(self, query_layer, key_layer, value_layer):
B, N, C = query_layer.size()
query_layer = self.transpose_for_scores(query_layer)
key_layer = self.transpose_for_scores(key_layer)
value_layer = self.transpose_for_scores(value_layer)
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
_, _, _, d = query_layer.size()
attention_scores = attention_scores / math.sqrt(d)
attention_probs = self.softmax(attention_scores)
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (C,)
attention_out = context_layer.view(*new_context_layer_shape)
return attention_out
class Mlp(nn.Module):
def __init__(self, hidden_size):
super(Mlp, self).__init__()
self.fc1 = nn.Linear(hidden_size, 4*hidden_size)
self.fc2 = nn.Linear(4*hidden_size, hidden_size)
self.act_fn = torch.nn.functional.gelu
self._init_weights()
def _init_weights(self):
nn.init.xavier_uniform_(self.fc1.weight)
nn.init.xavier_uniform_(self.fc2.weight)
nn.init.normal_(self.fc1.bias, std=1e-6)
nn.init.normal_(self.fc2.bias, std=1e-6)
def forward(self, x):
x = self.fc1(x)
x = self.act_fn(x)
x = self.fc2(x)
return x
# CPE (Conditional Positional Embedding)
class PEG(nn.Module):
def __init__(self, hidden_size):
super(PEG, self).__init__()
self.PEG = nn.Conv2d(hidden_size, hidden_size, kernel_size=3, padding=1, groups=hidden_size)
def forward(self, x):
x = self.PEG(x) + x
return x
class Intra_SA(nn.Module):
def __init__(self, dim, head_num):
super(Intra_SA, self).__init__()
self.hidden_size = dim // 2
self.head_num = head_num
self.attention_norm = nn.LayerNorm(dim)
self.conv_input = nn.Conv2d(dim, dim, kernel_size=1, padding=0)
self.qkv_local_h = nn.Linear(self.hidden_size, self.hidden_size * 3) # qkv_h
self.qkv_local_v = nn.Linear(self.hidden_size, self.hidden_size * 3) # qkv_v
self.fuse_out = nn.Conv2d(dim, dim, kernel_size=1, padding=0)
self.ffn_norm = nn.LayerNorm(dim)
self.ffn = Mlp(dim)
self.attn = Attention(head_num=self.head_num)
self.PEG = PEG(dim)
def forward(self, x):
h = x
B, C, H, W = x.size()
x = x.view(B, C, H*W).permute(0, 2, 1).contiguous()
x = self.attention_norm(x).permute(0, 2, 1).contiguous()
x = x.view(B, C, H, W)
x_input = torch.chunk(self.conv_input(x), 2, dim=1)
feature_h = (x_input[0]).permute(0, 2, 3, 1).contiguous()
feature_h = feature_h.view(B * H, W, C//2)
feature_v = (x_input[1]).permute(0, 3, 2, 1).contiguous()
feature_v = feature_v.view(B * W, H, C//2)
qkv_h = torch.chunk(self.qkv_local_h(feature_h), 3, dim=2)
qkv_v = torch.chunk(self.qkv_local_v(feature_v), 3, dim=2)
q_h, k_h, v_h = qkv_h[0], qkv_h[1], qkv_h[2]
q_v, k_v, v_v = qkv_v[0], qkv_v[1], qkv_v[2]
        if H == W:  # equal strip lengths, so horizontal and vertical attention can share one batched call
query = torch.cat((q_h, q_v), dim=0)
key = torch.cat((k_h, k_v), dim=0)
value = torch.cat((v_h, v_v), dim=0)
attention_output = self.attn(query, key, value)
attention_output = torch.chunk(attention_output, 2, dim=0)
attention_output_h = attention_output[0]
attention_output_v = attention_output[1]
attention_output_h = attention_output_h.view(B, H, W, C//2).permute(0, 3, 1, 2).contiguous()
attention_output_v = attention_output_v.view(B, W, H, C//2).permute(0, 3, 2, 1).contiguous()
attn_out = self.fuse_out(torch.cat((attention_output_h, attention_output_v), dim=1))
else:
attention_output_h = self.attn(q_h, k_h, v_h)
attention_output_v = self.attn(q_v, k_v, v_v)
attention_output_h = attention_output_h.view(B, H, W, C//2).permute(0, 3, 1, 2).contiguous()
attention_output_v = attention_output_v.view(B, W, H, C//2).permute(0, 3, 2, 1).contiguous()
attn_out = self.fuse_out(torch.cat((attention_output_h, attention_output_v), dim=1))
x = attn_out + h
x = x.view(B, C, H*W).permute(0, 2, 1).contiguous()
h = x
x = self.ffn_norm(x)
x = self.ffn(x)
x = x + h
x = x.permute(0, 2, 1).contiguous()
x = x.view(B, C, H, W)
x = self.PEG(x)
return x
class Inter_SA(nn.Module):
def __init__(self,dim, head_num):
super(Inter_SA, self).__init__()
self.hidden_size = dim
self.head_num = head_num
self.attention_norm = nn.LayerNorm(self.hidden_size)
self.conv_input = nn.Conv2d(self.hidden_size, self.hidden_size, kernel_size=1, padding=0)
self.conv_h = nn.Conv2d(self.hidden_size//2, 3 * (self.hidden_size//2), kernel_size=1, padding=0) # qkv_h
self.conv_v = nn.Conv2d(self.hidden_size//2, 3 * (self.hidden_size//2), kernel_size=1, padding=0) # qkv_v
self.ffn_norm = nn.LayerNorm(self.hidden_size)
self.ffn = Mlp(self.hidden_size)
self.fuse_out = nn.Conv2d(self.hidden_size, self.hidden_size, kernel_size=1, padding=0)
self.attn = Attention(head_num=self.head_num)
self.PEG = PEG(dim)
def forward(self, x):
h = x
B, C, H, W = x.size()
x = x.view(B, C, H*W).permute(0, 2, 1).contiguous()
x = self.attention_norm(x).permute(0, 2, 1).contiguous()
x = x.view(B, C, H, W)
x_input = torch.chunk(self.conv_input(x), 2, dim=1)
feature_h = torch.chunk(self.conv_h(x_input[0]), 3, dim=1)
feature_v = torch.chunk(self.conv_v(x_input[1]), 3, dim=1)
query_h, key_h, value_h = feature_h[0], feature_h[1], feature_h[2]
query_v, key_v, value_v = feature_v[0], feature_v[1], feature_v[2]
horizontal_groups = torch.cat((query_h, key_h, value_h), dim=0)
horizontal_groups = horizontal_groups.permute(0, 2, 1, 3).contiguous()
horizontal_groups = horizontal_groups.view(3*B, H, -1)
horizontal_groups = torch.chunk(horizontal_groups, 3, dim=0)
query_h, key_h, value_h = horizontal_groups[0], horizontal_groups[1], horizontal_groups[2]
vertical_groups = torch.cat((query_v, key_v, value_v), dim=0)
vertical_groups = vertical_groups.permute(0, 3, 1, 2).contiguous()
vertical_groups = vertical_groups.view(3*B, W, -1)
vertical_groups = torch.chunk(vertical_groups, 3, dim=0)
query_v, key_v, value_v = vertical_groups[0], vertical_groups[1], vertical_groups[2]
        if H == W:  # equal strip counts, so both directions can share one batched attention call
query = torch.cat((query_h, query_v), dim=0)
key = torch.cat((key_h, key_v), dim=0)
value = torch.cat((value_h, value_v), dim=0)
attention_output = self.attn(query, key, value)
attention_output = torch.chunk(attention_output, 2, dim=0)
attention_output_h = attention_output[0]
attention_output_v = attention_output[1]
attention_output_h = attention_output_h.view(B, H, C//2, W).permute(0, 2, 1, 3).contiguous()
attention_output_v = attention_output_v.view(B, W, C//2, H).permute(0, 2, 3, 1).contiguous()
attn_out = self.fuse_out(torch.cat((attention_output_h, attention_output_v), dim=1))
else:
attention_output_h = self.attn(query_h, key_h, value_h)
attention_output_v = self.attn(query_v, key_v, value_v)
attention_output_h = attention_output_h.view(B, H, C//2, W).permute(0, 2, 1, 3).contiguous()
attention_output_v = attention_output_v.view(B, W, C//2, H).permute(0, 2, 3, 1).contiguous()
attn_out = self.fuse_out(torch.cat((attention_output_h, attention_output_v), dim=1))
x = attn_out + h
x = x.view(B, C, H*W).permute(0, 2, 1).contiguous()
h = x
x = self.ffn_norm(x)
x = self.ffn(x)
x = x + h
x = x.permute(0, 2, 1).contiguous()
x = x.view(B, C, H, W)
x = self.PEG(x)
return x
class Stripformer(nn.Module):
def __init__(self):
super(Stripformer, self).__init__()
self.encoder = Embeddings()
head_num = 5
dim = 320
self.Trans_block_1 = Intra_SA(dim, head_num)
self.Trans_block_2 = Inter_SA(dim, head_num)
self.Trans_block_3 = Intra_SA(dim, head_num)
self.Trans_block_4 = Inter_SA(dim, head_num)
self.Trans_block_5 = Intra_SA(dim, head_num)
self.Trans_block_6 = Inter_SA(dim, head_num)
self.Trans_block_7 = Intra_SA(dim, head_num)
self.Trans_block_8 = Inter_SA(dim, head_num)
self.Trans_block_9 = Intra_SA(dim, head_num)
self.Trans_block_10 = Inter_SA(dim, head_num)
self.Trans_block_11 = Intra_SA(dim, head_num)
self.Trans_block_12 = Inter_SA(dim, head_num)
self.decoder = Embeddings_output()
def forward(self, x):
hx, residual_1, residual_2 = self.encoder(x)
hx = self.Trans_block_1(hx)
hx = self.Trans_block_2(hx)
hx = self.Trans_block_3(hx)
hx = self.Trans_block_4(hx)
hx = self.Trans_block_5(hx)
hx = self.Trans_block_6(hx)
hx = self.Trans_block_7(hx)
hx = self.Trans_block_8(hx)
hx = self.Trans_block_9(hx)
hx = self.Trans_block_10(hx)
hx = self.Trans_block_11(hx)
hx = self.Trans_block_12(hx)
hx = self.decoder(hx, residual_1, residual_2)
return hx + x
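A quick sanity check of the full model, assuming a dummy square input whose sides are divisible by 4 (the encoder downsamples twice); this sketch is illustrative, not part of the repository:

net = Stripformer().eval()
with torch.no_grad():
    out = net(torch.randn(1, 3, 256, 256))  # may take a moment on CPU
print(out.shape)  # torch.Size([1, 3, 256, 256]) -- the network is resolution-preserving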

@@ -0,0 +1,131 @@
import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F
class Vgg19(torch.nn.Module):
def __init__(self, requires_grad=False):
super(Vgg19, self).__init__()
vgg_pretrained_features = models.vgg19(pretrained=True).features
self.slice1 = torch.nn.Sequential()
for x in range(12):
self.slice1.add_module(str(x), vgg_pretrained_features[x])
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X):
h_relu1 = self.slice1(X)
return h_relu1
class ContrastLoss(nn.Module):
def __init__(self, ablation=False):
super(ContrastLoss, self).__init__()
self.vgg = Vgg19().cuda()
self.l1 = nn.L1Loss()
self.ab = ablation
self.down_sample_4 = nn.Upsample(scale_factor=1 / 4, mode='bilinear')
def forward(self, restore, sharp, blur):
B, C, H, W = restore.size()
restore_vgg, sharp_vgg, blur_vgg = self.vgg(restore), self.vgg(sharp), self.vgg(blur)
# filter out sharp regions
threshold = 0.01
mask = torch.mean(torch.abs(sharp-blur), dim=1).view(B, 1, H, W)
mask[mask <= threshold] = 0
mask[mask > threshold] = 1
mask = self.down_sample_4(mask)
d_ap = torch.mean(torch.abs((restore_vgg - sharp_vgg.detach())), dim=1).view(B, 1, H//4, W//4)
d_an = torch.mean(torch.abs((restore_vgg - blur_vgg.detach())), dim=1).view(B, 1, H//4, W//4)
mask_size = torch.sum(mask)
contrastive = torch.sum((d_ap / (d_an + 1e-7)) * mask) / mask_size
return contrastive
class ContrastLoss_Ori(nn.Module):
def __init__(self, ablation=False):
super(ContrastLoss_Ori, self).__init__()
self.vgg = Vgg19().cuda()
self.l1 = nn.L1Loss()
self.ab = ablation
def forward(self, restore, sharp, blur):
restore_vgg, sharp_vgg, blur_vgg = self.vgg(restore), self.vgg(sharp), self.vgg(blur)
d_ap = self.l1(restore_vgg, sharp_vgg.detach())
d_an = self.l1(restore_vgg, blur_vgg.detach())
contrastive_loss = d_ap / (d_an + 1e-7)
return contrastive_loss
class CharbonnierLoss(nn.Module):
"""Charbonnier Loss (L1)"""
def __init__(self, eps=1e-3):
super(CharbonnierLoss, self).__init__()
self.eps = eps
def forward(self, x, y):
diff = x - y
# loss = torch.sum(torch.sqrt(diff * diff + self.eps))
loss = torch.mean(torch.sqrt((diff * diff) + (self.eps * self.eps)))
return loss
class EdgeLoss(nn.Module):
def __init__(self):
super(EdgeLoss, self).__init__()
k = torch.Tensor([[.05, .25, .4, .25, .05]])
self.kernel = torch.matmul(k.t(), k).unsqueeze(0).repeat(3, 1, 1, 1)
if torch.cuda.is_available():
self.kernel = self.kernel.cuda()
self.loss = CharbonnierLoss()
def conv_gauss(self, img):
n_channels, _, kw, kh = self.kernel.shape
img = F.pad(img, (kw // 2, kh // 2, kw // 2, kh // 2), mode='replicate')
return F.conv2d(img, self.kernel, groups=n_channels)
def laplacian_kernel(self, current):
filtered = self.conv_gauss(current) # filter
down = filtered[:, :, ::2, ::2] # downsample
new_filter = torch.zeros_like(filtered)
new_filter[:, :, ::2, ::2] = down * 4 # upsample
filtered = self.conv_gauss(new_filter) # filter
diff = current - filtered
return diff
def forward(self, x, y):
# x = torch.clamp(x + 0.5, min = 0,max = 1)
# y = torch.clamp(y + 0.5, min = 0,max = 1)
loss = self.loss(self.laplacian_kernel(x), self.laplacian_kernel(y))
return loss
class Stripformer_Loss(nn.Module):
def __init__(self, ):
super(Stripformer_Loss, self).__init__()
self.char = CharbonnierLoss()
self.edge = EdgeLoss()
self.contrastive = ContrastLoss()
def forward(self, restore, sharp, blur):
char = self.char(restore, sharp)
edge = 0.05 * self.edge(restore, sharp)
contrastive = 0.0005 * self.contrastive(restore, sharp, blur)
loss = char + edge + contrastive
return loss
def get_loss(model):
if model['content_loss'] == 'Stripformer_Loss':
content_loss = Stripformer_Loss()
else:
raise ValueError("ContentLoss [%s] not recognized." % model['content_loss'])
return content_loss
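A hedged sketch of calling the combined loss on random tensors. It assumes a CUDA device (Vgg19 is moved to CUDA at construction) and downloads the torchvision VGG-19 weights on first use:

loss_fn = get_loss({'content_loss': 'Stripformer_Loss'})
restore = torch.rand(2, 3, 64, 64).cuda()  # placeholder network output
sharp = torch.rand(2, 3, 64, 64).cuda()    # placeholder ground truth
blur = torch.rand(2, 3, 64, 64).cuda()     # placeholder blurry input
total = loss_fn(restore, sharp, blur)      # char + 0.05 * edge + 0.0005 * contrastive
print(total.item())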

@@ -0,0 +1,35 @@
import numpy as np
import torch.nn as nn
from skimage.metrics import structural_similarity as SSIM  # compare_ssim was removed from skimage.measure
from util.metrics import PSNR
class DeblurModel(nn.Module):
def __init__(self):
super(DeblurModel, self).__init__()
def get_input(self, data):
img = data['a']
inputs = img
targets = data['b']
inputs, targets = inputs.cuda(), targets.cuda()
return inputs, targets
def tensor2im(self, image_tensor, imtype=np.uint8):
image_numpy = image_tensor[0].cpu().float().numpy()
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 0.5) * 255.0
return image_numpy
def get_images_and_metrics(self, inp, output, target) -> (float, float, np.ndarray):
inp = self.tensor2im(inp)
fake = self.tensor2im(output.data)
real = self.tensor2im(target.data)
psnr = PSNR(fake, real)
ssim = SSIM(fake.astype('uint8'), real.astype('uint8'), multichannel=True)
vis_img = np.hstack((inp, fake, real))
return psnr, ssim, vis_img
def get_model(model_config):
return DeblurModel()

@@ -0,0 +1,13 @@
import torch.nn as nn
from models.Stripformer import Stripformer
def get_generator(model_config):
generator_name = model_config['g_name']
if generator_name == 'Stripformer':
model_g = Stripformer()
else:
raise ValueError("Generator Network [%s] not recognized." % generator_name)
return nn.DataParallel(model_g)
def get_nets(model_config):
return get_generator(model_config)

@@ -0,0 +1 @@
Testing results are created in this folder.

@@ -0,0 +1,68 @@
from __future__ import print_function
import numpy as np
import torch
import cv2
import yaml
import os
from torch.autograd import Variable
from models.networks import get_generator
import torchvision
import time
import argparse
def get_args():
parser = argparse.ArgumentParser('Test an image')
parser.add_argument('--weights_path', required=True, help='Weights path')
return parser.parse_args()
if __name__ == '__main__':
args = get_args()
with open('config/config_Stripformer_gopro.yaml') as cfg:
        config = yaml.safe_load(cfg)
blur_path = './datasets/GoPro/test/blur/'
out_path = './out/Stripformer_GoPro_results'
if not os.path.isdir(out_path):
os.mkdir(out_path)
model = get_generator(config['model'])
model.load_state_dict(torch.load(args.weights_path))
model = model.cuda()
test_time = 0
iteration = 0
total_image_number = 1111
# warm-up
warm_up = 0
print('Hardware warm-up')
for file in os.listdir(blur_path):
for img_name in os.listdir(blur_path + '/' + file):
warm_up += 1
img = cv2.imread(blur_path + '/' + file + '/' + img_name)
img_tensor = torch.from_numpy(np.transpose(img / 255, (2, 0, 1)).astype('float32')) - 0.5
with torch.no_grad():
img_tensor = Variable(img_tensor.unsqueeze(0)).cuda()
result_image = model(img_tensor)
if warm_up == 20:
break
break
for file in os.listdir(blur_path):
if not os.path.isdir(out_path + '/' + file):
os.mkdir(out_path + '/' + file)
for img_name in os.listdir(blur_path + '/' + file):
img = cv2.imread(blur_path + '/' + file + '/' + img_name)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img_tensor = torch.from_numpy(np.transpose(img / 255, (2, 0, 1)).astype('float32')) - 0.5
with torch.no_grad():
iteration += 1
img_tensor = Variable(img_tensor.unsqueeze(0)).cuda()
start = time.time()
result_image = model(img_tensor)
stop = time.time()
print('Image:{}/{}, CNN Runtime:{:.4f}'.format(iteration, total_image_number, (stop - start)))
test_time += stop - start
print('Average Runtime:{:.4f}'.format(test_time / float(iteration)))
result_image = result_image + 0.5
out_file_name = out_path + '/' + file + '/' + img_name
torchvision.utils.save_image(result_image, out_file_name)
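Assuming this script is saved as predict_GoPro_test_results.py (the file name is a guess) and points at the final checkpoint the GoPro trainer writes, a typical invocation would look like:

python predict_GoPro_test_results.py --weights_path ./final_Stripformer_gopro.pth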

@@ -0,0 +1,64 @@
from __future__ import print_function
import numpy as np
import torch
import cv2
import yaml
import os
from torch.autograd import Variable
from models.networks import get_generator
import torchvision
import time
import argparse
def get_args():
parser = argparse.ArgumentParser('Test an image')
parser.add_argument('--weights_path', required=True, help='Weights path')
return parser.parse_args()
if __name__ == '__main__':
args = get_args()
with open('config/config_Stripformer_gopro.yaml') as cfg:
        config = yaml.safe_load(cfg)
blur_path = './datasets/HIDE/blur/'
out_path = './out/Stripformer_HIDE_results'
if not os.path.isdir(out_path):
os.mkdir(out_path)
model = get_generator(config['model'])
model.load_state_dict(torch.load(args.weights_path))
model = model.cuda()
test_time = 0
iteration = 0
total_image_number = 2025
# warm up
warm_up = 0
print('Hardware warm-up')
for img_name in os.listdir(blur_path):
warm_up += 1
img = cv2.imread(blur_path + '/' + img_name)
img_tensor = torch.from_numpy(np.transpose(img / 255, (2, 0, 1)).astype('float32')) - 0.5
with torch.no_grad():
img_tensor = Variable(img_tensor.unsqueeze(0)).cuda()
result_image = model(img_tensor)
if warm_up == 20:
break
break
for img_name in os.listdir(blur_path):
img = cv2.imread(blur_path + '/' + img_name)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img_tensor = torch.from_numpy(np.transpose(img / 255, (2, 0, 1)).astype('float32')) - 0.5
with torch.no_grad():
iteration += 1
img_tensor = Variable(img_tensor.unsqueeze(0)).cuda()
start = time.time()
result_image = model(img_tensor)
stop = time.time()
print('Image:{}/{}, CNN Runtime:{:.4f}'.format(iteration, total_image_number, (stop - start)))
test_time += stop - start
print('Average Runtime:{:.4f}'.format(test_time / float(iteration)))
result_image = result_image + 0.5
out_file_name = out_path + '/' + img_name
torchvision.utils.save_image(result_image, out_file_name)

@@ -0,0 +1,90 @@
from __future__ import print_function
import numpy as np
import torch
import cv2
import yaml
import os
from torch.autograd import Variable
from models.networks import get_generator
import torchvision
import time
import torch.nn.functional as F
import argparse
def get_args():
parser = argparse.ArgumentParser('Test an image')
parser.add_argument('--weights_path', required=True, help='Weights path')
return parser.parse_args()
if __name__ == '__main__':
args = get_args()
with open('config/config_Stripformer_gopro.yaml') as cfg:
        config = yaml.safe_load(cfg)
blur_path = './datasets/Realblur_J/test/blur/'
out_path = './out/Stripformer_realblur_J_results'
if not os.path.isdir(out_path):
os.mkdir(out_path)
model = get_generator(config['model'])
model.load_state_dict(torch.load(args.weights_path))
model = model.cuda()
test_time = 0
iteration = 0
total_image_number = 980
# warm up
warm_up = 0
print('Hardware warm-up')
for file in os.listdir(blur_path):
if not os.path.isdir(out_path + '/' + file):
os.mkdir(out_path + '/' + file)
for img_name in os.listdir(blur_path + '/' + file):
warm_up += 1
img = cv2.imread(blur_path + '/' + file + '/' + img_name)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img_tensor = torch.from_numpy(np.transpose(img / 255, (2, 0, 1)).astype('float32')) - 0.5
with torch.no_grad():
img_tensor = Variable(img_tensor.unsqueeze(0)).cuda()
factor = 8
h, w = img_tensor.shape[2], img_tensor.shape[3]
H, W = ((h + factor) // factor) * factor, ((w + factor) // factor) * factor
padh = H - h if h % factor != 0 else 0
padw = W - w if w % factor != 0 else 0
img_tensor = F.pad(img_tensor, (0, padw, 0, padh), 'reflect')
result_image = model(img_tensor)
if warm_up == 20:
break
break
for file in os.listdir(blur_path):
if not os.path.isdir(out_path + '/' + file):
os.mkdir(out_path + '/' + file)
for img_name in os.listdir(blur_path + '/' + file):
img = cv2.imread(blur_path + '/' + file + '/' + img_name)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img_tensor = torch.from_numpy(np.transpose(img / 255, (2, 0, 1)).astype('float32')) - 0.5
with torch.no_grad():
iteration += 1
img_tensor = Variable(img_tensor.unsqueeze(0)).cuda()
factor = 8
h, w = img_tensor.shape[2], img_tensor.shape[3]
H, W = ((h + factor) // factor) * factor, ((w + factor) // factor) * factor
padh = H - h if h % factor != 0 else 0
padw = W - w if w % factor != 0 else 0
img_tensor = F.pad(img_tensor, (0, padw, 0, padh), 'reflect')
start = time.time()
_output = model(img_tensor)
stop = time.time()
result_image = _output[:, :, :h, :w]
result_image = torch.clamp(result_image, -0.5, 0.5)
result_image = result_image + 0.5
test_time += stop - start
print('Image:{}/{}, CNN Runtime:{:.4f}'.format(iteration, total_image_number, (stop - start)))
print('Average Runtime:{:.4f}'.format(test_time / float(iteration)))
out_file_name = out_path + '/' + file + '/' + img_name
torchvision.utils.save_image(result_image, out_file_name)

@@ -0,0 +1,90 @@
from __future__ import print_function
import argparse
import numpy as np
import torch
import cv2
import yaml
import os
from torch.autograd import Variable
from models.networks import get_generator
import torchvision
import time
import torch.nn.functional as F
def get_args():
parser = argparse.ArgumentParser('Test an image')
parser.add_argument('--weights_path', required=True, help='Weights path')
return parser.parse_args()
if __name__ == '__main__':
args = get_args()
with open('config/config_Stripformer_gopro.yaml') as cfg:
        config = yaml.safe_load(cfg)
blur_path = './datasets/Realblur_R/test/blur/'
out_path = './out/Stripformer_realblur_R_results'
model = get_generator(config['model'])
model.load_state_dict(torch.load(args.weights_path))
model = model.cuda()
if not os.path.isdir(out_path):
os.mkdir(out_path)
test_time = 0
iteration = 0
total_image_number = 980
# warm up
warm_up = 0
print('Hardware warm-up')
for file in os.listdir(blur_path):
if not os.path.isdir(out_path + '/' + file):
os.mkdir(out_path + '/' + file)
for img_name in os.listdir(blur_path + '/' + file):
warm_up += 1
img = cv2.imread(blur_path + '/' + file + '/' + img_name)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img_tensor = torch.from_numpy(np.transpose(img / 255, (2, 0, 1)).astype('float32')) - 0.5
with torch.no_grad():
img_tensor = Variable(img_tensor.unsqueeze(0)).cuda()
factor = 8
h, w = img_tensor.shape[2], img_tensor.shape[3]
H, W = ((h + factor) // factor) * factor, ((w + factor) // factor) * factor
padh = H - h if h % factor != 0 else 0
padw = W - w if w % factor != 0 else 0
img_tensor = F.pad(img_tensor, (0, padw, 0, padh), 'reflect')
result_image = model(img_tensor)
if warm_up == 20:
break
break
for file in os.listdir(blur_path):
if not os.path.isdir(out_path + '/' + file):
os.mkdir(out_path + '/' + file)
for img_name in os.listdir(blur_path + '/' + file):
img = cv2.imread(blur_path + '/' + file + '/' + img_name)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img_tensor = torch.from_numpy(np.transpose(img / 255, (2, 0, 1)).astype('float32')) - 0.5
with torch.no_grad():
iteration += 1
img_tensor = Variable(img_tensor.unsqueeze(0)).cuda()
factor = 8
h, w = img_tensor.shape[2], img_tensor.shape[3]
H, W = ((h + factor) // factor) * factor, ((w + factor) // factor) * factor
padh = H - h if h % factor != 0 else 0
padw = W - w if w % factor != 0 else 0
img_tensor = F.pad(img_tensor, (0, padw, 0, padh), 'reflect')
start = time.time()
_output = model(img_tensor)
stop = time.time()
result_image = _output[:, :, :h, :w]
result_image = torch.clamp(result_image, -0.5, 0.5)
result_image = result_image + 0.5
test_time += stop - start
print('Image:{}/{}, CNN Runtime:{:.4f}'.format(iteration, total_image_number, (stop - start)))
print('Average Runtime:{:.4f}'.format(test_time / float(iteration)))
out_file_name = out_path + '/' + file + '/' + img_name
torchvision.utils.save_image(result_image, out_file_name)

@@ -0,0 +1,59 @@
import math
from torch.optim import lr_scheduler
class WarmRestart(lr_scheduler.CosineAnnealingLR):
"""This class implements Stochastic Gradient Descent with Warm Restarts(SGDR): https://arxiv.org/abs/1608.03983.
Set the learning rate of each parameter group using a cosine annealing schedule, When last_epoch=-1, sets initial lr as lr.
This can't support scheduler.step(epoch). please keep epoch=None.
"""
def __init__(self, optimizer, T_max=30, T_mult=1, eta_min=0, last_epoch=-1):
"""implements SGDR
Parameters:
----------
T_max : int
Maximum number of epochs.
T_mult : int
Multiplicative factor of T_max.
eta_min : int
Minimum learning rate. Default: 0.
last_epoch : int
The index of last epoch. Default: -1.
"""
self.T_mult = T_mult
super().__init__(optimizer, T_max, eta_min, last_epoch)
def get_lr(self):
if self.last_epoch == self.T_max:
self.last_epoch = 0
self.T_max *= self.T_mult
return [self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / 2 for
base_lr in self.base_lrs]
class LinearDecay(lr_scheduler._LRScheduler):
"""This class implements LinearDecay
"""
    def __init__(self, optimizer, num_epochs, start_epoch=0, min_lr=0, last_epoch=-1):
        """Implements LinearDecay.
        Parameters:
        ----------
        num_epochs : int
            Number of epochs; sets the slope of the linear decay.
        start_epoch : int
            Epoch at which the decay starts; the lr is constant before it. Default: 0.
        last_epoch : int
            The index of last epoch. Default: -1.
        min_lr : float
            Learning-rate floor the schedule decays toward. Default: 0.
        """
self.num_epochs = num_epochs
self.start_epoch = start_epoch
self.min_lr = min_lr
super().__init__(optimizer, last_epoch)
def get_lr(self):
if self.last_epoch < self.start_epoch:
return self.base_lrs
return [base_lr - ((base_lr - self.min_lr) / self.num_epochs) * (self.last_epoch - self.start_epoch) for
base_lr in self.base_lrs]
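A minimal sketch of driving LinearDecay; the model and epoch counts are placeholders:

import torch
net = torch.nn.Linear(4, 4)                               # stand-in model
opt = torch.optim.Adam(net.parameters(), lr=1e-4)
sched = LinearDecay(opt, num_epochs=100, start_epoch=50, min_lr=1e-7)
for epoch in range(100):
    # ... one training epoch ...
    sched.step()  # lr stays at 1e-4 for 50 epochs, then decays linearly toward min_lr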

@@ -0,0 +1,170 @@
import logging
from functools import partial
import os
import cv2
import torch
import torch.optim as optim
import tqdm
import yaml
from joblib import cpu_count
from torch.utils.data import DataLoader
import random
from dataset import PairedDataset
from metric_counter import MetricCounter
from models.losses import get_loss
from models.models import get_model
from models.networks import get_nets
import numpy as np
from torch.optim.lr_scheduler import CosineAnnealingLR
cv2.setNumThreads(0)
class Trainer:
def __init__(self, config, train: DataLoader, val: DataLoader):
self.config = config
self.train_dataset = train
self.val_dataset = val
self.metric_counter = MetricCounter(config['experiment_desc'])
def train(self):
self._init_params()
start_epoch = 0
if os.path.exists('last_Stripformer_gopro.pth'):
print('load_pretrained')
training_state = (torch.load('last_Stripformer_gopro.pth'))
start_epoch = training_state['epoch']
new_weight = self.netG.state_dict()
new_weight.update(training_state['model_state'])
self.netG.load_state_dict(new_weight)
new_optimizer = self.optimizer_G.state_dict()
new_optimizer.update(training_state['optimizer_state'])
self.optimizer_G.load_state_dict(new_optimizer)
new_scheduler = self.scheduler_G.state_dict()
new_scheduler.update(training_state['scheduler_state'])
self.scheduler_G.load_state_dict(new_scheduler)
else:
print('load_GoPro_pretrained')
training_state = (torch.load('final_Stripformer_pretrained.pth'))
new_weight = self.netG.state_dict()
new_weight.update(training_state)
self.netG.load_state_dict(new_weight)
        for epoch in range(start_epoch, self.config['num_epochs']):
            self._run_epoch(epoch)
            if epoch % 30 == 0 or epoch == (self.config['num_epochs'] - 1):
self._validate(epoch)
self.scheduler_G.step()
scheduler_state = self.scheduler_G.state_dict()
training_state = {'epoch': epoch, 'model_state': self.netG.state_dict(),
'scheduler_state': scheduler_state, 'optimizer_state': self.optimizer_G.state_dict()}
if self.metric_counter.update_best_model():
torch.save(training_state['model_state'], 'best_{}.pth'.format(self.config['experiment_desc']))
if epoch % 200 == 0:
torch.save(training_state, 'last_{}_{}.pth'.format(self.config['experiment_desc'], epoch))
            if epoch == (self.config['num_epochs'] - 1):
torch.save(training_state['model_state'], 'final_{}.pth'.format(self.config['experiment_desc']))
torch.save(training_state, 'last_{}.pth'.format(self.config['experiment_desc']))
logging.debug("Experiment Name: %s, Epoch: %d, Loss: %s" % (
self.config['experiment_desc'], epoch, self.metric_counter.loss_message()))
def _run_epoch(self, epoch):
self.metric_counter.clear()
for param_group in self.optimizer_G.param_groups:
lr = param_group['lr']
        epoch_size = self.config.get('train_batches_per_epoch') or len(self.train_dataset)
tq = tqdm.tqdm(self.train_dataset, total=epoch_size)
tq.set_description('Epoch {}, lr {}'.format(epoch, lr))
i = 0
for data in tq:
inputs, targets = self.model.get_input(data)
outputs = self.netG(inputs)
self.optimizer_G.zero_grad()
loss_G = self.criterionG(outputs, targets, inputs)
loss_G.backward()
self.optimizer_G.step()
self.metric_counter.add_losses(loss_G.item())
curr_psnr, curr_ssim, img_for_vis = self.model.get_images_and_metrics(inputs, outputs, targets)
self.metric_counter.add_metrics(curr_psnr, curr_ssim)
tq.set_postfix(loss=self.metric_counter.loss_message())
if not i:
self.metric_counter.add_image(img_for_vis, tag='train')
i += 1
if i > epoch_size:
break
tq.close()
self.metric_counter.write_to_tensorboard(epoch)
def _validate(self, epoch):
self.metric_counter.clear()
        epoch_size = self.config.get('val_batches_per_epoch') or len(self.val_dataset)
tq = tqdm.tqdm(self.val_dataset, total=epoch_size)
tq.set_description('Validation')
i = 0
for data in tq:
with torch.no_grad():
inputs, targets = self.model.get_input(data)
outputs = self.netG(inputs)
loss_G = self.criterionG(outputs, targets, inputs)
self.metric_counter.add_losses(loss_G.item())
curr_psnr, curr_ssim, img_for_vis = self.model.get_images_and_metrics(inputs, outputs, targets)
self.metric_counter.add_metrics(curr_psnr, curr_ssim)
if not i:
self.metric_counter.add_image(img_for_vis, tag='val')
i += 1
if i > epoch_size:
break
tq.close()
self.metric_counter.write_to_tensorboard(epoch, validation=True)
def _get_optim(self, params):
if self.config['optimizer']['name'] == 'adam':
optimizer = optim.Adam(params, lr=self.config['optimizer']['lr'])
else:
raise ValueError("Optimizer [%s] not recognized." % self.config['optimizer']['name'])
return optimizer
def _get_scheduler(self, optimizer):
if self.config['scheduler']['name'] == 'cosine':
scheduler = CosineAnnealingLR(optimizer, T_max=self.config['num_epochs'], eta_min=self.config['scheduler']['min_lr'])
else:
raise ValueError("Scheduler [%s] not recognized." % self.config['scheduler']['name'])
return scheduler
def _init_params(self):
self.criterionG = get_loss(self.config['model'])
self.netG = get_nets(self.config['model'])
self.netG.cuda()
self.model = get_model(self.config['model'])
self.optimizer_G = self._get_optim(filter(lambda p: p.requires_grad, self.netG.parameters()))
self.scheduler_G = self._get_scheduler(self.optimizer_G)
if __name__ == '__main__':
with open('config/config_Stripformer_gopro.yaml', 'r') as f:
        config = yaml.safe_load(f)
# setup
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
# set random seed
seed = 666
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
batch_size = config.pop('batch_size')
get_dataloader = partial(DataLoader, batch_size=batch_size, num_workers=cpu_count(), shuffle=True, drop_last=False)
datasets = map(config.pop, ('train', 'val'))
datasets = map(PairedDataset.from_config, datasets)
train, val = map(get_dataloader, datasets)
trainer = Trainer(config, train=train, val=val)
trainer.train()

@@ -0,0 +1,164 @@
import logging
from functools import partial
import os
import cv2
import torch
import torch.optim as optim
import tqdm
import yaml
from joblib import cpu_count
from torch.utils.data import DataLoader
import random
from dataset import PairedDataset
from metric_counter import MetricCounter
from models.losses import get_loss
from models.models import get_model
from models.networks import get_nets
import numpy as np
from torch.optim.lr_scheduler import CosineAnnealingLR
cv2.setNumThreads(0)
class Trainer:
def __init__(self, config, train: DataLoader, val: DataLoader):
self.config = config
self.train_dataset = train
self.val_dataset = val
self.metric_counter = MetricCounter(config['experiment_desc'])
def train(self):
self._init_params()
start_epoch = 0
if os.path.exists('last_Stripformer_pretrained.pth'):
print('load_pretrained')
training_state = (torch.load('last_Stripformer_pretrained.pth'))
start_epoch = training_state['epoch']
new_weight = self.netG.state_dict()
new_weight.update(training_state['model_state'])
self.netG.load_state_dict(new_weight)
new_optimizer = self.optimizer_G.state_dict()
new_optimizer.update(training_state['optimizer_state'])
self.optimizer_G.load_state_dict(new_optimizer)
new_scheduler = self.scheduler_G.state_dict()
new_scheduler.update(training_state['scheduler_state'])
self.scheduler_G.load_state_dict(new_scheduler)
        for epoch in range(start_epoch, self.config['num_epochs']):
            self._run_epoch(epoch)
            if epoch % 30 == 0 or epoch == (self.config['num_epochs'] - 1):
self._validate(epoch)
self.scheduler_G.step()
scheduler_state = self.scheduler_G.state_dict()
training_state = {'epoch': epoch, 'model_state': self.netG.state_dict(),
'scheduler_state': scheduler_state, 'optimizer_state': self.optimizer_G.state_dict()}
if self.metric_counter.update_best_model():
torch.save(training_state['model_state'], 'best_{}.pth'.format(self.config['experiment_desc']))
if epoch % 300 == 0:
torch.save(training_state, 'last_{}_{}.pth'.format(self.config['experiment_desc'], epoch))
            if epoch == (self.config['num_epochs'] - 1):
torch.save(training_state['model_state'], 'final_{}.pth'.format(self.config['experiment_desc']))
torch.save(training_state, 'last_{}.pth'.format(self.config['experiment_desc']))
logging.debug("Experiment Name: %s, Epoch: %d, Loss: %s" % (
self.config['experiment_desc'], epoch, self.metric_counter.loss_message()))
def _run_epoch(self, epoch):
self.metric_counter.clear()
for param_group in self.optimizer_G.param_groups:
lr = param_group['lr']
        epoch_size = self.config.get('train_batches_per_epoch') or len(self.train_dataset)
tq = tqdm.tqdm(self.train_dataset, total=epoch_size)
tq.set_description('Epoch {}, lr {}'.format(epoch, lr))
i = 0
for data in tq:
inputs, targets = self.model.get_input(data)
outputs = self.netG(inputs)
self.optimizer_G.zero_grad()
loss_G = self.criterionG(outputs, targets, inputs)
loss_G.backward()
self.optimizer_G.step()
self.metric_counter.add_losses(loss_G.item())
curr_psnr, curr_ssim, img_for_vis = self.model.get_images_and_metrics(inputs, outputs, targets)
self.metric_counter.add_metrics(curr_psnr, curr_ssim)
tq.set_postfix(loss=self.metric_counter.loss_message())
if not i:
self.metric_counter.add_image(img_for_vis, tag='train')
i += 1
if i > epoch_size:
break
tq.close()
self.metric_counter.write_to_tensorboard(epoch)
def _validate(self, epoch):
self.metric_counter.clear()
        epoch_size = self.config.get('val_batches_per_epoch') or len(self.val_dataset)
tq = tqdm.tqdm(self.val_dataset, total=epoch_size)
tq.set_description('Validation')
i = 0
for data in tq:
with torch.no_grad():
inputs, targets = self.model.get_input(data)
outputs = self.netG(inputs)
loss_G = self.criterionG(outputs, targets, inputs)
self.metric_counter.add_losses(loss_G.item())
curr_psnr, curr_ssim, img_for_vis = self.model.get_images_and_metrics(inputs, outputs, targets)
self.metric_counter.add_metrics(curr_psnr, curr_ssim)
if not i:
self.metric_counter.add_image(img_for_vis, tag='val')
i += 1
if i > epoch_size:
break
tq.close()
self.metric_counter.write_to_tensorboard(epoch, validation=True)
def _get_optim(self, params):
if self.config['optimizer']['name'] == 'adam':
optimizer = optim.Adam(params, lr=self.config['optimizer']['lr'])
else:
raise ValueError("Optimizer [%s] not recognized." % self.config['optimizer']['name'])
return optimizer
def _get_scheduler(self, optimizer):
if self.config['scheduler']['name'] == 'cosine':
scheduler = CosineAnnealingLR(optimizer, T_max=self.config['num_epochs'], eta_min=self.config['scheduler']['min_lr'])
else:
raise ValueError("Scheduler [%s] not recognized." % self.config['scheduler']['name'])
return scheduler
def _init_params(self):
self.criterionG = get_loss(self.config['model'])
self.netG = get_nets(self.config['model'])
self.netG.cuda()
self.model = get_model(self.config['model'])
self.optimizer_G = self._get_optim(filter(lambda p: p.requires_grad, self.netG.parameters()))
self.scheduler_G = self._get_scheduler(self.optimizer_G)
if __name__ == '__main__':
with open('config/config_Stripformer_pretrained.yaml', 'r') as f:
        config = yaml.safe_load(f)
# setup
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
# set random seed
seed = 666
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
batch_size = config.pop('batch_size')
get_dataloader = partial(DataLoader, batch_size=batch_size, num_workers=cpu_count(), shuffle=True, drop_last=False)
datasets = map(config.pop, ('train', 'val'))
datasets = map(PairedDataset.from_config, datasets)
train, val = map(get_dataloader, datasets)
trainer = Trainer(config, train=train, val=val)
trainer.train()
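
# --- Editor's usage sketch (not part of the original file) ---
# Restores the generator from the best_*.pth checkpoint written above.
# best_*.pth stores only the model state dict (see the update_best_model
# branch), and get_nets is the same factory used in _init_params.
def load_best_generator(config_path='config/config_Stripformer_pretrained.yaml'):
    with open(config_path) as f:
        cfg = yaml.load(f, Loader=yaml.SafeLoader)
    netG = get_nets(cfg['model'])
    state = torch.load('best_{}.pth'.format(cfg['experiment_desc']), map_location='cpu')
    # keys may carry a 'module.' prefix if get_nets wraps the net in DataParallel
    netG.load_state_dict(state)
    return netG.cuda().eval()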

@ -0,0 +1,66 @@
import dominate
from dominate.tags import *
import os
class HTML:
    def __init__(self, web_dir, title, image_subdir='', refresh=0):
self.title = title
self.web_dir = web_dir
self.img_subdir = image_subdir
self.img_dir = os.path.join(self.web_dir, image_subdir)
if not os.path.exists(self.web_dir):
os.makedirs(self.web_dir)
if not os.path.exists(self.img_dir):
os.makedirs(self.img_dir)
# print(self.img_dir)
self.doc = dominate.document(title=title)
        if refresh > 0:
            with self.doc.head:
                meta(http_equiv="refresh", content=str(refresh))
def get_image_dir(self):
return self.img_dir
    def add_header(self, text):
        with self.doc:
            h3(text)
def add_table(self, border=1):
self.t = table(border=border, style="table-layout: fixed;")
self.doc.add(self.t)
def add_images(self, ims, txts, links, width=400):
self.add_table()
with self.t:
with tr():
for im, txt, link in zip(ims, txts, links):
with td(style="word-wrap: break-word;", halign="center", valign="top"):
with p():
with a(href=os.path.join(link)):
img(style="width:%dpx" % width, src=os.path.join(im))
br()
p(txt)
    def save(self, file='index'):
        html_file = '%s/%s.html' % (self.web_dir, file)
        with open(html_file, 'wt') as f:
            f.write(self.doc.render())
if __name__ == '__main__':
html = HTML('web/', 'test_html')
html.add_header('hello world')
ims = []
txts = []
links = []
for n in range(4):
ims.append('image_%d.png' % n)
txts.append('text_%d' % n)
links.append('image_%d.png' % n)
html.add_images(ims, txts, links)
html.save()

@ -0,0 +1,33 @@
import random
from collections import deque

import torch
class ImagePool():
def __init__(self, pool_size):
self.pool_size = pool_size
self.sample_size = pool_size
if self.pool_size > 0:
self.num_imgs = 0
self.images = deque()
def add(self, images):
if self.pool_size == 0:
return images
for image in images.data:
image = torch.unsqueeze(image, 0)
if self.num_imgs < self.pool_size:
self.num_imgs = self.num_imgs + 1
self.images.append(image)
else:
self.images.popleft()
self.images.append(image)
def query(self):
if len(self.images) > self.sample_size:
return_images = list(random.sample(self.images, self.sample_size))
else:
return_images = list(self.images)
return torch.cat(return_images, 0)
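
# Editor's usage sketch (not part of the original file): the pool keeps at
# most pool_size single-image tensors and query() returns up to sample_size
# of them as one batch; shapes below are illustrative.
if __name__ == '__main__':
    pool = ImagePool(pool_size=50)
    for _ in range(10):
        pool.add(torch.randn(4, 3, 64, 64))  # e.g. a batch of generator outputs
    print(pool.query().shape)  # 40 images stored -> torch.Size([40, 3, 64, 64])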

@ -0,0 +1,54 @@
import math
from math import exp
import numpy as np
import torch
import torch.nn.functional as F
def gaussian(window_size, sigma):
gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
return gauss / gauss.sum()
def create_window(window_size, channel):
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
return window
def SSIM(img1, img2):
(_, channel, _, _) = img1.size()
window_size = 11
window = create_window(window_size, channel)
if img1.is_cuda:
window = window.cuda(img1.get_device())
window = window.type_as(img1)
mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1 * mu2
sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
C1 = 0.01 ** 2
C2 = 0.03 ** 2
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
return ssim_map.mean()
def PSNR(img1, img2):
mse = np.mean((img1 / 255. - img2 / 255.) ** 2)
if mse == 0:
return 100
PIXEL_MAX = 1
return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))
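
# Editor's usage sketch (not part of the original file): SSIM expects 4-D
# torch tensors (N, C, H, W) in a shared value range; PSNR expects numpy
# arrays in [0, 255], matching the /255. normalization and PIXEL_MAX = 1.
if __name__ == '__main__':
    a = torch.rand(1, 3, 64, 64)
    print(SSIM(a, a.clone()).item())  # identical inputs -> 1.0
    x = np.random.randint(0, 256, (64, 64, 3)).astype(np.float64)
    print(PSNR(x, x))  # zero MSE is clamped to 100 dB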

@ -0,0 +1,48 @@
from __future__ import print_function
import os

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import torch
def load_image(path):
    if path[-3:] == 'dng':
        import rawpy
        with rawpy.imread(path) as raw:
            img = raw.postprocess()
    elif path[-3:] in ('bmp', 'jpg', 'png'):
        import cv2
        img = cv2.imread(path)[:, :, ::-1]  # cv2 loads BGR; flip to RGB
    else:
        img = (255 * plt.imread(path)[:, :, :3]).astype('uint8')
    return img
def save_image(image_numpy, image_path):
image_pil = Image.fromarray(image_numpy)
image_pil.save(image_path)
def mkdirs(paths):
if isinstance(paths, list) and not isinstance(paths, str):
for path in paths:
mkdir(path)
else:
mkdir(paths)
def mkdir(path):
if not os.path.exists(path):
os.makedirs(path)
def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
# def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.):
image_numpy = image_tensor[0].cpu().float().numpy()
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
return image_numpy.astype(imtype)
def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
# def im2tensor(image, imtype=np.uint8, cent=1., factor=1.):
return torch.Tensor((image / factor - cent)
[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
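
# Editor's usage sketch (not part of the original file): tensor2im and
# im2tensor convert between a (1, C, H, W) tensor in [-1, 1] and an
# (H, W, C) uint8 image in [0, 255]; the round trip below only loses
# uint8 quantization.
if __name__ == '__main__':
    t = torch.rand(1, 3, 8, 8) * 2 - 1  # tensor in [-1, 1]
    im = tensor2im(t)                   # -> (8, 8, 3) uint8
    t2 = im2tensor(im)                  # -> (1, 3, 8, 8), back in [-1, 1]
    print(im.shape, tuple(t2.shape))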

@ -0,0 +1,216 @@
import numpy as np
import os
import time
from . import util
from . import html
import matplotlib.pyplot as plt
import math
def zoom_to_res(img, res=256, order=0, axis=0):
    # img is (3, X, X) when axis == 0, or (X, X, 3) when axis == 2
from scipy.ndimage import zoom
zoom_factor = res/img.shape[1]
if(axis==0):
return zoom(img,[1,zoom_factor,zoom_factor],order=order)
elif(axis==2):
return zoom(img,[zoom_factor,zoom_factor,1],order=order)
class Visualizer():
def __init__(self, opt):
# self.opt = opt
self.display_id = opt.display_id
# self.use_html = opt.is_train and not opt.no_html
self.win_size = opt.display_winsize
self.name = opt.name
self.display_cnt = 0 # display_current_results counter
self.display_cnt_high = 0
self.use_html = opt.use_html
if self.display_id > 0:
import visdom
            self.vis = visdom.Visdom(port=opt.display_port)
self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
util.mkdirs([self.web_dir,])
if self.use_html:
self.img_dir = os.path.join(self.web_dir, 'images')
print('create web directory %s...' % self.web_dir)
util.mkdirs([self.img_dir,])
# |visuals|: dictionary of images to display or save
def display_current_results(self, visuals, epoch, nrows=None, res=256):
if self.display_id > 0: # show images in the browser
title = self.name
if(nrows is None):
nrows = int(math.ceil(len(visuals.items()) / 2.0))
images = []
idx = 0
for label, image_numpy in visuals.items():
title += " | " if idx % nrows == 0 else ", "
title += label
img = image_numpy.transpose([2, 0, 1])
img = zoom_to_res(img,res=res,order=0)
images.append(img)
idx += 1
if len(visuals.items()) % 2 != 0:
white_image = np.ones_like(image_numpy.transpose([2, 0, 1]))*255
white_image = zoom_to_res(white_image,res=res,order=0)
images.append(white_image)
self.vis.images(images, nrow=nrows, win=self.display_id + 1,
opts=dict(title=title))
if self.use_html: # save images to a html file
for label, image_numpy in visuals.items():
img_path = os.path.join(self.img_dir, 'epoch%.3d_cnt%.6d_%s.png' % (epoch, self.display_cnt, label))
util.save_image(zoom_to_res(image_numpy, res=res, axis=2), img_path)
self.display_cnt += 1
self.display_cnt_high = np.maximum(self.display_cnt_high, self.display_cnt)
# update website
            webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1)
for n in range(epoch, 0, -1):
webpage.add_header('epoch [%d]' % n)
if(n==epoch):
high = self.display_cnt
else:
high = self.display_cnt_high
for c in range(high-1,-1,-1):
ims = []
txts = []
links = []
for label, image_numpy in visuals.items():
img_path = 'epoch%.3d_cnt%.6d_%s.png' % (n, c, label)
ims.append(os.path.join('images',img_path))
txts.append(label)
links.append(os.path.join('images',img_path))
webpage.add_images(ims, txts, links, width=self.win_size)
webpage.save()
# save errors into a directory
def plot_current_errors_save(self, epoch, counter_ratio, opt, errors,keys='+ALL',name='loss', to_plot=False):
if not hasattr(self, 'plot_data'):
self.plot_data = {'X':[],'Y':[], 'legend':list(errors.keys())}
self.plot_data['X'].append(epoch + counter_ratio)
self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']])
if(keys=='+ALL'):
plot_keys = self.plot_data['legend']
else:
plot_keys = keys
if(to_plot):
(f,ax) = plt.subplots(1,1)
for (k,kname) in enumerate(plot_keys):
kk = np.where(np.array(self.plot_data['legend'])==kname)[0][0]
x = self.plot_data['X']
y = np.array(self.plot_data['Y'])[:,kk]
if(to_plot):
ax.plot(x, y, 'o-', label=kname)
            np.save(os.path.join(self.web_dir, '%s_x' % kname), x)
            np.save(os.path.join(self.web_dir, '%s_y' % kname), y)
if(to_plot):
plt.legend(loc=0,fontsize='small')
plt.xlabel('epoch')
plt.ylabel('Value')
f.savefig(os.path.join(self.web_dir,'%s.png'%name))
f.clf()
plt.close()
# errors: dictionary of error labels and values
def plot_current_errors(self, epoch, counter_ratio, opt, errors):
if not hasattr(self, 'plot_data'):
self.plot_data = {'X':[],'Y':[], 'legend':list(errors.keys())}
self.plot_data['X'].append(epoch + counter_ratio)
self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']])
self.vis.line(
X=np.stack([np.array(self.plot_data['X'])]*len(self.plot_data['legend']),1),
Y=np.array(self.plot_data['Y']),
opts={
'title': self.name + ' loss over time',
'legend': self.plot_data['legend'],
'xlabel': 'epoch',
'ylabel': 'loss'},
win=self.display_id)
# errors: same format as |errors| of plotCurrentErrors
def print_current_errors(self, epoch, i, errors, t, t2=-1, t2o=-1, fid=None):
message = '(ep: %d, it: %d, t: %.3f[s], ept: %.2f/%.2f[h]) ' % (epoch, i, t, t2o, t2)
        message += ', '.join(['%s: %.3f' % (k, v) for k, v in errors.items()])
print(message)
if(fid is not None):
fid.write('%s\n'%message)
# save image to the disk
def save_images_simple(self, webpage, images, names, in_txts, prefix='', res=256):
image_dir = webpage.get_image_dir()
ims = []
txts = []
links = []
for name, image_numpy, txt in zip(names, images, in_txts):
image_name = '%s_%s.png' % (prefix, name)
save_path = os.path.join(image_dir, image_name)
if(res is not None):
util.save_image(zoom_to_res(image_numpy,res=res,axis=2), save_path)
else:
util.save_image(image_numpy, save_path)
ims.append(os.path.join(webpage.img_subdir,image_name))
# txts.append(name)
txts.append(txt)
links.append(os.path.join(webpage.img_subdir,image_name))
webpage.add_images(ims, txts, links, width=self.win_size)
# save image to the disk
def save_images(self, webpage, images, names, image_path, title=''):
image_dir = webpage.get_image_dir()
# short_path = ntpath.basename(image_path)
# name = os.path.splitext(short_path)[0]
# name = short_path
# webpage.add_header('%s, %s' % (name, title))
ims = []
txts = []
links = []
for label, image_numpy in zip(names, images):
image_name = '%s.jpg' % (label,)
save_path = os.path.join(image_dir, image_name)
util.save_image(image_numpy, save_path)
ims.append(image_name)
txts.append(label)
links.append(image_name)
webpage.add_images(ims, txts, links, width=self.win_size)
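
# Editor's usage sketch (not part of the original file): Visualizer only
# reads a handful of option attributes, so a SimpleNamespace can stand in
# for the usual argparse options. display_id=0 skips visdom entirely;
# use_html=True exercises the image-saving and HTML-report path.
def _visualizer_demo(checkpoints_dir='./checkpoints'):
    from types import SimpleNamespace
    opt = SimpleNamespace(display_id=0, display_winsize=256, display_port=8097,
                          name='demo', use_html=True, checkpoints_dir=checkpoints_dir)
    vis = Visualizer(opt)
    vis.display_current_results({'blurred': np.zeros((256, 256, 3), np.uint8)}, epoch=1)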