@@ -0,0 +1 @@
@@ -0,0 +1,75 @@
import torch.utils.data as data
from PIL import Image
import os
import os.path
import numpy as np
import h5py
import glob

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def make_dataset(dir):
    images = []
    if not os.path.isdir(dir):
        raise Exception('Check dataroot')
    for root, _, fnames in sorted(os.walk(dir)):
        for fname in fnames:
            if is_image_file(fname):
                # join with the walked root so files inside subfolders resolve
                path = os.path.join(root, fname)
                images.append(path)
    return images


def default_loader(path):
    return Image.open(path).convert('RGB')


class classification(data.Dataset):
    """Dataset backed by h5 files named <index>.h5, each holding a 'haze'
    image (H x W x C) and a 'label' array."""

    def __init__(self, root, transform=None, loader=default_loader, seed=None):
        self.root = root
        self.transform = transform
        self.loader = loader

        if seed is not None:
            np.random.seed(seed)

    def __getitem__(self, _):
        # A sample is drawn at random rather than by the requested index; note
        # that np.random.randint(1, len) yields 1 .. len-1, so the files are
        # assumed to be numbered from 1.
        index = np.random.randint(1, self.__len__())

        file_name = self.root + '/' + str(index) + '.h5'
        with h5py.File(file_name, 'r') as f:
            haze_image = f['haze'][:]
            label = f['label'][:]
        label = label.mean() - 1

        # HWC -> CHW
        haze_image = np.swapaxes(haze_image, 0, 2)
        haze_image = np.swapaxes(haze_image, 1, 2)

        return haze_image, label

    def __len__(self):
        train_list = glob.glob(self.root + '/*h5')
        return len(train_list)
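# Hedged usage sketch (editor addition): how the h5-backed dataset above can
# be wrapped in a DataLoader. The path is hypothetical; it assumes a folder of
# 1.h5 .. N.h5 files with 'haze' and 'label' entries, as __getitem__ expects.
if __name__ == '__main__':
    import torch
    dataset = classification(root='./facades/train_h5', seed=0)
    loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)
    haze, label = next(iter(loader))
    print(haze.shape, label)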
@@ -0,0 +1,133 @@
import torch.utils.data as data

from PIL import Image
import os
import os.path
import numpy as np


IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def make_dataset(dir):
    images = []
    if not os.path.isdir(dir):
        raise Exception('Check dataroot')
    for root, _, fnames in sorted(os.walk(dir)):
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)
                images.append(path)
    return images


def default_loader(path):
    return Image.open(path).convert('RGB')


class pix2pix(data.Dataset):
    def __init__(self, root, transform=None, loader=default_loader, seed=None):
        imgs = make_dataset(root)
        if len(imgs) == 0:
            raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n"
                               "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.loader = loader

        if seed is not None:
            np.random.seed(seed)

    def __getitem__(self, index):
        # NOTE: np.random.randint(0, 1) always returns 0, so only the first
        # branch below is ever taken; the rain-density branches (1-3) are kept
        # from an earlier configuration of the training data.
        index_folder = np.random.randint(0, 1)

        index_sub = np.random.randint(2, 5)

        label = index_folder

        if index_folder == 0:
            path = '/home/openset/Desktop/derain2018/facades/training2' + '/' + str(index) + '.jpg'

        if index_folder == 1:
            if index_sub < 4:
                path = '/home/openset/Desktop/derain2018/facades/DB_Rain_new/Rain_Heavy/train2018new' + '/' + str(index) + '.jpg'
            if index_sub == 4:
                index = np.random.randint(0, 400)
                path = '/home/openset/Desktop/derain2018/facades/DB_Rain/Rain_Heavy/trainnew' + '/' + str(index) + '.jpg'

        if index_folder == 2:
            if index_sub < 4:
                path = '/home/openset/Desktop/derain2018/facades/DB_Rain_new/Rain_Medium/train2018new' + '/' + str(index) + '.jpg'
            if index_sub == 4:
                index = np.random.randint(0, 400)
                path = '/home/openset/Desktop/derain2018/facades/DB_Rain/Rain_Medium/trainnew' + '/' + str(index) + '.jpg'

        if index_folder == 3:
            if index_sub < 4:
                path = '/home/openset/Desktop/derain2018/facades/DB_Rain_new/Rain_Light/train2018new' + '/' + str(index) + '.jpg'
            if index_sub == 4:
                index = np.random.randint(0, 400)
                path = '/home/openset/Desktop/derain2018/facades/DB_Rain/Rain_Light/trainnew' + '/' + str(index) + '.jpg'

        img = self.loader(path)

        w, h = img.size

        # NOTE: split a sample into imgA and imgB
        imgA = img.crop((0, 0, w // 2, h))
        imgB = img.crop((w // 2, 0, w, h))

        if self.transform is not None:
            # NOTE: preprocessing for each pair of images
            imgA, imgB = self.transform(imgA, imgB)

        return imgA, imgB, label

    def __len__(self):
        return len(self.imgs)
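# Editor addition: a minimal, standalone illustration of the side-by-side
# storage convention the dataset above relies on (rainy input on the left
# half, clean ground truth on the right half of one image).
def split_pair(img):
    """Split a side-by-side PIL image into (input, target) halves."""
    w, h = img.size
    return img.crop((0, 0, w // 2, h)), img.crop((w // 2, 0, w, h))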
@@ -0,0 +1,109 @@
import torch.utils.data as data
from PIL import Image
import os
import os.path
import numpy as np
import h5py
import glob
import scipy.ndimage  # used only by the disabled augmentations below

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def make_dataset(dir):
    images = []
    if not os.path.isdir(dir):
        raise Exception('Check dataroot')
    for root, _, fnames in sorted(os.walk(dir)):
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)
                images.append(path)
    return images


def default_loader(path):
    return Image.open(path).convert('RGB')


class pix2pix(data.Dataset):
    def __init__(self, root, transform=None, loader=default_loader, seed=None):
        self.root = root
        self.transform = transform
        self.loader = loader

        if seed is not None:
            np.random.seed(seed)

    def __getitem__(self, _):
        # A sample is drawn at random; files are assumed to be numbered from 1
        # (np.random.randint(1, len) yields 1 .. len-1).
        index = np.random.randint(1, self.__len__())

        file_name = self.root + '/' + str(index) + '.h5'
        with h5py.File(file_name, 'r') as f:
            haze_image = f['haze'][:]
            trans_map = f['trans'][:]
            ato_map = f['ato'][:]
            GT = f['gt'][:]

        # HWC -> CHW for every array
        haze_image = np.swapaxes(haze_image, 0, 2)
        trans_map = np.swapaxes(trans_map, 0, 2)
        ato_map = np.swapaxes(ato_map, 0, 2)
        GT = np.swapaxes(GT, 0, 2)

        haze_image = np.swapaxes(haze_image, 1, 2)
        trans_map = np.swapaxes(trans_map, 1, 2)
        ato_map = np.swapaxes(ato_map, 1, 2)
        GT = np.swapaxes(GT, 1, 2)

        # Disabled augmentations kept from development (horizontal flips;
        # small scipy.ndimage rotations and Gaussian blur were also tried):
        # if np.random.uniform() > 0.5:
        #     haze_image = np.flip(haze_image, 2).copy()
        #     GT = np.flip(GT, 2).copy()
        #     trans_map = np.flip(trans_map, 2).copy()

        return haze_image, GT, trans_map, ato_map

    def __len__(self):
        train_list = glob.glob(self.root + '/*h5')
        return len(train_list)
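# Editor addition: hedged sketch of how a training sample compatible with the
# dataset above could be written. Key names ('haze', 'trans', 'ato', 'gt') and
# the H x W x C layout match what __getitem__ reads; shapes are illustrative.
def write_sample(path, haze, trans, ato, gt):
    with h5py.File(path, 'w') as f:
        f.create_dataset('haze', data=haze)    # H x W x 3 hazy input
        f.create_dataset('trans', data=trans)  # transmission map
        f.create_dataset('ato', data=ato)      # atmospheric light map
        f.create_dataset('gt', data=gt)        # H x W x 3 clean image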
@@ -0,0 +1,88 @@
import torch.utils.data as data

from PIL import Image
import os
import os.path
import numpy as np


IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def make_dataset(dir):
    images = []
    if not os.path.isdir(dir):
        raise Exception('Check dataroot')
    for root, _, fnames in sorted(os.walk(dir)):
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)
                images.append(path)
    return images


def default_loader(path):
    return Image.open(path).convert('RGB')


class pix2pix(data.Dataset):
    def __init__(self, root, transform=None, loader=default_loader, seed=None):
        imgs = make_dataset(root)
        if len(imgs) == 0:
            raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n"
                               "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.loader = loader

        if seed is not None:
            np.random.seed(seed)

    def __getitem__(self, index):
        # A rain-density class (0 heavy, 1 medium, 2 light) is sampled at
        # random and doubles as the classification label; the incoming index
        # is replaced by a random draw from that class's folder.
        index_sub = np.random.randint(0, 3)
        label = index_sub

        if index_sub == 0:
            index = np.random.randint(0, 4000)
            path = '/media/openset/Z/derain2018/facades/DID-MDN-training/Rain_Heavy/train2018new' + '/' + str(index) + '.jpg'

        if index_sub == 1:
            index = np.random.randint(0, 4000)
            path = '/media/openset/Z/derain2018/facades/DID-MDN-training/Rain_Medium/train2018new' + '/' + str(index) + '.jpg'

        if index_sub == 2:
            index = np.random.randint(0, 4000)
            path = '/media/openset/Z/derain2018/facades/DID-MDN-training/Rain_Light/train2018new' + '/' + str(index) + '.jpg'

        img = self.loader(path)

        w, h = img.size

        # NOTE: split a sample into imgA and imgB
        imgA = img.crop((0, 0, w // 2, h))
        imgB = img.crop((w // 2, 0, w, h))

        if self.transform is not None:
            # NOTE: preprocessing for each pair of images
            imgA, imgB = self.transform(imgA, imgB)

        return imgA, imgB, label

    def __len__(self):
        return len(self.imgs)
@@ -0,0 +1,139 @@
import torch.utils.data as data

from PIL import Image
import os
import os.path
import numpy as np


IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def make_dataset(dir):
    images = []
    if not os.path.isdir(dir):
        raise Exception('Check dataroot')
    for root, _, fnames in sorted(os.walk(dir)):
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)
                images.append(path)
    return images


def default_loader(path):
    return Image.open(path).convert('RGB')


class pix2pix_val(data.Dataset):
    def __init__(self, root, transform=None, loader=default_loader, seed=None):
        imgs = make_dataset(root)
        if len(imgs) == 0:
            raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n"
                               "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.loader = loader

        if seed is not None:
            np.random.seed(seed)

    def __getitem__(self, index):
        path = self.imgs[index]

        # Indices past 75000 wrap around and re-read numbered .png files
        # directly under the dataset root.
        if index > 75000:
            index = index % 75000 + 10
            path = self.root + '/' + '%06d' % (index + 1) + '.png'

        img = self.loader(path)

        w, h = img.size

        # NOTE: split a sample into imgA and imgB
        imgA = img.crop((0, 0, w // 2, h))
        imgB = img.crop((w // 2, 0, w, h))

        if self.transform is not None:
            # NOTE: preprocessing for each pair of images
            imgA, imgB = self.transform(imgA, imgB)

        return imgA, imgB

    def __len__(self):
        return len(self.imgs)
After Width: | Height: | Size: 79 KiB |
After Width: | Height: | Size: 58 KiB |
After Width: | Height: | Size: 36 KiB |
After Width: | Height: | Size: 74 KiB |
After Width: | Height: | Size: 51 KiB |
After Width: | Height: | Size: 59 KiB |
After Width: | Height: | Size: 58 KiB |
After Width: | Height: | Size: 54 KiB |
@@ -0,0 +1,123 @@
import os
import torch
import numpy as np


def create_exp_dir(exp):
    try:
        os.makedirs(exp)
        print('Creating exp dir: %s' % exp)
    except OSError:
        pass
    return True


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
def getLoader(datasetName, dataroot, originalSize, imageSize, batchSize=64, workers=4,
              mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), split='train', shuffle=True, seed=None):

    if datasetName == 'pix2pix':
        from datasets.pix2pix import pix2pix as commonDataset
        import transforms.pix2pix as transforms
    elif datasetName == 'pix2pix_val':
        from datasets.pix2pix_val import pix2pix_val as commonDataset
        import transforms.pix2pix as transforms
    elif datasetName == 'pix2pix_class':
        from datasets.pix2pix_class import pix2pix as commonDataset
        import transforms.pix2pix as transforms

    if split == 'train':
        dataset = commonDataset(root=dataroot,
                                transform=transforms.Compose([
                                    transforms.Scale(originalSize),
                                    # transforms.RandomCrop(imageSize),
                                    # transforms.CenterCrop(imageSize),
                                    # transforms.RandomHorizontalFlip(),
                                    transforms.ToTensor(),
                                    transforms.Normalize(mean, std),
                                ]),
                                seed=seed)
    else:
        dataset = commonDataset(root=dataroot,
                                transform=transforms.Compose([
                                    transforms.Scale(originalSize),
                                    # transforms.CenterCrop(imageSize),
                                    transforms.ToTensor(),
                                    transforms.Normalize(mean, std),
                                ]),
                                seed=seed)

    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batchSize,
                                             shuffle=shuffle,
                                             num_workers=int(workers))
    return dataloader
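# Hedged usage sketch (editor addition); the dataroot is hypothetical and the
# call assumes the repo-local datasets/ and transforms/ packages are importable.
# train_loader = getLoader('pix2pix', './facades/train', originalSize=286,
#                          imageSize=256, batchSize=8, workers=4,
#                          split='train', shuffle=True, seed=0)
# for imgA, imgB, label in train_loader:
#     ...  # one preprocessed pair per sample, plus its rain-density label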
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
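def _demo_average_meter():
    # Editor addition: tiny illustration of AverageMeter's bookkeeping.
    meter = AverageMeter()
    for loss in (0.9, 0.7, 0.5):
        meter.update(loss)        # n=1 per update
    assert abs(meter.avg - 0.7) < 1e-9  # running mean of the three values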
class ImagePool:
    """Pool of previously generated images: query() returns either the new
    image or, with probability 0.5, a random earlier one (which the new image
    then replaces). Helps stabilize discriminator training."""

    def __init__(self, pool_size=50):
        self.pool_size = pool_size
        if pool_size > 0:
            self.num_imgs = 0
            self.images = []

    def query(self, image):
        if self.pool_size == 0:
            return image
        if self.num_imgs < self.pool_size:
            self.images.append(image.clone())
            self.num_imgs += 1
            return image
        else:
            if np.random.uniform(0, 1) > 0.5:
                random_id = np.random.randint(self.pool_size, size=1)[0]
                tmp = self.images[random_id].clone()
                self.images[random_id] = image.clone()
                return tmp
            else:
                return image
def adjust_learning_rate(optimizer, init_lr, epoch, factor, every):
    # Linearly decay the learning rate by init_lr/every per call; `epoch` and
    # `factor` are accepted for interface compatibility but unused.
    lrd = init_lr / every
    old_lr = optimizer.param_groups[0]['lr']
    print('learning rate: %f' % old_lr)
    lr = old_lr - lrd
    if lr < 0:
        lr = 0
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
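def _demo_lr_decay():
    # Editor addition: each call subtracts init_lr/every from the current lr,
    # reaching zero after `every` calls. Sketch with SGD; `factor` is unused.
    import torch.nn as nn
    opt = torch.optim.SGD(nn.Linear(2, 2).parameters(), lr=0.1)
    for epoch in range(3):
        adjust_learning_rate(opt, init_lr=0.1, epoch=epoch, factor=None, every=10)
    # lr is now 0.1 - 3 * (0.1 / 10) = 0.07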
@@ -0,0 +1 @@
@@ -0,0 +1,855 @@
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from collections import OrderedDict
import torchvision.models as models
from torch.autograd import Variable
def conv_block(in_dim, out_dim):
    return nn.Sequential(nn.Conv2d(in_dim, in_dim, kernel_size=3, stride=1, padding=1),
                         nn.ELU(True),
                         nn.Conv2d(in_dim, in_dim, kernel_size=3, stride=1, padding=1),
                         nn.ELU(True),
                         nn.Conv2d(in_dim, out_dim, kernel_size=1, stride=1, padding=0),
                         nn.AvgPool2d(kernel_size=2, stride=2))


def deconv_block(in_dim, out_dim):
    return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, stride=1, padding=1),
                         nn.ELU(True),
                         nn.Conv2d(out_dim, out_dim, kernel_size=3, stride=1, padding=1),
                         nn.ELU(True),
                         nn.UpsamplingNearest2d(scale_factor=2))
def blockUNet1(in_c, out_c, name, transposed=False, bn=False, relu=True, dropout=False):
    # NOTE: submodule names contain '.', which recent PyTorch versions reject
    # in add_module(); this matches the older PyTorch this code base targets.
    block = nn.Sequential()
    if relu:
        block.add_module('%s.relu' % name, nn.ReLU(inplace=True))
    else:
        block.add_module('%s.leakyrelu' % name, nn.LeakyReLU(0.2, inplace=True))
    if not transposed:
        block.add_module('%s.conv' % name, nn.Conv2d(in_c, out_c, 3, 1, 1, bias=False))
    else:
        block.add_module('%s.tconv' % name, nn.ConvTranspose2d(in_c, out_c, 3, 1, 1, bias=False))
    if bn:
        # despite the flag name, this attaches instance normalization
        block.add_module('%s.bn' % name, nn.InstanceNorm2d(out_c))
    if dropout:
        block.add_module('%s.dropout' % name, nn.Dropout2d(0.5, inplace=True))
    return block
def blockUNet(in_c, out_c, name, transposed=False, bn=False, relu=True, dropout=False):
    block = nn.Sequential()
    if relu:
        block.add_module('%s.relu' % name, nn.ReLU(inplace=True))
    else:
        block.add_module('%s.leakyrelu' % name, nn.LeakyReLU(0.2, inplace=True))
    if not transposed:
        block.add_module('%s.conv' % name, nn.Conv2d(in_c, out_c, 4, 2, 1, bias=False))
    else:
        block.add_module('%s.tconv' % name, nn.ConvTranspose2d(in_c, out_c, 4, 2, 1, bias=False))
    if bn:
        block.add_module('%s.bn' % name, nn.InstanceNorm2d(out_c))
    if dropout:
        block.add_module('%s.dropout' % name, nn.Dropout2d(0.5, inplace=True))
    return block
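def _demo_updown_shapes():
    # Editor addition: the 4x4 / stride-2 / pad-1 convolution used throughout
    # blockUNet halves spatial size; its transposed counterpart doubles it.
    x = torch.randn(1, 8, 64, 64)
    down = nn.Conv2d(8, 16, 4, 2, 1, bias=False)
    up = nn.ConvTranspose2d(16, 8, 4, 2, 1, bias=False)
    print(down(x).shape)      # torch.Size([1, 16, 32, 32])
    print(up(down(x)).shape)  # torch.Size([1, 8, 64, 64])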
class D1(nn.Module):
    def __init__(self, nc, ndf, hidden_size):
        super(D1, self).__init__()

        # 256
        self.conv1 = nn.Sequential(nn.Conv2d(nc, ndf, kernel_size=3, stride=1, padding=1),
                                   nn.ELU(True))
        # 256
        self.conv2 = conv_block(ndf, ndf)
        # 128
        self.conv3 = conv_block(ndf, ndf*2)
        # 64
        self.conv4 = conv_block(ndf*2, ndf*3)
        # 32
        self.encode = nn.Conv2d(ndf*3, hidden_size, kernel_size=1, stride=1, padding=0)
        self.decode = nn.Conv2d(hidden_size, ndf, kernel_size=1, stride=1, padding=0)
        # 32
        self.deconv4 = deconv_block(ndf, ndf)
        # 64
        self.deconv3 = deconv_block(ndf, ndf)
        # 128
        self.deconv2 = deconv_block(ndf, ndf)
        # 256
        self.deconv1 = nn.Sequential(nn.Conv2d(ndf, ndf, kernel_size=3, stride=1, padding=1),
                                     nn.ELU(True),
                                     nn.Conv2d(ndf, ndf, kernel_size=3, stride=1, padding=1),
                                     nn.ELU(True),
                                     nn.Conv2d(ndf, nc, kernel_size=3, stride=1, padding=1),
                                     nn.Tanh())
        # Alternative single-conv head kept from the original:
        # self.deconv1 = nn.Sequential(nn.Conv2d(ndf, nc, kernel_size=3, stride=1, padding=1),
        #                              nn.Tanh())

    def forward(self, x):
        out1 = self.conv1(x)
        out2 = self.conv2(out1)
        out3 = self.conv3(out2)
        out4 = self.conv4(out3)
        out5 = self.encode(out4)
        dout5 = self.decode(out5)
        dout4 = self.deconv4(dout5)
        dout3 = self.deconv3(dout4)
        dout2 = self.deconv2(dout3)
        dout1 = self.deconv1(dout2)
        return dout1
class D(nn.Module):
    def __init__(self, nc, nf):
        super(D, self).__init__()

        main = nn.Sequential()
        # 256
        layer_idx = 1
        name = 'layer%d' % layer_idx
        main.add_module('%s.conv' % name, nn.Conv2d(nc, nf, 4, 2, 1, bias=False))

        # 128
        layer_idx += 1
        name = 'layer%d' % layer_idx
        main.add_module(name, blockUNet(nf, nf*2, name, transposed=False, bn=True, relu=False, dropout=False))

        # 64
        layer_idx += 1
        name = 'layer%d' % layer_idx
        nf = nf * 2
        main.add_module(name, blockUNet(nf, nf*2, name, transposed=False, bn=True, relu=False, dropout=False))

        # 32
        layer_idx += 1
        name = 'layer%d' % layer_idx
        nf = nf * 2
        main.add_module('%s.leakyrelu' % name, nn.LeakyReLU(0.2, inplace=True))
        main.add_module('%s.conv' % name, nn.Conv2d(nf, nf*2, 4, 1, 1, bias=False))
        main.add_module('%s.bn' % name, nn.InstanceNorm2d(nf*2))

        # 31
        layer_idx += 1
        name = 'layer%d' % layer_idx
        nf = nf * 2
        main.add_module('%s.leakyrelu' % name, nn.LeakyReLU(0.2, inplace=True))
        main.add_module('%s.conv' % name, nn.Conv2d(nf, 1, 4, 1, 1, bias=False))
        main.add_module('%s.sigmoid' % name, nn.Sigmoid())
        # 30 (sizePatchGAN=30)

        self.main = main

    def forward(self, x):
        output = self.main(x)
        return output
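def _demo_patchgan():
    # Editor addition: D is a PatchGAN-style discriminator, so it outputs a
    # map of per-patch real/fake scores rather than one scalar. Hedged smoke
    # test for a 256x256 RGB input (the dotted module names require the older
    # PyTorch this code base targets).
    netD = D(nc=3, nf=64)
    scores = netD(torch.randn(1, 3, 256, 256))
    print(scores.shape)  # expected (1, 1, 30, 30): the 30x30 patch map noted above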
class ShareSepConv(nn.Module):
    """Depthwise convolution that shares a single kernel across all input
    channels; the kernel is initialized to a centre spike, so the layer
    starts out as an identity mapping."""

    def __init__(self, kernel_size):
        super(ShareSepConv, self).__init__()
        assert kernel_size % 2 == 1, 'kernel size should be odd'
        self.padding = (kernel_size - 1) // 2
        weight_tensor = torch.zeros(1, 1, kernel_size, kernel_size)
        weight_tensor[0, 0, (kernel_size-1)//2, (kernel_size-1)//2] = 1
        self.weight = nn.Parameter(weight_tensor)
        self.kernel_size = kernel_size

    def forward(self, x):
        inc = x.size(1)
        expand_weight = self.weight.expand(inc, 1, self.kernel_size, self.kernel_size).contiguous()
        # groups=inc makes this a depthwise convolution with the shared kernel
        return F.conv2d(x, expand_weight,
                        None, 1, self.padding, 1, inc)
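def _demo_sharesepconv():
    # Editor addition: with its centre-spike initialization, ShareSepConv is
    # an identity mapping before any training updates.
    conv = ShareSepConv(3)
    x = torch.randn(2, 4, 8, 8)
    assert torch.allclose(conv(x), x, atol=1e-6)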
class BottleneckBlockdls(nn.Module):
    def __init__(self, in_planes, out_planes, dropRate=0.0):
        super(BottleneckBlockdls, self).__init__()
        inter_planes = out_planes * 4
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv_o = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1,
                                padding=0, bias=False)
        self.sharewconv1 = ShareSepConv(3)  # defined but unused in forward()
        self.conv1 = nn.Conv2d(in_planes, inter_planes, kernel_size=1, stride=1,
                               padding=0, bias=False)
        self.sharewconv2 = ShareSepConv(3)
        self.bn2 = nn.BatchNorm2d(inter_planes)
        self.conv2 = nn.Conv2d(inter_planes, inter_planes, kernel_size=3, stride=1,
                               padding=2, dilation=2, bias=False)
        self.bn4 = nn.BatchNorm2d(inter_planes)
        self.conv4 = nn.Conv2d(inter_planes, out_planes, kernel_size=3, stride=1,
                               padding=2, dilation=2, bias=False)
        self.droprate = dropRate

    def forward(self, x):
        out = self.conv1(self.relu(self.bn1(x)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)
        out = self.conv2(self.relu(self.bn2(out)))
        outx = self.conv_o(x)
        out = outx + self.conv4(self.sharewconv2(self.relu(self.bn4(out))))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)
        return torch.cat([x, out], 1)
class BottleneckBlockdl(nn.Module):
    def __init__(self, in_planes, out_planes, dropRate=0.0):
        super(BottleneckBlockdl, self).__init__()
        inter_planes = out_planes * 3
        self.bn1 = nn.InstanceNorm2d(in_planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv_o = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1,
                                padding=0, bias=False)
        self.conv1 = nn.Conv2d(in_planes, inter_planes, kernel_size=1, stride=1,
                               padding=0, bias=False)
        self.bn2 = nn.InstanceNorm2d(inter_planes)
        self.conv2 = nn.Conv2d(inter_planes, inter_planes, kernel_size=3, stride=1,
                               padding=1, dilation=1, bias=False)
        self.bn3 = nn.InstanceNorm2d(inter_planes)
        self.conv3 = nn.Conv2d(inter_planes, inter_planes, kernel_size=3, stride=1,
                               padding=2, dilation=2, bias=False)
        self.bn4 = nn.InstanceNorm2d(inter_planes)
        self.sharewconv = ShareSepConv(3)
        self.conv4 = nn.Conv2d(inter_planes, out_planes, kernel_size=3, stride=1,
                               padding=2, dilation=2, bias=False)
        self.droprate = dropRate

    def forward(self, x):
        out = self.conv1(self.relu(self.bn1(x)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)
        out = self.conv2(self.relu(self.bn2(out)))
        out = self.conv3(self.relu(self.bn3(out)))
        outx = self.conv_o(x)
        out = outx + self.conv4(self.sharewconv(self.relu(self.bn4(out))))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)
        return torch.cat([x, out], 1)
class BottleneckBlockrs1(nn.Module):
    def __init__(self, in_planes, out_planes, dropRate=0.0):
        super(BottleneckBlockrs1, self).__init__()
        inter_planes = out_planes * 3
        self.bn1 = nn.InstanceNorm2d(in_planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv_o = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1,
                                padding=0, bias=False)
        self.conv1 = nn.Conv2d(in_planes, inter_planes, kernel_size=1, stride=1,
                               padding=0, bias=False)
        self.bn2 = nn.InstanceNorm2d(inter_planes)
        self.conv2 = nn.Conv2d(inter_planes, inter_planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn3 = nn.InstanceNorm2d(inter_planes)
        self.conv3 = nn.Conv2d(inter_planes, inter_planes, kernel_size=3, stride=1,
                               padding=2, dilation=2, bias=False)
        self.bn4 = nn.InstanceNorm2d(inter_planes)
        self.conv4 = nn.Conv2d(inter_planes, out_planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.droprate = dropRate

    def forward(self, x):
        out = self.conv1(self.relu(self.bn1(x)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)
        out = self.conv2(self.relu(self.bn2(out)))
        out = self.conv3(self.relu(self.bn3(out)))
        outx = self.conv_o(x)
        out = outx + self.conv4(self.relu(self.bn4(out)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)
        return torch.cat([x, out], 1)
class BottleneckBlockrs(nn.Module):
    def __init__(self, in_planes, out_planes, dropRate=0.0):
        super(BottleneckBlockrs, self).__init__()
        inter_planes = out_planes * 3
        self.bn1 = nn.InstanceNorm2d(in_planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv_o = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1,
                                padding=0, bias=False)
        self.conv1 = nn.Conv2d(in_planes, inter_planes, kernel_size=1, stride=1,
                               padding=0, bias=False)
        self.bn2 = nn.InstanceNorm2d(inter_planes)
        self.conv2 = nn.Conv2d(inter_planes, inter_planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn3 = nn.InstanceNorm2d(inter_planes)
        self.conv3 = nn.Conv2d(inter_planes, inter_planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn4 = nn.InstanceNorm2d(inter_planes)
        self.conv4 = nn.Conv2d(inter_planes, out_planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.droprate = dropRate

    def forward(self, x):
        out = self.conv1(self.relu(self.bn1(x)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)
        out = self.conv2(self.relu(self.bn2(out)))
        out = self.conv3(self.relu(self.bn3(out)))
        outx = self.conv_o(x)
        out = outx + self.conv4(self.relu(self.bn4(out)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)
        return torch.cat([x, out], 1)
class TransitionBlock(nn.Module):
    """1x1 (transposed) conv followed by 2x nearest-neighbour upsampling."""
    def __init__(self, in_planes, out_planes, dropRate=0.0):
        super(TransitionBlock, self).__init__()
        self.bn1 = nn.InstanceNorm2d(in_planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.ConvTranspose2d(in_planes, out_planes, kernel_size=1, stride=1,
                                        padding=0, bias=False)
        self.droprate = dropRate

    def forward(self, x):
        out = self.conv1(self.relu(self.bn1(x)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)
        return F.upsample_nearest(out, scale_factor=2)


class TransitionBlock1(nn.Module):
    """1x1 (transposed) conv followed by 2x average-pool downsampling."""
    def __init__(self, in_planes, out_planes, dropRate=0.0):
        super(TransitionBlock1, self).__init__()
        self.bn1 = nn.InstanceNorm2d(in_planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.ConvTranspose2d(in_planes, out_planes, kernel_size=1, stride=1,
                                        padding=0, bias=False)
        self.droprate = dropRate

    def forward(self, x):
        out = self.conv1(self.relu(self.bn1(x)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)
        return F.avg_pool2d(out, 2)


class TransitionBlock3(nn.Module):
    """1x1 (transposed) conv; keeps spatial resolution."""
    def __init__(self, in_planes, out_planes, dropRate=0.0):
        super(TransitionBlock3, self).__init__()
        self.bn1 = nn.InstanceNorm2d(in_planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.ConvTranspose2d(in_planes, out_planes, kernel_size=1, stride=1,
                                        padding=0, bias=False)
        self.droprate = dropRate

    def forward(self, x):
        out = self.conv1(self.relu(self.bn1(x)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)
        return out
class vgg19ca(nn.Module):
    def __init__(self):
        super(vgg19ca, self).__init__()

        ############# 256-256 ##############
        haze_class = models.vgg19_bn(pretrained=True)
        self.feature = nn.Sequential(haze_class.features[0])

        for i in range(1, 3):
            self.feature.add_module(str(i), haze_class.features[i])

        self.conv16 = nn.Conv2d(64, 24, kernel_size=3, stride=1, padding=1)  # 1mm
        # NOTE: the in_features below (127896 = 24 * 73 * 73) hard-wire a
        # 512x512 input (73 = floor(512 / 7) after the 7x7 average pooling).
        self.dense_classifier = nn.Linear(127896, 512)
        self.dense_classifier1 = nn.Linear(512, 4)

    def forward(self, x):
        feature = self.feature(x)
        feature = self.conv16(feature)

        out = F.relu(feature, inplace=True)
        out = F.avg_pool2d(out, kernel_size=7).view(out.size(0), -1)
        out = F.relu(self.dense_classifier(out))
        out = self.dense_classifier1(out)

        return out
class BottleneckBlock(nn.Module):
    def __init__(self, in_planes, out_planes, dropRate=0.0):
        super(BottleneckBlock, self).__init__()
        inter_planes = out_planes * 3
        self.bn1 = nn.InstanceNorm2d(in_planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv_o = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1,
                                padding=0, bias=False)
        self.conv1 = nn.Conv2d(in_planes, inter_planes, kernel_size=1, stride=1,
                               padding=0, bias=False)
        self.bn2 = nn.InstanceNorm2d(inter_planes)
        self.conv2 = nn.Conv2d(inter_planes, inter_planes, kernel_size=3, stride=1,
                               padding=1, bias=False)

        self.bn4 = nn.InstanceNorm2d(inter_planes)
        self.conv4 = nn.Conv2d(inter_planes, out_planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.droprate = dropRate

    def forward(self, x):
        out = self.conv1(self.relu(self.bn1(x)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)
        out = self.conv2(self.relu(self.bn2(out)))
        outx = self.conv_o(x)
        out = outx + self.conv4(self.relu(self.bn4(out)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)
        return out
class TransitionBlockbil(nn.Module):
    def __init__(self, in_planes, out_planes, dropRate=0.0):
        super(TransitionBlockbil, self).__init__()
        self.bn1 = nn.InstanceNorm2d(in_planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1,
                               padding=0, bias=False)
        self.bn2 = nn.InstanceNorm2d(out_planes)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.droprate = dropRate

    def forward(self, x):
        out = self.conv1(self.relu(self.bn1(x)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)
        out = F.upsample_bilinear(out, scale_factor=2)
        return self.conv2(out)
class Deblur_first(nn.Module):
    def __init__(self, in_channels):
        super(Deblur_first, self).__init__()

        self.dense_block1 = BottleneckBlockrs(in_channels, 32 - in_channels)
        self.trans_block1 = TransitionBlock1(32, 16)

        ############# Block2-down 32-32 ##############
        self.dense_block2 = BottleneckBlockdl(16, 16)
        self.trans_block2 = TransitionBlock3(32, 16)

        ############# Block3-down 16-16 ##############
        self.dense_block3 = BottleneckBlockdl(16, 16)
        self.trans_block3 = TransitionBlock3(32, 16)

        ############# Block5-up 16-16 ##############
        self.dense_block5 = BottleneckBlockdl(32, 16)
        self.trans_block5 = TransitionBlock3(48, 16)

        self.dense_block6 = BottleneckBlockrs(16, 16)
        self.trans_block6 = TransitionBlockbil(32, 16)

        self.conv_refin = nn.Conv2d(16, 16, 3, 1, 1)
        self.conv_refin_in = nn.Conv2d(in_channels, 16, 3, 1, 1)
        self.tanh = nn.Tanh()

        self.conv1010 = nn.Conv2d(20, 1, kernel_size=1, stride=1, padding=0)  # 1mm
        self.conv1020 = nn.Conv2d(20, 1, kernel_size=1, stride=1, padding=0)  # 1mm
        self.conv1030 = nn.Conv2d(20, 1, kernel_size=1, stride=1, padding=0)  # 1mm
        self.conv1040 = nn.Conv2d(20, 1, kernel_size=1, stride=1, padding=0)  # 1mm

        self.refine3 = nn.Conv2d(16, 3, kernel_size=3, stride=1, padding=1)

        self.upsample = F.upsample_nearest

        self.relu = nn.LeakyReLU(0.2, inplace=True)
        self.refineclean1 = nn.Conv2d(3, 8, kernel_size=7, stride=1, padding=3)
        self.refineclean2 = nn.Conv2d(8, 3, kernel_size=3, stride=1, padding=1)

    def forward(self, x, smaps):
        ## 256x256
        x1 = self.dense_block1(torch.cat([x, smaps], 1))
        x1 = self.trans_block1(x1)

        ### 32x32
        x2 = self.dense_block2(x1)
        x2 = self.trans_block2(x2)

        ### 16 x 16
        x3 = self.dense_block3(x2)
        x3 = self.trans_block3(x3)

        x5_in = torch.cat([x3, x1], 1)
        x5_i = self.dense_block5(x5_in)

        x5 = self.trans_block5(x5_i)
        x6 = self.dense_block6(x5)
        x6 = self.trans_block6(x6)

        # Predict a residual, subtract it from the input to get the restored
        # image, then refine with two convolutions.
        x7 = self.relu(self.conv_refin_in(torch.cat([x, smaps], 1))) - self.relu(self.conv_refin(x6))
        residual = self.tanh(self.refine3(x7))
        clean = x - residual
        clean = self.relu(self.refineclean1(clean))
        clean = self.tanh(self.refineclean2(clean))

        return clean, x5
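def _demo_deblur_first():
    # Editor addition: hedged smoke test. in_channels counts the image plus
    # its map(s): RGB (3) + an assumed 1-channel class/segmentation map = 4.
    # Spatial size must be even (one stride-2 pool, one 2x upsample).
    net = Deblur_first(in_channels=4)
    x = torch.randn(1, 3, 128, 128)
    smap = torch.randn(1, 1, 128, 128)
    clean, feat = net(x, smap)
    print(clean.shape, feat.shape)  # (1, 3, 128, 128) and (1, 16, 64, 64)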
class Deblur_class(nn.Module):
    def __init__(self):
        super(Deblur_class, self).__init__()

        ##### stage class networks ###########
        self.deblur_class1 = Deblur_first(4)
        self.deblur_class2 = Deblur_first(4)
        self.deblur_class3 = Deblur_first(11)
        self.deblur_class4 = Deblur_first(4)
        ######################################

    def forward(self, x_input1, x_input2, x_input3, x_input4, class1, class2, class3, class4):
        xh_class1, x_lst1 = self.deblur_class1(x_input1, class1)
        xh_class2, x_lst2 = self.deblur_class2(x_input2, class2)
        xh_class3, x_lst3 = self.deblur_class3(x_input3, class3)
        xh_class4, x_lst4 = self.deblur_class4(x_input4, class4)

        return xh_class1, xh_class2, xh_class3, xh_class4, x_lst1, x_lst2, x_lst3, x_lst4
class BottleneckBlockcf(nn.Module):
    def __init__(self, in_planes, out_planes, dropRate=0.0):
        super(BottleneckBlockcf, self).__init__()
        inter_planes = out_planes * 3
        self.bn1 = nn.InstanceNorm2d(in_planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv_o = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1,
                                padding=0, bias=False)
        self.conv1 = nn.Conv2d(in_planes, inter_planes, kernel_size=1, stride=1,
                               padding=0, bias=False)
        self.bn2 = nn.InstanceNorm2d(inter_planes)
        self.conv2 = nn.Conv2d(inter_planes, inter_planes, kernel_size=3, stride=1,
                               padding=1, bias=False)

        self.bn4 = nn.InstanceNorm2d(inter_planes)
        self.conv4 = nn.Conv2d(inter_planes, out_planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.droprate = dropRate

    def forward(self, x):
        out = self.conv1(self.relu(self.bn1(x)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)
        out = self.conv2(self.relu(self.bn2(out)))
        out = self.conv4(self.relu(self.bn4(out)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)
        return out
class scale_kernel_conf(nn.Module):
    def __init__(self):
        super(scale_kernel_conf, self).__init__()

        self.conv1 = nn.Conv2d(6, 16, 3, 1, 1)
        self.trans_block1 = TransitionBlock1(16, 16)
        self.conv2 = BottleneckBlockcf(16, 32)
        self.trans_block2 = TransitionBlock1(32, 16)
        self.conv3 = BottleneckBlockcf(16, 32)
        self.trans_block3 = TransitionBlock1(32, 16)
        self.conv4 = BottleneckBlockcf(16, 32)
        self.trans_block4 = TransitionBlock3(32, 16)
        self.conv_refin = nn.Conv2d(16, 3, 1, 1, 0)
        self.sig = torch.nn.Sigmoid()

        self.relu = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x, target):
        x1 = self.conv1(torch.cat([x, target], 1))
        x1 = self.trans_block1(x1)
        x2 = self.conv2(x1)
        x2 = self.trans_block2(x2)
        x3 = self.conv3(x2)
        x3 = self.trans_block3(x3)
        x4 = self.conv4(x3)  # the original applied conv3 a second time here, leaving conv4 unused
        x4 = self.trans_block4(x4)
        # The pool size (16) and upsampling factor (128) hard-wire 128x128 inputs.
        residual = self.sig(self.conv_refin(self.sig(F.avg_pool2d(x4, 16))))
        residual = F.upsample_nearest(residual, scale_factor=128)
        return residual
class Deblur_segdl(nn.Module):
    def __init__(self):
        super(Deblur_segdl, self).__init__()
        self.deblur_class1 = Deblur_first(4)
        self.deblur_class2 = Deblur_first(4)
        self.deblur_class3 = Deblur_first(4)
        self.deblur_class4 = Deblur_first(4)
        self.dense_block1 = BottleneckBlockrs(7, 57)
        self.dense_block_cl = BottleneckBlock(64, 32)
        self.trans_block1 = TransitionBlock1(64, 32)

        ############# Block2-down 32-32 ##############
        self.dense_block2 = BottleneckBlockrs1(67, 64)
        self.trans_block2 = TransitionBlock3(131, 64)

        ############# Block3-down 16-16 ##############
        self.dense_block3 = BottleneckBlockdl(64, 64)
        self.trans_block3 = TransitionBlock3(128, 64)

        self.dense_block3_1 = BottleneckBlockdl(64, 64)
        self.trans_block3_1 = TransitionBlock3(128, 64)

        self.dense_block3_2 = BottleneckBlockdl(64, 64)
        self.trans_block3_2 = TransitionBlock3(128, 64)

        ############# Block4-up 8-8 ##############
        self.dense_block4 = BottleneckBlockdl(64, 64)
        self.trans_block4 = TransitionBlock3(128, 64)

        ############# Block5-up 16-16 ##############
        self.dense_block5 = BottleneckBlockrs1(128, 64)
        self.trans_block5 = TransitionBlockbil(195, 64)

        self.dense_block6 = BottleneckBlockrs(71, 64)
        self.trans_block6 = TransitionBlock3(135, 16)

        self.conv_refin = nn.Conv2d(23, 16, 3, 1, 1)
        self.conv_refin_in = nn.Conv2d(7, 16, 3, 1, 1)
        self.conv_refin_in64 = nn.Conv2d(3, 16, 3, 1, 1)
        self.conv_refin64 = nn.Conv2d(192, 16, 3, 1, 1)
        self.tanh = nn.Tanh()

        self.conv1010 = nn.Conv2d(20, 1, kernel_size=1, stride=1, padding=0)  # 1mm
        self.conv1020 = nn.Conv2d(20, 1, kernel_size=1, stride=1, padding=0)  # 1mm
        self.conv1030 = nn.Conv2d(20, 1, kernel_size=1, stride=1, padding=0)  # 1mm
        self.conv1040 = nn.Conv2d(20, 1, kernel_size=1, stride=1, padding=0)  # 1mm

        self.refine3 = nn.Conv2d(16, 3, kernel_size=3, stride=1, padding=1)

        self.upsample = F.upsample_nearest

        self.relu = nn.LeakyReLU(0.2, inplace=True)
        self.refineclean1 = nn.Conv2d(3, 8, kernel_size=7, stride=1, padding=3)
        self.refineclean2 = nn.Conv2d(8, 3, kernel_size=3, stride=1, padding=1)

        self.conv11 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)  # 1mm
        self.conv21 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)  # 1mm
        self.conv31 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)  # 1mm
        self.conv3_11 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)  # 1mm
        self.conv3_21 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)  # 1mm
        self.conv41 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)  # 1mm
        self.conv51 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)  # 1mm

        self.batchnorm20 = nn.InstanceNorm2d(20)
        self.batchnorm1 = nn.InstanceNorm2d(1)
        self.conf_ker = scale_kernel_conf()

    def forward(self, x, x_64, smaps, class1, class2, class3, class4, target,
                class_msk1, class_msk2, class_msk3, class_msk4):
        ## 256x256
        xcl_class1, xh_class1 = self.deblur_class1(x, class1)
        xcl_class2, xh_class2 = self.deblur_class2(x, class2)
        xcl_class3, xh_class3 = self.deblur_class3(x, class3)
        xcl_class4, xh_class4 = self.deblur_class4(x, class4)
        x_cl = self.dense_block_cl(torch.cat([xh_class1, xh_class2, xh_class3, xh_class4], 1))
        x1 = self.dense_block1(torch.cat([x, smaps], 1))
        x1 = self.trans_block1(x1)

        ### 32x32
        x2 = self.dense_block2(torch.cat([x1, x_64, x_cl], 1))
        x2 = self.trans_block2(x2)

        ### 16 x 16
        x3 = self.dense_block3(x2)
        x3 = self.trans_block3(x3)

        x3_1 = self.dense_block3_1(x3)
        x3_1 = self.trans_block3_1(x3_1)

        x3_2 = self.dense_block3_2(x3_1)
        x3_2 = self.trans_block3_2(x3_2)

        x4 = self.dense_block4(x3_2)
        x4 = self.trans_block4(x4)
        x5_in = torch.cat([x4, x1, x_cl], 1)
        x5_i = self.dense_block5(x5_in)

        # Coarse restoration predicted at the downsampled scale
        xhat64 = self.relu(self.conv_refin_in64(x_64)) - self.relu(self.conv_refin64(x5_i))
        xhat64 = self.tanh(self.refine3(xhat64))
        x5 = self.trans_block5(torch.cat([x5_i, xhat64], 1))
        x6 = self.dense_block6(torch.cat([x5, x, smaps], 1))
        x6 = self.trans_block6(x6)
        shape_out = x6.data.size()
        shape_out = shape_out[2:4]
        # Multi-scale side features, upsampled to the output resolution
        x11 = self.upsample(self.relu(self.conv11(torch.cat([x1, x_cl], 1))), size=shape_out)
        x21 = self.upsample(self.relu(self.conv21(x2)), size=shape_out)
        x31 = self.upsample(self.relu(self.conv31(x3)), size=shape_out)
        x3_11 = self.upsample(self.relu(self.conv3_11(x3_1)), size=shape_out)
        x3_21 = self.upsample(self.relu(self.conv3_21(x3_2)), size=shape_out)
        x41 = self.upsample(self.relu(self.conv41(x4)), size=shape_out)
        x51 = self.upsample(self.relu(self.conv51(x5)), size=shape_out)
        x6 = torch.cat([x6, x51, x41, x3_21, x3_11, x31, x21, x11], 1)
        x7 = self.relu(self.conv_refin_in(torch.cat([x, smaps], 1))) - self.relu(self.conv_refin(x6))
        residual = self.tanh(self.refine3(x7))
        clean = x - residual
        clean = self.relu(self.refineclean1(clean))
        clean = self.tanh(self.refineclean2(clean))

        clean64 = x_64 - xhat64
        clean64 = self.relu(self.refineclean1(clean64))
        clean64 = self.tanh(self.refineclean2(clean64))

        # Per-class confidence maps comparing the restoration with the target
        xmask1 = self.conf_ker(clean * class_msk1, target * class_msk1)
        xmask2 = self.conf_ker(clean * class_msk2, target * class_msk2)
        xmask3 = self.conf_ker(clean * class_msk3, target * class_msk3)
        xmask4 = self.conf_ker(clean * class_msk4, target * class_msk4)

        return clean, clean64, xmask1, xmask2, xmask3, xmask4, xcl_class1, xcl_class2, xcl_class3, xcl_class4
class Segmentation(nn.Module):
    def __init__(self):
        super(Segmentation, self).__init__()

        self.dense_block1 = BottleneckBlockrs1(3, 61)
        self.trans_block1 = TransitionBlock1(64, 64)

        ############# Block2-down 32-32 ##############
        self.dense_block2 = BottleneckBlockdls(67, 64)
        self.trans_block2 = TransitionBlock1(131, 64)

        ############# Block3-down 16-16 ##############
        self.dense_block3 = BottleneckBlockdls(64, 64)
        self.trans_block3 = TransitionBlock3(128, 64)

        self.dense_block3_1 = BottleneckBlockdls(64, 64)
        self.trans_block3_1 = TransitionBlock3(128, 64)

        self.dense_block3_2 = BottleneckBlockdls(64, 64)
        self.trans_block3_2 = TransitionBlock3(128, 64)

        ############# Block4-up 8-8 ##############
        self.dense_block4 = BottleneckBlockdls(128, 64)
        self.trans_block4 = TransitionBlock(192, 64)

        ############# Block5-up 16-16 ##############
        self.dense_block5 = BottleneckBlockdls(128, 64)
        self.trans_block5 = TransitionBlock(196, 64)

        self.dense_block6 = BottleneckBlockrs1(64, 64)
        self.trans_block6 = TransitionBlock3(128, 16)

        self.conv_refin = nn.Conv2d(23, 16, 3, 1, 1)
        self.conv_refin64 = nn.Conv2d(192, 16, 3, 1, 1)
        self.tanh = nn.Sigmoid()  # note: named 'tanh' but actually a sigmoid

        self.refine3 = nn.Conv2d(16, 4, kernel_size=3, stride=1, padding=1)
        self.refine3_i = nn.Conv2d(16, 4, kernel_size=3, stride=1, padding=1)

        self.upsample = F.upsample_nearest

        self.relu = nn.LeakyReLU(0.2, inplace=True)
        self.refineclean1 = nn.Conv2d(4, 8, kernel_size=7, stride=1, padding=3)
        self.refineclean2 = nn.Conv2d(8, 4, kernel_size=3, stride=1, padding=1)

        self.conv11 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)  # 1mm
        self.conv21 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)  # 1mm
        self.conv31 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)  # 1mm
        self.conv3_11 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)  # 1mm
        self.conv3_21 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)  # 1mm
        self.conv41 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)  # 1mm
        self.conv51 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)  # 1mm

        self.batchnorm20 = nn.BatchNorm2d(20)
        self.batchnorm1 = nn.BatchNorm2d(1)

    def forward(self, x, x_64):
        ## 256x256
        x1 = self.dense_block1(x)
        x1 = self.trans_block1(x1)

        ### 32x32
        x2 = self.dense_block2(torch.cat([x1, x_64], 1))
        x2 = self.trans_block2(x2)

        ### 16 x 16
        x3 = self.dense_block3(x2)
        x3 = self.trans_block3(x3)

        x3_1 = self.dense_block3_1(x3)
        x3_1 = self.trans_block3_1(x3_1)

        x3_2 = self.dense_block3_2(x3_1)
        x3_2 = self.trans_block3_2(x3_2)

        x4_in = torch.cat([x3_2, x2], 1)
        x4 = self.dense_block4(x4_in)
        x4 = self.trans_block4(x4)
        x5_in = torch.cat([x4, x1], 1)
        x5_i = self.dense_block5(x5_in)
        # Coarse segmentation predicted at the downsampled scale
        xhat64 = self.relu(self.conv_refin64(x5_i))
        xhat64 = self.tanh(self.refine3_i(xhat64))
        x5 = self.trans_block5(torch.cat([x5_i, xhat64], 1))

        x6 = self.dense_block6(x5)
        x6 = self.trans_block6(x6)
        shape_out = x6.data.size()
        shape_out = shape_out[2:4]
        x11 = self.upsample(self.relu(self.conv11(x1)), size=shape_out)
        x21 = self.upsample(self.relu(self.conv21(x2)), size=shape_out)
        x31 = self.upsample(self.relu(self.conv31(x3)), size=shape_out)
        x3_11 = self.upsample(self.relu(self.conv3_11(x3_1)), size=shape_out)
        x3_21 = self.upsample(self.relu(self.conv3_21(x3_2)), size=shape_out)
        x41 = self.upsample(self.relu(self.conv41(x4)), size=shape_out)
        x51 = self.upsample(self.relu(self.conv51(x5)), size=shape_out)
        x6 = torch.cat([x6, x51, x41, x3_21, x3_11, x31, x21, x11], 1)
        x7 = self.relu(self.conv_refin(x6))
        residual = self.tanh(self.refine3(x7))

        return residual, xhat64
@ -0,0 +1,37 @@
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## ECE Department, Rutgers University
## Email: zhang.hang@rutgers.edu
## Copyright (c) 2017
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

import os

from torch.autograd import Variable

from myutils import utils


class StyleLoader():
    def __init__(self, style_folder, style_size, cuda=True):
        self.folder = style_folder
        self.style_size = style_size
        self.files = os.listdir(style_folder)
        self.cuda = cuda

    def get(self, i):
        idx = i % len(self.files)
        filepath = os.path.join(self.folder, self.files[idx])
        style = utils.tensor_load_rgbimage(filepath, self.style_size)
        style = style.unsqueeze(0)
        style = utils.preprocess_batch(style)
        if self.cuda:
            style = style.cuda()
        style_v = Variable(style, requires_grad=False)
        return style_v

    def size(self):
        return len(self.files)
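
# Usage sketch (illustrative; the folder path is a placeholder and must
# contain at least one image):
# loader = StyleLoader('./styles', style_size=512, cuda=False)
# style = loader.get(0)   # (1, 3, H, W) BGR batch, ready for the style network
# print(loader.size())    # number of style images found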
@ -0,0 +1,310 @@
from __future__ import print_function

import os
import math
import torch
import argparse
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.legacy.nn as lnn
import torch.nn.functional as F

from functools import reduce
from torch.autograd import Variable
from torch.utils.serialization import load_lua


class LambdaBase(nn.Sequential):
    def __init__(self, fn, *args):
        super(LambdaBase, self).__init__(*args)
        self.lambda_func = fn

    def forward_prepare(self, input):
        output = []
        for module in self._modules.values():
            output.append(module(input))
        return output if output else input


class Lambda(LambdaBase):
    def forward(self, input):
        return self.lambda_func(self.forward_prepare(input))


class LambdaMap(LambdaBase):
    def forward(self, input):
        # result is a list [Variable1, Variable2, ...]
        return list(map(self.lambda_func, self.forward_prepare(input)))


class LambdaReduce(LambdaBase):
    def forward(self, input):
        # result is a single Variable
        return reduce(self.lambda_func, self.forward_prepare(input))


def copy_param(m, n):
    if m.weight is not None: n.weight.data.copy_(m.weight)
    if m.bias is not None: n.bias.data.copy_(m.bias)
    if hasattr(n, 'running_mean'): n.running_mean.copy_(m.running_mean)
    if hasattr(n, 'running_var'): n.running_var.copy_(m.running_var)


def add_submodule(seq, *args):
    for n in args:
        seq.add_module(str(len(seq._modules)), n)


def lua_recursive_model(module, seq):
    for m in module.modules:
        name = type(m).__name__
        real = m
        if name == 'TorchObject':
            name = m._typename.replace('cudnn.', '')
            m = m._obj

        if name == 'SpatialConvolution' or name == 'nn.SpatialConvolutionMM':
            if not hasattr(m, 'groups') or m.groups is None: m.groups = 1
            n = nn.Conv2d(m.nInputPlane, m.nOutputPlane, (m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), 1, m.groups, bias=(m.bias is not None))
            print(m.weight.size())
            print(n)
            copy_param(m, n)
            add_submodule(seq, n)
        elif name == 'SpatialBatchNormalization':
            n = nn.BatchNorm2d(m.running_mean.size(0), m.eps, m.momentum, m.affine)
            copy_param(m, n)
            add_submodule(seq, n)
        elif name == 'VolumetricBatchNormalization':
            n = nn.BatchNorm3d(m.running_mean.size(0), m.eps, m.momentum, m.affine)
            copy_param(m, n)
            add_submodule(seq, n)
        elif name == 'ReLU':
            n = nn.ReLU()
            add_submodule(seq, n)
        elif name == 'Sigmoid':
            n = nn.Sigmoid()
            add_submodule(seq, n)
        elif name == 'SpatialMaxPooling':
            n = nn.MaxPool2d((m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), ceil_mode=m.ceil_mode)
            add_submodule(seq, n)
        elif name == 'SpatialAveragePooling':
            n = nn.AvgPool2d((m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), ceil_mode=m.ceil_mode)
            add_submodule(seq, n)
        elif name == 'SpatialUpSamplingNearest':
            n = nn.UpsamplingNearest2d(scale_factor=m.scale_factor)
            add_submodule(seq, n)
        elif name == 'View':
            n = Lambda(lambda x: x.view(x.size(0), -1))
            add_submodule(seq, n)
        elif name == 'Reshape':
            n = Lambda(lambda x: x.view(x.size(0), -1))
            add_submodule(seq, n)
        elif name == 'Linear':
            # Linear in pytorch only accepts 2D input
            n1 = Lambda(lambda x: x.view(1, -1) if 1 == len(x.size()) else x)
            n2 = nn.Linear(m.weight.size(1), m.weight.size(0), bias=(m.bias is not None))
            copy_param(m, n2)
            n = nn.Sequential(n1, n2)
            add_submodule(seq, n)
        elif name == 'Dropout':
            m.inplace = False
            n = nn.Dropout(m.p)
            add_submodule(seq, n)
        elif name == 'SoftMax':
            n = nn.Softmax()
            add_submodule(seq, n)
        elif name == 'Identity':
            n = Lambda(lambda x: x)  # do nothing
            add_submodule(seq, n)
        elif name == 'SpatialFullConvolution':
            n = nn.ConvTranspose2d(m.nInputPlane, m.nOutputPlane, (m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), (m.adjW, m.adjH))
            copy_param(m, n)
            add_submodule(seq, n)
        elif name == 'VolumetricFullConvolution':
            n = nn.ConvTranspose3d(m.nInputPlane, m.nOutputPlane, (m.kT, m.kW, m.kH), (m.dT, m.dW, m.dH), (m.padT, m.padW, m.padH), (m.adjT, m.adjW, m.adjH), m.groups)
            copy_param(m, n)
            add_submodule(seq, n)
        elif name == 'SpatialReplicationPadding':
            n = nn.ReplicationPad2d((m.pad_l, m.pad_r, m.pad_t, m.pad_b))
            add_submodule(seq, n)
        elif name == 'SpatialReflectionPadding':
            n = nn.ReflectionPad2d((m.pad_l, m.pad_r, m.pad_t, m.pad_b))
            add_submodule(seq, n)
        elif name == 'Copy':
            n = Lambda(lambda x: x)  # do nothing
            add_submodule(seq, n)
        elif name == 'Narrow':
            n = Lambda(lambda x, a=(m.dimension, m.index, m.length): x.narrow(*a))
            add_submodule(seq, n)
        elif name == 'SpatialCrossMapLRN':
            lrn = lnn.SpatialCrossMapLRN(m.size, m.alpha, m.beta, m.k)
            n = Lambda(lambda x, lrn=lrn: Variable(lrn.forward(x.data)))
            add_submodule(seq, n)
        elif name == 'Sequential':
            n = nn.Sequential()
            lua_recursive_model(m, n)
            add_submodule(seq, n)
        elif name == 'ConcatTable':  # output is a list
            n = LambdaMap(lambda x: x)
            lua_recursive_model(m, n)
            add_submodule(seq, n)
        elif name == 'CAddTable':  # input is a list
            n = LambdaReduce(lambda x, y: x + y)
            add_submodule(seq, n)
        elif name == 'Concat':
            dim = m.dimension
            n = LambdaReduce(lambda x, y, dim=dim: torch.cat((x, y), dim))
            lua_recursive_model(m, n)
            add_submodule(seq, n)
        elif name == 'TorchObject':
            print('Not Implemented', name, real._typename)
        else:
            print('Not Implemented', name)


def lua_recursive_source(module):
    s = []
    for m in module.modules:
        name = type(m).__name__
        real = m
        if name == 'TorchObject':
            name = m._typename.replace('cudnn.', '')
            m = m._obj

        if name == 'SpatialConvolution' or name == 'nn.SpatialConvolutionMM':
            if not hasattr(m, 'groups') or m.groups is None: m.groups = 1
            s += ['nn.Conv2d({},{},{},{},{},{},{},bias={}),#Conv2d'.format(m.nInputPlane,
                m.nOutputPlane, (m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), 1, m.groups, m.bias is not None)]
        elif name == 'SpatialBatchNormalization':
            s += ['nn.BatchNorm2d({},{},{},{}),#BatchNorm2d'.format(m.running_mean.size(0), m.eps, m.momentum, m.affine)]
        elif name == 'VolumetricBatchNormalization':
            s += ['nn.BatchNorm3d({},{},{},{}),#BatchNorm3d'.format(m.running_mean.size(0), m.eps, m.momentum, m.affine)]
        elif name == 'ReLU':
            s += ['nn.ReLU()']
        elif name == 'Sigmoid':
            s += ['nn.Sigmoid()']
        elif name == 'SpatialMaxPooling':
            s += ['nn.MaxPool2d({},{},{},ceil_mode={}),#MaxPool2d'.format((m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), m.ceil_mode)]
        elif name == 'SpatialAveragePooling':
            s += ['nn.AvgPool2d({},{},{},ceil_mode={}),#AvgPool2d'.format((m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), m.ceil_mode)]
        elif name == 'SpatialUpSamplingNearest':
            s += ['nn.UpsamplingNearest2d(scale_factor={})'.format(m.scale_factor)]
        elif name == 'View':
            s += ['Lambda(lambda x: x.view(x.size(0),-1)), # View']
        elif name == 'Reshape':
            s += ['Lambda(lambda x: x.view(x.size(0),-1)), # Reshape']
        elif name == 'Linear':
            s1 = 'Lambda(lambda x: x.view(1,-1) if 1==len(x.size()) else x )'
            s2 = 'nn.Linear({},{},bias={})'.format(m.weight.size(1), m.weight.size(0), (m.bias is not None))
            s += ['nn.Sequential({},{}),#Linear'.format(s1, s2)]
        elif name == 'Dropout':
            s += ['nn.Dropout({})'.format(m.p)]
        elif name == 'SoftMax':
            s += ['nn.Softmax()']
        elif name == 'Identity':
            s += ['Lambda(lambda x: x), # Identity']
        elif name == 'SpatialFullConvolution':
            s += ['nn.ConvTranspose2d({},{},{},{},{},{})'.format(m.nInputPlane,
                m.nOutputPlane, (m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), (m.adjW, m.adjH))]
        elif name == 'VolumetricFullConvolution':
            s += ['nn.ConvTranspose3d({},{},{},{},{},{},{})'.format(m.nInputPlane,
                m.nOutputPlane, (m.kT, m.kW, m.kH), (m.dT, m.dW, m.dH), (m.padT, m.padW, m.padH), (m.adjT, m.adjW, m.adjH), m.groups)]
        elif name == 'SpatialReplicationPadding':
            s += ['nn.ReplicationPad2d({})'.format((m.pad_l, m.pad_r, m.pad_t, m.pad_b))]
        elif name == 'SpatialReflectionPadding':
            s += ['nn.ReflectionPad2d({})'.format((m.pad_l, m.pad_r, m.pad_t, m.pad_b))]
        elif name == 'Copy':
            s += ['Lambda(lambda x: x), # Copy']
        elif name == 'Narrow':
            s += ['Lambda(lambda x,a={}: x.narrow(*a))'.format((m.dimension, m.index, m.length))]
        elif name == 'SpatialCrossMapLRN':
            lrn = 'lnn.SpatialCrossMapLRN(*{})'.format((m.size, m.alpha, m.beta, m.k))
            s += ['Lambda(lambda x,lrn={}: Variable(lrn.forward(x.data)))'.format(lrn)]
        elif name == 'Sequential':
            s += ['nn.Sequential( # Sequential']
            s += lua_recursive_source(m)
            s += [')']
        elif name == 'ConcatTable':
            s += ['LambdaMap(lambda x: x, # ConcatTable']
            s += lua_recursive_source(m)
            s += [')']
        elif name == 'CAddTable':
            s += ['LambdaReduce(lambda x,y: x+y), # CAddTable']
        elif name == 'Concat':
            dim = m.dimension
            s += ['LambdaReduce(lambda x,y,dim={}: torch.cat((x,y),dim), # Concat'.format(m.dimension)]
            s += lua_recursive_source(m)
            s += [')']
        else:
            # NOTE: append a one-element list, not a bare string (a bare string
            # would be spliced into the list character by character)
            s += ['# ' + name + ' Not Implemented,']
    s = ['\t{}'.format(x) for x in s]
    return s


def simplify_source(s):
    s = map(lambda x: x.replace(',(1, 1),(0, 0),1,1,bias=True),#Conv2d', ')'), s)
    s = map(lambda x: x.replace(',(0, 0),1,1,bias=True),#Conv2d', ')'), s)
    s = map(lambda x: x.replace(',1,1,bias=True),#Conv2d', ')'), s)
    s = map(lambda x: x.replace(',bias=True),#Conv2d', ')'), s)
    s = map(lambda x: x.replace('),#Conv2d', ')'), s)
    s = map(lambda x: x.replace(',1e-05,0.1,True),#BatchNorm2d', ')'), s)
    s = map(lambda x: x.replace('),#BatchNorm2d', ')'), s)
    s = map(lambda x: x.replace(',(0, 0),ceil_mode=False),#MaxPool2d', ')'), s)
    s = map(lambda x: x.replace(',ceil_mode=False),#MaxPool2d', ')'), s)
    s = map(lambda x: x.replace('),#MaxPool2d', ')'), s)
    s = map(lambda x: x.replace(',(0, 0),ceil_mode=False),#AvgPool2d', ')'), s)
    s = map(lambda x: x.replace(',ceil_mode=False),#AvgPool2d', ')'), s)
    s = map(lambda x: x.replace(',bias=True)),#Linear', ')), # Linear'), s)
    s = map(lambda x: x.replace(')),#Linear', ')), # Linear'), s)

    s = map(lambda x: '{},\n'.format(x), s)
    s = map(lambda x: x[1:], s)
    s = reduce(lambda x, y: x + y, s)
    return s


def torch_to_pytorch(t7_filename, outputname=None):
    model = load_lua(t7_filename, unknown_classes=True)
    if type(model).__name__ == 'hashable_uniq_dict': model = model.model
    model.gradInput = None
    slist = lua_recursive_source(lnn.Sequential().add(model))
    s = simplify_source(slist)
    header = '''
import torch
import torch.nn as nn
import torch.legacy.nn as lnn

from functools import reduce
from torch.autograd import Variable

class LambdaBase(nn.Sequential):
    def __init__(self, fn, *args):
        super(LambdaBase, self).__init__(*args)
        self.lambda_func = fn

    def forward_prepare(self, input):
        output = []
        for module in self._modules.values():
            output.append(module(input))
        return output if output else input

class Lambda(LambdaBase):
    def forward(self, input):
        return self.lambda_func(self.forward_prepare(input))

class LambdaMap(LambdaBase):
    def forward(self, input):
        return list(map(self.lambda_func,self.forward_prepare(input)))

class LambdaReduce(LambdaBase):
    def forward(self, input):
        return reduce(self.lambda_func,self.forward_prepare(input))
'''
    varname = t7_filename.replace('.t7', '').replace('.', '_').replace('-', '_')
    s = '{}\n\n{} = {}'.format(header, varname, s[:-2])

    if outputname is None: outputname = varname
    with open(outputname + '.py', "w") as pyfile:
        pyfile.write(s)

    n = nn.Sequential()
    lua_recursive_model(model, n)
    torch.save(n.state_dict(), outputname + '.pth')


parser = argparse.ArgumentParser(description='Convert torch t7 model to pytorch')
parser.add_argument('--model', '-m', type=str, required=True,
                    help='torch model file in t7 format')
parser.add_argument('--output', '-o', type=str, default=None,
                    help='output file name prefix, xxx.py xxx.pth')
args = parser.parse_args()

torch_to_pytorch(args.model, args.output)
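
# Example invocation (illustrative; assumes this file is saved as
# convert_torch.py and that a Torch7 checkpoint such as vgg16.t7 exists):
#   python convert_torch.py -m vgg16.t7 -o vgg16
# which writes vgg16.py (the module definition) and vgg16.pth (the state dict).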
@ -0,0 +1,94 @@
import os

import numpy as np
import torch
from PIL import Image
from torch.autograd import Variable
from torch.utils.serialization import load_lua

from myutils.vgg16 import Vgg16


def tensor_load_rgbimage(filename, size=None, scale=None, keep_asp=False):
    img = Image.open(filename).convert('RGB')
    if size is not None:
        if keep_asp:
            size2 = int(size * 1.0 / img.size[0] * img.size[1])
            img = img.resize((size, size2), Image.ANTIALIAS)
        else:
            img = img.resize((size, size), Image.ANTIALIAS)
    elif scale is not None:
        img = img.resize((int(img.size[0] / scale), int(img.size[1] / scale)), Image.ANTIALIAS)
    img = np.array(img).transpose(2, 0, 1)
    img = torch.from_numpy(img).float()
    return img


def tensor_save_rgbimage(tensor, filename, cuda=False):
    if cuda:
        img = tensor.clone().cpu().clamp(0, 255).numpy()
    else:
        img = tensor.clone().clamp(0, 255).numpy()
    img = img.transpose(1, 2, 0).astype('uint8')
    img = Image.fromarray(img)
    img.save(filename)


def tensor_save_bgrimage(tensor, filename, cuda=False):
    (b, g, r) = torch.chunk(tensor, 3)
    tensor = torch.cat((r, g, b))
    tensor_save_rgbimage(tensor, filename, cuda)


def gram_matrix(y):
    (b, ch, h, w) = y.size()
    features = y.view(b, ch, w * h)
    features_t = features.transpose(1, 2)
    gram = features.bmm(features_t) / (ch * h * w)
    return gram


def subtract_imagenet_mean_batch(batch):
    """Subtract the ImageNet mean pixel-wise from a BGR image batch."""
    tensortype = type(batch.data)
    mean = tensortype(batch.data.size())
    mean[:, 0, :, :] = 103.939
    mean[:, 1, :, :] = 116.779
    mean[:, 2, :, :] = 123.680
    return batch - Variable(mean)


def add_imagenet_mean_batch(batch):
    """Add the ImageNet mean pixel-wise to a BGR image batch."""
    tensortype = type(batch.data)
    mean = tensortype(batch.data.size())
    mean[:, 0, :, :] = 103.939
    mean[:, 1, :, :] = 116.779
    mean[:, 2, :, :] = 123.680
    return batch + Variable(mean)


def imagenet_clamp_batch(batch, low, high):
    batch[:, 0, :, :].data.clamp_(low - 103.939, high - 103.939)
    batch[:, 1, :, :].data.clamp_(low - 116.779, high - 116.779)
    batch[:, 2, :, :].data.clamp_(low - 123.680, high - 123.680)


def preprocess_batch(batch):
    batch = batch.transpose(0, 1)
    (r, g, b) = torch.chunk(batch, 3)
    batch = torch.cat((b, g, r))
    batch = batch.transpose(0, 1)
    return batch


def init_vgg16(model_folder):
    """Load the VGG-16 features and cache them as a PyTorch state dict."""
    if not os.path.exists(os.path.join(model_folder, 'vgg16.weight')):
        if not os.path.exists(os.path.join(model_folder, 'vgg16.t7')):
            os.system(
                'wget http://cs.stanford.edu/people/jcjohns/fast-neural-style/models/vgg16.t7 -O ' + os.path.join(model_folder, 'vgg16.t7'))
        vgglua = load_lua(os.path.join(model_folder, 'vgg16.t7'))
        vgg = Vgg16()
        for (src, dst) in zip(vgglua.parameters()[0], vgg.parameters()):
            dst.data[:] = src
        torch.save(vgg.state_dict(), os.path.join(model_folder, 'vgg16.weight'))
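
# Illustrative sanity checks (added sketch, not in the original file; guarded
# so importing this module is unaffected):
if __name__ == '__main__':
    y = torch.rand(2, 8, 16, 16)
    assert gram_matrix(y).size() == (2, 8, 8)          # one ch x ch Gram per sample
    rgb = torch.rand(4, 3, 32, 32)
    assert preprocess_batch(rgb).size() == rgb.size()  # RGB -> BGR, shape kept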
@ -0,0 +1,49 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class Vgg16(torch.nn.Module):
    def __init__(self):
        super(Vgg16, self).__init__()
        self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)

        self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)

        self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)

        self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
        self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)

        self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)

    def forward(self, X):
        h = F.relu(self.conv1_1(X))
        h = F.relu(self.conv1_2(h))
        relu1_2 = h
        h = F.max_pool2d(h, kernel_size=2, stride=2)

        h = F.relu(self.conv2_1(h))
        h = F.relu(self.conv2_2(h))
        relu2_2 = h
        h = F.max_pool2d(h, kernel_size=2, stride=2)

        h = F.relu(self.conv3_1(h))
        h = F.relu(self.conv3_2(h))
        h = F.relu(self.conv3_3(h))
        relu3_3 = h
        h = F.max_pool2d(h, kernel_size=2, stride=2)

        h = F.relu(self.conv4_1(h))
        h = F.relu(self.conv4_2(h))
        h = F.relu(self.conv4_3(h))
        relu4_3 = h

        return [relu1_2, relu2_2, relu3_3, relu4_3]
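
# Illustrative forward pass (added sketch, not in the original file): the
# network returns the four ReLU feature maps used for the perceptual losses.
if __name__ == '__main__':
    net = Vgg16()
    feats = net(torch.rand(1, 3, 128, 128))
    print([f.size() for f in feats])  # channel counts: 64, 128, 256, 512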
@ -0,0 +1,314 @@
from __future__ import print_function
import argparse
import os
import sys
import random
import time
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
cudnn.benchmark = True
cudnn.fastest = True
import torch.optim as optim
import torchvision.utils as vutils
from torch.autograd import Variable
from PIL import Image

from misc import *
import models.face_fed as net

from myutils.vgg16 import Vgg16
from myutils import utils
import pdb

# Pre-defined parameters
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', required=False,
                    default='pix2pix', help='')
parser.add_argument('--dataroot', required=False,
                    default='', help='path to trn dataset')
parser.add_argument('--valDataroot', required=False,
                    default='', help='path to val dataset')
parser.add_argument('--mode', type=str, default='B2A', help='B2A: facade, A2B: edges2shoes')
parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
parser.add_argument('--valBatchSize', type=int, default=1, help='input batch size')
parser.add_argument('--originalSize', type=int,
                    default=128, help='the height / width of the original input image')
parser.add_argument('--imageSize', type=int,
                    default=128, help='the height / width of the cropped input image to network')
parser.add_argument('--inputChannelSize', type=int,
                    default=3, help='size of the input channels')
parser.add_argument('--outputChannelSize', type=int,
                    default=3, help='size of the output channels')
parser.add_argument('--ngf', type=int, default=64)
parser.add_argument('--ndf', type=int, default=64)
parser.add_argument('--niter', type=int, default=400, help='number of epochs to train for')
parser.add_argument('--lrD', type=float, default=0.0002, help='learning rate, default=0.0002')
parser.add_argument('--lrG', type=float, default=0.0002, help='learning rate, default=0.0002')
parser.add_argument('--annealStart', type=int, default=0, help='annealing learning rate start to')
parser.add_argument('--annealEvery', type=int, default=400, help='epoch to reaching at learning rate of 0')
parser.add_argument('--lambdaGAN', type=float, default=0.01, help='lambdaGAN')
parser.add_argument('--lambdaIMG', type=float, default=1, help='lambdaIMG')
parser.add_argument('--poolSize', type=int, default=50, help='Buffer size for storing previously generated samples from G')
parser.add_argument('--wd', type=float, default=0.0000, help='weight decay in D')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam')
parser.add_argument('--netG', default='', help="path to netG (to continue training)")
parser.add_argument('--netD', default='', help="path to netD (to continue training)")
parser.add_argument('--workers', type=int, help='number of data loading workers', default=1)
parser.add_argument('--exp', default='sample', help='folder to output images and model checkpoints')
parser.add_argument('--display', type=int, default=5, help='interval for displaying train logs')
parser.add_argument('--evalIter', type=int, default=500, help='interval for evaluating (generating) images from valDataroot')
opt = parser.parse_args()
print(opt)

create_exp_dir(opt.exp)
opt.manualSeed = random.randint(1, 10000)
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)
torch.cuda.manual_seed_all(opt.manualSeed)
print("Random Seed: ", opt.manualSeed)

# Initialize dataloaders
dataloader = getLoader(opt.dataset,
                       opt.dataroot,
                       opt.originalSize,
                       opt.imageSize,
                       opt.batchSize,
                       opt.workers,
                       mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
                       split='val',
                       shuffle=True,
                       seed=opt.manualSeed)
opt.dataset = 'pix2pix_val'

valDataloader = getLoader(opt.dataset,
                          opt.valDataroot,
                          opt.originalSize,
                          opt.imageSize,
                          opt.valBatchSize,
                          opt.workers,
                          mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
                          split='val',
                          shuffle=False,
                          seed=opt.manualSeed)

# Get logger
trainLogger = open('%s/train.log' % opt.exp, 'w')

ngf = opt.ngf
ndf = opt.ndf
inputChannelSize = opt.inputChannelSize
outputChannelSize = opt.outputChannelSize

# Load the pre-trained deblurring model
netS = net.Segmentation()
netG = net.Deblur_segdl()

netG.apply(weights_init)
if opt.netG != '':
    state_dict_g = torch.load(opt.netG)
    new_state_dict_g = {}
    for k, v in state_dict_g.items():
        name = k[7:]  # remove the `module.` prefix added by DataParallel
        new_state_dict_g[name] = v
    # load params
    netG.load_state_dict(new_state_dict_g)
print(netG)
netG.eval()
netS.load_state_dict(torch.load('./pretrained_models/SMaps_Best.pth'))
netS.cuda()
netG.cuda()

# Initialize testing data
target = torch.FloatTensor(opt.batchSize, outputChannelSize, opt.imageSize, opt.imageSize)
input = torch.FloatTensor(opt.batchSize, inputChannelSize, opt.imageSize, opt.imageSize)
depth = torch.FloatTensor(opt.batchSize, inputChannelSize, opt.imageSize, opt.imageSize)
ato = torch.FloatTensor(opt.batchSize, inputChannelSize, opt.imageSize, opt.imageSize)
label_d = torch.FloatTensor(opt.batchSize)

val_target = torch.FloatTensor(opt.valBatchSize, outputChannelSize, opt.imageSize, opt.imageSize)
val_input = torch.FloatTensor(opt.valBatchSize, inputChannelSize, opt.imageSize, opt.imageSize)
val_depth = torch.FloatTensor(opt.valBatchSize, inputChannelSize, opt.imageSize, opt.imageSize)
val_ato = torch.FloatTensor(opt.valBatchSize, inputChannelSize, opt.imageSize, opt.imageSize)

# note: integer division, the sizes are in pixels
target_128 = torch.FloatTensor(opt.batchSize, outputChannelSize, opt.imageSize // 4, opt.imageSize // 4)
input_128 = torch.FloatTensor(opt.batchSize, inputChannelSize, opt.imageSize // 4, opt.imageSize // 4)
target_256 = torch.FloatTensor(opt.batchSize, outputChannelSize, opt.imageSize // 2, opt.imageSize // 2)
input_256 = torch.FloatTensor(opt.batchSize, inputChannelSize, opt.imageSize // 2, opt.imageSize // 2)

val_target_128 = torch.FloatTensor(opt.batchSize, outputChannelSize, opt.imageSize // 4, opt.imageSize // 4)
val_input_128 = torch.FloatTensor(opt.batchSize, inputChannelSize, opt.imageSize // 4, opt.imageSize // 4)
val_target_256 = torch.FloatTensor(opt.batchSize, outputChannelSize, opt.imageSize // 2, opt.imageSize // 2)
val_input_256 = torch.FloatTensor(opt.batchSize, inputChannelSize, opt.imageSize // 2, opt.imageSize // 2)

target, input, depth, ato = target.cuda(), input.cuda(), depth.cuda(), ato.cuda()
val_target, val_input, val_depth, val_ato = val_target.cuda(), val_input.cuda(), val_depth.cuda(), val_ato.cuda()

# `volatile` was removed from recent PyTorch; inference below runs under torch.no_grad()
target = Variable(target)
input = Variable(input)
depth = Variable(depth)
ato = Variable(ato)

target_128, input_128 = target_128.cuda(), input_128.cuda()
val_target_128, val_input_128 = val_target_128.cuda(), val_input_128.cuda()
target_256, input_256 = target_256.cuda(), input_256.cuda()
val_target_256, val_input_256 = val_target_256.cuda(), val_input_256.cuda()

target_128 = Variable(target_128)
input_128 = Variable(input_128)
target_256 = Variable(target_256)
input_256 = Variable(input_256)

label_d = Variable(label_d.cuda())


def norm_ip(img, min, max):
    img.clamp_(min=min, max=max)
    img.add_(-min).div_(max - min)
    return img


def norm_range(t, range):
    if range is not None:
        norm_ip(t, range[0], range[1])
    else:
        norm_ip(t, -1, +1)
    return t
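
# Worked example (illustrative): with the default range, norm_range maps the
# generator's tanh output from [-1, 1] into [0, 1] in place, e.g.
#   norm_ip(torch.tensor([-1.0, 0.0, 1.0]), -1, 1) -> tensor([0.0, 0.5, 1.0])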

# Get optimizer
optimizerG = optim.Adam(netG.parameters(), lr=opt.lrG, betas=(opt.beta1, 0.999), weight_decay=0.00005)

# Begin testing
for epoch in range(1):
    heavy, medium, light = 200, 200, 200
    for i, data in enumerate(valDataloader, 0):
        print('Image:' + str(i))
        data_val = data

        t0 = time.time()

        val_input_cpu, val_target_cpu = data_val

        val_target_cpu, val_input_cpu = val_target_cpu.float().cuda(), val_input_cpu.float().cuda()
        val_batch_output = torch.FloatTensor(val_input.size()).fill_(0)

        val_input.resize_as_(val_input_cpu).copy_(val_input_cpu)
        val_target = Variable(val_target_cpu)

        z = 0

        with torch.no_grad():
            for idx in range(val_input.size(0)):
                single_img = val_input[idx, :, :, :].unsqueeze(0)
                val_inputv = Variable(single_img)
                print(val_inputv.size())
                val_inputv_256 = torch.nn.functional.interpolate(val_inputv, scale_factor=0.5)
                val_inputv_128 = torch.nn.functional.interpolate(val_inputv, scale_factor=0.25)

                ## Get the deblurred results ##
                t1 = time.time()
                print('running time:' + str(t1 - t0))

                smaps, smaps64 = netS(val_inputv, val_inputv_256)
                class1 = torch.zeros([1, 1, 128, 128], dtype=torch.float32)
                class1[:, 0, :, :] = smaps[:, 0, :, :]
                class2 = torch.zeros([1, 1, 128, 128], dtype=torch.float32)
                class2[:, 0, :, :] = smaps[:, 1, :, :]
                class3 = torch.zeros([1, 1, 128, 128], dtype=torch.float32)
                class3[:, 0, :, :] = smaps[:, 2, :, :]
                class4 = torch.zeros([1, 1, 128, 128], dtype=torch.float32)
                class4[:, 0, :, :] = smaps[:, 3, :, :]
                class_msk1 = torch.zeros([1, 3, 128, 128], dtype=torch.float32)
                class_msk1[:, 0, :, :] = smaps[:, 0, :, :]
                class_msk1[:, 1, :, :] = smaps[:, 0, :, :]
                class_msk1[:, 2, :, :] = smaps[:, 0, :, :]
                class_msk2 = torch.zeros([1, 3, 128, 128], dtype=torch.float32)
                class_msk2[:, 0, :, :] = smaps[:, 1, :, :]
                class_msk2[:, 1, :, :] = smaps[:, 1, :, :]
                class_msk2[:, 2, :, :] = smaps[:, 1, :, :]
                class_msk3 = torch.zeros([1, 3, 128, 128], dtype=torch.float32)
                class_msk3[:, 0, :, :] = smaps[:, 2, :, :]
                class_msk3[:, 1, :, :] = smaps[:, 2, :, :]
                class_msk3[:, 2, :, :] = smaps[:, 2, :, :]
                class_msk4 = torch.zeros([1, 3, 128, 128], dtype=torch.float32)
                class_msk4[:, 0, :, :] = smaps[:, 3, :, :]
                class_msk4[:, 1, :, :] = smaps[:, 3, :, :]
                class_msk4[:, 2, :, :] = smaps[:, 3, :, :]
                class1 = class1.float().cuda()
                class2 = class2.float().cuda()
                class3 = class3.float().cuda()
                class4 = class4.float().cuda()
                class_msk4 = class_msk4.float().cuda()
                class_msk3 = class_msk3.float().cuda()
                class_msk2 = class_msk2.float().cuda()
                class_msk1 = class_msk1.float().cuda()
                x_hat_val, x_hat_val64, xmask1, xmask2, xmask3, xmask4, xcl_class1, xcl_class2, xcl_class3, xcl_class4 = netG(val_inputv, val_inputv_256, smaps, class1, class2, class3, class4, val_inputv, class_msk1, class_msk2, class_msk3, class_msk4)
                print(smaps.size())
                tensor = x_hat_val.data.cpu()

                ### Save the deblurred results ###
                directory = './result_all/deblurh/'
                if not os.path.exists(directory):
                    os.makedirs(directory)

                tensor = torch.squeeze(tensor)
                tensor = norm_range(tensor, None)
                print(tensor.min(), tensor.max())

                filename = directory + str(i + 1) + '.png'
                ndarr = tensor.mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy()
                im = Image.fromarray(ndarr)
                im.save(filename)
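
# Note (illustrative): the manual normalize -> byte -> PIL save sequence above
# is equivalent to torchvision's helper with an explicit input range:
#   vutils.save_image(x_hat_val.data, filename, normalize=True, range=(-1, 1))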
@ -0,0 +1,514 @@
from __future__ import print_function
import argparse
import os
import sys
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
cudnn.benchmark = True
cudnn.fastest = True
import torch.optim as optim
import torchvision.utils as vutils
from torch.autograd import Variable

import numpy as np
import h5py
from scipy import signal

from misc import *
import models.face_fed as net

from myutils.vgg16 import Vgg16
from myutils import utils
import pdb
import torch.nn.functional as F
from torchvision import transforms

from os import listdir
from os.path import isfile, join

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', required=False,
                    default='pix2pix_class', help='')
parser.add_argument('--dataroot', required=False,
                    default='', help='path to trn dataset')
parser.add_argument('--valDataroot', required=False,
                    default='', help='path to val dataset')
parser.add_argument('--mode', type=str, default='B2A', help='B2A: facade, A2B: edges2shoes')
parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
parser.add_argument('--valBatchSize', type=int, default=120, help='input batch size')
parser.add_argument('--originalSize', type=int,
                    default=175, help='the height / width of the original input image')
parser.add_argument('--imageSize', type=int,
                    default=128, help='the height / width of the cropped input image to network')
parser.add_argument('--inputChannelSize', type=int,
                    default=3, help='size of the input channels')
parser.add_argument('--outputChannelSize', type=int,
                    default=3, help='size of the output channels')
parser.add_argument('--ngf', type=int, default=64)
parser.add_argument('--ndf', type=int, default=64)
parser.add_argument('--niter', type=int, default=5000, help='number of epochs to train for')
parser.add_argument('--lrD', type=float, default=0.0002, help='learning rate, default=0.0002')
parser.add_argument('--lrG', type=float, default=0.0002, help='learning rate, default=0.0002')
parser.add_argument('--annealStart', type=int, default=0, help='annealing learning rate start to')
parser.add_argument('--annealEvery', type=int, default=400, help='epoch to reaching at learning rate of 0')
parser.add_argument('--lambdaGAN', type=float, default=0.01, help='lambdaGAN')
parser.add_argument('--lambdaIMG', type=float, default=2.0, help='lambdaIMG')
parser.add_argument('--poolSize', type=int, default=50, help='Buffer size for storing previously generated samples from G')
parser.add_argument('--wd', type=float, default=0.0000, help='weight decay in D')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam')
parser.add_argument('--netG', default='', help="path to netG (to continue training)")
parser.add_argument('--netD', default='', help="path to netD (to continue training)")
parser.add_argument('--workers', type=int, help='number of data loading workers', default=1)
parser.add_argument('--exp', default='sample', help='folder to output images and model checkpoints')
parser.add_argument('--display', type=int, default=5, help='interval for displaying train logs')
parser.add_argument('--evalIter', type=int, default=500, help='interval for evaluating (generating) images from valDataroot')
opt = parser.parse_args()
print(opt)

# Load the bank of blur kernels used to synthesize blurred inputs
k_filename = './kernel.mat'
kfp = h5py.File(k_filename, 'r')
kernels = np.array(kfp['kernels'])
kernels = kernels.transpose([0, 2, 1])

create_exp_dir(opt.exp)
opt.manualSeed = random.randint(1, 10000)
# opt.manualSeed = 101
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)
torch.cuda.manual_seed_all(opt.manualSeed)
print("Random Seed: ", opt.manualSeed)
# Get dataloaders
opt.dataset = 'pix2pix_val'
print(opt.dataroot)
dataloader = getLoader(opt.dataset,
                       opt.dataroot,
                       opt.originalSize,
                       opt.imageSize,
                       opt.batchSize,
                       opt.workers,
                       mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
                       split='train',
                       shuffle=True,
                       seed=opt.manualSeed)

opt.dataset = 'pix2pix_val'
valDataloader = getLoader(opt.dataset,
                          opt.valDataroot,
                          opt.originalSize,
                          opt.imageSize,
                          opt.valBatchSize,
                          opt.workers,
                          mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
                          split='val',
                          shuffle=False,
                          seed=opt.manualSeed)

# Get logger
trainLogger = open('%s/train.log' % opt.exp, 'w')


def gradient(y):
    gradient_h = torch.abs(y[:, :, :, :-1] - y[:, :, :, 1:])
    gradient_v = torch.abs(y[:, :, :-1, :] - y[:, :, 1:, :])
    return gradient_h, gradient_v
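
# Shapes (illustrative): for y of size (b, c, h, w), gradient(y) returns
# horizontal and vertical finite differences of sizes (b, c, h, w-1) and
# (b, c, h-1, w); they feed the edge-preserving L1 terms in the loss below.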
ngf = opt.ngf
ndf = opt.ndf
inputChannelSize = opt.inputChannelSize
outputChannelSize = opt.outputChannelSize

# Get models
netS = net.Segmentation()
netG = net.Deblur_segdl()

netS.load_state_dict(torch.load('./pretrained_models/SMaps_Best.pth'))

# To resume training, load a checkpoint saved by DataParallel:
# state_dict_g = torch.load('./face_deblur/Deblur_epoch_46.pth')
# new_state_dict_g = {}
# for k, v in state_dict_g.items():
#     name = k[7:]  # remove the `module.` prefix
#     new_state_dict_g[name] = v
# netG.load_state_dict(new_state_dict_g)

netG = torch.nn.DataParallel(netG)
netS = torch.nn.DataParallel(netS)
netG.train()
criterionCAE = nn.L1Loss()
criterionCAE1 = nn.SmoothL1Loss()

target = torch.FloatTensor(opt.batchSize, outputChannelSize, opt.imageSize, opt.imageSize)
input = torch.FloatTensor(opt.batchSize, inputChannelSize, opt.imageSize, opt.imageSize)
depth = torch.FloatTensor(opt.batchSize, inputChannelSize, opt.imageSize, opt.imageSize)
ato = torch.FloatTensor(opt.batchSize, inputChannelSize, opt.imageSize, opt.imageSize)
label_d = torch.FloatTensor(opt.batchSize)

# note: integer division, the sizes are in pixels
target_128 = torch.FloatTensor(opt.batchSize, outputChannelSize, opt.imageSize // 4, opt.imageSize // 4)
input_128 = torch.FloatTensor(opt.batchSize, inputChannelSize, opt.imageSize // 4, opt.imageSize // 4)
target_256 = torch.FloatTensor(opt.batchSize, outputChannelSize, opt.imageSize // 2, opt.imageSize // 2)
input_256 = torch.FloatTensor(opt.batchSize, inputChannelSize, opt.imageSize // 2, opt.imageSize // 2)

val_target = torch.FloatTensor(opt.valBatchSize, outputChannelSize, opt.imageSize, opt.imageSize)
val_input = torch.FloatTensor(opt.valBatchSize, inputChannelSize, opt.imageSize, opt.imageSize)
val_depth = torch.FloatTensor(opt.valBatchSize, inputChannelSize, opt.imageSize, opt.imageSize)
val_ato = torch.FloatTensor(opt.valBatchSize, inputChannelSize, opt.imageSize, opt.imageSize)
val_target_128 = torch.FloatTensor(opt.batchSize, outputChannelSize, opt.imageSize // 4, opt.imageSize // 4)
val_input_128 = torch.FloatTensor(opt.batchSize, inputChannelSize, opt.imageSize // 4, opt.imageSize // 4)
val_target_256 = torch.FloatTensor(opt.batchSize, outputChannelSize, opt.imageSize // 2, opt.imageSize // 2)
val_input_256 = torch.FloatTensor(opt.batchSize, inputChannelSize, opt.imageSize // 2, opt.imageSize // 2)

# NOTE: size of 2D output maps in the discriminator
sizePatchGAN = 30
real_label = 1
fake_label = 0

# image pool storing previously generated samples from G
imagePool = ImagePool(opt.poolSize)

# NOTE weights for L_cGAN and L_L1 (i.e. Eq.(4) in the paper)
lambdaGAN = opt.lambdaGAN
lambdaIMG = opt.lambdaIMG

netG.cuda()
netS.cuda()
criterionCAE.cuda()
criterionCAE1.cuda()

target, input, depth, ato = target.cuda(), input.cuda(), depth.cuda(), ato.cuda()
val_target, val_input, val_depth, val_ato = val_target.cuda(), val_input.cuda(), val_depth.cuda(), val_ato.cuda()

target = Variable(target)
input = Variable(input)

target_128, input_128 = target_128.cuda(), input_128.cuda()
val_target_128, val_input_128 = val_target_128.cuda(), val_input_128.cuda()
target_256, input_256 = target_256.cuda(), input_256.cuda()
val_target_256, val_input_256 = val_target_256.cuda(), val_input_256.cuda()

target_128 = Variable(target_128)
input_128 = Variable(input_128)
target_256 = Variable(target_256)
input_256 = Variable(input_256)
ato = Variable(ato)

# Initialize VGG-16 for the perceptual losses
vgg = Vgg16()
utils.init_vgg16('./models/')
vgg.load_state_dict(torch.load(os.path.join('./models/', "vgg16.weight")))
vgg.cuda()

label_d = Variable(label_d.cuda())

# Get a fixed validation batch and save it for visual reference
print(len(dataloader))
val_iter = iter(valDataloader)
data_val = next(val_iter)

val_input_cpu, val_target_cpu = data_val

val_target_cpu, val_input_cpu = val_target_cpu.float().cuda(), val_input_cpu.float().cuda()

val_target.resize_as_(val_target_cpu).copy_(val_target_cpu)
val_input.resize_as_(val_input_cpu).copy_(val_input_cpu)

vutils.save_image(val_target, '%s/real_target.png' % opt.exp, normalize=True)
vutils.save_image(val_input, '%s/real_input.png' % opt.exp, normalize=True)
optimizerG = optim.Adam(netG.parameters(), lr=opt.lrG, betas=(opt.beta1, 0.999), weight_decay=0.00005)

# NOTE training loop
ganIterations = 0
count = 48

for epoch in range(opt.niter):
    # halve the generator learning rate every 19 epochs
    if epoch % 19 == 0 and epoch > 0:
        opt.lrG = opt.lrG / 2.0
        for param_group in optimizerG.param_groups:
            param_group['lr'] = opt.lrG
    if epoch >= opt.annealStart:
        adjust_learning_rate(optimizerG, opt.lrG, epoch, None, opt.annealEvery)

    for i, data in enumerate(dataloader, 0):
        input_cpu, target_cpu = data
        batch_size = target_cpu.size(0)
        b, ch, x, y = target_cpu.size()
        x1 = int((x - opt.imageSize) / 2)
        y1 = int((y - opt.imageSize) / 2)

        # Synthesize blurred inputs: convolve each channel with a random
        # kernel, add noise, then center-crop input and target
        input_cpu = input_cpu.numpy()
        target_cpu = target_cpu.numpy()
        for j in range(batch_size):
            index = random.randint(0, 24500)
            input_cpu[j, 0, :, :] = signal.convolve(input_cpu[j, 0, :, :], kernels[index, :, :], mode='same')
            input_cpu[j, 1, :, :] = signal.convolve(input_cpu[j, 1, :, :], kernels[index, :, :], mode='same')
            input_cpu[j, 2, :, :] = signal.convolve(input_cpu[j, 2, :, :], kernels[index, :, :], mode='same')
        input_cpu = input_cpu + (1.0 / 255.0) * np.random.normal(0, 4, input_cpu.shape)
        input_cpu = input_cpu[:, :, x1:x1 + opt.imageSize, y1:y1 + opt.imageSize]
        target_cpu = target_cpu[:, :, x1:x1 + opt.imageSize, y1:y1 + opt.imageSize]
        input_cpu = torch.from_numpy(input_cpu)
        target_cpu = torch.from_numpy(target_cpu)

        target_cpu, input_cpu = target_cpu.float().cuda(), input_cpu.float().cuda()

        # get paired data
        target.data.resize_as_(target_cpu).copy_(target_cpu)
        input.data.resize_as_(input_cpu).copy_(input_cpu)
        input_256 = torch.nn.functional.interpolate(input, scale_factor=0.5)
        target_256 = torch.nn.functional.interpolate(target, scale_factor=0.5)

        with torch.no_grad():
            smaps_i, smaps_i64 = netS(input, input_256)
            smaps, smaps64 = netS(target, target_256)

        # Per-class segmentation channels as 1-channel inputs and 3-channel masks
        class1 = torch.zeros([batch_size, 1, 128, 128], dtype=torch.float32)
        class1[:, 0, :, :] = smaps_i[:, 0, :, :]
        class2 = torch.zeros([batch_size, 1, 128, 128], dtype=torch.float32)
        class2[:, 0, :, :] = smaps_i[:, 1, :, :]
        class3 = torch.zeros([batch_size, 1, 128, 128], dtype=torch.float32)
        class3[:, 0, :, :] = smaps_i[:, 2, :, :]
        class4 = torch.zeros([batch_size, 1, 128, 128], dtype=torch.float32)
        class4[:, 0, :, :] = smaps_i[:, 3, :, :]
        class_msk1 = torch.zeros([batch_size, 3, 128, 128], dtype=torch.float32)
        class_msk1[:, 0, :, :] = smaps[:, 0, :, :]
        class_msk1[:, 1, :, :] = smaps[:, 0, :, :]
        class_msk1[:, 2, :, :] = smaps[:, 0, :, :]
        class_msk2 = torch.zeros([batch_size, 3, 128, 128], dtype=torch.float32)
        class_msk2[:, 0, :, :] = smaps[:, 1, :, :]
        class_msk2[:, 1, :, :] = smaps[:, 1, :, :]
        class_msk2[:, 2, :, :] = smaps[:, 1, :, :]
        class_msk3 = torch.zeros([batch_size, 3, 128, 128], dtype=torch.float32)
        class_msk3[:, 0, :, :] = smaps[:, 2, :, :]
        class_msk3[:, 1, :, :] = smaps[:, 2, :, :]
        class_msk3[:, 2, :, :] = smaps[:, 2, :, :]
        class_msk4 = torch.zeros([batch_size, 3, 128, 128], dtype=torch.float32)
        class_msk4[:, 0, :, :] = smaps[:, 3, :, :]
        class_msk4[:, 1, :, :] = smaps[:, 3, :, :]
        class_msk4[:, 2, :, :] = smaps[:, 3, :, :]
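
        # A compact equivalent of the per-class mask construction above
        # (illustrative sketch, not used by the script):
        # def make_masks(smaps):
        #     cls = [smaps[:, k:k + 1, :, :] for k in range(4)]   # (b, 1, h, w)
        #     msk = [c.repeat(1, 3, 1, 1) for c in cls]           # (b, 3, h, w)
        #     return cls, msk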
        class1 = class1.float().cuda()
        class2 = class2.float().cuda()
        class3 = class3.float().cuda()
        class4 = class4.float().cuda()
        class_msk4 = class_msk4.float().cuda()
        class_msk3 = class_msk3.float().cuda()
        class_msk2 = class_msk2.float().cuda()
        class_msk1 = class_msk1.float().cuda()
        x_hat1, x_hat64, xmask1, xmask2, xmask3, xmask4, xcl_class1, xcl_class2, xcl_class3, xcl_class4 = netG(input, input_256, smaps_i, class1, class2, class3, class4, target, class_msk1, class_msk2, class_msk3, class_msk4)

        x_hat = x_hat1

        if ganIterations % 2 == 0:
            netG.zero_grad()  # start to update G

        if epoch > 4 or (epoch < 4 and epoch % 2 == 0):
            with torch.no_grad():
                smaps, smaps64 = netS(target, target_256)
            L_img_ = 0.33 * criterionCAE(x_hat64, target_256)
            L_img_ = L_img_ + 1.2 * criterionCAE(xmask1 * class_msk1 * x_hat + (1 - xmask1) * class_msk1 * target, class_msk1 * target)
            L_img_ = L_img_ + 1.2 * criterionCAE(xmask2 * class_msk2 * x_hat + (1 - xmask2) * class_msk2 * target, class_msk2 * target)
            L_img_ = L_img_ + 3.6 * criterionCAE(xmask3 * class_msk3 * x_hat + (1 - xmask3) * class_msk3 * target, class_msk3 * target)
            L_img_ = L_img_ + 1.2 * criterionCAE(xmask4 * class_msk4 * x_hat + (1 - xmask4) * class_msk4 * target, class_msk4 * target)
            if ganIterations % (25 * opt.display) == 0:
                print(L_img_.item())
                sys.stdout.flush()
            if ganIterations < 1000:
                lam_cmp = 1.0
            else:
                lam_cmp = 0.09
            sng = 0.00000001
            # log-barrier terms keep the predicted masks away from zero
            L_img_ = L_img_ - (lam_cmp / 4.0) * torch.mean(torch.log(xmask1 + sng))
            L_img_ = L_img_ - (lam_cmp / 4.0) * torch.mean(torch.log(xmask2 + sng))
            L_img_ = L_img_ - (lam_cmp / 4.0) * torch.mean(torch.log(xmask3 + sng))
            L_img_ = L_img_ - (lam_cmp / 4.0) * torch.mean(torch.log(xmask4 + sng))
            if ganIterations % (50 * opt.display) == 0:
                print(L_img_.item())
                sys.stdout.flush()

            # Edge-preserving gradient losses at full and half resolution
            gradh_xhat, gradv_xhat = gradient(x_hat)
            gradh_tar, gradv_tar = gradient(target)
            gradh_xhat64, gradv_xhat64 = gradient(x_hat64)
            gradh_tar64, gradv_tar64 = gradient(target_256)
            L_img_ = L_img_ + 0.15 * criterionCAE(gradh_xhat, gradh_tar) + 0.15 * criterionCAE(gradv_xhat, gradv_tar) + 0.08 * criterionCAE(gradh_xhat64, gradh_tar64) + 0.08 * criterionCAE(gradv_xhat64, gradv_tar64)
            if ganIterations % (25 * opt.display) == 0:
                print(L_img_.item())
                print(torch.mean(torch.log(xmask1)).item(), torch.mean(torch.log(xmask2)).item(), torch.mean(xmask3).item(), torch.mean(xmask4).item())
                sys.stdout.flush()
            L_img = lambdaIMG * L_img_

            if lambdaIMG != 0:
                L_img.backward(retain_graph=True)

            # Perceptual loss 1 (relu2_2 features, full and half resolution)
            features_content = vgg(target)
            f_xc_c = Variable(features_content[1].data, requires_grad=False)
            features_y = vgg(x_hat)

            features_content = vgg(target_256)
            f_xc_c64 = Variable(features_content[1].data, requires_grad=False)
            features_y64 = vgg(x_hat64)

            content_loss = 1.8 * lambdaIMG * criterionCAE(features_y[1], f_xc_c) + 1.8 * 0.33 * lambdaIMG * criterionCAE(features_y64[1], f_xc_c64)
            content_loss.backward(retain_graph=True)

            # Perceptual loss 2 (relu1_2 features)
            features_content = vgg(target)
            f_xc_c = Variable(features_content[0].data, requires_grad=False)
            features_y = vgg(x_hat)

            features_content = vgg(target_256)
            f_xc_c64 = Variable(features_content[0].data, requires_grad=False)
            features_y64 = vgg(x_hat64)

            content_loss1 = 1.8 * lambdaIMG * criterionCAE(features_y[0], f_xc_c) + 1.8 * 0.33 * lambdaIMG * criterionCAE(features_y64[0], f_xc_c64)
            content_loss1.backward(retain_graph=True)

        else:
            L_img_ = 1.2 * criterionCAE(xcl_class1, target)
            L_img_ = L_img_ + 1.2 * criterionCAE(xcl_class2, target)
            L_img_ = L_img_ + 3.6 * criterionCAE(xcl_class3, target)
            L_img_ = L_img_ + 1.2 * criterionCAE(xcl_class4, target)
            L_img = lambdaIMG * L_img_
            if lambdaIMG != 0:
                L_img.backward(retain_graph=True)
            if ganIterations % (25 * opt.display) == 0:
                print(L_img_.item())
                print("updating first stage parameters")
                sys.stdout.flush()

        if ganIterations % 2 == 0:
            optimizerG.step()
        ganIterations += 1

        if ganIterations % opt.display == 0:
            print('[%d/%d][%d/%d] Loss: %f '
                  % (epoch, opt.niter, i, len(dataloader),
                     L_img.item()))
            sys.stdout.flush()
            trainLogger.write('%d\t%f\n' %
                              (i, L_img.item()))
            trainLogger.flush()
        if ganIterations % (int(len(dataloader) / 2)) == 0:
            val_batch_output = torch.zeros([16, 3, 128, 128], dtype=torch.float32)
            for idx in range(val_input.size(0)):
                single_img = val_input[idx, :, :, :].unsqueeze(0)
                with torch.no_grad():
                    # Blur the validation image with a held-out kernel
                    index = idx + 24500
                    val_inputv = single_img.cpu().numpy()
                    val_inputv[0, 0, :, :] = signal.convolve(val_inputv[0, 0, :, :], kernels[index, :, :], mode='same')
                    val_inputv[0, 1, :, :] = signal.convolve(val_inputv[0, 1, :, :], kernels[index, :, :], mode='same')
                    val_inputv[0, 2, :, :] = signal.convolve(val_inputv[0, 2, :, :], kernels[index, :, :], mode='same')
                    val_inputv = val_inputv[:, :, x1:x1 + opt.imageSize, y1:y1 + opt.imageSize]
                    val_inputv = val_inputv + (1.0 / 255.0) * np.random.normal(0, 4, val_inputv.shape)
                    val_inputv = torch.from_numpy(val_inputv)
                    val_inputv = val_inputv.float().cuda()
                    val_inputv_256 = torch.nn.functional.interpolate(val_inputv, scale_factor=0.5)
                    smaps, smaps64 = netS(val_inputv, val_inputv_256)
                    class1 = torch.zeros([1, 1, 128, 128], dtype=torch.float32)
                    class1[:, 0, :, :] = smaps[:, 0, :, :]
                    class2 = torch.zeros([1, 1, 128, 128], dtype=torch.float32)
                    class2[:, 0, :, :] = smaps[:, 1, :, :]
                    class3 = torch.zeros([1, 1, 128, 128], dtype=torch.float32)
                    class3[:, 0, :, :] = smaps[:, 2, :, :]
                    class4 = torch.zeros([1, 1, 128, 128], dtype=torch.float32)
                    class4[:, 0, :, :] = smaps[:, 3, :, :]
                    class_msk1 = torch.zeros([1, 3, 128, 128], dtype=torch.float32)
                    class_msk1[:, 0, :, :] = smaps[:, 0, :, :]
                    class_msk1[:, 1, :, :] = smaps[:, 0, :, :]
                    class_msk1[:, 2, :, :] = smaps[:, 0, :, :]
                    class_msk2 = torch.zeros([1, 3, 128, 128], dtype=torch.float32)
                    class_msk2[:, 0, :, :] = smaps[:, 1, :, :]
                    class_msk2[:, 1, :, :] = smaps[:, 1, :, :]
                    class_msk2[:, 2, :, :] = smaps[:, 1, :, :]
                    class_msk3 = torch.zeros([1, 3, 128, 128], dtype=torch.float32)
                    class_msk3[:, 0, :, :] = smaps[:, 2, :, :]
                    class_msk3[:, 1, :, :] = smaps[:, 2, :, :]
                    class_msk3[:, 2, :, :] = smaps[:, 2, :, :]
                    class_msk4 = torch.zeros([1, 3, 128, 128], dtype=torch.float32)
                    class_msk4[:, 0, :, :] = smaps[:, 3, :, :]
                    class_msk4[:, 1, :, :] = smaps[:, 3, :, :]
                    class_msk4[:, 2, :, :] = smaps[:, 3, :, :]
                    # moved to GPU as in the training branch above (the original
                    # omitted these transfers here)
                    class1 = class1.float().cuda()
                    class2 = class2.float().cuda()
                    class3 = class3.float().cuda()
                    class4 = class4.float().cuda()
                    class_msk1 = class_msk1.float().cuda()
                    class_msk2 = class_msk2.float().cuda()
                    class_msk3 = class_msk3.float().cuda()
                    class_msk4 = class_msk4.float().cuda()
                    x_hat_val, x_hat_val64, xmask1, xmask2, xmask3, xmask4, xcl_class1, xcl_class2, xcl_class3, xcl_class4 = netG(val_inputv, val_inputv_256, smaps, class1, class2, class3, class4, val_inputv, class_msk1, class_msk2, class_msk3, class_msk4)
                    val_batch_output[idx, :, :, :].copy_(x_hat_val.data[0, :, :, :])
                    ## We use a random label here just for intermediate result visualization (no need to worry about the label here) ##

            vutils.save_image(val_batch_output, '%s/generated_epoch_iter%08d.png' %
                              (opt.exp, ganIterations), normalize=True, scale_each=False)
            del val_batch_output
            torch.save(netG.state_dict(), '%s/Deblur_epoch_%d.pth' % (opt.exp, count))
            count = count + 1
trainLogger.close()
@ -0,0 +1 @@
@ -0,0 +1,227 @@
|
||||
from __future__ import division
|
||||
import torch
|
||||
import math
|
||||
import random
|
||||
from PIL import Image, ImageOps
|
||||
import numpy as np
|
||||
import numbers
|
||||
import types
|
||||
|
||||
class Compose(object):
|
||||
"""Composes several transforms together.
|
||||
Args:
|
||||
transforms (List[Transform]): list of transforms to compose.
|
||||
Example:
|
||||
>>> transforms.Compose([
|
||||
>>> transforms.CenterCrop(10),
|
||||
>>> transforms.ToTensor(),
|
||||
>>> ])
|
||||
"""
|
||||
def __init__(self, transforms):
|
||||
self.transforms = transforms
|
||||
|
||||
def __call__(self, imgA, imgB):
|
||||
for t in self.transforms:
|
||||
imgA, imgB = t(imgA, imgB)
|
||||
return imgA, imgB
|
||||

class ToTensor(object):
    """Converts a PIL.Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
    """
    def __call__(self, picA, picB):
        pics = [picA, picB]
        output = []
        for pic in pics:
            if isinstance(pic, np.ndarray):
                # handle numpy array
                img = torch.from_numpy(pic.transpose((2, 0, 1)))
            else:
                # handle PIL Image
                img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
                # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
                if pic.mode == 'YCbCr':
                    nchannel = 3
                else:
                    nchannel = len(pic.mode)
                img = img.view(pic.size[1], pic.size[0], nchannel)
                # put it from HWC to CHW format
                # yikes, this transpose takes 80% of the loading time/CPU
                img = img.transpose(0, 1).transpose(0, 2).contiguous()
            img = img.float().div(255.)
            output.append(img)
        return output[0], output[1]
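# For example, a 128x128 RGB PIL.Image comes back as a FloatTensor of shape
# (3, 128, 128) with values in [0.0, 1.0]; both images of the pair are
# converted the same way.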

class ToPILImage(object):
    """Converts a torch.*Tensor of range [0, 1] and shape C x H x W
    or numpy ndarray of dtype=uint8, range [0, 255] and shape H x W x C
    to a PIL.Image of range [0, 255].
    """
    def __call__(self, picA, picB):
        pics = [picA, picB]
        output = []
        for pic in pics:
            npimg = pic
            mode = None
            if not isinstance(npimg, np.ndarray):
                npimg = pic.mul(255).byte().numpy()
                npimg = np.transpose(npimg, (1, 2, 0))

            if npimg.shape[2] == 1:
                npimg = npimg[:, :, 0]
                mode = "L"
            output.append(Image.fromarray(npimg, mode=mode))
        return output[0], output[1]


class Normalize(object):
    """Given mean: (R, G, B) and std: (R, G, B),
    will normalize each channel of the torch.*Tensor, i.e.
    channel = (channel - mean) / std
    """
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensorA, tensorB):
        tensors = [tensorA, tensorB]
        output = []
        for tensor in tensors:
            # TODO: make efficient
            for t, m, s in zip(tensor, self.mean, self.std):
                t.sub_(m).div_(s)
            output.append(tensor)
        return output[0], output[1]
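# For example, Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) maps the [0, 1]
# output of ToTensor into [-1, 1]. Note that sub_/div_ modify the input
# tensors in place.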

class Scale(object):
    """Rescales the input PIL.Image to the given 'size'.
    'size' will be the size of the smaller edge.
    For example, if height > width, then image will be
    rescaled to (size * height / width, size)
    size: size of the smaller edge
    interpolation: Default: PIL.Image.BILINEAR

    NOTE: the aspect-preserving branch below is commented out, so this
    variant actually resizes both images to exactly (size, size).
    """
    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, imgA, imgB):
        imgs = [imgA, imgB]
        output = []
        for img in imgs:
            # w, h = img.size
            # if (w <= h and w == self.size) or (h <= w and h == self.size):
            #     output.append(img)
            #     continue
            # if w < h:
            #     ow = self.size
            #     oh = int(self.size * h / w)
            #     output.append(img.resize((ow, oh), self.interpolation))
            #     continue
            # else:
            #     oh = self.size
            #     ow = int(self.size * w / h)
            oh = self.size
            ow = self.size
            output.append(img.resize((ow, oh), self.interpolation))
            # print(output[0].size)
        return output[0], output[1]


class CenterCrop(object):
    """Crops the given PIL.Image at the center to have a region of
    the given size. size can be a tuple (target_height, target_width)
    or an integer, in which case the target will be of a square shape (size, size)
    """
    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, imgA, imgB):
        imgs = [imgA, imgB]
        output = []
        for img in imgs:
            w, h = img.size
            th, tw = self.size
            x1 = int(round((w - tw) / 2.))
            y1 = int(round((h - th) / 2.))
            output.append(img.crop((x1, y1, x1 + tw, y1 + th)))
        return output[0], output[1]


class Pad(object):
    """Pads the given PIL.Image on all sides with the given "pad" value."""
    def __init__(self, padding, fill=0):
        assert isinstance(padding, numbers.Number)
        assert isinstance(fill, numbers.Number) or isinstance(fill, str) or isinstance(fill, tuple)
        self.padding = padding
        self.fill = fill

    def __call__(self, imgA, imgB):
        imgs = [imgA, imgB]
        output = []
        for img in imgs:
            output.append(ImageOps.expand(img, border=self.padding, fill=self.fill))
        return output[0], output[1]


class Lambda(object):
    """Applies a lambda as a transform."""
    def __init__(self, lambd):
        assert isinstance(lambd, types.LambdaType)
        self.lambd = lambd

    def __call__(self, imgA, imgB):
        imgs = [imgA, imgB]
        output = []
        for img in imgs:
            output.append(self.lambd(img))
        return output[0], output[1]


class RandomCrop(object):
    """Crops the given PIL.Image at a random location to have a region of
    the given size. size can be a tuple (target_height, target_width)
    or an integer, in which case the target will be of a square shape (size, size)
    """
    def __init__(self, size, padding=0):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding

    def __call__(self, imgA, imgB):
        imgs = [imgA, imgB]
        output = []
        # (x1, y1) is drawn once and reused, so both images get the same crop.
        x1 = -1
        y1 = -1
        for img in imgs:
            if self.padding > 0:
                img = ImageOps.expand(img, border=self.padding, fill=0)

            w, h = img.size
            th, tw = self.size
            if w == tw and h == th:
                output.append(img)
                continue

            if x1 == -1 and y1 == -1:
                x1 = random.randint(0, w - tw)
                y1 = random.randint(0, h - th)
            output.append(img.crop((x1, y1, x1 + tw, y1 + th)))
        return output[0], output[1]
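# A quick sanity check (hypothetical sizes): the crop offsets are shared, so
# corresponding pixels of the pair stay aligned:
#
#   a = Image.new('RGB', (64, 64)); b = a.copy()
#   ca, cb = RandomCrop(32)(a, b)
#   assert ca.size == cb.size == (32, 32)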

class RandomHorizontalFlip(object):
    """Randomly horizontally flips the given PIL.Image with a probability of 0.5
    """
    def __call__(self, imgA, imgB):
        imgs = [imgA, imgB]
        output = []
        # flag = random.random() < 0.5
        # NOTE: the condition below is never true, so flipping is effectively
        # disabled in this variant; restore the line above to re-enable it.
        flag = random.random() < -1

        for img in imgs:
            if flag:
                output.append(img.transpose(Image.FLIP_LEFT_RIGHT))
            else:
                output.append(img)
        return output[0], output[1]

@ -0,0 +1,224 @@
from __future__ import division
import torch
import math
import random
from PIL import Image, ImageOps
import numpy as np
import numbers
import types


class Compose(object):
    """Composes several transforms together.

    Args:
        transforms (List[Transform]): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, imgA, imgB, imgC):
        for t in self.transforms:
            imgA, imgB, imgC = t(imgA, imgB, imgC)
        return imgA, imgB, imgC


class ToTensor(object):
    """Converts a PIL.Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
    """
    def __call__(self, picA, picB, picC):
        pics = [picA, picB, picC]
        output = []
        for pic in pics:
            if isinstance(pic, np.ndarray):
                # handle numpy array
                img = torch.from_numpy(pic.transpose((2, 0, 1)))
            else:
                # handle PIL Image
                img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
                # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
                if pic.mode == 'YCbCr':
                    nchannel = 3
                else:
                    nchannel = len(pic.mode)
                img = img.view(pic.size[1], pic.size[0], nchannel)
                # put it from HWC to CHW format
                # yikes, this transpose takes 80% of the loading time/CPU
                img = img.transpose(0, 1).transpose(0, 2).contiguous()
            img = img.float().div(255.)
            output.append(img)
        return output[0], output[1], output[2]


class ToPILImage(object):
    """Converts a torch.*Tensor of range [0, 1] and shape C x H x W
    or numpy ndarray of dtype=uint8, range [0, 255] and shape H x W x C
    to a PIL.Image of range [0, 255].
    """
    def __call__(self, picA, picB, picC):
        pics = [picA, picB, picC]
        output = []
        for pic in pics:
            npimg = pic
            mode = None
            if not isinstance(npimg, np.ndarray):
                npimg = pic.mul(255).byte().numpy()
                npimg = np.transpose(npimg, (1, 2, 0))

            if npimg.shape[2] == 1:
                npimg = npimg[:, :, 0]
                mode = "L"
            output.append(Image.fromarray(npimg, mode=mode))
        return output[0], output[1], output[2]


class Normalize(object):
    """Given mean: (R, G, B) and std: (R, G, B),
    will normalize each channel of the torch.*Tensor, i.e.
    channel = (channel - mean) / std
    """
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensorA, tensorB, tensorC):
        tensors = [tensorA, tensorB, tensorC]
        output = []
        for tensor in tensors:
            # TODO: make efficient
            for t, m, s in zip(tensor, self.mean, self.std):
                t.sub_(m).div_(s)
            output.append(tensor)
        return output[0], output[1], output[2]


class Scale(object):
    """Rescales the input PIL.Image to the given 'size'.
    'size' will be the size of the smaller edge.
    For example, if height > width, then image will be
    rescaled to (size * height / width, size)
    size: size of the smaller edge
    interpolation: Default: PIL.Image.BILINEAR
    """
    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, imgA, imgB, imgC):
        imgs = [imgA, imgB, imgC]
        output = []
        for img in imgs:
            w, h = img.size
            if (w <= h and w == self.size) or (h <= w and h == self.size):
                output.append(img)
                continue
            if w < h:
                ow = self.size
                oh = int(self.size * h / w)
                output.append(img.resize((ow, oh), self.interpolation))
                continue
            else:
                oh = self.size
                ow = int(self.size * w / h)
                output.append(img.resize((ow, oh), self.interpolation))
        return output[0], output[1], output[2]


class CenterCrop(object):
    """Crops the given PIL.Image at the center to have a region of
    the given size. size can be a tuple (target_height, target_width)
    or an integer, in which case the target will be of a square shape (size, size)
    """
    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, imgA, imgB, imgC):
        imgs = [imgA, imgB, imgC]
        output = []
        for img in imgs:
            w, h = img.size
            th, tw = self.size
            x1 = int(round((w - tw) / 2.))
            y1 = int(round((h - th) / 2.))
            output.append(img.crop((x1, y1, x1 + tw, y1 + th)))
        return output[0], output[1], output[2]


class Pad(object):
    """Pads the given PIL.Image on all sides with the given "pad" value."""
    def __init__(self, padding, fill=0):
        assert isinstance(padding, numbers.Number)
        assert isinstance(fill, numbers.Number) or isinstance(fill, str) or isinstance(fill, tuple)
        self.padding = padding
        self.fill = fill

    def __call__(self, imgA, imgB, imgC):
        imgs = [imgA, imgB, imgC]
        output = []
        for img in imgs:
            output.append(ImageOps.expand(img, border=self.padding, fill=self.fill))
        return output[0], output[1], output[2]


class Lambda(object):
    """Applies a lambda as a transform."""
    def __init__(self, lambd):
        assert isinstance(lambd, types.LambdaType)
        self.lambd = lambd

    def __call__(self, imgA, imgB, imgC):
        imgs = [imgA, imgB, imgC]
        output = []
        for img in imgs:
            output.append(self.lambd(img))
        return output[0], output[1], output[2]


class RandomCrop(object):
    """Crops the given PIL.Image at a random location to have a region of
    the given size. size can be a tuple (target_height, target_width)
    or an integer, in which case the target will be of a square shape (size, size)
    """
    def __init__(self, size, padding=0):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding

    def __call__(self, imgA, imgB, imgC):
        imgs = [imgA, imgB, imgC]
        output = []
        # (x1, y1) is drawn once and reused, so all three images get the same crop.
        x1 = -1
        y1 = -1
        for img in imgs:
            if self.padding > 0:
                img = ImageOps.expand(img, border=self.padding, fill=0)

            w, h = img.size
            th, tw = self.size
            if w == tw and h == th:
                output.append(img)
                continue

            if x1 == -1 and y1 == -1:
                x1 = random.randint(0, w - tw)
                y1 = random.randint(0, h - th)
            output.append(img.crop((x1, y1, x1 + tw, y1 + th)))
        return output[0], output[1], output[2]


class RandomHorizontalFlip(object):
    """Randomly horizontally flips the given PIL.Image with a probability of 0.5
    """
    def __call__(self, imgA, imgB, imgC):
        imgs = [imgA, imgB, imgC]
        output = []
        # flag = random.random() < 0.5
        # NOTE: the condition below is never true, so flipping is effectively
        # disabled in this variant; restore the line above to re-enable it.
        flag = random.random() < -1

        for img in imgs:
            if flag:
                output.append(img.transpose(Image.FLIP_LEFT_RIGHT))
            else:
                output.append(img)
        return output[0], output[1], output[2]

@ -0,0 +1,225 @@
from __future__ import division
import torch
import math
import random
from PIL import Image, ImageOps
import numpy as np
import numbers
import types


class Compose(object):
    """Composes several transforms together.

    Args:
        transforms (List[Transform]): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, imgA, imgB):
        for t in self.transforms:
            imgA, imgB = t(imgA, imgB)
        return imgA, imgB


class ToTensor(object):
    """Converts a PIL.Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
    """
    def __call__(self, picA, picB):
        pics = [picA, picB]
        output = []
        for pic in pics:
            if isinstance(pic, np.ndarray):
                # handle numpy array
                img = torch.from_numpy(pic.transpose((2, 0, 1)))
            else:
                # handle PIL Image
                img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
                # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
                if pic.mode == 'YCbCr':
                    nchannel = 3
                else:
                    nchannel = len(pic.mode)
                img = img.view(pic.size[1], pic.size[0], nchannel)
                # put it from HWC to CHW format
                # yikes, this transpose takes 80% of the loading time/CPU
                img = img.transpose(0, 1).transpose(0, 2).contiguous()
            img = img.float().div(255.)
            output.append(img)
        return output[0], output[1]


class ToPILImage(object):
    """Converts a torch.*Tensor of range [0, 1] and shape C x H x W
    or numpy ndarray of dtype=uint8, range [0, 255] and shape H x W x C
    to a PIL.Image of range [0, 255].
    """
    def __call__(self, picA, picB):
        pics = [picA, picB]
        output = []
        for pic in pics:
            npimg = pic
            mode = None
            if not isinstance(npimg, np.ndarray):
                npimg = pic.mul(255).byte().numpy()
                npimg = np.transpose(npimg, (1, 2, 0))

            if npimg.shape[2] == 1:
                npimg = npimg[:, :, 0]
                mode = "L"
            output.append(Image.fromarray(npimg, mode=mode))
        return output[0], output[1]


class Normalize(object):
    """Given mean: (R, G, B) and std: (R, G, B),
    will normalize each channel of the torch.*Tensor, i.e.
    channel = (channel - mean) / std
    """
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensorA, tensorB):
        tensors = [tensorA, tensorB]
        output = []
        for tensor in tensors:
            # TODO: make efficient
            for t, m, s in zip(tensor, self.mean, self.std):
                t.sub_(m).div_(s)
            output.append(tensor)
        return output[0], output[1]


class Scale(object):
    """Rescales the input PIL.Image to the given 'size'.
    'size' will be the size of the smaller edge.
    For example, if height > width, then image will be
    rescaled to (size * height / width, size)
    size: size of the smaller edge
    interpolation: Default: PIL.Image.BILINEAR
    """
    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, imgA, imgB):
        imgs = [imgA, imgB]
        output = []
        for img in imgs:
            w, h = img.size
            if (w <= h and w == self.size) or (h <= w and h == self.size):
                output.append(img)
                continue
            if w < h:
                ow = self.size
                oh = int(self.size * h / w)
                output.append(img.resize((ow, oh), self.interpolation))
                continue
            else:
                oh = self.size
                ow = int(self.size * w / h)
                output.append(img.resize((ow, oh), self.interpolation))
        return output[0], output[1]


class CenterCrop(object):
    """Crops the given PIL.Image at the center to have a region of
    the given size. size can be a tuple (target_height, target_width)
    or an integer, in which case the target will be of a square shape (size, size)
    """
    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, imgA, imgB):
        imgs = [imgA, imgB]
        output = []
        for img in imgs:
            w, h = img.size
            th, tw = self.size
            x1 = int(round((w - tw) / 2.))
            y1 = int(round((h - th) / 2.))
            # output.append(img.crop((x1, y1, x1 + tw, y1 + th)))
            # NOTE: the crop above is commented out, so this variant is a
            # pass-through that returns the images unchanged.
            output.append(img)
        return output[0], output[1]


class Pad(object):
    """Pads the given PIL.Image on all sides with the given "pad" value."""
    def __init__(self, padding, fill=0):
        assert isinstance(padding, numbers.Number)
        assert isinstance(fill, numbers.Number) or isinstance(fill, str) or isinstance(fill, tuple)
        self.padding = padding
        self.fill = fill

    def __call__(self, imgA, imgB):
        imgs = [imgA, imgB]
        output = []
        for img in imgs:
            output.append(ImageOps.expand(img, border=self.padding, fill=self.fill))
        return output[0], output[1]


class Lambda(object):
    """Applies a lambda as a transform."""
    def __init__(self, lambd):
        assert isinstance(lambd, types.LambdaType)
        self.lambd = lambd

    def __call__(self, imgA, imgB):
        imgs = [imgA, imgB]
        output = []
        for img in imgs:
            output.append(self.lambd(img))
        return output[0], output[1]


class RandomCrop(object):
    """Crops the given PIL.Image at a random location to have a region of
    the given size. size can be a tuple (target_height, target_width)
    or an integer, in which case the target will be of a square shape (size, size)
    """
    def __init__(self, size, padding=0):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding

    def __call__(self, imgA, imgB):
        imgs = [imgA, imgB]
        output = []
        # (x1, y1) is drawn once and reused, so both images get the same crop.
        x1 = -1
        y1 = -1
        for img in imgs:
            if self.padding > 0:
                img = ImageOps.expand(img, border=self.padding, fill=0)

            w, h = img.size
            th, tw = self.size
            if w == tw and h == th:
                output.append(img)
                continue

            if x1 == -1 and y1 == -1:
                x1 = random.randint(0, w - tw)
                y1 = random.randint(0, h - th)
            output.append(img.crop((x1, y1, x1 + tw, y1 + th)))
        return output[0], output[1]


class RandomHorizontalFlip(object):
    """Randomly horizontally flips the given PIL.Image with a probability of 0.5
    """
    def __call__(self, imgA, imgB):
        imgs = [imgA, imgB]
        output = []
        # The flip decision is drawn once so both images flip together.
        flag = random.random() < 0.5
        for img in imgs:
            if flag:
                output.append(img.transpose(Image.FLIP_LEFT_RIGHT))
            else:
                output.append(img)
        return output[0], output[1]