From df69507ef276ddc1d568d431426d9691b00180b9 Mon Sep 17 00:00:00 2001 From: pp00704831 Date: Sat, 16 Jul 2022 15:46:17 +0800 Subject: [PATCH] first version --- aug.py | 62 ++++ config/config_Stripformer_gopro.yaml | 40 +++ config/config_Stripformer_pretrained.yaml | 40 +++ dataset.py | 140 ++++++++ datasets/datasets.txt | 1 + evaluate_RealBlur_J.py | 101 ++++++ evaluate_RealBlur_R.py | 101 ++++++ evaluation_GoPro.m | 60 ++++ evaluation_HIDE.m | 60 ++++ metric_counter.py | 55 +++ models/Stripformer.py | 374 +++++++++++++++++++++ models/__init__.py | 0 models/losses.py | 131 ++++++++ models/models.py | 35 ++ models/networks.py | 13 + out/Results.txt | 1 + predict_GoPro_test_results.py | 68 ++++ predict_HIDE_results.py | 64 ++++ predict_RealBlur_J_test_results.py | 90 +++++ predict_RealBlur_R_test_results.py | 90 +++++ schedulers.py | 59 ++++ train_Stripformer_gopro.py | 170 ++++++++++ train_Stripformer_pretrained.py | 164 +++++++++ util/__init__.py | 0 util/__pycache__/__init__.cpython-36.pyc | Bin 0 -> 126 bytes util/__pycache__/image_pool.cpython-36.pyc | Bin 0 -> 1167 bytes util/__pycache__/metrics.cpython-36.pyc | Bin 0 -> 1873 bytes util/html.py | 66 ++++ util/image_pool.py | 33 ++ util/metrics.py | 54 +++ util/util.py | 48 +++ util/visualizer.py | 216 ++++++++++++ 32 files changed, 2336 insertions(+) create mode 100644 aug.py create mode 100644 config/config_Stripformer_gopro.yaml create mode 100644 config/config_Stripformer_pretrained.yaml create mode 100644 dataset.py create mode 100644 datasets/datasets.txt create mode 100644 evaluate_RealBlur_J.py create mode 100644 evaluate_RealBlur_R.py create mode 100644 evaluation_GoPro.m create mode 100644 evaluation_HIDE.m create mode 100644 metric_counter.py create mode 100644 models/Stripformer.py create mode 100644 models/__init__.py create mode 100644 models/losses.py create mode 100644 models/models.py create mode 100644 models/networks.py create mode 100644 out/Results.txt create mode 100644 predict_GoPro_test_results.py create mode 100644 predict_HIDE_results.py create mode 100644 predict_RealBlur_J_test_results.py create mode 100644 predict_RealBlur_R_test_results.py create mode 100644 schedulers.py create mode 100644 train_Stripformer_gopro.py create mode 100644 train_Stripformer_pretrained.py create mode 100644 util/__init__.py create mode 100644 util/__pycache__/__init__.cpython-36.pyc create mode 100644 util/__pycache__/image_pool.cpython-36.pyc create mode 100644 util/__pycache__/metrics.cpython-36.pyc create mode 100644 util/html.py create mode 100644 util/image_pool.py create mode 100644 util/metrics.py create mode 100644 util/util.py create mode 100644 util/visualizer.py diff --git a/aug.py b/aug.py new file mode 100644 index 0000000..09ec8d6 --- /dev/null +++ b/aug.py @@ -0,0 +1,62 @@ +from typing import List + +import albumentations as albu +from torchvision import transforms + +def get_transforms(size: int, scope: str = 'geometric', crop='random'): + augs = {'strong': albu.Compose([albu.HorizontalFlip(), + albu.ShiftScaleRotate(shift_limit=0.0, scale_limit=0.2, rotate_limit=20, p=.4), + albu.ElasticTransform(), + albu.OpticalDistortion(), + albu.OneOf([ + albu.CLAHE(clip_limit=2), + albu.IAASharpen(), + albu.IAAEmboss(), + albu.RandomBrightnessContrast(), + albu.RandomGamma() + ], p=0.5), + albu.OneOf([ + albu.RGBShift(), + albu.HueSaturationValue(), + ], p=0.5), + ]), + 'weak': albu.Compose([albu.HorizontalFlip(), + ]), + 'geometric': albu.Compose([albu.HorizontalFlip(), + albu.VerticalFlip(), + albu.RandomRotate90(), + 
]), + 'None': None + } + + aug_fn = augs[scope] + crop_fn = {'random': albu.RandomCrop(size, size, always_apply=True), + 'center': albu.CenterCrop(size, size, always_apply=True)}[crop] + + pipeline = albu.Compose([aug_fn, crop_fn], additional_targets={'target': 'image'}) + + + def process(a, b): + r = pipeline(image=a, target=b) + return r['image'], r['target'] + + return process + + +def get_normalize(): + transform = transforms.Compose([ + transforms.ToTensor() + ]) + + def process(a, b): + image = transform(a).permute(1, 2, 0) - 0.5 + target = transform(b).permute(1, 2, 0) - 0.5 + return image, target + + return process + + + + + + diff --git a/config/config_Stripformer_gopro.yaml b/config/config_Stripformer_gopro.yaml new file mode 100644 index 0000000..9f9da35 --- /dev/null +++ b/config/config_Stripformer_gopro.yaml @@ -0,0 +1,40 @@ +--- +experiment_desc: Stripformer_gopro + +train: + files_a: ./datasets/GoPro/train/blur/**/*.png + files_b: ./datasets/GoPro/train/sharp/**/*.png + size: &SIZE 512 + crop: random + preload: &PRELOAD false + preload_size: &PRELOAD_SIZE 0 + bounds: [0, 1] + scope: geometric + +val: + files_a: ./datasets/GoPro/test/blur/**/*.png + files_b: ./datasets/GoPro/test/sharp/**/*.png + size: *SIZE + scope: None + crop: random + preload: *PRELOAD + preload_size: *PRELOAD_SIZE + bounds: [0, 1] + +model: + g_name: Stripformer + content_loss: Stripformer_Loss + +num_epochs: 1000 +train_batches_per_epoch: 2103 +val_batches_per_epoch: 1111 +batch_size: 8 +image_size: [512, 512] + +optimizer: + name: adam + lr: 0.0001 +scheduler: + name: cosine + start_epoch: 50 + min_lr: 0.0000001 diff --git a/config/config_Stripformer_pretrained.yaml b/config/config_Stripformer_pretrained.yaml new file mode 100644 index 0000000..ca66266 --- /dev/null +++ b/config/config_Stripformer_pretrained.yaml @@ -0,0 +1,40 @@ +--- +experiment_desc: Stripformer_pretrained + +train: + files_a: ./datasets/GoPro/train/blur/**/*.png + files_b: ./datasets/GoPro/train/sharp/**/*.png + size: &SIZE 256 + crop: random + preload: &PRELOAD false + preload_size: &PRELOAD_SIZE 0 + bounds: [0, 1] + scope: geometric + +val: + files_a: ./datasets/GoPro/test/blur/**/*.png + files_b: ./datasets/GoPro/test/sharp/**/*.png + size: *SIZE + scope: None + crop: random + preload: *PRELOAD + preload_size: *PRELOAD_SIZE + bounds: [0, 1] + +model: + g_name: Stripformer + content_loss: Stripformer_Loss + +num_epochs: 3000 +train_batches_per_epoch: 2103 +val_batches_per_epoch: 1111 +batch_size: 8 +image_size: [256, 256] + +optimizer: + name: adam + lr: 0.0001 +scheduler: + name: cosine + start_epoch: 50 + min_lr: 0.0000001 diff --git a/dataset.py b/dataset.py new file mode 100644 index 0000000..25d54f1 --- /dev/null +++ b/dataset.py @@ -0,0 +1,140 @@ +import os +from copy import deepcopy +from functools import partial +from glob import glob +from hashlib import sha1 +from typing import Callable, Iterable, Optional, Tuple + +import cv2 +import numpy as np +from glog import logger +from joblib import Parallel, cpu_count, delayed +from skimage.io import imread +from torch.utils.data import Dataset +from tqdm import tqdm + +import aug + + +def subsample(data: Iterable, bounds: Tuple[float, float], hash_fn: Callable, n_buckets=100, salt='', verbose=True): + data = list(data) + buckets = split_into_buckets(data, n_buckets=n_buckets, salt=salt, hash_fn=hash_fn) + + lower_bound, upper_bound = [x * n_buckets for x in bounds] + msg = f'Subsampling buckets from {lower_bound} to {upper_bound}, total buckets number is {n_buckets}' + if salt: 
+ msg += f'; salt is {salt}' + if verbose: + logger.info(msg) + return np.array([sample for bucket, sample in zip(buckets, data) if lower_bound <= bucket < upper_bound]) + + +def hash_from_paths(x: Tuple[str, str], salt: str = '') -> str: + path_a, path_b = x + names = ''.join(map(os.path.basename, (path_a, path_b))) + return sha1(f'{names}_{salt}'.encode()).hexdigest() + + +def split_into_buckets(data: Iterable, n_buckets: int, hash_fn: Callable, salt=''): + hashes = map(partial(hash_fn, salt=salt), data) + return np.array([int(x, 16) % n_buckets for x in hashes]) + + +def _read_img(x: str): + img = cv2.imread(x) + if img is None: + logger.warning(f'Can not read image {x} with OpenCV, switching to scikit-image') + img = imread(x) + return img + + +class PairedDataset(Dataset): + def __init__(self, + files_a: Tuple[str], + files_b: Tuple[str], + transform_fn: Callable, + normalize_fn: Callable, + corrupt_fn: Optional[Callable] = None, + preload: bool = True, + preload_size: Optional[int] = 0, + verbose=True): + + assert len(files_a) == len(files_b) + + self.preload = preload + self.data_a = files_a + self.data_b = files_b + self.verbose = verbose + self.corrupt_fn = corrupt_fn + self.transform_fn = transform_fn + self.normalize_fn = normalize_fn + logger.info(f'Dataset has been created with {len(self.data_a)} samples') + + if preload: + preload_fn = partial(self._bulk_preload, preload_size=preload_size) + if files_a == files_b: + self.data_a = self.data_b = preload_fn(self.data_a) + else: + self.data_a, self.data_b = map(preload_fn, (self.data_a, self.data_b)) + self.preload = True + + def _bulk_preload(self, data: Iterable[str], preload_size: int): + jobs = [delayed(self._preload)(x, preload_size=preload_size) for x in data] + jobs = tqdm(jobs, desc='preloading images', disable=not self.verbose) + return Parallel(n_jobs=cpu_count(), backend='threading')(jobs) + + @staticmethod + def _preload(x: str, preload_size: int): + img = _read_img(x) + if preload_size: + h, w, *_ = img.shape + h_scale = preload_size / h + w_scale = preload_size / w + scale = max(h_scale, w_scale) + img = cv2.resize(img, fx=scale, fy=scale, dsize=None) + assert min(img.shape[:2]) >= preload_size, f'weird img shape: {img.shape}' + return img + + def _preprocess(self, img, res): + def transpose(x): + return np.transpose(x, (2, 0, 1)) + + return map(transpose, self.normalize_fn(img, res)) + + def __len__(self): + return len(self.data_a) + + def __getitem__(self, idx): + a, b = self.data_a[idx], self.data_b[idx] + if not self.preload: + a, b = map(_read_img, (a, b)) + a, b = self.transform_fn(a, b) + if self.corrupt_fn is not None: + a = self.corrupt_fn(a) + a, b = self._preprocess(a, b) + return {'a': a, 'b': b} + + @staticmethod + def from_config(config): + config = deepcopy(config) + files_a, files_b = map(lambda x: sorted(glob(config[x], recursive=True)), ('files_a', 'files_b')) + transform_fn = aug.get_transforms(size=config['size'], scope=config['scope'], crop=config['crop']) + normalize_fn = aug.get_normalize() + + hash_fn = hash_from_paths + # ToDo: add more hash functions + verbose = config.get('verbose', True) + data = subsample(data=zip(files_a, files_b), + bounds=config.get('bounds', (0, 1)), + hash_fn=hash_fn, + verbose=verbose) + + files_a, files_b = map(list, zip(*data)) + + return PairedDataset(files_a=files_a, + files_b=files_b, + preload=config['preload'], + preload_size=config['preload_size'], + normalize_fn=normalize_fn, + transform_fn=transform_fn, + verbose=verbose) diff --git 
a/datasets/datasets.txt b/datasets/datasets.txt new file mode 100644 index 0000000..ca64d94 --- /dev/null +++ b/datasets/datasets.txt @@ -0,0 +1 @@ +Download 'GoPro' datasets and put the datasets into folder './datasets' \ No newline at end of file diff --git a/evaluate_RealBlur_J.py b/evaluate_RealBlur_J.py new file mode 100644 index 0000000..aaad77d --- /dev/null +++ b/evaluate_RealBlur_J.py @@ -0,0 +1,101 @@ +import os +from skimage import io +import cv2 +import numpy as np +from skimage.metrics import structural_similarity +import concurrent.futures + +def image_align(deblurred, gt): + # this function is based on kohler evaluation code + z = deblurred + c = np.ones_like(z) + x = gt + + zs = (np.sum(x * z) / np.sum(z * z)) * z # simple intensity matching + + warp_mode = cv2.MOTION_HOMOGRAPHY + warp_matrix = np.eye(3, 3, dtype=np.float32) + + # Specify the number of iterations. + number_of_iterations = 100 + + termination_eps = 0 + + criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, + number_of_iterations, termination_eps) + + # Run the ECC algorithm. The results are stored in warp_matrix. + (cc, warp_matrix) = cv2.findTransformECC(cv2.cvtColor(x, cv2.COLOR_RGB2GRAY), cv2.cvtColor(zs, cv2.COLOR_RGB2GRAY), + warp_matrix, warp_mode, criteria, inputMask=None, gaussFiltSize=5) + + target_shape = x.shape + shift = warp_matrix + + zr = cv2.warpPerspective( + zs, + warp_matrix, + (target_shape[1], target_shape[0]), + flags=cv2.INTER_CUBIC + cv2.WARP_INVERSE_MAP, + borderMode=cv2.BORDER_REFLECT) + + cr = cv2.warpPerspective( + np.ones_like(zs, dtype='float32'), + warp_matrix, + (target_shape[1], target_shape[0]), + flags=cv2.INTER_NEAREST + cv2.WARP_INVERSE_MAP, + borderMode=cv2.BORDER_CONSTANT, + borderValue=0) + + zr = zr * cr + xr = x * cr + + return zr, xr, cr, shift + +def compute_psnr(image_true, image_test, image_mask, data_range=None): + # this function is based on skimage.metrics.peak_signal_noise_ratio + err = np.sum((image_true - image_test) ** 2, dtype=np.float64) / np.sum(image_mask) + return 10 * np.log10((data_range ** 2) / err) + + +def compute_ssim(tar_img, prd_img, cr1): + ssim_pre, ssim_map = structural_similarity(tar_img, prd_img, multichannel=True, gaussian_weights=True, + use_sample_covariance=False, data_range=1.0, full=True) + ssim_map = ssim_map * cr1 + r = int(3.5 * 1.5 + 0.5) # radius as in ndimage + win_size = 2 * r + 1 + pad = (win_size - 1) // 2 + ssim = ssim_map[pad:-pad, pad:-pad, :] + crop_cr1 = cr1[pad:-pad, pad:-pad, :] + ssim = ssim.sum(axis=0).sum(axis=0) / crop_cr1.sum(axis=0).sum(axis=0) + ssim = np.mean(ssim) + return ssim + +total_psnr = 0. +total_ssim = 0. 
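+# Running totals for the RealBlur-J evaluation: each prediction is intensity-matched and
+# ECC-aligned (homography) to its ground truth, then PSNR/SSIM are computed over the valid warped region.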
+count = 0 +img_path = './out/Stripformer_realblur_J_results' +gt_path = './datasets/Realblur_J/test/sharp' +print(img_path) +for file in os.listdir(img_path): + for img_name in os.listdir(img_path + '/' + file): + count += 1 + number = img_name.split('_')[1] + gt_name = 'gt_' + number + img_dir = img_path + '/' + file + '/' + img_name + gt_dir = gt_path + '/' + file + '/' + gt_name + with concurrent.futures.ProcessPoolExecutor(max_workers=10) as executor: + tar_img = io.imread(gt_dir) + prd_img = io.imread(img_dir) + tar_img = tar_img.astype(np.float32) / 255.0 + prd_img = prd_img.astype(np.float32) / 255.0 + prd_img, tar_img, cr1, shift = image_align(prd_img, tar_img) + PSNR = compute_psnr(tar_img, prd_img, cr1, data_range=1) + SSIM = compute_ssim(tar_img, prd_img, cr1) + total_psnr += PSNR + total_ssim += SSIM + print(count, PSNR) + +print('PSNR:', total_psnr / count) +print('SSIM:', total_ssim / count) +print(img_path) + diff --git a/evaluate_RealBlur_R.py b/evaluate_RealBlur_R.py new file mode 100644 index 0000000..91bd441 --- /dev/null +++ b/evaluate_RealBlur_R.py @@ -0,0 +1,101 @@ +import os +from skimage import io +import cv2 +import numpy as np +from skimage.metrics import structural_similarity +import concurrent.futures + +def image_align(deblurred, gt): + # this function is based on kohler evaluation code + z = deblurred + c = np.ones_like(z) + x = gt + + zs = (np.sum(x * z) / np.sum(z * z)) * z # simple intensity matching + + warp_mode = cv2.MOTION_HOMOGRAPHY + warp_matrix = np.eye(3, 3, dtype=np.float32) + + # Specify the number of iterations. + number_of_iterations = 100 + + termination_eps = 0 + + criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, + number_of_iterations, termination_eps) + + # Run the ECC algorithm. The results are stored in warp_matrix. + (cc, warp_matrix) = cv2.findTransformECC(cv2.cvtColor(x, cv2.COLOR_RGB2GRAY), cv2.cvtColor(zs, cv2.COLOR_RGB2GRAY), + warp_matrix, warp_mode, criteria, inputMask=None, gaussFiltSize=5) + + target_shape = x.shape + shift = warp_matrix + + zr = cv2.warpPerspective( + zs, + warp_matrix, + (target_shape[1], target_shape[0]), + flags=cv2.INTER_CUBIC + cv2.WARP_INVERSE_MAP, + borderMode=cv2.BORDER_REFLECT) + + cr = cv2.warpPerspective( + np.ones_like(zs, dtype='float32'), + warp_matrix, + (target_shape[1], target_shape[0]), + flags=cv2.INTER_NEAREST + cv2.WARP_INVERSE_MAP, + borderMode=cv2.BORDER_CONSTANT, + borderValue=0) + + zr = zr * cr + xr = x * cr + + return zr, xr, cr, shift + +def compute_psnr(image_true, image_test, image_mask, data_range=None): + # this function is based on skimage.metrics.peak_signal_noise_ratio + err = np.sum((image_true - image_test) ** 2, dtype=np.float64) / np.sum(image_mask) + return 10 * np.log10((data_range ** 2) / err) + + +def compute_ssim(tar_img, prd_img, cr1): + ssim_pre, ssim_map = structural_similarity(tar_img, prd_img, multichannel=True, gaussian_weights=True, + use_sample_covariance=False, data_range=1.0, full=True) + ssim_map = ssim_map * cr1 + r = int(3.5 * 1.5 + 0.5) # radius as in ndimage + win_size = 2 * r + 1 + pad = (win_size - 1) // 2 + ssim = ssim_map[pad:-pad, pad:-pad, :] + crop_cr1 = cr1[pad:-pad, pad:-pad, :] + ssim = ssim.sum(axis=0).sum(axis=0) / crop_cr1.sum(axis=0).sum(axis=0) + ssim = np.mean(ssim) + return ssim + +total_psnr = 0. +total_ssim = 0. 
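+# Same protocol as RealBlur-J: ECC homography alignment to the ground truth, then masked PSNR/SSIM
+# averaged over the RealBlur-R test set.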
+count = 0 +img_path = './out/Stripformer_realblur_R_results' +gt_path = './datasets/Realblur_R/test/sharp' +print(img_path) +for file in os.listdir(img_path): + for img_name in os.listdir(img_path + '/' + file): + count += 1 + number = img_name.split('_')[1] + gt_name = 'gt_' + number + img_dir = img_path + '/' + file + '/' + img_name + gt_dir = gt_path + '/' + file + '/' + gt_name + with concurrent.futures.ProcessPoolExecutor(max_workers=10) as executor: + tar_img = io.imread(gt_dir) + prd_img = io.imread(img_dir) + tar_img = tar_img.astype(np.float32) / 255.0 + prd_img = prd_img.astype(np.float32) / 255.0 + prd_img, tar_img, cr1, shift = image_align(prd_img, tar_img) + PSNR = compute_psnr(tar_img, prd_img, cr1, data_range=1) + SSIM = compute_ssim(tar_img, prd_img, cr1) + total_psnr += PSNR + total_ssim += SSIM + print(count, PSNR) + +print('PSNR:', total_psnr / count) +print('SSIM:', total_ssim / count) +print(img_path) + diff --git a/evaluation_GoPro.m b/evaluation_GoPro.m new file mode 100644 index 0000000..6838242 --- /dev/null +++ b/evaluation_GoPro.m @@ -0,0 +1,60 @@ +p = genpath('.\out\stripformer_10_patches');% GoPro Deblur Results +gt = genpath('.\datasets\GoPro\test\sharp');% GoPro GT Results + +length_p = size(p,2); +path = {}; +temp = []; +for i = 1:length_p + if p(i) ~= ';' + temp = [temp p(i)]; + else + temp = [temp '\']; + path = [path ; temp]; + temp = []; + end +end +clear p length_p temp; +length_gt = size(gt,2); +path_gt = {}; +temp_gt = []; +for i = 1:length_gt + if gt(i) ~= ';' + temp_gt = [temp_gt gt(i)]; + else + temp_gt = [temp_gt '\']; + path_gt = [path_gt ; temp_gt]; + temp_gt = []; + end +end +clear gt length_gt temp_gt; + +file_num = size(path,1); +total_psnr = 0; +n = 0; +total_ssim = 0; +for i = 1:file_num + file_path = path{i}; + gt_file_path = path_gt{i}; + img_path_list = dir(strcat(file_path,'*.png')); + gt_path_list = dir(strcat(gt_file_path,'*.png')); + img_num = length(img_path_list); + if img_num > 0 + for j = 1:img_num + image_name = img_path_list(j).name; + gt_name = gt_path_list(j).name; + image = imread(strcat(file_path,image_name)); + gt = imread(strcat(gt_file_path,gt_name)); + size(image); + size(gt); + peaksnr = psnr(image,gt); + ssimval = ssim(image,gt); + total_psnr = total_psnr + peaksnr; + total_ssim = total_ssim + ssimval; + n = n + 1 + end + end +end +psnr = total_psnr / n +ssim = total_ssim / n +close all;clear all; + diff --git a/evaluation_HIDE.m b/evaluation_HIDE.m new file mode 100644 index 0000000..e9ad868 --- /dev/null +++ b/evaluation_HIDE.m @@ -0,0 +1,60 @@ +p = genpath('.\out\Stripformer_HIDE_result');% HIDE Deblur Results +gt = genpath('.\datasets\HIDE\sharp');% HIDE GT Results + +length_p = size(p,2); +path = {}; +temp = []; +for i = 1:length_p + if p(i) ~= ';' + temp = [temp p(i)]; + else + temp = [temp '\']; + path = [path ; temp]; + temp = []; + end +end +clear p length_p temp; +length_gt = size(gt,2); +path_gt = {}; +temp_gt = []; +for i = 1:length_gt + if gt(i) ~= ';' + temp_gt = [temp_gt gt(i)]; + else + temp_gt = [temp_gt '\']; + path_gt = [path_gt ; temp_gt]; + temp_gt = []; + end +end +clear gt length_gt temp_gt; + +file_num = size(path,1); +total_psnr = 0; +n = 0; +total_ssim = 0; +for i = 1:file_num + file_path = path{i}; + gt_file_path = path_gt{i}; + img_path_list = dir(strcat(file_path,'*.png')); + gt_path_list = dir(strcat(gt_file_path,'*.png')); + img_num = length(img_path_list); + if img_num > 0 + for j = 1:img_num + image_name = img_path_list(j).name; + gt_name = gt_path_list(j).name; + image = 
imread(strcat(file_path,image_name)); + gt = imread(strcat(gt_file_path,gt_name)); + size(image); + size(gt); + peaksnr = psnr(image,gt); + ssimval = ssim(image,gt); + total_psnr = total_psnr + peaksnr; + total_ssim = total_ssim + ssimval; + n = n + 1 + end + end +end +psnr = total_psnr / n +ssim = total_ssim / n +close all;clear all; + diff --git a/metric_counter.py b/metric_counter.py new file mode 100644 index 0000000..807ccc1 --- /dev/null +++ b/metric_counter.py @@ -0,0 +1,55 @@ +import logging +from collections import defaultdict + +import numpy as np +from tensorboardX import SummaryWriter + +WINDOW_SIZE = 100 + + +class MetricCounter: + def __init__(self, exp_name): + self.writer = SummaryWriter(exp_name) + logging.basicConfig(filename='{}.log'.format(exp_name), level=logging.DEBUG) + self.metrics = defaultdict(list) + self.images = defaultdict(list) + self.best_metric = 0 + + def add_image(self, x: np.ndarray, tag: str): + self.images[tag].append(x) + + def clear(self): + self.metrics = defaultdict(list) + self.images = defaultdict(list) + + def add_losses(self, l_G): + for name, value in zip(('G_loss', None), (l_G, None)): + self.metrics[name].append(value) + + def add_metrics(self, psnr, ssim): + for name, value in zip(('PSNR', 'SSIM'), + (psnr, ssim)): + self.metrics[name].append(value) + + def loss_message(self): + metrics = ((k, np.mean(self.metrics[k][-WINDOW_SIZE:])) for k in ('G_loss', 'PSNR', 'SSIM')) + return '; '.join(map(lambda x: f'{x[0]}={x[1]:.4f}', metrics)) + + def write_to_tensorboard(self, epoch_num, validation=False): + scalar_prefix = 'Validation' if validation else 'Train' + for tag in ('G_loss', 'SSIM', 'PSNR'): + self.writer.add_scalar(f'{scalar_prefix}_{tag}', np.mean(self.metrics[tag]), global_step=epoch_num) + for tag in self.images: + imgs = self.images[tag] + if imgs: + imgs = np.array(imgs) + self.writer.add_images(tag, imgs[:, :, :, ::-1].astype('float32') / 255, dataformats='NHWC', + global_step=epoch_num) + self.images[tag] = [] + + def update_best_model(self): + cur_metric = np.mean(self.metrics['PSNR']) + if self.best_metric < cur_metric: + self.best_metric = cur_metric + return True + return False diff --git a/models/Stripformer.py b/models/Stripformer.py new file mode 100644 index 0000000..55c1ec7 --- /dev/null +++ b/models/Stripformer.py @@ -0,0 +1,374 @@ +import torch +import torch.nn as nn +import math + +class Embeddings(nn.Module): + def __init__(self): + super(Embeddings, self).__init__() + + self.activation = nn.LeakyReLU(0.2, True) + + self.en_layer1_1 = nn.Sequential( + nn.Conv2d(3, 64, kernel_size=3, padding=1), + self.activation, + ) + self.en_layer1_2 = nn.Sequential( + nn.Conv2d(64, 64, kernel_size=3, padding=1), + self.activation, + nn.Conv2d(64, 64, kernel_size=3, padding=1)) + self.en_layer1_3 = nn.Sequential( + nn.Conv2d(64, 64, kernel_size=3, padding=1), + self.activation, + nn.Conv2d(64, 64, kernel_size=3, padding=1)) + self.en_layer1_4 = nn.Sequential( + nn.Conv2d(64, 64, kernel_size=3, padding=1), + self.activation, + nn.Conv2d(64, 64, kernel_size=3, padding=1)) + + self.en_layer2_1 = nn.Sequential( + nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), + self.activation, + ) + self.en_layer2_2 = nn.Sequential( + nn.Conv2d(128, 128, kernel_size=3, padding=1), + self.activation, + nn.Conv2d(128, 128, kernel_size=3, padding=1)) + self.en_layer2_3 = nn.Sequential( + nn.Conv2d(128, 128, kernel_size=3, padding=1), + self.activation, + nn.Conv2d(128, 128, kernel_size=3, padding=1)) + self.en_layer2_4 = nn.Sequential( + 
nn.Conv2d(128, 128, kernel_size=3, padding=1), + self.activation, + nn.Conv2d(128, 128, kernel_size=3, padding=1)) + + + self.en_layer3_1 = nn.Sequential( + nn.Conv2d(128, 320, kernel_size=3, stride=2, padding=1), + self.activation, + ) + + + def forward(self, x): + + hx = self.en_layer1_1(x) + hx = self.activation(self.en_layer1_2(hx) + hx) + hx = self.activation(self.en_layer1_3(hx) + hx) + hx = self.activation(self.en_layer1_4(hx) + hx) + residual_1 = hx + hx = self.en_layer2_1(hx) + hx = self.activation(self.en_layer2_2(hx) + hx) + hx = self.activation(self.en_layer2_3(hx) + hx) + hx = self.activation(self.en_layer2_4(hx) + hx) + residual_2 = hx + hx = self.en_layer3_1(hx) + + return hx, residual_1, residual_2 + + +class Embeddings_output(nn.Module): + def __init__(self): + super(Embeddings_output, self).__init__() + + self.activation = nn.LeakyReLU(0.2, True) + + self.de_layer3_1 = nn.Sequential( + nn.ConvTranspose2d(320, 192, kernel_size=4, stride=2, padding=1), + self.activation, + ) + head_num = 3 + dim = 192 + + self.de_layer2_2 = nn.Sequential( + nn.Conv2d(192+128, 192, kernel_size=1, padding=0), + self.activation, + ) + + self.de_block_1 = Intra_SA(dim, head_num) + self.de_block_2 = Inter_SA(dim, head_num) + self.de_block_3 = Intra_SA(dim, head_num) + self.de_block_4 = Inter_SA(dim, head_num) + self.de_block_5 = Intra_SA(dim, head_num) + self.de_block_6 = Inter_SA(dim, head_num) + + + self.de_layer2_1 = nn.Sequential( + nn.ConvTranspose2d(192, 64, kernel_size=4, stride=2, padding=1), + self.activation, + ) + + self.de_layer1_3 = nn.Sequential( + nn.Conv2d(128, 64, kernel_size=1, padding=0), + self.activation, + nn.Conv2d(64, 64, kernel_size=3, padding=1)) + self.de_layer1_2 = nn.Sequential( + nn.Conv2d(64, 64, kernel_size=3, padding=1), + self.activation, + nn.Conv2d(64, 64, kernel_size=3, padding=1)) + self.de_layer1_1 = nn.Sequential( + nn.Conv2d(64, 3, kernel_size=3, padding=1), + self.activation + ) + + def forward(self, x, residual_1, residual_2): + + + hx = self.de_layer3_1(x) + + hx = self.de_layer2_2(torch.cat((hx, residual_2), dim = 1)) + hx = self.de_block_1(hx) + hx = self.de_block_2(hx) + hx = self.de_block_3(hx) + hx = self.de_block_4(hx) + hx = self.de_block_5(hx) + hx = self.de_block_6(hx) + hx = self.de_layer2_1(hx) + + hx = self.activation(self.de_layer1_3(torch.cat((hx, residual_1), dim = 1)) + hx) + hx = self.activation(self.de_layer1_2(hx) + hx) + hx = self.de_layer1_1(hx) + + return hx + +class Attention(nn.Module): + def __init__(self, head_num): + super(Attention, self).__init__() + self.num_attention_heads = head_num + self.softmax = nn.Softmax(dim=-1) + + def transpose_for_scores(self, x): + B, N, C = x.size() + attention_head_size = int(C / self.num_attention_heads) + new_x_shape = x.size()[:-1] + (self.num_attention_heads, attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3).contiguous() + + def forward(self, query_layer, key_layer, value_layer): + B, N, C = query_layer.size() + query_layer = self.transpose_for_scores(query_layer) + key_layer = self.transpose_for_scores(key_layer) + value_layer = self.transpose_for_scores(value_layer) + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + _, _, _, d = query_layer.size() + attention_scores = attention_scores / math.sqrt(d) + attention_probs = self.softmax(attention_scores) + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + 
(C,) + attention_out = context_layer.view(*new_context_layer_shape) + + return attention_out + + +class Mlp(nn.Module): + def __init__(self, hidden_size): + super(Mlp, self).__init__() + self.fc1 = nn.Linear(hidden_size, 4*hidden_size) + self.fc2 = nn.Linear(4*hidden_size, hidden_size) + self.act_fn = torch.nn.functional.gelu + self._init_weights() + + def _init_weights(self): + nn.init.xavier_uniform_(self.fc1.weight) + nn.init.xavier_uniform_(self.fc2.weight) + nn.init.normal_(self.fc1.bias, std=1e-6) + nn.init.normal_(self.fc2.bias, std=1e-6) + + def forward(self, x): + x = self.fc1(x) + x = self.act_fn(x) + x = self.fc2(x) + return x + + +# CPE (Conditional Positional Embedding) +class PEG(nn.Module): + def __init__(self, hidden_size): + super(PEG, self).__init__() + self.PEG = nn.Conv2d(hidden_size, hidden_size, kernel_size=3, padding=1, groups=hidden_size) + + def forward(self, x): + x = self.PEG(x) + x + return x + + +class Intra_SA(nn.Module): + def __init__(self, dim, head_num): + super(Intra_SA, self).__init__() + self.hidden_size = dim // 2 + self.head_num = head_num + self.attention_norm = nn.LayerNorm(dim) + self.conv_input = nn.Conv2d(dim, dim, kernel_size=1, padding=0) + self.qkv_local_h = nn.Linear(self.hidden_size, self.hidden_size * 3) # qkv_h + self.qkv_local_v = nn.Linear(self.hidden_size, self.hidden_size * 3) # qkv_v + self.fuse_out = nn.Conv2d(dim, dim, kernel_size=1, padding=0) + self.ffn_norm = nn.LayerNorm(dim) + self.ffn = Mlp(dim) + self.attn = Attention(head_num=self.head_num) + self.PEG = PEG(dim) + def forward(self, x): + h = x + B, C, H, W = x.size() + + x = x.view(B, C, H*W).permute(0, 2, 1).contiguous() + x = self.attention_norm(x).permute(0, 2, 1).contiguous() + x = x.view(B, C, H, W) + + x_input = torch.chunk(self.conv_input(x), 2, dim=1) + feature_h = (x_input[0]).permute(0, 2, 3, 1).contiguous() + feature_h = feature_h.view(B * H, W, C//2) + feature_v = (x_input[1]).permute(0, 3, 2, 1).contiguous() + feature_v = feature_v.view(B * W, H, C//2) + qkv_h = torch.chunk(self.qkv_local_h(feature_h), 3, dim=2) + qkv_v = torch.chunk(self.qkv_local_v(feature_v), 3, dim=2) + q_h, k_h, v_h = qkv_h[0], qkv_h[1], qkv_h[2] + q_v, k_v, v_v = qkv_v[0], qkv_v[1], qkv_v[2] + + if H == W: + query = torch.cat((q_h, q_v), dim=0) + key = torch.cat((k_h, k_v), dim=0) + value = torch.cat((v_h, v_v), dim=0) + attention_output = self.attn(query, key, value) + attention_output = torch.chunk(attention_output, 2, dim=0) + attention_output_h = attention_output[0] + attention_output_v = attention_output[1] + attention_output_h = attention_output_h.view(B, H, W, C//2).permute(0, 3, 1, 2).contiguous() + attention_output_v = attention_output_v.view(B, W, H, C//2).permute(0, 3, 2, 1).contiguous() + attn_out = self.fuse_out(torch.cat((attention_output_h, attention_output_v), dim=1)) + else: + attention_output_h = self.attn(q_h, k_h, v_h) + attention_output_v = self.attn(q_v, k_v, v_v) + attention_output_h = attention_output_h.view(B, H, W, C//2).permute(0, 3, 1, 2).contiguous() + attention_output_v = attention_output_v.view(B, W, H, C//2).permute(0, 3, 2, 1).contiguous() + attn_out = self.fuse_out(torch.cat((attention_output_h, attention_output_v), dim=1)) + + x = attn_out + h + x = x.view(B, C, H*W).permute(0, 2, 1).contiguous() + h = x + x = self.ffn_norm(x) + x = self.ffn(x) + x = x + h + x = x.permute(0, 2, 1).contiguous() + x = x.view(B, C, H, W) + + x = self.PEG(x) + + return x + +class Inter_SA(nn.Module): + def __init__(self,dim, head_num): + super(Inter_SA, self).__init__() + 
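+        # Inter-strip attention: each whole horizontal (or vertical) strip is reshaped into a single
+        # token, so attention here mixes information across strips, complementing the within-strip Intra_SA.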
self.hidden_size = dim + self.head_num = head_num + self.attention_norm = nn.LayerNorm(self.hidden_size) + self.conv_input = nn.Conv2d(self.hidden_size, self.hidden_size, kernel_size=1, padding=0) + self.conv_h = nn.Conv2d(self.hidden_size//2, 3 * (self.hidden_size//2), kernel_size=1, padding=0) # qkv_h + self.conv_v = nn.Conv2d(self.hidden_size//2, 3 * (self.hidden_size//2), kernel_size=1, padding=0) # qkv_v + self.ffn_norm = nn.LayerNorm(self.hidden_size) + self.ffn = Mlp(self.hidden_size) + self.fuse_out = nn.Conv2d(self.hidden_size, self.hidden_size, kernel_size=1, padding=0) + self.attn = Attention(head_num=self.head_num) + self.PEG = PEG(dim) + + def forward(self, x): + h = x + B, C, H, W = x.size() + + x = x.view(B, C, H*W).permute(0, 2, 1).contiguous() + x = self.attention_norm(x).permute(0, 2, 1).contiguous() + x = x.view(B, C, H, W) + + x_input = torch.chunk(self.conv_input(x), 2, dim=1) + feature_h = torch.chunk(self.conv_h(x_input[0]), 3, dim=1) + feature_v = torch.chunk(self.conv_v(x_input[1]), 3, dim=1) + query_h, key_h, value_h = feature_h[0], feature_h[1], feature_h[2] + query_v, key_v, value_v = feature_v[0], feature_v[1], feature_v[2] + + horizontal_groups = torch.cat((query_h, key_h, value_h), dim=0) + horizontal_groups = horizontal_groups.permute(0, 2, 1, 3).contiguous() + horizontal_groups = horizontal_groups.view(3*B, H, -1) + horizontal_groups = torch.chunk(horizontal_groups, 3, dim=0) + query_h, key_h, value_h = horizontal_groups[0], horizontal_groups[1], horizontal_groups[2] + + vertical_groups = torch.cat((query_v, key_v, value_v), dim=0) + vertical_groups = vertical_groups.permute(0, 3, 1, 2).contiguous() + vertical_groups = vertical_groups.view(3*B, W, -1) + vertical_groups = torch.chunk(vertical_groups, 3, dim=0) + query_v, key_v, value_v = vertical_groups[0], vertical_groups[1], vertical_groups[2] + + + if H == W: + query = torch.cat((query_h, query_v), dim=0) + key = torch.cat((key_h, key_v), dim=0) + value = torch.cat((value_h, value_v), dim=0) + attention_output = self.attn(query, key, value) + attention_output = torch.chunk(attention_output, 2, dim=0) + attention_output_h = attention_output[0] + attention_output_v = attention_output[1] + attention_output_h = attention_output_h.view(B, H, C//2, W).permute(0, 2, 1, 3).contiguous() + attention_output_v = attention_output_v.view(B, W, C//2, H).permute(0, 2, 3, 1).contiguous() + attn_out = self.fuse_out(torch.cat((attention_output_h, attention_output_v), dim=1)) + else: + attention_output_h = self.attn(query_h, key_h, value_h) + attention_output_v = self.attn(query_v, key_v, value_v) + attention_output_h = attention_output_h.view(B, H, C//2, W).permute(0, 2, 1, 3).contiguous() + attention_output_v = attention_output_v.view(B, W, C//2, H).permute(0, 2, 3, 1).contiguous() + attn_out = self.fuse_out(torch.cat((attention_output_h, attention_output_v), dim=1)) + + x = attn_out + h + x = x.view(B, C, H*W).permute(0, 2, 1).contiguous() + h = x + x = self.ffn_norm(x) + x = self.ffn(x) + x = x + h + x = x.permute(0, 2, 1).contiguous() + x = x.view(B, C, H, W) + + x = self.PEG(x) + + return x + +class Stripformer(nn.Module): + def __init__(self): + super(Stripformer, self).__init__() + + self.encoder = Embeddings() + head_num = 5 + dim = 320 + self.Trans_block_1 = Intra_SA(dim, head_num) + self.Trans_block_2 = Inter_SA(dim, head_num) + self.Trans_block_3 = Intra_SA(dim, head_num) + self.Trans_block_4 = Inter_SA(dim, head_num) + self.Trans_block_5 = Intra_SA(dim, head_num) + self.Trans_block_6 = Inter_SA(dim, head_num) + 
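+        # Blocks 7-12 below continue the alternating Intra-SA / Inter-SA pattern of blocks 1-6
+        # at the 320-channel bottleneck (1/4 resolution, 5 attention heads).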
self.Trans_block_7 = Intra_SA(dim, head_num) + self.Trans_block_8 = Inter_SA(dim, head_num) + self.Trans_block_9 = Intra_SA(dim, head_num) + self.Trans_block_10 = Inter_SA(dim, head_num) + self.Trans_block_11 = Intra_SA(dim, head_num) + self.Trans_block_12 = Inter_SA(dim, head_num) + self.decoder = Embeddings_output() + + + def forward(self, x): + + hx, residual_1, residual_2 = self.encoder(x) + hx = self.Trans_block_1(hx) + hx = self.Trans_block_2(hx) + hx = self.Trans_block_3(hx) + hx = self.Trans_block_4(hx) + hx = self.Trans_block_5(hx) + hx = self.Trans_block_6(hx) + hx = self.Trans_block_7(hx) + hx = self.Trans_block_8(hx) + hx = self.Trans_block_9(hx) + hx = self.Trans_block_10(hx) + hx = self.Trans_block_11(hx) + hx = self.Trans_block_12(hx) + hx = self.decoder(hx, residual_1, residual_2) + + return hx + x + + + + diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/models/losses.py b/models/losses.py new file mode 100644 index 0000000..226f026 --- /dev/null +++ b/models/losses.py @@ -0,0 +1,131 @@ +import torch +import torch.nn as nn +import torchvision.models as models +import torch.nn.functional as F + +class Vgg19(torch.nn.Module): + def __init__(self, requires_grad=False): + super(Vgg19, self).__init__() + vgg_pretrained_features = models.vgg19(pretrained=True).features + self.slice1 = torch.nn.Sequential() + + for x in range(12): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h_relu1 = self.slice1(X) + return h_relu1 + +class ContrastLoss(nn.Module): + def __init__(self, ablation=False): + + super(ContrastLoss, self).__init__() + self.vgg = Vgg19().cuda() + self.l1 = nn.L1Loss() + self.ab = ablation + self.down_sample_4 = nn.Upsample(scale_factor=1 / 4, mode='bilinear') + def forward(self, restore, sharp, blur): + B, C, H, W = restore.size() + restore_vgg, sharp_vgg, blur_vgg = self.vgg(restore), self.vgg(sharp), self.vgg(blur) + + # filter out sharp regions + threshold = 0.01 + mask = torch.mean(torch.abs(sharp-blur), dim=1).view(B, 1, H, W) + mask[mask <= threshold] = 0 + mask[mask > threshold] = 1 + mask = self.down_sample_4(mask) + d_ap = torch.mean(torch.abs((restore_vgg - sharp_vgg.detach())), dim=1).view(B, 1, H//4, W//4) + d_an = torch.mean(torch.abs((restore_vgg - blur_vgg.detach())), dim=1).view(B, 1, H//4, W//4) + mask_size = torch.sum(mask) + contrastive = torch.sum((d_ap / (d_an + 1e-7)) * mask) / mask_size + + return contrastive + + +class ContrastLoss_Ori(nn.Module): + def __init__(self, ablation=False): + super(ContrastLoss_Ori, self).__init__() + self.vgg = Vgg19().cuda() + self.l1 = nn.L1Loss() + self.ab = ablation + + def forward(self, restore, sharp, blur): + + restore_vgg, sharp_vgg, blur_vgg = self.vgg(restore), self.vgg(sharp), self.vgg(blur) + d_ap = self.l1(restore_vgg, sharp_vgg.detach()) + d_an = self.l1(restore_vgg, blur_vgg.detach()) + contrastive_loss = d_ap / (d_an + 1e-7) + + return contrastive_loss + +class CharbonnierLoss(nn.Module): + """Charbonnier Loss (L1)""" + + def __init__(self, eps=1e-3): + super(CharbonnierLoss, self).__init__() + self.eps = eps + + def forward(self, x, y): + diff = x - y + # loss = torch.sum(torch.sqrt(diff * diff + self.eps)) + loss = torch.mean(torch.sqrt((diff * diff) + (self.eps * self.eps))) + return loss + + +class EdgeLoss(nn.Module): + def __init__(self): + super(EdgeLoss, self).__init__() + k = torch.Tensor([[.05, .25, 
.4, .25, .05]]) + self.kernel = torch.matmul(k.t(), k).unsqueeze(0).repeat(3, 1, 1, 1) + if torch.cuda.is_available(): + self.kernel = self.kernel.cuda() + self.loss = CharbonnierLoss() + + def conv_gauss(self, img): + n_channels, _, kw, kh = self.kernel.shape + img = F.pad(img, (kw // 2, kh // 2, kw // 2, kh // 2), mode='replicate') + return F.conv2d(img, self.kernel, groups=n_channels) + + def laplacian_kernel(self, current): + filtered = self.conv_gauss(current) # filter + down = filtered[:, :, ::2, ::2] # downsample + new_filter = torch.zeros_like(filtered) + new_filter[:, :, ::2, ::2] = down * 4 # upsample + filtered = self.conv_gauss(new_filter) # filter + diff = current - filtered + return diff + + def forward(self, x, y): + # x = torch.clamp(x + 0.5, min = 0,max = 1) + # y = torch.clamp(y + 0.5, min = 0,max = 1) + loss = self.loss(self.laplacian_kernel(x), self.laplacian_kernel(y)) + return loss + + +class Stripformer_Loss(nn.Module): + + def __init__(self, ): + super(Stripformer_Loss, self).__init__() + + self.char = CharbonnierLoss() + self.edge = EdgeLoss() + self.contrastive = ContrastLoss() + + def forward(self, restore, sharp, blur): + char = self.char(restore, sharp) + edge = 0.05 * self.edge(restore, sharp) + contrastive = 0.0005 * self.contrastive(restore, sharp, blur) + loss = char + edge + contrastive + return loss + + +def get_loss(model): + if model['content_loss'] == 'Stripformer_Loss': + content_loss = Stripformer_Loss() + else: + raise ValueError("ContentLoss [%s] not recognized." % model['content_loss']) + return content_loss \ No newline at end of file diff --git a/models/models.py b/models/models.py new file mode 100644 index 0000000..b05fb8e --- /dev/null +++ b/models/models.py @@ -0,0 +1,35 @@ +import numpy as np +import torch.nn as nn +from skimage.measure import compare_ssim as SSIM + +from util.metrics import PSNR + + +class DeblurModel(nn.Module): + def __init__(self): + super(DeblurModel, self).__init__() + + def get_input(self, data): + img = data['a'] + inputs = img + targets = data['b'] + inputs, targets = inputs.cuda(), targets.cuda() + return inputs, targets + + def tensor2im(self, image_tensor, imtype=np.uint8): + image_numpy = image_tensor[0].cpu().float().numpy() + image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 0.5) * 255.0 + return image_numpy + + def get_images_and_metrics(self, inp, output, target) -> (float, float, np.ndarray): + inp = self.tensor2im(inp) + fake = self.tensor2im(output.data) + real = self.tensor2im(target.data) + psnr = PSNR(fake, real) + ssim = SSIM(fake.astype('uint8'), real.astype('uint8'), multichannel=True) + vis_img = np.hstack((inp, fake, real)) + return psnr, ssim, vis_img + + +def get_model(model_config): + return DeblurModel() diff --git a/models/networks.py b/models/networks.py new file mode 100644 index 0000000..8dfc72a --- /dev/null +++ b/models/networks.py @@ -0,0 +1,13 @@ +import torch.nn as nn +from models.Stripformer import Stripformer + +def get_generator(model_config): + generator_name = model_config['g_name'] + if generator_name == 'Stripformer': + model_g = Stripformer() + else: + raise ValueError("Generator Network [%s] not recognized." 
% generator_name) + return nn.DataParallel(model_g) + +def get_nets(model_config): + return get_generator(model_config) diff --git a/out/Results.txt b/out/Results.txt new file mode 100644 index 0000000..d9b9b6a --- /dev/null +++ b/out/Results.txt @@ -0,0 +1 @@ +testing results are created in this folder \ No newline at end of file diff --git a/predict_GoPro_test_results.py b/predict_GoPro_test_results.py new file mode 100644 index 0000000..82fa930 --- /dev/null +++ b/predict_GoPro_test_results.py @@ -0,0 +1,68 @@ +from __future__ import print_function +import numpy as np +import torch +import cv2 +import yaml +import os +from torch.autograd import Variable +from models.networks import get_generator +import torchvision +import time +import argparse + +def get_args(): + parser = argparse.ArgumentParser('Test an image') + parser.add_argument('--weights_path', required=True, help='Weights path') + return parser.parse_args() + +if __name__ == '__main__': + args = get_args() + with open('config/config_Stripformer_gopro.yaml') as cfg: + config = yaml.load(cfg) + blur_path = './datasets/GoPro/test/blur/' + out_path = './out/Stripformer_GoPro_results' + if not os.path.isdir(out_path): + os.mkdir(out_path) + model = get_generator(config['model']) + model.load_state_dict(torch.load(args.weights_path)) + model = model.cuda() + + test_time = 0 + iteration = 0 + total_image_number = 1111 + + # warm-up + warm_up = 0 + print('Hardware warm-up') + for file in os.listdir(blur_path): + for img_name in os.listdir(blur_path + '/' + file): + warm_up += 1 + img = cv2.imread(blur_path + '/' + file + '/' + img_name) + img_tensor = torch.from_numpy(np.transpose(img / 255, (2, 0, 1)).astype('float32')) - 0.5 + with torch.no_grad(): + img_tensor = Variable(img_tensor.unsqueeze(0)).cuda() + result_image = model(img_tensor) + if warm_up == 20: + break + break + + for file in os.listdir(blur_path): + if not os.path.isdir(out_path + '/' + file): + os.mkdir(out_path + '/' + file) + for img_name in os.listdir(blur_path + '/' + file): + img = cv2.imread(blur_path + '/' + file + '/' + img_name) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img_tensor = torch.from_numpy(np.transpose(img / 255, (2, 0, 1)).astype('float32')) - 0.5 + with torch.no_grad(): + iteration += 1 + img_tensor = Variable(img_tensor.unsqueeze(0)).cuda() + + start = time.time() + result_image = model(img_tensor) + stop = time.time() + print('Image:{}/{}, CNN Runtime:{:.4f}'.format(iteration, total_image_number, (stop - start))) + test_time += stop - start + print('Average Runtime:{:.4f}'.format(test_time / float(iteration))) + result_image = result_image + 0.5 + out_file_name = out_path + '/' + file + '/' + img_name + torchvision.utils.save_image(result_image, out_file_name) diff --git a/predict_HIDE_results.py b/predict_HIDE_results.py new file mode 100644 index 0000000..5e40d94 --- /dev/null +++ b/predict_HIDE_results.py @@ -0,0 +1,64 @@ +from __future__ import print_function +import numpy as np +import torch +import cv2 +import yaml +import os +from torch.autograd import Variable +from models.networks import get_generator +import torchvision +import time +import argparse + +def get_args(): + parser = argparse.ArgumentParser('Test an image') + parser.add_argument('--weights_path', required=True, help='Weights path') + return parser.parse_args() + +if __name__ == '__main__': + args = get_args() + with open('config/config_Stripformer_gopro.yaml') as cfg: + config = yaml.load(cfg) + blur_path = './datasets/HIDE/blur/' + out_path = 
'./out/Stripformer_HIDE_results' + if not os.path.isdir(out_path): + os.mkdir(out_path) + model = get_generator(config['model']) + model.load_state_dict(torch.load(args.weights_path)) + model = model.cuda() + test_time = 0 + iteration = 0 + total_image_number = 2025 + + # warm up + warm_up = 0 + print('Hardware warm-up') + for img_name in os.listdir(blur_path): + warm_up += 1 + img = cv2.imread(blur_path + '/' + img_name) + img_tensor = torch.from_numpy(np.transpose(img / 255, (2, 0, 1)).astype('float32')) - 0.5 + with torch.no_grad(): + img_tensor = Variable(img_tensor.unsqueeze(0)).cuda() + result_image = model(img_tensor) + if warm_up == 20: + break + break + + for img_name in os.listdir(blur_path): + img = cv2.imread(blur_path + '/' + img_name) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img_tensor = torch.from_numpy(np.transpose(img / 255, (2, 0, 1)).astype('float32')) - 0.5 + with torch.no_grad(): + iteration += 1 + img_tensor = Variable(img_tensor.unsqueeze(0)).cuda() + + start = time.time() + result_image = model(img_tensor) + stop = time.time() + + print('Image:{}/{}, CNN Runtime:{:.4f}'.format(iteration, total_image_number, (stop - start))) + test_time += stop - start + print('Average Runtime:{:.4f}'.format(test_time / float(iteration))) + result_image = result_image + 0.5 + out_file_name = out_path + '/' + img_name + torchvision.utils.save_image(result_image, out_file_name) \ No newline at end of file diff --git a/predict_RealBlur_J_test_results.py b/predict_RealBlur_J_test_results.py new file mode 100644 index 0000000..36d34fb --- /dev/null +++ b/predict_RealBlur_J_test_results.py @@ -0,0 +1,90 @@ +from __future__ import print_function +import numpy as np +import torch +import cv2 +import yaml +import os +from torch.autograd import Variable +from models.networks import get_generator +import torchvision +import time +import torch.nn.functional as F +import argparse + +def get_args(): + parser = argparse.ArgumentParser('Test an image') + parser.add_argument('--weights_path', required=True, help='Weights path') + return parser.parse_args() + +if __name__ == '__main__': + args = get_args() + with open('config/config_Stripformer_gopro.yaml') as cfg: + config = yaml.load(cfg) + blur_path = './datasets/Realblur_J/test/blur/' + out_path = './out/Stripformer_realblur_J_results' + if not os.path.isdir(out_path): + os.mkdir(out_path) + model = get_generator(config['model']) + model.load_state_dict(torch.load(args.weights_path)) + model = model.cuda() + test_time = 0 + iteration = 0 + total_image_number = 980 + + # warm up + warm_up = 0 + print('Hardware warm-up') + for file in os.listdir(blur_path): + if not os.path.isdir(out_path + '/' + file): + os.mkdir(out_path + '/' + file) + for img_name in os.listdir(blur_path + '/' + file): + warm_up += 1 + img = cv2.imread(blur_path + '/' + file + '/' + img_name) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img_tensor = torch.from_numpy(np.transpose(img / 255, (2, 0, 1)).astype('float32')) - 0.5 + with torch.no_grad(): + img_tensor = Variable(img_tensor.unsqueeze(0)).cuda() + factor = 8 + h, w = img_tensor.shape[2], img_tensor.shape[3] + H, W = ((h + factor) // factor) * factor, ((w + factor) // factor) * factor + padh = H - h if h % factor != 0 else 0 + padw = W - w if w % factor != 0 else 0 + img_tensor = F.pad(img_tensor, (0, padw, 0, padh), 'reflect') + result_image = model(img_tensor) + if warm_up == 20: + break + break + + for file in os.listdir(blur_path): + if not os.path.isdir(out_path + '/' + file): + os.mkdir(out_path + '/' + 
file) + for img_name in os.listdir(blur_path + '/' + file): + img = cv2.imread(blur_path + '/' + file + '/' + img_name) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img_tensor = torch.from_numpy(np.transpose(img / 255, (2, 0, 1)).astype('float32')) - 0.5 + with torch.no_grad(): + iteration += 1 + img_tensor = Variable(img_tensor.unsqueeze(0)).cuda() + + factor = 8 + h, w = img_tensor.shape[2], img_tensor.shape[3] + H, W = ((h + factor) // factor) * factor, ((w + factor) // factor) * factor + padh = H - h if h % factor != 0 else 0 + padw = W - w if w % factor != 0 else 0 + img_tensor = F.pad(img_tensor, (0, padw, 0, padh), 'reflect') + H, W = img_tensor.shape[2], img_tensor.shape[3] + + start = time.time() + _output = model(img_tensor) + stop = time.time() + + result_image = _output[:, :, :h, :w] + result_image = torch.clamp(result_image, -0.5, 0.5) + result_image = result_image + 0.5 + + test_time += stop - start + print('Image:{}/{}, CNN Runtime:{:.4f}'.format(iteration, total_image_number, (stop - start))) + print('Average Runtime:{:.4f}'.format(test_time / float(iteration))) + out_file_name = out_path + '/' + file + '/' + img_name + torchvision.utils.save_image(result_image, out_file_name) + diff --git a/predict_RealBlur_R_test_results.py b/predict_RealBlur_R_test_results.py new file mode 100644 index 0000000..e889df2 --- /dev/null +++ b/predict_RealBlur_R_test_results.py @@ -0,0 +1,90 @@ +from __future__ import print_function +import argparse +import numpy as np +import torch +import cv2 +import yaml +import os +from torch.autograd import Variable +from models.networks import get_generator +import torchvision +import time +import torch.nn.functional as F + +def get_args(): + parser = argparse.ArgumentParser('Test an image') + parser.add_argument('--weights_path', required=True, help='Weights path') + return parser.parse_args() + +if __name__ == '__main__': + args = get_args() + with open('config/config_Stripformer_gopro.yaml') as cfg: + config = yaml.load(cfg) + blur_path = './datasets/Realblur_R/test/blur/' + out_path = './out/Stripformer_realblur_R_results' + model = get_generator(config['model']) + model.load_state_dict(torch.load(args.weights_path)) + model = model.cuda() + if not os.path.isdir(out_path): + os.mkdir(out_path) + test_time = 0 + iteration = 0 + total_image_number = 980 + + # warm up + warm_up = 0 + print('Hardware warm-up') + for file in os.listdir(blur_path): + if not os.path.isdir(out_path + '/' + file): + os.mkdir(out_path + '/' + file) + for img_name in os.listdir(blur_path + '/' + file): + warm_up += 1 + img = cv2.imread(blur_path + '/' + file + '/' + img_name) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img_tensor = torch.from_numpy(np.transpose(img / 255, (2, 0, 1)).astype('float32')) - 0.5 + with torch.no_grad(): + img_tensor = Variable(img_tensor.unsqueeze(0)).cuda() + factor = 8 + h, w = img_tensor.shape[2], img_tensor.shape[3] + H, W = ((h + factor) // factor) * factor, ((w + factor) // factor) * factor + padh = H - h if h % factor != 0 else 0 + padw = W - w if w % factor != 0 else 0 + img_tensor = F.pad(img_tensor, (0, padw, 0, padh), 'reflect') + result_image = model(img_tensor) + if warm_up == 20: + break + break + + for file in os.listdir(blur_path): + if not os.path.isdir(out_path + '/' + file): + os.mkdir(out_path + '/' + file) + for img_name in os.listdir(blur_path + '/' + file): + img = cv2.imread(blur_path + '/' + file + '/' + img_name) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img_tensor = torch.from_numpy(np.transpose(img / 255, (2, 0, 
1)).astype('float32')) - 0.5 + with torch.no_grad(): + iteration += 1 + img_tensor = Variable(img_tensor.unsqueeze(0)).cuda() + + factor = 8 + h, w = img_tensor.shape[2], img_tensor.shape[3] + H, W = ((h + factor) // factor) * factor, ((w + factor) // factor) * factor + padh = H - h if h % factor != 0 else 0 + padw = W - w if w % factor != 0 else 0 + img_tensor = F.pad(img_tensor, (0, padw, 0, padh), 'reflect') + H, W = img_tensor.shape[2], img_tensor.shape[3] + + start = time.time() + _output = model(img_tensor) + stop = time.time() + + result_image = _output[:, :, :h, :w] + result_image = torch.clamp(result_image, -0.5, 0.5) + result_image = result_image + 0.5 + + test_time += stop - start + print('Image:{}/{}, CNN Runtime:{:.4f}'.format(iteration, total_image_number, (stop - start))) + print('Average Runtime:{:.4f}'.format(test_time / float(iteration))) + out_file_name = out_path + '/' + file + '/' + img_name + torchvision.utils.save_image(result_image, out_file_name) + diff --git a/schedulers.py b/schedulers.py new file mode 100644 index 0000000..ca1841f --- /dev/null +++ b/schedulers.py @@ -0,0 +1,59 @@ +import math + +from torch.optim import lr_scheduler + + +class WarmRestart(lr_scheduler.CosineAnnealingLR): + """This class implements Stochastic Gradient Descent with Warm Restarts(SGDR): https://arxiv.org/abs/1608.03983. + + Set the learning rate of each parameter group using a cosine annealing schedule, When last_epoch=-1, sets initial lr as lr. + This can't support scheduler.step(epoch). please keep epoch=None. + """ + + def __init__(self, optimizer, T_max=30, T_mult=1, eta_min=0, last_epoch=-1): + """implements SGDR + + Parameters: + ---------- + T_max : int + Maximum number of epochs. + T_mult : int + Multiplicative factor of T_max. + eta_min : int + Minimum learning rate. Default: 0. + last_epoch : int + The index of last epoch. Default: -1. 
+ """ + self.T_mult = T_mult + super().__init__(optimizer, T_max, eta_min, last_epoch) + + def get_lr(self): + if self.last_epoch == self.T_max: + self.last_epoch = 0 + self.T_max *= self.T_mult + return [self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / 2 for + base_lr in self.base_lrs] + + +class LinearDecay(lr_scheduler._LRScheduler): + """This class implements LinearDecay + + """ + + def __init__(self, optimizer, num_epochs, start_epoch=0, min_lr=0, last_epoch=-1): + """implements LinearDecay + + Parameters: + ---------- + + """ + self.num_epochs = num_epochs + self.start_epoch = start_epoch + self.min_lr = min_lr + super().__init__(optimizer, last_epoch) + + def get_lr(self): + if self.last_epoch < self.start_epoch: + return self.base_lrs + return [base_lr - ((base_lr - self.min_lr) / self.num_epochs) * (self.last_epoch - self.start_epoch) for + base_lr in self.base_lrs] diff --git a/train_Stripformer_gopro.py b/train_Stripformer_gopro.py new file mode 100644 index 0000000..2194587 --- /dev/null +++ b/train_Stripformer_gopro.py @@ -0,0 +1,170 @@ +import logging +from functools import partial +import os +import cv2 +import torch +import torch.optim as optim +import tqdm +import yaml +from joblib import cpu_count +from torch.utils.data import DataLoader +import random +from dataset import PairedDataset +from metric_counter import MetricCounter +from models.losses import get_loss +from models.models import get_model +from models.networks import get_nets +import numpy as np +from torch.optim.lr_scheduler import CosineAnnealingLR +cv2.setNumThreads(0) + + +class Trainer: + def __init__(self, config, train: DataLoader, val: DataLoader): + self.config = config + self.train_dataset = train + self.val_dataset = val + self.metric_counter = MetricCounter(config['experiment_desc']) + + def train(self): + self._init_params() + start_epoch = 0 + if os.path.exists('last_Stripformer_gopro.pth'): + print('load_pretrained') + training_state = (torch.load('last_Stripformer_gopro.pth')) + start_epoch = training_state['epoch'] + new_weight = self.netG.state_dict() + new_weight.update(training_state['model_state']) + self.netG.load_state_dict(new_weight) + new_optimizer = self.optimizer_G.state_dict() + new_optimizer.update(training_state['optimizer_state']) + self.optimizer_G.load_state_dict(new_optimizer) + new_scheduler = self.scheduler_G.state_dict() + new_scheduler.update(training_state['scheduler_state']) + self.scheduler_G.load_state_dict(new_scheduler) + else: + print('load_GoPro_pretrained') + training_state = (torch.load('final_Stripformer_pretrained.pth')) + new_weight = self.netG.state_dict() + new_weight.update(training_state) + self.netG.load_state_dict(new_weight) + + + for epoch in range(start_epoch, config['num_epochs']): + self._run_epoch(epoch) + if epoch % 30 == 0 or epoch == (config['num_epochs']-1): + self._validate(epoch) + self.scheduler_G.step() + + scheduler_state = self.scheduler_G.state_dict() + training_state = {'epoch': epoch, 'model_state': self.netG.state_dict(), + 'scheduler_state': scheduler_state, 'optimizer_state': self.optimizer_G.state_dict()} + if self.metric_counter.update_best_model(): + torch.save(training_state['model_state'], 'best_{}.pth'.format(self.config['experiment_desc'])) + + if epoch % 200 == 0: + torch.save(training_state, 'last_{}_{}.pth'.format(self.config['experiment_desc'], epoch)) + + if epoch == (config['num_epochs']-1): + torch.save(training_state['model_state'], 
'final_{}.pth'.format(self.config['experiment_desc'])) + + torch.save(training_state, 'last_{}.pth'.format(self.config['experiment_desc'])) + logging.debug("Experiment Name: %s, Epoch: %d, Loss: %s" % ( + self.config['experiment_desc'], epoch, self.metric_counter.loss_message())) + + def _run_epoch(self, epoch): + self.metric_counter.clear() + for param_group in self.optimizer_G.param_groups: + lr = param_group['lr'] + + epoch_size = config.get('train_batches_per_epoch') or len(self.train_dataset) + tq = tqdm.tqdm(self.train_dataset, total=epoch_size) + tq.set_description('Epoch {}, lr {}'.format(epoch, lr)) + i = 0 + for data in tq: + inputs, targets = self.model.get_input(data) + outputs = self.netG(inputs) + self.optimizer_G.zero_grad() + loss_G = self.criterionG(outputs, targets, inputs) + loss_G.backward() + self.optimizer_G.step() + self.metric_counter.add_losses(loss_G.item()) + curr_psnr, curr_ssim, img_for_vis = self.model.get_images_and_metrics(inputs, outputs, targets) + self.metric_counter.add_metrics(curr_psnr, curr_ssim) + tq.set_postfix(loss=self.metric_counter.loss_message()) + if not i: + self.metric_counter.add_image(img_for_vis, tag='train') + i += 1 + if i > epoch_size: + break + tq.close() + self.metric_counter.write_to_tensorboard(epoch) + + def _validate(self, epoch): + self.metric_counter.clear() + epoch_size = config.get('val_batches_per_epoch') or len(self.val_dataset) + tq = tqdm.tqdm(self.val_dataset, total=epoch_size) + tq.set_description('Validation') + i = 0 + for data in tq: + with torch.no_grad(): + inputs, targets = self.model.get_input(data) + outputs = self.netG(inputs) + loss_G = self.criterionG(outputs, targets, inputs) + self.metric_counter.add_losses(loss_G.item()) + curr_psnr, curr_ssim, img_for_vis = self.model.get_images_and_metrics(inputs, outputs, targets) + self.metric_counter.add_metrics(curr_psnr, curr_ssim) + if not i: + self.metric_counter.add_image(img_for_vis, tag='val') + i += 1 + if i > epoch_size: + break + tq.close() + self.metric_counter.write_to_tensorboard(epoch, validation=True) + + + def _get_optim(self, params): + if self.config['optimizer']['name'] == 'adam': + optimizer = optim.Adam(params, lr=self.config['optimizer']['lr']) + else: + raise ValueError("Optimizer [%s] not recognized." % self.config['optimizer']['name']) + return optimizer + + def _get_scheduler(self, optimizer): + if self.config['scheduler']['name'] == 'cosine': + scheduler = CosineAnnealingLR(optimizer, T_max=self.config['num_epochs'], eta_min=self.config['scheduler']['min_lr']) + else: + raise ValueError("Scheduler [%s] not recognized." 
% self.config['scheduler']['name']) + return scheduler + + def _init_params(self): + self.criterionG = get_loss(self.config['model']) + self.netG = get_nets(self.config['model']) + self.netG.cuda() + self.model = get_model(self.config['model']) + self.optimizer_G = self._get_optim(filter(lambda p: p.requires_grad, self.netG.parameters())) + self.scheduler_G = self._get_scheduler(self.optimizer_G) + + +if __name__ == '__main__': + with open('config/config_Stripformer_gopro.yaml', 'r') as f: + config = yaml.load(f) + # setup + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = True + + # set random seed + seed = 666 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + random.seed(seed) + np.random.seed(seed) + + batch_size = config.pop('batch_size') + get_dataloader = partial(DataLoader, batch_size=batch_size, num_workers=cpu_count(), shuffle=True, drop_last=False) + + datasets = map(config.pop, ('train', 'val')) + datasets = map(PairedDataset.from_config, datasets) + train, val = map(get_dataloader, datasets) + trainer = Trainer(config, train=train, val=val) + trainer.train() diff --git a/train_Stripformer_pretrained.py b/train_Stripformer_pretrained.py new file mode 100644 index 0000000..aac7ab6 --- /dev/null +++ b/train_Stripformer_pretrained.py @@ -0,0 +1,164 @@ +import logging +from functools import partial +import os +import cv2 +import torch +import torch.optim as optim +import tqdm +import yaml +from joblib import cpu_count +from torch.utils.data import DataLoader +import random +from dataset import PairedDataset +from metric_counter import MetricCounter +from models.losses import get_loss +from models.models import get_model +from models.networks import get_nets +import numpy as np +from torch.optim.lr_scheduler import CosineAnnealingLR +cv2.setNumThreads(0) + + +class Trainer: + def __init__(self, config, train: DataLoader, val: DataLoader): + self.config = config + self.train_dataset = train + self.val_dataset = val + self.metric_counter = MetricCounter(config['experiment_desc']) + + def train(self): + self._init_params() + start_epoch = 0 + if os.path.exists('last_Stripformer_pretrained.pth'): + print('load_pretrained') + training_state = (torch.load('last_Stripformer_pretrained.pth')) + start_epoch = training_state['epoch'] + new_weight = self.netG.state_dict() + new_weight.update(training_state['model_state']) + self.netG.load_state_dict(new_weight) + new_optimizer = self.optimizer_G.state_dict() + new_optimizer.update(training_state['optimizer_state']) + self.optimizer_G.load_state_dict(new_optimizer) + new_scheduler = self.scheduler_G.state_dict() + new_scheduler.update(training_state['scheduler_state']) + self.scheduler_G.load_state_dict(new_scheduler) + + + for epoch in range(start_epoch, config['num_epochs']): + self._run_epoch(epoch) + if epoch % 30 == 0 or epoch == (config['num_epochs']-1): + self._validate(epoch) + self.scheduler_G.step() + + scheduler_state = self.scheduler_G.state_dict() + training_state = {'epoch': epoch, 'model_state': self.netG.state_dict(), + 'scheduler_state': scheduler_state, 'optimizer_state': self.optimizer_G.state_dict()} + if self.metric_counter.update_best_model(): + torch.save(training_state['model_state'], 'best_{}.pth'.format(self.config['experiment_desc'])) + + if epoch % 300 == 0: + torch.save(training_state, 'last_{}_{}.pth'.format(self.config['experiment_desc'], epoch)) + + if epoch == (config['num_epochs']-1): + torch.save(training_state['model_state'], 'final_{}.pth'.format(self.config['experiment_desc'])) 
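+            # Checkpointing: 'best_*.pth' is written whenever MetricCounter reports
+            # a new best model, 'final_*.pth' holds the weights of the last epoch,
+            # and 'last_*.pth' (refreshed below every epoch) stores the full training
+            # state that the resume check at the top of train() loads.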
+ + torch.save(training_state, 'last_{}.pth'.format(self.config['experiment_desc'])) + logging.debug("Experiment Name: %s, Epoch: %d, Loss: %s" % ( + self.config['experiment_desc'], epoch, self.metric_counter.loss_message())) + + def _run_epoch(self, epoch): + self.metric_counter.clear() + for param_group in self.optimizer_G.param_groups: + lr = param_group['lr'] + + epoch_size = config.get('train_batches_per_epoch') or len(self.train_dataset) + tq = tqdm.tqdm(self.train_dataset, total=epoch_size) + tq.set_description('Epoch {}, lr {}'.format(epoch, lr)) + i = 0 + for data in tq: + inputs, targets = self.model.get_input(data) + outputs = self.netG(inputs) + self.optimizer_G.zero_grad() + loss_G = self.criterionG(outputs, targets, inputs) + loss_G.backward() + self.optimizer_G.step() + self.metric_counter.add_losses(loss_G.item()) + curr_psnr, curr_ssim, img_for_vis = self.model.get_images_and_metrics(inputs, outputs, targets) + self.metric_counter.add_metrics(curr_psnr, curr_ssim) + tq.set_postfix(loss=self.metric_counter.loss_message()) + if not i: + self.metric_counter.add_image(img_for_vis, tag='train') + i += 1 + if i > epoch_size: + break + tq.close() + self.metric_counter.write_to_tensorboard(epoch) + + def _validate(self, epoch): + self.metric_counter.clear() + epoch_size = config.get('val_batches_per_epoch') or len(self.val_dataset) + tq = tqdm.tqdm(self.val_dataset, total=epoch_size) + tq.set_description('Validation') + i = 0 + for data in tq: + with torch.no_grad(): + inputs, targets = self.model.get_input(data) + outputs = self.netG(inputs) + loss_G = self.criterionG(outputs, targets, inputs) + self.metric_counter.add_losses(loss_G.item()) + curr_psnr, curr_ssim, img_for_vis = self.model.get_images_and_metrics(inputs, outputs, targets) + self.metric_counter.add_metrics(curr_psnr, curr_ssim) + if not i: + self.metric_counter.add_image(img_for_vis, tag='val') + i += 1 + if i > epoch_size: + break + tq.close() + self.metric_counter.write_to_tensorboard(epoch, validation=True) + + + def _get_optim(self, params): + if self.config['optimizer']['name'] == 'adam': + optimizer = optim.Adam(params, lr=self.config['optimizer']['lr']) + else: + raise ValueError("Optimizer [%s] not recognized." % self.config['optimizer']['name']) + return optimizer + + def _get_scheduler(self, optimizer): + if self.config['scheduler']['name'] == 'cosine': + scheduler = CosineAnnealingLR(optimizer, T_max=self.config['num_epochs'], eta_min=self.config['scheduler']['min_lr']) + else: + raise ValueError("Scheduler [%s] not recognized." 
% self.config['scheduler']['name']) + return scheduler + + def _init_params(self): + self.criterionG = get_loss(self.config['model']) + self.netG = get_nets(self.config['model']) + self.netG.cuda() + self.model = get_model(self.config['model']) + self.optimizer_G = self._get_optim(filter(lambda p: p.requires_grad, self.netG.parameters())) + self.scheduler_G = self._get_scheduler(self.optimizer_G) + + +if __name__ == '__main__': + with open('config/config_Stripformer_pretrained.yaml', 'r') as f: + config = yaml.load(f) + # setup + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = True + + # set random seed + seed = 666 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + random.seed(seed) + np.random.seed(seed) + + batch_size = config.pop('batch_size') + get_dataloader = partial(DataLoader, batch_size=batch_size, num_workers=cpu_count(), shuffle=True, drop_last=False) + + datasets = map(config.pop, ('train', 'val')) + datasets = map(PairedDataset.from_config, datasets) + train, val = map(get_dataloader, datasets) + trainer = Trainer(config, train=train, val=val) + trainer.train() diff --git a/util/__init__.py b/util/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/util/__pycache__/__init__.cpython-36.pyc b/util/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b55f4d90d0673b6413fa436ec62519012688787 GIT binary patch literal 126 zcmXr!<>h)4aXNtk2p)q77+?f49Dul(1xTbY1T$zd`mJOr0tq9CU)uU6Ax`>*<*61X zW`@Q_`Yx$SIi*GJj(%lE`lThAIr{POnR%Hd@$q^EmA5!-a`RJ4b5iXs;;iC>hJ5Coo=`N=i|@Mzc?Yk$-;m&*@syjfDlB` zj6{v55wmL^anLMtayN1*?(xja{m7@}hzLixS46mT76pR7BcXQ3GI88}cbII%hKmQ4U645Os{UU>lYQu-&0IWphf&c-^l2IiZ+!!T$js z9SGwMVb_+mz`{3@+$^5JSK0AajjG?Ild{-^|E&veKYf^Ifl zON*tj5leW4Cqk@(QLTX>K#D^H&&}D6pdz1hT4-8mv_zihu@w(6PR-r8dB0T+A)&;%5@HPs1`)~@{3b3-`hZ9iVml02+yLmj;7txOegU+=$1)xS5 zeuAFhdxT*|9(y%;il_K0nu2VJgg3LMXcd8bYio?*3;k$!9*iXzVhAEZ5`BHD#Wj-n zDQ0K|oq&H^LS0nD?5V-a;B*znk7S2}JG`&G1cYm)?0b0?)d5jt{j|al0oWoL z(~?!YBa!aKzo3_=HqPFVq7#cQT0@A=0Q^Kc83QC>Iw19Bd4zvg?YE?s!1 zj%Hy2JUqb4nXMU0X2NjMAM-e>d>rlhsOO88LPnFR3Acr$sn_Wg^`pZ&a}WmI+SPMy=$M)}wi%q6XRAX+t_)xdLl4Tg7!}dv^;DzV+t+U# zgTur-AasUN>=OqsfINne5g+^5BWsX%YlLD-PGi0S`wp{Dv#4vql;mXqx;^TmT!p@- zC3H2iq;m@jlUd&o{w{-ZV|HsXr-oK3Es#_~`RY!$tRv*kOYx8DqKT3z#eD|Oxo)5P z%OC<7Q6djKM|3K55$o7ZoD^x5dfrQ2^XaC3X> zivesDy7M5sf<54XVtx*0!6H8cYTj|nrkFh9x_k3Jh;YgDz4fwOzt@!GI<3kmhijF} z|9{Rkui?UMXVlbH!XsunhiNS)g$$+UyjL!#Z^LJ;@qAISCXX-ZmLzmX^l*=OaL*p@ KS)T2=-s*o~uBv7L literal 0 HcmV?d00001 diff --git a/util/html.py b/util/html.py new file mode 100644 index 0000000..ebec4fb --- /dev/null +++ b/util/html.py @@ -0,0 +1,66 @@ +import dominate +from dominate.tags import * +import os + + +class HTML: + def __init__(self, web_dir, title, image_subdir='', reflesh=0): + self.title = title + self.web_dir = web_dir + # self.img_dir = os.path.join(self.web_dir, ) + self.img_subdir = image_subdir + self.img_dir = os.path.join(self.web_dir, image_subdir) + if not os.path.exists(self.web_dir): + os.makedirs(self.web_dir) + if not os.path.exists(self.img_dir): + os.makedirs(self.img_dir) + # print(self.img_dir) + + self.doc = dominate.document(title=title) + if reflesh > 0: + with self.doc.head: + meta(http_equiv="reflesh", content=str(reflesh)) + + def get_image_dir(self): + return self.img_dir + + def add_header(self, str): + with self.doc: + h3(str) + + def add_table(self, border=1): + self.t = table(border=border, style="table-layout: fixed;") + self.doc.add(self.t) + + def add_images(self, ims, txts, links, 
width=400): + self.add_table() + with self.t: + with tr(): + for im, txt, link in zip(ims, txts, links): + with td(style="word-wrap: break-word;", halign="center", valign="top"): + with p(): + with a(href=os.path.join(link)): + img(style="width:%dpx" % width, src=os.path.join(im)) + br() + p(txt) + + def save(self,file='index'): + html_file = '%s/%s.html' % (self.web_dir,file) + f = open(html_file, 'wt') + f.write(self.doc.render()) + f.close() + + +if __name__ == '__main__': + html = HTML('web/', 'test_html') + html.add_header('hello world') + + ims = [] + txts = [] + links = [] + for n in range(4): + ims.append('image_%d.png' % n) + txts.append('text_%d' % n) + links.append('image_%d.png' % n) + html.add_images(ims, txts, links) + html.save() diff --git a/util/image_pool.py b/util/image_pool.py new file mode 100644 index 0000000..590bba8 --- /dev/null +++ b/util/image_pool.py @@ -0,0 +1,33 @@ +import random +import numpy as np +import torch +from torch.autograd import Variable +from collections import deque + + +class ImagePool(): + def __init__(self, pool_size): + self.pool_size = pool_size + self.sample_size = pool_size + if self.pool_size > 0: + self.num_imgs = 0 + self.images = deque() + + def add(self, images): + if self.pool_size == 0: + return images + for image in images.data: + image = torch.unsqueeze(image, 0) + if self.num_imgs < self.pool_size: + self.num_imgs = self.num_imgs + 1 + self.images.append(image) + else: + self.images.popleft() + self.images.append(image) + + def query(self): + if len(self.images) > self.sample_size: + return_images = list(random.sample(self.images, self.sample_size)) + else: + return_images = list(self.images) + return torch.cat(return_images, 0) diff --git a/util/metrics.py b/util/metrics.py new file mode 100644 index 0000000..13e4671 --- /dev/null +++ b/util/metrics.py @@ -0,0 +1,54 @@ +import math +from math import exp + +import numpy as np +import torch +import torch.nn.functional as F +from torch.autograd import Variable + + +def gaussian(window_size, sigma): + gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) + return gauss / gauss.sum() + + +def create_window(window_size, channel): + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) + window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) + return window + + +def SSIM(img1, img2): + (_, channel, _, _) = img1.size() + window_size = 11 + window = create_window(window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) + mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq + sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq + sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 + + C1 = 0.01 ** 2 + C2 = 0.03 ** 2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) + return ssim_map.mean() + + +def PSNR(img1, img2): + mse = np.mean((img1 / 255. - img2 / 255.) 
** 2) + if mse == 0: + return 100 + PIXEL_MAX = 1 + return 20 * math.log10(PIXEL_MAX / math.sqrt(mse)) diff --git a/util/util.py b/util/util.py new file mode 100644 index 0000000..6d6367a --- /dev/null +++ b/util/util.py @@ -0,0 +1,48 @@ +from __future__ import print_function + +import numpy as np +from PIL import Image +import numpy as np +import os +import matplotlib.pyplot as plt +import torch + +def load_image(path): + if(path[-3:] == 'dng'): + import rawpy + with rawpy.imread(path) as raw: + img = raw.postprocess() + elif(path[-3:]=='bmp' or path[-3:]=='jpg' or path[-3:]=='png'): + import cv2 + return cv2.imread(path)[:,:,::-1] + else: + img = (255*plt.imread(path)[:,:,:3]).astype('uint8') + + return img + +def save_image(image_numpy, image_path, ): + image_pil = Image.fromarray(image_numpy) + image_pil.save(image_path) + +def mkdirs(paths): + if isinstance(paths, list) and not isinstance(paths, str): + for path in paths: + mkdir(path) + else: + mkdir(paths) + +def mkdir(path): + if not os.path.exists(path): + os.makedirs(path) + + +def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): +# def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.): + image_numpy = image_tensor[0].cpu().float().numpy() + image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor + return image_numpy.astype(imtype) + +def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): +# def im2tensor(image, imtype=np.uint8, cent=1., factor=1.): + return torch.Tensor((image / factor - cent) + [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) diff --git a/util/visualizer.py b/util/visualizer.py new file mode 100644 index 0000000..1509c8a --- /dev/null +++ b/util/visualizer.py @@ -0,0 +1,216 @@ +import numpy as np +import os +import time +from . import util +from . import html +# from pdb import set_trace as st +import matplotlib.pyplot as plt +import math +# from IPython import embed + +def zoom_to_res(img,res=256,order=0,axis=0): + # img 3xXxX + from scipy.ndimage import zoom + zoom_factor = res/img.shape[1] + if(axis==0): + return zoom(img,[1,zoom_factor,zoom_factor],order=order) + elif(axis==2): + return zoom(img,[zoom_factor,zoom_factor,1],order=order) + +class Visualizer(): + def __init__(self, opt): + # self.opt = opt + self.display_id = opt.display_id + # self.use_html = opt.is_train and not opt.no_html + self.win_size = opt.display_winsize + self.name = opt.name + self.display_cnt = 0 # display_current_results counter + self.display_cnt_high = 0 + self.use_html = opt.use_html + + if self.display_id > 0: + import visdom + self.vis = visdom.Visdom(port = opt.display_port) + + self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') + util.mkdirs([self.web_dir,]) + if self.use_html: + self.img_dir = os.path.join(self.web_dir, 'images') + print('create web directory %s...' 
% self.web_dir) + util.mkdirs([self.img_dir,]) + + # |visuals|: dictionary of images to display or save + def display_current_results(self, visuals, epoch, nrows=None, res=256): + if self.display_id > 0: # show images in the browser + title = self.name + if(nrows is None): + nrows = int(math.ceil(len(visuals.items()) / 2.0)) + images = [] + idx = 0 + for label, image_numpy in visuals.items(): + title += " | " if idx % nrows == 0 else ", " + title += label + img = image_numpy.transpose([2, 0, 1]) + img = zoom_to_res(img,res=res,order=0) + images.append(img) + idx += 1 + if len(visuals.items()) % 2 != 0: + white_image = np.ones_like(image_numpy.transpose([2, 0, 1]))*255 + white_image = zoom_to_res(white_image,res=res,order=0) + images.append(white_image) + self.vis.images(images, nrow=nrows, win=self.display_id + 1, + opts=dict(title=title)) + + if self.use_html: # save images to a html file + for label, image_numpy in visuals.items(): + img_path = os.path.join(self.img_dir, 'epoch%.3d_cnt%.6d_%s.png' % (epoch, self.display_cnt, label)) + util.save_image(zoom_to_res(image_numpy, res=res, axis=2), img_path) + + self.display_cnt += 1 + self.display_cnt_high = np.maximum(self.display_cnt_high, self.display_cnt) + + # update website + webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, reflesh=1) + for n in range(epoch, 0, -1): + webpage.add_header('epoch [%d]' % n) + if(n==epoch): + high = self.display_cnt + else: + high = self.display_cnt_high + for c in range(high-1,-1,-1): + ims = [] + txts = [] + links = [] + + for label, image_numpy in visuals.items(): + img_path = 'epoch%.3d_cnt%.6d_%s.png' % (n, c, label) + ims.append(os.path.join('images',img_path)) + txts.append(label) + links.append(os.path.join('images',img_path)) + webpage.add_images(ims, txts, links, width=self.win_size) + webpage.save() + + # save errors into a directory + def plot_current_errors_save(self, epoch, counter_ratio, opt, errors,keys='+ALL',name='loss', to_plot=False): + if not hasattr(self, 'plot_data'): + self.plot_data = {'X':[],'Y':[], 'legend':list(errors.keys())} + self.plot_data['X'].append(epoch + counter_ratio) + self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']]) + + # embed() + if(keys=='+ALL'): + plot_keys = self.plot_data['legend'] + else: + plot_keys = keys + + if(to_plot): + (f,ax) = plt.subplots(1,1) + for (k,kname) in enumerate(plot_keys): + kk = np.where(np.array(self.plot_data['legend'])==kname)[0][0] + x = self.plot_data['X'] + y = np.array(self.plot_data['Y'])[:,kk] + if(to_plot): + ax.plot(x, y, 'o-', label=kname) + np.save(os.path.join(self.web_dir,'%s_x')%kname,x) + np.save(os.path.join(self.web_dir,'%s_y')%kname,y) + + if(to_plot): + plt.legend(loc=0,fontsize='small') + plt.xlabel('epoch') + plt.ylabel('Value') + f.savefig(os.path.join(self.web_dir,'%s.png'%name)) + f.clf() + plt.close() + + # errors: dictionary of error labels and values + def plot_current_errors(self, epoch, counter_ratio, opt, errors): + if not hasattr(self, 'plot_data'): + self.plot_data = {'X':[],'Y':[], 'legend':list(errors.keys())} + self.plot_data['X'].append(epoch + counter_ratio) + self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']]) + self.vis.line( + X=np.stack([np.array(self.plot_data['X'])]*len(self.plot_data['legend']),1), + Y=np.array(self.plot_data['Y']), + opts={ + 'title': self.name + ' loss over time', + 'legend': self.plot_data['legend'], + 'xlabel': 'epoch', + 'ylabel': 'loss'}, + win=self.display_id) + + # errors: same format as |errors| of 
plotCurrentErrors + def print_current_errors(self, epoch, i, errors, t, t2=-1, t2o=-1, fid=None): + message = '(ep: %d, it: %d, t: %.3f[s], ept: %.2f/%.2f[h]) ' % (epoch, i, t, t2o, t2) + message += (', ').join(['%s: %.3f' % (k, v) for k, v in errors.items()]) + + print(message) + if(fid is not None): + fid.write('%s\n'%message) + + + # save image to the disk + def save_images_simple(self, webpage, images, names, in_txts, prefix='', res=256): + image_dir = webpage.get_image_dir() + ims = [] + txts = [] + links = [] + + for name, image_numpy, txt in zip(names, images, in_txts): + image_name = '%s_%s.png' % (prefix, name) + save_path = os.path.join(image_dir, image_name) + if(res is not None): + util.save_image(zoom_to_res(image_numpy,res=res,axis=2), save_path) + else: + util.save_image(image_numpy, save_path) + + ims.append(os.path.join(webpage.img_subdir,image_name)) + # txts.append(name) + txts.append(txt) + links.append(os.path.join(webpage.img_subdir,image_name)) + # embed() + webpage.add_images(ims, txts, links, width=self.win_size) + + # save image to the disk + def save_images(self, webpage, images, names, image_path, title=''): + image_dir = webpage.get_image_dir() + # short_path = ntpath.basename(image_path) + # name = os.path.splitext(short_path)[0] + # name = short_path + # webpage.add_header('%s, %s' % (name, title)) + ims = [] + txts = [] + links = [] + + for label, image_numpy in zip(names, images): + image_name = '%s.jpg' % (label,) + save_path = os.path.join(image_dir, image_name) + util.save_image(image_numpy, save_path) + + ims.append(image_name) + txts.append(label) + links.append(image_name) + webpage.add_images(ims, txts, links, width=self.win_size) + + # save image to the disk + # def save_images(self, webpage, visuals, image_path, short=False): + # image_dir = webpage.get_image_dir() + # if short: + # short_path = ntpath.basename(image_path) + # name = os.path.splitext(short_path)[0] + # else: + # name = image_path + + # webpage.add_header(name) + # ims = [] + # txts = [] + # links = [] + + # for label, image_numpy in visuals.items(): + # image_name = '%s_%s.png' % (name, label) + # save_path = os.path.join(image_dir, image_name) + # util.save_image(image_numpy, save_path) + + # ims.append(image_name) + # txts.append(label) + # links.append(image_name) + # webpage.add_images(ims, txts, links, width=self.win_size)
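For completeness, a minimal sketch of how the two schedulers added in schedulers.py can be exercised on their own (the trainers above instantiate torch's CosineAnnealingLR directly). The model, optimizer, and hyper-parameter values below are illustrative placeholders, not values taken from this patch:

from torch import nn, optim

from schedulers import WarmRestart, LinearDecay

# Toy stand-ins for netG / optimizer_G; any module with an Adam optimizer will do.
model = nn.Conv2d(3, 3, kernel_size=3)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Cosine annealing with warm restarts: the lr decays towards eta_min over
# T_max epochs, then the cycle restarts with T_max scaled by T_mult.
scheduler = WarmRestart(optimizer, T_max=30, T_mult=1, eta_min=1e-7)

# Alternative: keep the base lr for the first start_epoch epochs, then
# decay it linearly towards min_lr over num_epochs epochs.
# scheduler = LinearDecay(optimizer, num_epochs=1000, start_epoch=50, min_lr=1e-7)

for epoch in range(5):
    optimizer.step()   # a real training epoch would run here
    scheduler.step()   # keep epoch=None, as noted in the WarmRestart docstring
    print(epoch, optimizer.param_groups[0]['lr'])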