diff --git a/datasets/__init__.py b/datasets/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/datasets/__init__.py @@ -0,0 +1 @@ + diff --git a/datasets/__init__.pyc b/datasets/__init__.pyc new file mode 100644 index 0000000..13923cb Binary files /dev/null and b/datasets/__init__.pyc differ diff --git a/datasets/classification.py b/datasets/classification.py new file mode 100644 index 0000000..b2d4534 --- /dev/null +++ b/datasets/classification.py @@ -0,0 +1,75 @@ +import torch.utils.data as data +from PIL import Image +import os +import os.path +import numpy as np +import h5py +import glob + +IMG_EXTENSIONS = [ + '.jpg', '.JPG', '.jpeg', '.JPEG', + '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', +] + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + +def make_dataset(dir): + images = [] + if not os.path.isdir(dir): + raise Exception('Check dataroot') + for root, _, fnames in sorted(os.walk(dir)): + for fname in fnames: + if is_image_file(fname): + path = os.path.join(dir, fname) + item = path + images.append(item) + return images + +def default_loader(path): + return Image.open(path).convert('RGB') + +class classification(data.Dataset): + def __init__(self, root, transform=None, loader=default_loader, seed=None): + # imgs = make_dataset(root) + # if len(imgs) == 0: + # raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" + # "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) + self.root = root + # self.imgs = imgs + self.transform = transform + self.loader = loader + + if seed is not None: + np.random.seed(seed) + + def __getitem__(self, _): + index = np.random.randint(1,self.__len__()) + # path = self.imgs[index] + # img = self.loader(path) + #img = img.resize((w, h), Image.BILINEAR) + + + + file_name=self.root+'/'+str(index)+'.h5' + f=h5py.File(file_name,'r') + + haze_image=f['haze'][:] + label=f['label'][:] + label=label.mean()-1 + + haze_image=np.swapaxes(haze_image,0,2) + haze_image=np.swapaxes(haze_image,1,2) + + + # if self.transform is not None: + # # NOTE preprocessing for each pair of images + # imgA, imgB = self.transform(imgA, imgB) + return haze_image, label + + def __len__(self): + train_list=glob.glob(self.root+'/*h5') + # print len(train_list) + return len(train_list) + + # return len(self.imgs) diff --git a/datasets/pix2pix.py b/datasets/pix2pix.py new file mode 100644 index 0000000..ac7e664 --- /dev/null +++ b/datasets/pix2pix.py @@ -0,0 +1,133 @@ +import torch.utils.data as data + +from PIL import Image +import os +import os.path +import numpy as np + + + +IMG_EXTENSIONS = [ + '.jpg', '.JPG', '.jpeg', '.JPEG', + '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '', +] + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + +def make_dataset(dir): + images = [] + if not os.path.isdir(dir): + raise Exception('Check dataroot') + for root, _, fnames in sorted(os.walk(dir)): + for fname in fnames: + if is_image_file(fname): + path = os.path.join(dir, fname) + item = path + images.append(item) + return images + +def default_loader(path): + return Image.open(path).convert('RGB') + +class pix2pix(data.Dataset): + def __init__(self, root, transform=None, loader=default_loader, seed=None): + imgs = make_dataset(root) + if len(imgs) == 0: + raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" + "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) + self.root = root + self.imgs = imgs + 
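+        # NOTE: `transform` is expected to be a paired transform (e.g. from
+        # transforms.pix2pix) applied jointly to (imgA, imgB) in __getitem__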
self.transform = transform + self.loader = loader + + if seed is not None: + np.random.seed(seed) + + def __getitem__(self, index): + # index = np.random.randint(self.__len__(), size=1)[0] + # index = np.random.randint(self.__len__(), size=1)[0]+1 + # index = np.random.randint(self.__len__(), size=1)[0] + + # index_folder = np.random.randint(1,4) + index_folder = np.random.randint(0,1) + + index_sub = np.random.randint(2, 5) + + label=index_folder + + + if index_folder==0: + path='/home/openset/Desktop/derain2018/facades/training2'+'/'+str(index)+'.jpg' + + + + if index_folder==1: + if index_sub<4: + path='/home/openset/Desktop/derain2018/facades/DB_Rain_new/Rain_Heavy/train2018new'+'/'+str(index)+'.jpg' + if index_sub==4: + index = np.random.randint(0,400) + path='/home/openset/Desktop/derain2018/facades/DB_Rain/Rain_Heavy/trainnew'+'/'+str(index)+'.jpg' + + if index_folder==2: + if index_sub<4: + path='/home/openset/Desktop/derain2018/facades/DB_Rain_new/Rain_Medium/train2018new'+'/'+str(index)+'.jpg' + if index_sub==4: + index = np.random.randint(0,400) + path='/home/openset/Desktop/derain2018/facades/DB_Rain/Rain_Medium/trainnew'+'/'+str(index)+'.jpg' + + if index_folder==3: + if index_sub<4: + path='/home/openset/Desktop/derain2018/facades/DB_Rain_new/Rain_Light/train2018new'+'/'+str(index)+'.jpg' + if index_sub==4: + index = np.random.randint(0,400) + path='/home/openset/Desktop/derain2018/facades/DB_Rain/Rain_Light/trainnew'+'/'+str(index)+'.jpg' + + + + # img = self.loader(path) + + img = self.loader(path) + + # NOTE: img -> PIL Image + # w, h = img.size + # w, h = 1024, 512 + # img = img.resize((w, h), Image.BILINEAR) + # pix = np.array(I) + # + # r = 16 + # eps = 1 + # + # I = img.crop((0, 0, w/2, h)) + # pix = np.array(I) + # base=guidedfilter(pix, pix, r, eps) + # base = PIL.Image.fromarray(numpy.uint8(base)) + # + # + # + # imgA=base + # imgB=I-base + # imgC = img.crop((w/2, 0, w, h)) + + w, h = img.size + # img = img.resize((w, h), Image.BILINEAR) + + + # NOTE: split a sample into imgA and imgB + imgA = img.crop((0, 0, w/2, h)) + # imgC = img.crop((2*w/3, 0, w, h)) + + imgB = img.crop((w/2, 0, w, h)) + + + if self.transform is not None: + # NOTE preprocessing for each pair of images + # imgA, imgB, imgC = self.transform(imgA, imgB, imgC) + imgA, imgB = self.transform(imgA, imgB) + + return imgA, imgB, label + + def __len__(self): + # return 679 + print(len(self.imgs)) + return len(self.imgs) diff --git a/datasets/pix2pix.pyc b/datasets/pix2pix.pyc new file mode 100644 index 0000000..b5afb46 Binary files /dev/null and b/datasets/pix2pix.pyc differ diff --git a/datasets/pix2pix2.py b/datasets/pix2pix2.py new file mode 100644 index 0000000..943aec4 --- /dev/null +++ b/datasets/pix2pix2.py @@ -0,0 +1,109 @@ +import torch.utils.data as data +from PIL import Image +import os +import os.path +import numpy as np +import h5py +import glob +import scipy.ndimage +IMG_EXTENSIONS = [ + '.jpg', '.JPG', '.jpeg', '.JPEG', + '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', +] + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + +def make_dataset(dir): + images = [] + if not os.path.isdir(dir): + raise Exception('Check dataroot') + for root, _, fnames in sorted(os.walk(dir)): + for fname in fnames: + if is_image_file(fname): + path = os.path.join(dir, fname) + item = path + images.append(item) + return images + +def default_loader(path): + return Image.open(path).convert('RGB') + +class pix2pix(data.Dataset): + def __init__(self, root, 
transform=None, loader=default_loader, seed=None): + # imgs = make_dataset(root) + # if len(imgs) == 0: + # raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" + # "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) + self.root = root + # self.imgs = imgs + self.transform = transform + self.loader = loader + + if seed is not None: + np.random.seed(seed) + + def __getitem__(self, _): + index = np.random.randint(1,self.__len__()) + # index = np.random.randint(self.__len__(), size=1)[0] + + # path = self.imgs[index] + # img = self.loader(path) + #img = img.resize((w, h), Image.BILINEAR) + + + + file_name=self.root+'/'+str(index)+'.h5' + f=h5py.File(file_name,'r') + + haze_image=f['haze'][:] + trans_map=f['trans'][:] + ato_map=f['ato'][:] + GT=f['gt'][:] + + + + haze_image=np.swapaxes(haze_image,0,2) + trans_map=np.swapaxes(trans_map,0,2) + ato_map=np.swapaxes(ato_map,0,2) + GT=np.swapaxes(GT,0,2) + + + + haze_image=np.swapaxes(haze_image,1,2) + trans_map=np.swapaxes(trans_map,1,2) + ato_map=np.swapaxes(ato_map,1,2) + GT=np.swapaxes(GT,1,2) + + # if np.random.uniform()>0.5: + # haze_image=np.flip(haze_image,2).copy() + # GT = np.flip(GT, 2).copy() + # trans_map=np.flip(trans_map, 2).copy() + # if np.random.uniform()>0.5: + # angle = np.random.uniform(-10, 10) + # haze_image=scipy.ndimage.interpolation.rotate(haze_image, angle) + # GT = scipy.ndimage.interpolation.rotate(GT, angle) + + # if np.random.uniform()>0.5: + # angle = np.random.uniform(-10, 10) + # haze_image=scipy.ndimage.interpolation.rotate(haze_image, angle) + # GT = scipy.ndimage.interpolation.rotate(GT, angle) + + # if np.random.uniform()>0.5: + # std = np.random.uniform(0.2, 1.2) + # haze_image = scipy.ndimage.filters.gaussian_filter(haze_image, std,mode='constant') + + # haze_image=np.random.uniform(-10/5000,10/5000,size=haze_image.shape) + # haze_image = np.maximum(0, haze_image) + + # if self.transform is not None: + # # NOTE preprocessing for each pair of images + # imgA, imgB = self.transform(imgA, imgB) + return haze_image, GT, trans_map, ato_map + + def __len__(self): + train_list=glob.glob(self.root+'/*h5') + # print len(train_list) + return len(train_list) + + # return len(self.imgs) diff --git a/datasets/pix2pix_class.py b/datasets/pix2pix_class.py new file mode 100644 index 0000000..23ef4b6 --- /dev/null +++ b/datasets/pix2pix_class.py @@ -0,0 +1,88 @@ +import torch.utils.data as data + +from PIL import Image +import os +import os.path +import numpy as np + + +IMG_EXTENSIONS = [ + '.jpg', '.JPG', '.jpeg', '.JPEG', + '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '', +] + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + +def make_dataset(dir): + images = [] + if not os.path.isdir(dir): + raise Exception('Check dataroot') + for root, _, fnames in sorted(os.walk(dir)): + for fname in fnames: + if is_image_file(fname): + path = os.path.join(dir, fname) + item = path + images.append(item) + return images + +def default_loader(path): + return Image.open(path).convert('RGB') + +class pix2pix(data.Dataset): + def __init__(self, root, transform=None, loader=default_loader, seed=None): + imgs = make_dataset(root) + if len(imgs) == 0: + raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" + "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) + self.root = root + self.imgs = imgs + self.transform = transform + self.loader = loader + + if seed is not None: + np.random.seed(seed) + + def __getitem__(self, index): + + 
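+        # NOTE: the incoming `index` is ignored; each call draws a random
+        # rain-density class (0: Heavy, 1: Medium, 2: Light), samples a random
+        # image from the matching training folder, and returns the rainy/clean
+        # halves of that image together with the density label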
index_sub = np.random.randint(0, 3) + label=index_sub + + + + if index_sub==0: + index = np.random.randint(0,4000) + path='/media/openset/Z/derain2018/facades/DID-MDN-training/Rain_Heavy/train2018new'+'/'+str(index)+'.jpg' + + + if index_sub==1: + index = np.random.randint(0,4000) + path='/media/openset/Z/derain2018/facades/DID-MDN-training/Rain_Medium/train2018new'+'/'+str(index)+'.jpg' + + if index_sub==2: + index = np.random.randint(0,4000) + path='/media/openset/Z/derain2018/facades/DID-MDN-training/Rain_Light/train2018new'+'/'+str(index)+'.jpg' + + + + img = self.loader(path) + + + w, h = img.size + + + # NOTE: split a sample into imgA and imgB + imgA = img.crop((0, 0, w/2, h)) + imgB = img.crop((w/2, 0, w, h)) + + + if self.transform is not None: + # NOTE preprocessing for each pair of images + imgA, imgB = self.transform(imgA, imgB) + + return imgA, imgB, label + + def __len__(self): + # return 679 + # print(len(self.imgs)) + return len(self.imgs) diff --git a/datasets/pix2pix_class.pyc b/datasets/pix2pix_class.pyc new file mode 100644 index 0000000..347b305 Binary files /dev/null and b/datasets/pix2pix_class.pyc differ diff --git a/datasets/pix2pix_val.py b/datasets/pix2pix_val.py new file mode 100644 index 0000000..6fe0070 --- /dev/null +++ b/datasets/pix2pix_val.py @@ -0,0 +1,139 @@ +import torch.utils.data as data + +from PIL import Image +import os +import os.path +import numpy as np +import cv2 + +# from guidedfilter import guidedfilter +# import guidedfilter.guidedfilter as guidedfilter + + + + +IMG_EXTENSIONS = [ + '.jpg', '.JPG', '.jpeg', '.JPEG', + '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '', +] + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + +def make_dataset(dir): + images = [] + if not os.path.isdir(dir): + raise Exception('Check dataroot') + for root, _, fnames in sorted(os.walk(dir)): + for fname in fnames: + if is_image_file(fname): + path = os.path.join(dir, fname) + item = path + images.append(item) + return images + +def default_loader(path): + return Image.open(path).convert('RGB') + +class pix2pix_val(data.Dataset): + def __init__(self, root, transform=None, loader=default_loader, seed=None): + imgs = make_dataset(root) + if len(imgs) == 0: + raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" + "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) + self.root = root + self.imgs = imgs + self.transform = transform + self.loader = loader + + if seed is not None: + np.random.seed(seed) + + def __getitem__(self, index): + # index = np.random.randint(self.__len__(), size=1)[0] + # index = np.random.randint(self.__len__(), size=1)[0] + + path = self.imgs[index] + + if index>75000: + index = index%75000+10 + path=self.root+'/'+'%06d'%(index+1)+'.png' + + index_folder = np.random.randint(0,4) + label=index_folder + + # path='/home/openset/Desktop/derain2018/facades/DB_Rain_test/Rain_Heavy/test2018'+'/'+str(index)+'.jpg' + img = self.loader(path) + + # # NOTE: img -> PIL Image + # w, h = img.size + # w, h = 1024, 512 + # img = img.resize((w, h), Image.BILINEAR) + # # NOTE: split a sample into imgA and imgB + # imgA = img.crop((0, 0, w/2, h)) + # imgB = img.crop((w/2, 0, w, h)) + # if self.transform is not None: + # # NOTE preprocessing for each pair of images + # imgA, imgB = self.transform(imgA, imgB) + # return imgA, imgB + + + # w, h = 1536, 512 + # img = img.resize((w, h), Image.BILINEAR) + # + # + # # NOTE: split a sample into imgA and imgB + # imgA = img.crop((0, 0, w/3, 
h)) + # imgC = img.crop((2*w/3, 0, w, h)) + # + # imgB = img.crop((w/3, 0, 2*w/3, h)) + + # w, h = 1024, 512 + # img = img.resize((w, h), Image.BILINEAR) + # + # r = 16 + # eps = 1 + # + # # I = img.crop((0, 0, w/2, h)) + # # pix = np.array(I) + # # print + # # base[idx,:,:,:]=guidedfilter(pix[], pix[], r, eps) + # # base[]=guidedfilter(pix[], pix[], r, eps) + # # base[]=guidedfilter(pix[], pix[], r, eps) + # + # + # # base = PIL.Image.fromarray(numpy.uint8(base)) + # + # # NOTE: split a sample into imgA and imgB + # imgA = img.crop((0, 0, w/3, h)) + # imgC = img.crop((2*w/3, 0, w, h)) + # + # imgB = img.crop((w/3, 0, 2*w/3, h)) + # imgA=base + # imgB=I-base + # imgC = img.crop((w/2, 0, w, h)) + w, h = img.size + #print(w,h) + # w, h = 586*2, 586 + + # img = img.resize((w, h), Image.BILINEAR) + + + # NOTE: split a sample into imgA and imgB + imgA = img.crop((0, 0, w/2, h)) + # imgC = img.crop((2*w/3, 0, w, h)) + + imgB = img.crop((w/2, 0, w, h)) + + + + if self.transform is not None: + # NOTE preprocessing for each pair of images + # imgA, imgB, imgC = self.transform(imgA, imgB, imgC) + imgA, imgB = self.transform(img, img) + + + return imgA, imgB + + def __len__(self): + return len(self.imgs) diff --git a/datasets/pix2pix_val.pyc b/datasets/pix2pix_val.pyc new file mode 100644 index 0000000..6e869d0 Binary files /dev/null and b/datasets/pix2pix_val.pyc differ diff --git a/datasets/util.py b/datasets/util.py new file mode 100644 index 0000000..3394ac3 --- /dev/null +++ b/datasets/util.py @@ -0,0 +1,1103 @@ +# -*- coding: utf-8 -*- +# Copyright (C) 2015-2017 by Brendt Wohlberg +# All rights reserved. BSD 3-clause License. +# This file is part of the SPORCO package. Details of the copyright +# and user license can be found in the 'LICENSE.txt' file distributed +# with the package. + +"""Utility functions""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from future.utils import PY2 +from builtins import range +from builtins import object + +from timeit import default_timer as timer +import os +import imghdr +import io +import platform +import multiprocessing as mp +import itertools +import collections +import socket +if PY2: + import urllib2 as urlrequest + import urllib2 as urlerror +else: + import urllib.request as urlrequest + import urllib.error as urlerror +import numpy as np +from scipy import misc +import scipy.ndimage.interpolation as sni + +import sporco.linalg as sla +import sporco.plot as spl + +__author__ = """Brendt Wohlberg """ + + +import warnings + +def plot(*args, **kwargs): + warnings.warn("sporco.util.plot is deprecated: use sporco.plot.plot", + PendingDeprecationWarning) + return spl.plot(*args, **kwargs) + +def surf(*args, **kwargs): + warnings.warn("sporco.util.surf is deprecated: use sporco.plot.surf", + PendingDeprecationWarning) + return spl.surf(*args, **kwargs) + +def imview(*args, **kwargs): + warnings.warn("sporco.util.imview is deprecated: use sporco.plot.imview", + PendingDeprecationWarning) + return spl.imview(*args, **kwargs) + + + +# Python 2/3 unicode literal compatibility +if PY2: + import codecs + def u(x): + """Python 2/3 compatible definition of utf8 literals""" + return x.decode('utf8') +else: + def u(x): + """Python 2/3 compatible definition of utf8 literals""" + return x + + + + +def ntpl2array(ntpl): + """ + Convert a :func:`collections.namedtuple` object to a :class:`numpy.ndarray` + object that can be saved using :func:`numpy.savez`. 
+ + Parameters + ---------- + ntpl : collections.namedtuple object + Named tuple object to be converted to ndarray + + Returns + ------- + arr : ndarray + Array representation of input named tuple + """ + + return np.asarray((np.vstack([col for col in ntpl]), ntpl._fields, + ntpl.__class__.__name__)) + + + +def array2ntpl(arr): + """ + Convert a :class:`numpy.ndarray` object constructed by :func:`ntpl2array` + back to the original :func:`collections.namedtuple` representation. + + Parameters + ---------- + arr : ndarray + Array representation of named tuple constructed by :func:`ntpl2array` + + Returns + ------- + ntpl : collections.namedtuple object + Named tuple object with the same name and fields as the original named + typle object provided to :func:`ntpl2array` + """ + + cls = collections.namedtuple(arr[2], arr[1]) + return cls(*tuple(arr[0])) + + + +def transpose_ntpl_list(lst): + """Transpose a list of named tuple objects (of the same type) into a + named tuple of lists. + + Parameters + ---------- + lst : list of collections.namedtuple object + List of named tuple objects of the same type + + Returns + ------- + ntpl : collections.namedtuple object + Named tuple object with each entry consisting of a list of the + corresponding fields of the named tuple objects in list ``lst`` + """ + + cls = collections.namedtuple(lst[0].__class__.__name__, lst[0]._fields) + if len(lst) == 0: + return None + else: + return cls(*[[lst[k][l] for k in range(len(lst))] + for l in range(len(lst[0]))]) + + + +def solve_status_str(hdrtxt, fwiter=4, fpothr=2): + """Construct header and format details for status display of an + iterative solver. + + Parameters + ---------- + hdrtxt : tuple of strings + Tuple of field header strings + fwiter : int, optional (default 4) + Number of characters in iteration count integer field + fpothr : int, optional (default 2) + Precision of other float field + + Returns + ------- + hdrstr : string + Complete header string + fmtstr : string + Complete print formatting string for numeric values + nsep : integer + Number of characters in separator string + """ + + # Field width for all fields other than first depends on precision + fwothr = fpothr + 6 + # Construct header string from hdrtxt list of column headers + hdrstr = ("%-*s" % (fwiter+2, hdrtxt[0])) + \ + ((("%%-%ds " % (fwothr+1)) * (len(hdrtxt)-1)) % \ + tuple(hdrtxt[1:])) + # Construct iteration status format string + fmtstr = ("%%%dd" % (fwiter)) + (((" %%%d.%de" % (fwothr, fpothr)) * \ + (len(hdrtxt)-1))) + # Compute length of separator string + nsep = fwiter + (fwothr + 2)*(len(hdrtxt)-1) + + return hdrstr, fmtstr, nsep + + + +def tiledict(D, sz=None): + """Construct an image allowing visualization of dictionary content. + + Parameters + ---------- + D : array_like + Dictionary matrix/array. + sz : tuple + Size of each block in dictionary. + + Returns + ------- + im : ndarray + Image tiled with dictionary entries. 
+ """ + + # Handle standard 2D (non-convolutional) dictionary + if D.ndim == 2: + D = D.reshape((sz + (D.shape[1],))) + sz = None + dsz = D.shape + + if D.ndim == 4: + axisM = 3 + szni = 3 + else: + axisM = 2 + szni = 2 + + # Construct dictionary atom size vector if not provided + if sz is None: + sz = np.tile(np.array(dsz[0:2]).reshape([2, 1]), (1, D.shape[axisM])) + else: + sz = np.array(sum(tuple((x[0:2],) * x[szni] for x in sz), ())).T + + # Compute the maximum atom dimensions + mxsz = np.amax(sz, 1) + + # Shift and scale values to [0, 1] + D = D - D.min() + D = D / D.max() + + # Construct tiled image + N = dsz[axisM] + Vr = int(np.floor(np.sqrt(N))) + Vc = int(np.ceil(N/float(Vr))) + if D.ndim == 4: + im = np.ones((Vr*mxsz[0] + Vr-1, Vc*mxsz[1] + Vc-1, dsz[2])) + else: + im = np.ones((Vr*mxsz[0] + Vr-1, Vc*mxsz[1] + Vc-1)) + k = 0 + for l in range(0, Vr): + for m in range(0, Vc): + r = mxsz[0]*l + l + c = mxsz[1]*m + m + if D.ndim == 4: + im[r:(r+sz[0, k]), c:(c+sz[1, k]), :] = D[0:sz[0, k], + 0:sz[1, k], :, k] + else: + im[r:(r+sz[0, k]), c:(c+sz[1, k])] = D[0:sz[0, k], + 0:sz[1, k], k] + k = k + 1 + if k >= N: + break + if k >= N: + break + + return im + + + +def imageblocks(imgs, blksz): + """Extract all blocks of specified size from an image or list of images. + + Parameters + ---------- + imgs: array_like or tuple of array_like + Single image or tuple of images from which to extract blocks + blksz : tuple of two ints + Size of the blocks + + Returns + ------- + blks : ndarray + Array of extracted blocks + """ + + # See http://stackoverflow.com/questions/16774148 and + # sklearn.feature_extraction.image.extract_patches_2d + if not isinstance(imgs, tuple): + imgs = (imgs,) + + blks = np.array([]).reshape(blksz + (0,)) + for im in imgs: + Nr, Nc = im.shape + nr, nc = blksz + shape = (Nr-nr+1, Nc-nc+1, nr, nc) + strides = im.itemsize*np.array([Nc, 1, Nc, 1]) + sb = np.lib.stride_tricks.as_strided(np.ascontiguousarray(im), + shape=shape, strides=strides) + sb = np.ascontiguousarray(sb) + sb.shape = (-1, nr, nc) + sb = np.rollaxis(sb, 0, 3) + blks = np.dstack((blks, sb)) + + return blks + + + +def rgb2gray(rgb): + """Convert an RGB image (or images) to grayscale. + + Parameters + ---------- + rgb : ndarray + RGB image as Nr x Nc x 3 or Nr x Nc x 3 x K array + + Returns + ------- + gry : ndarray + Grayscale image as Nr x Nc or Nr x Nc x K array + """ + + w = sla.atleast_nd(rgb.ndim, np.array([0.299, 0.587, 0.144], + dtype=rgb.dtype, ndmin=3)) + return np.sum(w * rgb, axis=2) + + + +def complex_randn(*args): + """Return a complex array of samples drawn from a standard normal + distribution. + + Parameters + ---------- + d0, d1, ..., dn: int + Dimensions of the random array + + Returns + ------- + a : ndarray + Random array of shape (d0, d1, ..., dn) + """ + + return np.random.randn(*args) + 1j*np.random.randn(*args) + + + +def spnoise(s, frc, smn=0.0, smx=1.0): + """Return image with salt & pepper noise imposed on it. + + Parameters + ---------- + s : ndarray + Input image + frc : float + Desired fraction of pixels corrupted by noise + smn : float, optional (default 0.0) + Lower value for noise (pepper) + smx : float, optional (default 1.0) + Upper value for noise (salt) + + Returns + ------- + sn : ndarray + Noisy image + """ + + sn = s.copy() + spm = np.random.uniform(-1.0, 1.0, s.shape) + sn[spm < frc - 1.0] = smn + sn[spm > 1.0 - frc] = smx + return sn + + + +def tikhonov_filter(s, lmbda, npd=16): + r"""Lowpass filter based on Tikhonov regularization. 
+ + Lowpass filter image(s) and return low and high frequency + components, consisting of the lowpass filtered image and its + difference with the input image. The lowpass filter is equivalent to + Tikhonov regularization with `lmbda` as the regularization parameter + and a discrete gradient as the operator in the regularization term, + i.e. the lowpass component is the solution to + + .. math:: + \mathrm{argmin}_\mathbf{x} \; (1/2) \left\|\mathbf{x} - \mathbf{s} + \right\|_2^2 + (\lambda / 2) \sum_i \| G_i \mathbf{x} \|_2^2 \;\;, + + where :math:`\mathbf{s}` is the input image, :math:`\lambda` is the + regularization parameter, and :math:`G_i` is an operator that + computes the discrete gradient along image axis :math:`i`. Once the + lowpass component :math:`\mathbf{x}` has been computed, the highpass + component is just :math:`\mathbf{s} - \mathbf{x}`. + + + Parameters + ---------- + s : array_like + Input image or array of images. + lmbda : float + Regularization parameter controlling lowpass filtering. + npd : int, optional (default=16) + Number of samples to pad at image boundaries. + + Returns + ------- + sl : array_like + Lowpass image or array of images. + sh : array_like + Highpass image or array of images. + """ + + grv = np.array([-1.0, 1.0]).reshape([2, 1]) + gcv = np.array([-1.0, 1.0]).reshape([1, 2]) + Gr = sla.fftn(grv, (s.shape[0]+2*npd, s.shape[1]+2*npd), (0, 1)) + Gc = sla.fftn(gcv, (s.shape[0]+2*npd, s.shape[1]+2*npd), (0, 1)) + A = 1.0 + lmbda*np.conj(Gr)*Gr + lmbda*np.conj(Gc)*Gc + if s.ndim > 2: + A = A[(slice(None),)*2 + (np.newaxis,)*(s.ndim-2)] + sp = np.pad(s, ((npd, npd),)*2 + ((0, 0),)*(s.ndim-2), 'symmetric') + slp = np.real(sla.ifftn(sla.fftn(sp, axes=(0, 1)) / A, axes=(0, 1))) + sl = slp[npd:(slp.shape[0]-npd), npd:(slp.shape[1]-npd)] + sh = s - sl + return sl.astype(s.dtype), sh.astype(s.dtype) + + + +def idle_cpu_count(mincpu=1): + """Estimate number of idle CPUs, for use by multiprocessing code + needing to determine how many processes can be run without excessive + load. This function uses :func:`os.getloadavg` which is only available + under a Unix OS. + + Parameters + ---------- + mincpu : int + Minimum number of CPUs to report, independent of actual estimate + + Returns + ------- + idle : int + Estimate of number of idle CPUs + """ + + if PY2: + ncpu = mp.cpu_count() + else: + ncpu = os.cpu_count() + idle = int(ncpu - np.floor(os.getloadavg()[0])) + return max(mincpu, idle) + + + +def grid_search(fn, grd, fmin=True, nproc=None): + """Perform a grid search for optimal parameters of a specified + function. In the simplest case the function returns a float value, + and a single optimum value and corresponding parameter values are + identified. If the function returns a tuple of values, each of + these is taken to define a separate function on the search grid, + with optimum function values and corresponding parameter values + being identified for each of them. On all platforms except Windows + (where ``mp.Pool`` usage has some limitations), the computation + of the function at the grid points is computed in parallel. + + **Warning:** This function will hang if `fn` makes use of :mod:`pyfftw` + with multi-threading enabled (the + `bug `_ has been reported). + When using the FFT functions in :mod:`sporco.linalg`, multi-threading + can be disabled by including the following code:: + + import sporco.linalg + sporco.linalg.pyfftw_threads = 1 + + + Parameters + ---------- + fn : function + Function to be evaluated. 
It should take a tuple of parameter values as + an argument, and return a float value or a tuple of float values. + grd : tuple of array_like + A tuple providing an array of sample points for each axis of the grid + on which the search is to be performed. + fmin : bool, optional (default True) + Determine whether optimal function values are selected as minima or + maxima. If `fmin` is True then minima are selected. + nproc : int or None, optional (default None) + Number of processes to run in parallel. If None, the number of + CPUs of the system is used. + + Returns + ------- + sprm : ndarray + Optimal parameter values on each axis. If `fn` is multi-valued, + `sprm` is a matrix with rows corresponding to parameter values + and columns corresponding to function values. + sfvl : float or ndarray + Optimum function value or values + fvmx : ndarray + Function value(s) on search grid + sidx : tuple of int or tuple of ndarray + Indices of optimal values on parameter grid + """ + + if fmin: + slct = np.argmin + else: + slct = np.argmax + fprm = itertools.product(*grd) + if platform.system() == 'Windows': + fval = list(map(fn, fprm)) + else: + if nproc is None: + nproc = mp.cpu_count() + pool = mp.Pool(processes=nproc) + fval = pool.map(fn, fprm) + pool.close() + pool.join() + if isinstance(fval[0], (tuple, list, np.ndarray)): + nfnv = len(fval[0]) + fvmx = np.reshape(fval, [a.size for a in grd] + [nfnv,]) + sidx = np.unravel_index(slct(fvmx.reshape((-1, nfnv)), axis=0), + fvmx.shape[0:-1]) + (np.array((range(nfnv))),) + sprm = np.array([grd[k][sidx[k]] for k in range(len(grd))]) + sfvl = tuple(fvmx[sidx]) + else: + fvmx = np.reshape(fval, [a.size for a in grd]) + sidx = np.unravel_index(slct(fvmx), fvmx.shape) + sprm = np.array([grd[k][sidx[k]] for k in range(len(grd))]) + sfvl = fvmx[sidx] + + return sprm, sfvl, fvmx, sidx + + + +def convdicts(): + """Access a set of example learned convolutional dictionaries. + + Returns + ------- + cdd : dict + A dict associating description strings with dictionaries represented + as ndarrays + + Examples + -------- + Print the dict keys to obtain the identifiers of the available + dictionaries + + >>> from sporco import util + >>> cd = util.convdicts() + >>> print(cd.keys()) + ['G:12x12x72', 'G:8x8x16,12x12x32,16x16x48', ...] + + Select a specific example dictionary using the corresponding identifier + + >>> D = cd['G:8x8x96'] + """ + + pth = os.path.join(os.path.dirname(__file__), 'data', 'convdict.npz') + npz = np.load(pth) + cdd = {} + for k in list(npz.keys()): + cdd[k] = npz[k] + return cdd + + + +def netgetdata(url, maxtry=3, timeout=10): + """ + Get content of a file via a URL. 
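+
+    On a socket timeout the download is retried up to `maxtry` times; any
+    other URL error is re-raised immediately.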
+ + Parameters + ---------- + url : string + URL of the file to be downloaded + maxtry : int, optional (default 3) + Maximum number of download retries + timeout : int, optional (default 10) + Timeout in seconds for blocking operations + + Returns + ------- + str : io.BytesIO + Buffered I/O stream + + Raises + ------ + urlerror.URLError (urllib2.URLError in Python 2, + urllib.error.URLError in Python 3) + If the file cannot be downloaded + """ + + err = ValueError('maxtry parameter should be greater than zero') + for ntry in range(maxtry): + try: + rspns = urlrequest.urlopen(url, timeout=timeout) + cntnt = rspns.read() + break + except urlerror.URLError as e: + err = e + if not isinstance(e.reason, socket.timeout): + raise + else: + raise err + + return io.BytesIO(cntnt) + + + +class ExampleImages(object): + """Access a set of example images.""" + + def __init__(self, scaled=False, dtype=None, zoom=None, gray=False, + pth=None): + """Initialise an ExampleImages object. + + Parameters + ---------- + scaled : bool, optional (default False) + Flag indicating whether images should be on the range [0,...,255] + with np.uint8 dtype (False), or on the range [0,...,1] with + np.float32 dtype (True) + dtype : data-type or None, optional (default None) + Desired data type of images. If `scaled` is True and `dtype` is an + integer type, the output data type is np.float32 + zoom : float or None, optional (default None) + Optional support rescaling factor to apply to the images + gray : bool, optional (default False) + Flag indicating whether RGB images should be converted to grayscale + pth : string or None (default None) + Path to directory containing image files. If the value is None the + path points to a set of example images that are included with the + package. + """ + + self.scaled = scaled + self.dtype = dtype + self.zoom = zoom + self.gray = gray + if pth is None: + self.bpth = os.path.join(os.path.dirname(__file__), 'data') + else: + self.bpth = pth + self.imglst = [] + self.grpimg = {} + for dirpath, dirnames, filenames in os.walk(self.bpth): + # It would be more robust and portable to use + # pathlib.PurePath.relative_to + prnpth = dirpath[len(self.bpth)+1:] + for f in filenames: + fpth = os.path.join(dirpath, f) + if imghdr.what(fpth) is not None: + gpth = os.path.join(prnpth, f) + self.imglst.append(gpth) + if prnpth not in self.grpimg: + self.grpimg[prnpth] = [] + self.grpimg[prnpth].append(gpth) + + + + def images(self): + """Get list of available images. + + Returns + ------- + nlst : list + A list of names of available images + """ + + return self.imglst + + + + def groups(self): + """Get list of available image groups. + + Returns + ------- + grp : list + A list of names of available image groups + """ + + return list(self.grpimg.keys()) + + + + def groupimages(self, grp): + """Get list of available images in specified group. + + Parameters + ---------- + grp : str + Name of image group + + Returns + ------- + nlst : list + A list of names of available images in the specified group + """ + + return self.grpimg[grp] + + + + def image(self, fname, group=None, scaled=None, dtype=None, idxexp=None, + zoom=None, gray=None): + """Get named image. + + Parameters + ---------- + fname : string + Filename of image + group : string or None, optional (default None) + Name of image group + scaled : bool or None, optional (default None) + Flag indicating whether images should be on the range [0,...,255] + with np.uint8 dtype (False), or on the range [0,...,1] with + np.float32 dtype (True). 
If the value is None, scaling behaviour + is determined by the `scaling` parameter passed to the object + initializer, otherwise that selection is overridden. + dtype : data-type or None, optional (default None) + Desired data type of images. If `scaled` is True and `dtype` is an + integer type, the output data type is np.float32. If the value is + None, the data type is determined by the `dtype` parameter passed to + the object initializer, otherwise that selection is overridden. + idxexp : index expression or None, optional (default None) + An index expression selecting, for example, a cropped region of + the requested image. This selection is applied *before* any + `zoom` rescaling so the expression does not need to be modified when + the zoom factor is changed. + zoom : float or None, optional (default None) + Optional rescaling factor to apply to the images. If the value is + None, support rescaling behaviour is determined by the `zoom` + parameter passed to the object initializer, otherwise that selection + is overridden. + gray : bool or None, optional (default None) + Flag indicating whether RGB images should be converted to grayscale. + If the value is None, behaviour is determined by the `gray` + parameter passed to the object initializer. + + Returns + ------- + img : ndarray + Image array + + Raises + ------ + IOError + If the image is not accessible + """ + + if scaled is None: + scaled = self.scaled + if dtype is None: + if self.dtype is None: + dtype = np.uint8 + else: + dtype = self.dtype + if scaled and np.issubdtype(dtype, np.integer): + dtype = np.float32 + if zoom is None: + zoom = self.zoom + if gray is None: + gray = self.gray + if group is None: + pth = os.path.join(self.bpth, fname) + else: + pth = os.path.join(self.bpth, group, fname) + + try: + img = np.asarray(misc.imread(pth), dtype=dtype) + except IOError: + raise IOError('Could not access image %s in group %s' % + (fname, group)) + + if scaled: + img /= 255.0 + if idxexp is not None: + img = img[idxexp] + if zoom is not None: + if img.ndim == 2: + img = sni.zoom(img, zoom) + else: + img = sni.zoom(img, (zoom,)*2 + (1,)*(img.ndim-2)) + if gray: + img = rgb2gray(img) + + return img + + + +class Timer(object): + """Timer class supporting multiple independent labelled timers. + + The timer is based on the relative time returned by + :func:`timeit.default_timer`. + """ + + def __init__(self, labels=None, dfltlbl='main', alllbl='all'): + """Initialise timer object. + + Parameters + ---------- + labels : string or list, optional (default None) + Specify the label(s) of the timer(s) to be initialised to zero. + dfltlbl : string, optional (default 'main') + Set the default timer label to be used when methods are + called without specifying a label + alllbl : string, optional (default 'all') + Set the label string that will be used to denote all timer labels + """ + + # Initialise current and accumulated time dictionaries + self.t0 = {} + self.td = {} + # Record default label and string indicating all labels + self.dfltlbl = dfltlbl + self.alllbl = alllbl + # Initialise dictionary entries for labels to be created + # immediately + if labels is not None: + if not isinstance(labels, (list, tuple)): + labels = [labels,] + for lbl in labels: + self.td[lbl] = 0.0 + self.t0[lbl] = None + + + + def start(self, labels=None): + """Start specified timer(s). + + Parameters + ---------- + labels : string or list, optional (default None) + Specify the label(s) of the timer(s) to be started. 
If it is + ``None``, start the default timer with label specified by the + ``dfltlbl`` parameter of :meth:`__init__`. + """ + + # Default label is self.dfltlbl + if labels is None: + labels = self.dfltlbl + # If label is not a list or tuple, create a singleton list + # containing it + if not isinstance(labels, (list, tuple)): + labels = [labels,] + # Iterate over specified label(s) + t = timer() + for lbl in labels: + # On first call to start for a label, set its accumulator to zero + if lbl not in self.td: + self.td[lbl] = 0.0 + self.t0[lbl] = None + # Record the time at which start was called for this lbl if + # it isn't already running + if self.t0[lbl] is None: + self.t0[lbl] = t + + + + def stop(self, labels=None): + """Stop specified timer(s). + + Parameters + ---------- + labels : string or list, optional (default None) + Specify the label(s) of the timer(s) to be stopped. If it is + ``None``, stop the default timer with label specified by the + ``dfltlbl`` parameter of :meth:`__init__`. If it is equal to + the string specified by the ``alllbl`` parameter of + :meth:`__init__`, stop all timers. + """ + + # Get current time + t = timer() + # Default label is self.dfltlbl + if labels is None: + labels = self.dfltlbl + # All timers are affected if label is equal to self.alllbl, + # otherwise only the timer(s) specified by label + if labels == self.alllbl: + labels = self.t0.keys() + elif not isinstance(labels, (list, tuple)): + labels = [labels,] + # Iterate over specified label(s) + for lbl in labels: + if lbl not in self.t0: + raise KeyError('Unrecognized timer key %s' % lbl) + # If self.t0[lbl] is None, the corresponding timer is + # already stopped, so no action is required + if self.t0[lbl] is not None: + # Increment time accumulator from the elapsed time + # since most recent start call + self.td[lbl] += t - self.t0[lbl] + # Set start time to None to indicate timer is not running + self.t0[lbl] = None + + + + def reset(self, labels=None): + """Reset specified timer(s). + + Parameters + ---------- + labels : string or list, optional (default None) + Specify the label(s) of the timer(s) to be stopped. If it is + ``None``, stop the default timer with label specified by the + ``dfltlbl`` parameter of :meth:`__init__`. If it is equal to + the string specified by the ``alllbl`` parameter of + :meth:`__init__`, stop all timers. + """ + + # Get current time + t = timer() + # Default label is self.dfltlbl + if labels is None: + labels = self.dfltlbl + # All timers are affected if label is equal to self.alllbl, + # otherwise only the timer(s) specified by label + if labels == self.alllbl: + labels = self.t0.keys() + elif not isinstance(labels, (list, tuple)): + labels = [labels,] + # Iterate over specified label(s) + for lbl in labels: + if lbl not in self.t0: + raise KeyError('Unrecognized timer key %s' % lbl) + # Set start time to None to indicate timer is not running + self.t0[lbl] = None + # Set time accumulator to zero + self.td[lbl] = 0.0 + + + + def elapsed(self, label=None, total=True): + """Get elapsed time since timer start. + + Parameters + ---------- + label : string, optional (default None) + Specify the label of the timer for which the elapsed time is + required. If it is ``None``, the default timer with label + specified by the ``dfltlbl`` parameter of :meth:`__init__` + is selected. 
+ total : bool, optional (default True) + If ``True`` return the total elapsed time since the first + call of :meth:`start` for the selected timer, otherwise + return the elapsed time since the most recent call of + :meth:`start` for which there has not been a corresponding + call to :meth:`stop`. + + Returns + ------- + dlt : float + Elapsed time + """ + + # Get current time + t = timer() + # Default label is self.dfltlbl + if label is None: + label = self.dfltlbl + # Return 0.0 if default timer selected and it is not initialised + if label not in self.t0: + return 0.0 + # Raise exception if timer with specified label does not exist + if label not in self.t0: + raise KeyError('Unrecognized timer key %s' % label) + # If total flag is True return sum of accumulated time from + # previous start/stop calls and current start call, otherwise + # return just the time since the current start call + te = 0.0 + if self.t0[label] is not None: + te = t - self.t0[label] + if total: + te += self.td[label] + + return te + + + + def labels(self): + """Get a list of timer labels. + + Returns + ------- + lbl : list + List of timer labels + """ + + return self.t0.keys() + + + + def __str__(self): + """Return string representation of object. + + The representation consists of a table with the following columns: + + * Timer label + * Accumulated time from past start/stop calls + * Time since current start call, or 'Stopped' if timer is not + currently running + """ + + # Get current time + t = timer() + # Length of label field, calculated from max label length + lfldln = max([len(lbl) for lbl in self.t0] + [len(self.dfltlbl),]) + 2 + # Header string for table of timers + s = '%-*s Accum. Current\n' % (lfldln, 'Label') + s += '-' * (lfldln + 25) + '\n' + # Construct table of timer details + for lbl in sorted(self.t0): + td = self.td[lbl] + if self.t0[lbl] is None: + ts = ' Stopped' + else: + ts = ' %.2e s' % (t - self.t0[lbl]) + s += '%-*s %.2e s %s\n' % (lfldln, lbl, td, ts) + + return s + + + + +class ContextTimer(object): + """A wrapper class for :class:`Timer` that enables its use as a + context manager. + + For example, instead of + + >>> t = Timer() + >>> t.start() + >>> do_something() + >>> t.stop() + >>> elapsed = t.elapsed() + + one can use + + >>> t = Timer() + >>> with ContextTimer(t): + ... do_something() + >>> elapsed = t.elapsed() + """ + + def __init__(self, timer=None, label=None, action='StartStop'): + """Initialise context manager timer wrapper. + + Parameters + ---------- + timer : class:`Timer` object, optional (default None) + Specify the timer object to be used as a context manager. If + ``None``, a new class:`Timer` object is constructed. + label : string, optional (default None) + Specify the label of the timer to be used. If it is ``None``, + start the default timer. + action : string, optional (default 'StartStop') + Specify actions to be taken on context entry and exit. If + the value is 'StartStop', start the timer on entry and stop + on exit; if it is 'StopStart', stop the timer on entry and + start it on exit. 
+ """ + + if action not in ['StartStop', 'StopStart']: + raise ValueError('Unrecognized action %s' % action) + if timer is None: + self.timer = Timer() + else: + self.timer = timer + self.label = label + self.action = action + + + def __enter__(self): + """Start the timer and return this ContextTimer instance.""" + + if self.action == 'StartStop': + self.timer.start(self.label) + else: + self.timer.stop(self.label) + return self + + + + def __exit__(self, type, value, traceback): + """Stop the timer and return True if no exception was raised within + the 'with' block, otherwise return False. + """ + + if self.action == 'StartStop': + self.timer.stop(self.label) + else: + self.timer.start(self.label) + if type: + return False + else: + return True + + + def elapsed(self, total=True): + """Return the elapsed time for the timer. + + Parameters + ---------- + total : bool, optional (default True) + If ``True`` return the total elapsed time since the first + call of :meth:`start` for the selected timer, otherwise + return the elapsed time since the most recent call of + :meth:`start` for which there has not been a corresponding + call to :meth:`stop`. + + Returns + ------- + dlt : float + Elapsed time + """ + + return self.timer.elapsed(self.label, total=total) diff --git a/facades/github/000001.png b/facades/github/000001.png new file mode 100644 index 0000000..3e6f487 Binary files /dev/null and b/facades/github/000001.png differ diff --git a/facades/github/000002.png b/facades/github/000002.png new file mode 100644 index 0000000..9621a23 Binary files /dev/null and b/facades/github/000002.png differ diff --git a/facades/github/000003.png b/facades/github/000003.png new file mode 100644 index 0000000..d5316e6 Binary files /dev/null and b/facades/github/000003.png differ diff --git a/facades/github/000004.png b/facades/github/000004.png new file mode 100644 index 0000000..5132477 Binary files /dev/null and b/facades/github/000004.png differ diff --git a/facades/github/000005.png b/facades/github/000005.png new file mode 100644 index 0000000..4fd3c2a Binary files /dev/null and b/facades/github/000005.png differ diff --git a/facades/github/000006.png b/facades/github/000006.png new file mode 100644 index 0000000..1f99351 Binary files /dev/null and b/facades/github/000006.png differ diff --git a/facades/github/000007.png b/facades/github/000007.png new file mode 100644 index 0000000..47fc41b Binary files /dev/null and b/facades/github/000007.png differ diff --git a/facades/github/000008.png b/facades/github/000008.png new file mode 100644 index 0000000..aa7ba9f Binary files /dev/null and b/facades/github/000008.png differ diff --git a/facades/github/000009.png b/facades/github/000009.png new file mode 100644 index 0000000..3e6f487 Binary files /dev/null and b/facades/github/000009.png differ diff --git a/facades/github/000010.png b/facades/github/000010.png new file mode 100644 index 0000000..9621a23 Binary files /dev/null and b/facades/github/000010.png differ diff --git a/facades/github/000011.png b/facades/github/000011.png new file mode 100644 index 0000000..d5316e6 Binary files /dev/null and b/facades/github/000011.png differ diff --git a/facades/github/000012.png b/facades/github/000012.png new file mode 100644 index 0000000..5132477 Binary files /dev/null and b/facades/github/000012.png differ diff --git a/facades/github/000013.png b/facades/github/000013.png new file mode 100644 index 0000000..4fd3c2a Binary files /dev/null and b/facades/github/000013.png differ diff --git 
a/facades/github/000014.png b/facades/github/000014.png new file mode 100644 index 0000000..1f99351 Binary files /dev/null and b/facades/github/000014.png differ diff --git a/facades/github/000015.png b/facades/github/000015.png new file mode 100644 index 0000000..47fc41b Binary files /dev/null and b/facades/github/000015.png differ diff --git a/facades/github/000016.png b/facades/github/000016.png new file mode 100644 index 0000000..aa7ba9f Binary files /dev/null and b/facades/github/000016.png differ diff --git a/kernel.mat b/kernel.mat new file mode 100644 index 0000000..76f552f Binary files /dev/null and b/kernel.mat differ diff --git a/misc.py b/misc.py new file mode 100644 index 0000000..03e795a --- /dev/null +++ b/misc.py @@ -0,0 +1,123 @@ +import torch +import os +import sys + + +def create_exp_dir(exp): + try: + os.makedirs(exp) + print('Creating exp dir: %s' % exp) + except OSError: + pass + return True + + +def weights_init(m): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + m.weight.data.normal_(0.0, 0.02) + elif classname.find('BatchNorm') != -1: + m.weight.data.normal_(1.0, 0.02) + m.bias.data.fill_(0) + + +def getLoader(datasetName, dataroot, originalSize, imageSize, batchSize=64, workers=4, + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), split='train', shuffle=True, seed=None): + + #import pdb; pdb.set_trace() + if datasetName == 'pix2pix': + # from datasets.pix2pix import pix2pix as commonDataset + # import transforms.pix2pix as transforms + from datasets.pix2pix import pix2pix as commonDataset + import transforms.pix2pix as transforms + elif datasetName == 'pix2pix_val': + # from datasets.pix2pix_val import pix2pix_val as commonDataset + # import transforms.pix2pix as transforms + from datasets.pix2pix_val import pix2pix_val as commonDataset + import transforms.pix2pix as transforms + if datasetName == 'pix2pix_class': + # from datasets.pix2pix import pix2pix as commonDataset + # import transforms.pix2pix as transforms + from datasets.pix2pix_class import pix2pix as commonDataset + import transforms.pix2pix as transforms + if split == 'train': + dataset = commonDataset(root=dataroot, + transform=transforms.Compose([ + transforms.Scale(originalSize), + #transforms.RandomCrop(imageSize), + #transforms.CenterCrop(imageSize), + #transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(mean, std), + ]), + seed=seed) + else: + dataset = commonDataset(root=dataroot, + transform=transforms.Compose([ + transforms.Scale(originalSize), + #transforms.CenterCrop(imageSize), + transforms.ToTensor(), + transforms.Normalize(mean, std), + ]), + seed=seed) + + dataloader = torch.utils.data.DataLoader(dataset, + batch_size=batchSize, + shuffle=shuffle, + num_workers=int(workers)) + return dataloader + + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +import numpy as np +class ImagePool: + def __init__(self, pool_size=50): + self.pool_size = pool_size + if pool_size > 0: + self.num_imgs = 0 + self.images = [] + + def query(self, image): + if self.pool_size == 0: + return image + if self.num_imgs < self.pool_size: + self.images.append(image.clone()) + self.num_imgs += 1 + return image + else: + if np.random.uniform(0,1) > 0.5: + random_id = 
np.random.randint(self.pool_size, size=1)[0] + tmp = self.images[random_id].clone() + self.images[random_id] = image.clone() + return tmp + else: + return image + + +def adjust_learning_rate(optimizer, init_lr, epoch, factor, every): + #import pdb; pdb.set_trace() + lrd = init_lr / every + old_lr = optimizer.param_groups[0]['lr'] + # linearly decaying lr + print('learning rate: %f'%old_lr) + lr = old_lr - lrd + if lr < 0: lr = 0 + for param_group in optimizer.param_groups: + param_group['lr'] = lr diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/models/__init__.py @@ -0,0 +1 @@ + diff --git a/models/face_fed.py b/models/face_fed.py new file mode 100644 index 0000000..488eb84 --- /dev/null +++ b/models/face_fed.py @@ -0,0 +1,855 @@ +import torch +import torch.backends.cudnn as cudnn +import torch.nn as nn +import torch.nn.functional as F +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.model_zoo as model_zoo +from collections import OrderedDict +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.model_zoo as model_zoo +from collections import OrderedDict +import torchvision.models as models +from torch.autograd import Variable + + + +def conv_block(in_dim,out_dim): + return nn.Sequential(nn.Conv2d(in_dim,in_dim,kernel_size=3,stride=1,padding=1), + nn.ELU(True), + nn.Conv2d(in_dim,in_dim,kernel_size=3,stride=1,padding=1), + nn.ELU(True), + nn.Conv2d(in_dim,out_dim,kernel_size=1,stride=1,padding=0), + nn.AvgPool2d(kernel_size=2,stride=2)) +def deconv_block(in_dim,out_dim): + return nn.Sequential(nn.Conv2d(in_dim,out_dim,kernel_size=3,stride=1,padding=1), + nn.ELU(True), + nn.Conv2d(out_dim,out_dim,kernel_size=3,stride=1,padding=1), + nn.ELU(True), + nn.UpsamplingNearest2d(scale_factor=2)) + + +def blockUNet1(in_c, out_c, name, transposed=False, bn=False, relu=True, dropout=False): + block = nn.Sequential() + if relu: + block.add_module('%s.relu' % name, nn.ReLU(inplace=True)) + else: + block.add_module('%s.leakyrelu' % name, nn.LeakyReLU(0.2, inplace=True)) + if not transposed: + block.add_module('%s.conv' % name, nn.Conv2d(in_c, out_c, 3, 1, 1, bias=False)) + else: + block.add_module('%s.tconv' % name, nn.ConvTranspose2d(in_c, out_c, 3, 1, 1, bias=False)) + if bn: + block.add_module('%s.bn' % name, nn.InstanceNorm2d(out_c)) + if dropout: + block.add_module('%s.dropout' % name, nn.Dropout2d(0.5, inplace=True)) + return block + +def blockUNet(in_c, out_c, name, transposed=False, bn=False, relu=True, dropout=False): + block = nn.Sequential() + if relu: + block.add_module('%s.relu' % name, nn.ReLU(inplace=True)) + else: + block.add_module('%s.leakyrelu' % name, nn.LeakyReLU(0.2, inplace=True)) + if not transposed: + block.add_module('%s.conv' % name, nn.Conv2d(in_c, out_c, 4, 2, 1, bias=False)) + else: + block.add_module('%s.tconv' % name, nn.ConvTranspose2d(in_c, out_c, 4, 2, 1, bias=False)) + if bn: + block.add_module('%s.bn' % name, nn.InstanceNorm2d(out_c)) + if dropout: + block.add_module('%s.dropout' % name, nn.Dropout2d(0.5, inplace=True)) + return block + + +class D1(nn.Module): + def __init__(self, nc, ndf, hidden_size): + super(D1, self).__init__() + + # 256 + self.conv1 = nn.Sequential(nn.Conv2d(nc,ndf,kernel_size=3,stride=1,padding=1), + nn.ELU(True)) + # 256 + self.conv2 = conv_block(ndf,ndf) + # 128 + self.conv3 = conv_block(ndf, ndf*2) + # 64 + self.conv4 = conv_block(ndf*2, ndf*3) + # 32 + self.encode = nn.Conv2d(ndf*3, hidden_size, 
kernel_size=1,stride=1,padding=0) + self.decode = nn.Conv2d(hidden_size, ndf, kernel_size=1,stride=1,padding=0) + # 32 + self.deconv4 = deconv_block(ndf, ndf) + # 64 + self.deconv3 = deconv_block(ndf, ndf) + # 128 + self.deconv2 = deconv_block(ndf, ndf) + # 256 + self.deconv1 = nn.Sequential(nn.Conv2d(ndf,ndf,kernel_size=3,stride=1,padding=1), + nn.ELU(True), + nn.Conv2d(ndf,ndf,kernel_size=3,stride=1,padding=1), + nn.ELU(True), + nn.Conv2d(ndf, nc, kernel_size=3, stride=1, padding=1), + nn.Tanh()) + """ + self.deconv1 = nn.Sequential(nn.Conv2d(ndf,nc,kernel_size=3,stride=1,padding=1), + nn.Tanh()) + """ + def forward(self,x): + out1 = self.conv1(x) + out2 = self.conv2(out1) + out3 = self.conv3(out2) + out4 = self.conv4(out3) + out5 = self.encode(out4) + dout5= self.decode(out5) + dout4= self.deconv4(dout5) + dout3= self.deconv3(dout4) + dout2= self.deconv2(dout3) + dout1= self.deconv1(dout2) + return dout1 + +class D(nn.Module): + def __init__(self, nc, nf): + super(D, self).__init__() + + main = nn.Sequential() + # 256 + layer_idx = 1 + name = 'layer%d' % layer_idx + main.add_module('%s.conv' % name, nn.Conv2d(nc, nf, 4, 2, 1, bias=False)) + + # 128 + layer_idx += 1 + name = 'layer%d' % layer_idx + main.add_module(name, blockUNet(nf, nf*2, name, transposed=False, bn=True, relu=False, dropout=False)) + + # 64 + layer_idx += 1 + name = 'layer%d' % layer_idx + nf = nf * 2 + main.add_module(name, blockUNet(nf, nf*2, name, transposed=False, bn=True, relu=False, dropout=False)) + + # 32 + layer_idx += 1 + name = 'layer%d' % layer_idx + nf = nf * 2 + main.add_module('%s.leakyrelu' % name, nn.LeakyReLU(0.2, inplace=True)) + main.add_module('%s.conv' % name, nn.Conv2d(nf, nf*2, 4, 1, 1, bias=False)) + main.add_module('%s.bn' % name, nn.InstanceNorm2d(nf*2)) + + # 31 + layer_idx += 1 + name = 'layer%d' % layer_idx + nf = nf * 2 + main.add_module('%s.leakyrelu' % name, nn.LeakyReLU(0.2, inplace=True)) + main.add_module('%s.conv' % name, nn.Conv2d(nf, 1, 4, 1, 1, bias=False)) + main.add_module('%s.sigmoid' % name , nn.Sigmoid()) + # 30 (sizePatchGAN=30) + + self.main = main + + def forward(self, x): + output = self.main(x) + return output + +class ShareSepConv(nn.Module): + def __init__(self, kernel_size): + super(ShareSepConv, self).__init__() + assert kernel_size % 2 == 1, 'kernel size should be odd' + self.padding = (kernel_size - 1)//2 + weight_tensor = torch.zeros(1, 1, kernel_size, kernel_size) + weight_tensor[0, 0, (kernel_size-1)//2, (kernel_size-1)//2] = 1 + self.weight = nn.Parameter(weight_tensor) + self.kernel_size = kernel_size + + def forward(self, x): + inc = x.size(1) + expand_weight = self.weight.expand(inc, 1, self.kernel_size, self.kernel_size).contiguous() + return F.conv2d(x, expand_weight, + None, 1, self.padding, 1, inc) + + +class BottleneckBlockdls(nn.Module): + def __init__(self, in_planes, out_planes, dropRate=0.0): + super(BottleneckBlockdls, self).__init__() + inter_planes = out_planes * 4 + self.bn1 = nn.BatchNorm2d(in_planes) + self.relu = nn.ReLU(inplace=True) + self.conv_o = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, + padding=0, bias=False) + self.sharewconv1 = ShareSepConv(3) + self.conv1 = nn.Conv2d(in_planes, inter_planes, kernel_size=1, stride=1, + padding=0, bias=False) + self.sharewconv2 = ShareSepConv(3) + self.bn2 = nn.BatchNorm2d(inter_planes) + self.conv2 = nn.Conv2d(inter_planes, inter_planes, kernel_size=3, stride=1, + padding=2, dilation=2, bias=False) + self.bn4 = nn.BatchNorm2d(inter_planes) + self.conv4 = nn.Conv2d(inter_planes, 
out_planes, kernel_size=3, stride=1, + padding=2, dilation=2, bias=False) + self.droprate = dropRate + def forward(self, x): + out = self.conv1(self.relu(self.bn1(x))) + if self.droprate > 0: + out = F.dropout(out, p=self.droprate, inplace=False, training=self.training) + out = self.conv2(self.relu(self.bn2(out))) + outx = self.conv_o(x) + out = outx + self.conv4(self.sharewconv2(self.relu(self.bn4(out)))) + if self.droprate > 0: + out = F.dropout(out, p=self.droprate, inplace=False, training=self.training) + return torch.cat([x, out], 1) + +class BottleneckBlockdl(nn.Module): + def __init__(self, in_planes, out_planes, dropRate=0.0): + super(BottleneckBlockdl, self).__init__() + inter_planes = out_planes * 3 + self.bn1 = nn.InstanceNorm2d(in_planes) + self.relu = nn.ReLU(inplace=True) + self.conv_o = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, + padding=0, bias=False) + self.conv1 = nn.Conv2d(in_planes, inter_planes, kernel_size=1, stride=1, + padding=0, bias=False) + self.bn2 = nn.InstanceNorm2d(inter_planes) + self.conv2 = nn.Conv2d(inter_planes, inter_planes, kernel_size=3, stride=1, + padding=1, dilation=1, bias=False) + self.bn3 = nn.InstanceNorm2d(inter_planes) + self.conv3 = nn.Conv2d(inter_planes, inter_planes, kernel_size=3, stride=1, + padding=2, dilation=2, bias=False) + self.bn4 = nn.InstanceNorm2d(inter_planes) + self.sharewconv = ShareSepConv(3) + self.conv4 = nn.Conv2d(inter_planes, out_planes, kernel_size=3, stride=1, + padding=2, dilation=2, bias=False) + self.droprate = dropRate + def forward(self, x): + out = self.conv1(self.relu(self.bn1(x))) + if self.droprate > 0: + out = F.dropout(out, p=self.droprate, inplace=False, training=self.training) + out = self.conv2(self.relu(self.bn2(out))) + out = self.conv3(self.relu(self.bn3(out))) + outx = self.conv_o(x) + out = outx + self.conv4(self.sharewconv(self.relu(self.bn4(out)))) + if self.droprate > 0: + out = F.dropout(out, p=self.droprate, inplace=False, training=self.training) + return torch.cat([x, out], 1) + + +class BottleneckBlockrs1(nn.Module): + def __init__(self, in_planes, out_planes, dropRate=0.0): + super(BottleneckBlockrs1, self).__init__() + inter_planes = out_planes * 3 + self.bn1 = nn.InstanceNorm2d(in_planes) + self.relu = nn.ReLU(inplace=True) + self.conv_o = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, + padding=0, bias=False) + self.conv1 = nn.Conv2d(in_planes, inter_planes, kernel_size=1, stride=1, + padding=0, bias=False) + self.bn2 = nn.InstanceNorm2d(inter_planes) + self.conv2 = nn.Conv2d(inter_planes, inter_planes, kernel_size=3, stride=1, + padding=1, bias=False) + self.bn3 = nn.InstanceNorm2d(inter_planes) + self.conv3 = nn.Conv2d(inter_planes, inter_planes, kernel_size=3, stride=1, + padding=2, dilation=2, bias=False) + self.bn4 = nn.InstanceNorm2d(inter_planes) + self.conv4 = nn.Conv2d(inter_planes, out_planes, kernel_size=3, stride=1, + padding=1, bias=False) + self.droprate = dropRate + def forward(self, x): + out = self.conv1(self.relu(self.bn1(x))) + if self.droprate > 0: + out = F.dropout(out, p=self.droprate, inplace=False, training=self.training) + out = self.conv2(self.relu(self.bn2(out))) + out = self.conv3(self.relu(self.bn3(out))) + outx = self.conv_o(x) + out = outx + self.conv4(self.relu(self.bn4(out))) + if self.droprate > 0: + out = F.dropout(out, p=self.droprate, inplace=False, training=self.training) + return torch.cat([x, out], 1) + +class BottleneckBlockrs(nn.Module): + def __init__(self, in_planes, out_planes, dropRate=0.0): + 
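# bottleneck with a 1x1 shortcut (conv_o): 1x1 expand, 3x3 convs, residual add, then dense concat with the input +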
super(BottleneckBlockrs, self).__init__() + inter_planes = out_planes * 3 + self.bn1 = nn.InstanceNorm2d(in_planes) + self.relu = nn.ReLU(inplace=True) + self.conv_o = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, + padding=0, bias=False) + self.conv1 = nn.Conv2d(in_planes, inter_planes, kernel_size=1, stride=1, + padding=0, bias=False) + self.bn2 = nn.InstanceNorm2d(inter_planes) + self.conv2 = nn.Conv2d(inter_planes, inter_planes, kernel_size=3, stride=1, + padding=1, bias=False) + self.bn3 = nn.InstanceNorm2d(inter_planes) + self.conv3 = nn.Conv2d(inter_planes, inter_planes, kernel_size=3, stride=1, + padding=1, bias=False) + self.bn4 = nn.InstanceNorm2d(inter_planes) + self.conv4 = nn.Conv2d(inter_planes, out_planes, kernel_size=3, stride=1, + padding=1, bias=False) + self.droprate = dropRate + def forward(self, x): + out = self.conv1(self.relu(self.bn1(x))) + if self.droprate > 0: + out = F.dropout(out, p=self.droprate, inplace=False, training=self.training) + out = self.conv2(self.relu(self.bn2(out))) + out = self.conv3(self.relu(self.bn3(out))) + outx = self.conv_o(x) + out = outx + self.conv4(self.relu(self.bn4(out))) + if self.droprate > 0: + out = F.dropout(out, p=self.droprate, inplace=False, training=self.training) + return torch.cat([x, out], 1) + + + + + +class TransitionBlock(nn.Module): + def __init__(self, in_planes, out_planes, dropRate=0.0): + super(TransitionBlock, self).__init__() + self.bn1 = nn.InstanceNorm2d(in_planes) + self.relu = nn.ReLU(inplace=True) + self.conv1 = nn.ConvTranspose2d(in_planes, out_planes, kernel_size=1, stride=1, + padding=0, bias=False) + self.droprate = dropRate + def forward(self, x): + out = self.conv1(self.relu(self.bn1(x))) + if self.droprate > 0: + out = F.dropout(out, p=self.droprate, inplace=False, training=self.training) + return F.upsample_nearest(out, scale_factor=2) + + + +class TransitionBlock1(nn.Module): + def __init__(self, in_planes, out_planes, dropRate=0.0): + super(TransitionBlock1, self).__init__() + self.bn1 = nn.InstanceNorm2d(in_planes) + self.relu = nn.ReLU(inplace=True) + self.conv1 = nn.ConvTranspose2d(in_planes, out_planes, kernel_size=1, stride=1, + padding=0, bias=False) + self.droprate = dropRate + def forward(self, x): + out = self.conv1(self.relu(self.bn1(x))) + if self.droprate > 0: + out = F.dropout(out, p=self.droprate, inplace=False, training=self.training) + return F.avg_pool2d(out, 2) + + + +class TransitionBlock3(nn.Module): + def __init__(self, in_planes, out_planes, dropRate=0.0): + super(TransitionBlock3, self).__init__() + self.bn1 = nn.InstanceNorm2d(in_planes) + self.relu = nn.ReLU(inplace=True) + self.conv1 = nn.ConvTranspose2d(in_planes, out_planes, kernel_size=1, stride=1, + padding=0, bias=False) + self.droprate = dropRate + def forward(self, x): + out = self.conv1(self.relu(self.bn1(x))) + if self.droprate > 0: + out = F.dropout(out, p=self.droprate, inplace=False, training=self.training) + return out + + + + + + +class vgg19ca(nn.Module): + def __init__(self): + super(vgg19ca, self).__init__() + + + + + ############# 256-256 ############## + haze_class = models.vgg19_bn(pretrained=True) + self.feature = nn.Sequential(haze_class.features[0]) + + for i in range(1,3): + self.feature.add_module(str(i),haze_class.features[i]) + + self.conv16=nn.Conv2d(64, 24, kernel_size=3,stride=1,padding=1) # 1mm + self.dense_classifier=nn.Linear(127896, 512) + self.dense_classifier1=nn.Linear(512, 4) + + + def forward(self, x): + + feature=self.feature(x) + # feature = Variable(feature.data, 
requires_grad=True) + + feature=self.conv16(feature) + # print feature.size() + + # feature=Variable(feature.data,requires_grad=True) + + + + out = F.relu(feature, inplace=True) + out = F.avg_pool2d(out, kernel_size=7).view(out.size(0), -1) + # print out.size() + + # out=Variable(out.data,requires_grad=True) + out = F.relu(self.dense_classifier(out)) + out = (self.dense_classifier1(out)) + + + return out +class BottleneckBlock(nn.Module): + def __init__(self, in_planes, out_planes, dropRate=0.0): + super(BottleneckBlock, self).__init__() + inter_planes = out_planes * 3 + self.bn1 = nn.InstanceNorm2d(in_planes) + self.relu = nn.ReLU(inplace=True) + self.conv_o = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, + padding=0, bias=False) + self.conv1 = nn.Conv2d(in_planes, inter_planes, kernel_size=1, stride=1, + padding=0, bias=False) + self.bn2 = nn.InstanceNorm2d(inter_planes) + self.conv2 = nn.Conv2d(inter_planes, inter_planes, kernel_size=3, stride=1, + padding=1, bias=False) + + self.bn4 = nn.InstanceNorm2d(inter_planes) + self.conv4 = nn.Conv2d(inter_planes, out_planes, kernel_size=3, stride=1, + padding=1, bias=False) + self.droprate = dropRate + def forward(self, x): + out = self.conv1(self.relu(self.bn1(x))) + if self.droprate > 0: + out = F.dropout(out, p=self.droprate, inplace=False, training=self.training) + out = self.conv2(self.relu(self.bn2(out))) + outx = self.conv_o(x) + out = outx + self.conv4(self.relu(self.bn4(out))) + if self.droprate > 0: + out = F.dropout(out, p=self.droprate, inplace=False, training=self.training) + return out + +class TransitionBlockbil(nn.Module): + def __init__(self, in_planes, out_planes, dropRate=0.0): + super(TransitionBlockbil, self).__init__() + self.bn1 = nn.InstanceNorm2d(in_planes) + self.relu = nn.ReLU(inplace=True) + self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, + padding=0, bias=False) + self.bn2 = nn.InstanceNorm2d(out_planes) + self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, + padding=1, bias=False) + self.droprate = dropRate + def forward(self, x): + out = self.conv1(self.relu(self.bn1(x))) + if self.droprate > 0: + out = F.dropout(out, p=self.droprate, inplace=False, training=self.training) + out = F.upsample_bilinear(out, scale_factor=2) + return self.conv2(out) + + +class Deblur_first(nn.Module): + def __init__(self,in_channels): + super(Deblur_first, self).__init__() + + self.dense_block1=BottleneckBlockrs(in_channels,32-in_channels) + self.trans_block1=TransitionBlock1(32,16) + + ############# Block2-down 32-32 ############## + self.dense_block2=BottleneckBlockdl(16,16) + self.trans_block2=TransitionBlock3(32,16) + + ############# Block3-down 16-16 ############## + self.dense_block3=BottleneckBlockdl(16,16) + self.trans_block3=TransitionBlock3(32,16) + + + ############# Block5-up 16-16 ############## + self.dense_block5=BottleneckBlockdl(32,16) + self.trans_block5=TransitionBlock3(48,16) + + self.dense_block6=BottleneckBlockrs(16,16) + self.trans_block6=TransitionBlockbil(32,16) + + + self.conv_refin=nn.Conv2d(16,16,3,1,1) + self.conv_refin_in=nn.Conv2d(in_channels,16,3,1,1) + self.tanh=nn.Tanh() + + + self.conv1010 = nn.Conv2d(20, 1, kernel_size=1,stride=1,padding=0) # 1mm + self.conv1020 = nn.Conv2d(20, 1, kernel_size=1,stride=1,padding=0) # 1mm + self.conv1030 = nn.Conv2d(20, 1, kernel_size=1,stride=1,padding=0) # 1mm + self.conv1040 = nn.Conv2d(20, 1, kernel_size=1,stride=1,padding=0) # 1mm + + self.refine3= nn.Conv2d(16, 3, kernel_size=3,stride=1,padding=1) + # 
self.refine3= nn.Conv2d(20+4, 3, kernel_size=7,stride=1,padding=3) + + self.upsample = F.upsample_nearest + + self.relu=nn.LeakyReLU(0.2, inplace=True) + self.refineclean1= nn.Conv2d(3, 8, kernel_size=7,stride=1,padding=3) + self.refineclean2= nn.Conv2d(8, 3, kernel_size=3,stride=1,padding=1) + + + def forward(self, x,smaps): + ## 256x256 + x1=self.dense_block1(torch.cat([x,smaps],1)) + x1=self.trans_block1(x1) + + ### 32x32 + x2=(self.dense_block2(x1)) + x2=self.trans_block2(x2) + + #print x2.size() + ### 16 X 16 + x3=(self.dense_block3(x2)) + x3=self.trans_block3(x3) + + x5_in=torch.cat([x3, x1], 1) + x5_i=(self.dense_block5(x5_in)) + + x5=self.trans_block5(x5_i) + x6=(self.dense_block6(x5)) + x6=(self.trans_block6(x6)) + + x7=self.relu(self.conv_refin_in(torch.cat([x,smaps],1))) - self.relu(self.conv_refin(x6)) + residual=self.tanh(self.refine3(x7)) + clean = x - residual + clean = self.relu(self.refineclean1(clean)) + clean = self.tanh(self.refineclean2(clean)) + + + return clean,x5 + +class Deblur_class(nn.Module): + def __init__(self): + super(Deblur_class, self).__init__() + + ##### stage class networks ########### + self.deblur_class1 = Deblur_first(4) + self.deblur_class2 = Deblur_first(4) + self.deblur_class3 = Deblur_first(11) + self.deblur_class4 = Deblur_first(4) + ###################################### + + def forward(self, x_input1,x_input2,x_input3,x_input4,class1,class2,class3,class4): + xh_class1,x_lst1 = self.deblur_class1(x_input1,class1) + xh_class2,x_lst2 = self.deblur_class2(x_input2,class2) + xh_class3,x_lst3 = self.deblur_class3(x_input3,class3) + xh_class4,x_lst4 = self.deblur_class4(x_input4,class4) + + return xh_class1,xh_class2,xh_class3,xh_class4,x_lst1,x_lst2,x_lst3,x_lst4 + +class BottleneckBlockcf(nn.Module): + def __init__(self, in_planes, out_planes, dropRate=0.0): + super(BottleneckBlockcf, self).__init__() + inter_planes = out_planes * 3 + self.bn1 = nn.InstanceNorm2d(in_planes) + self.relu = nn.ReLU(inplace=True) + self.conv_o = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, + padding=0, bias=False) + self.conv1 = nn.Conv2d(in_planes, inter_planes, kernel_size=1, stride=1, + padding=0, bias=False) + self.bn2 = nn.InstanceNorm2d(inter_planes) + self.conv2 = nn.Conv2d(inter_planes, inter_planes, kernel_size=3, stride=1, + padding=1, bias=False) + + self.bn4 = nn.InstanceNorm2d(inter_planes) + self.conv4 = nn.Conv2d(inter_planes, out_planes, kernel_size=3, stride=1, + padding=1, bias=False) + self.droprate = dropRate + def forward(self, x): + out = self.conv1(self.relu(self.bn1(x))) + if self.droprate > 0: + out = F.dropout(out, p=self.droprate, inplace=False, training=self.training) + out = self.conv2(self.relu(self.bn2(out))) + out = self.conv4(self.relu(self.bn4(out))) + if self.droprate > 0: + out = F.dropout(out, p=self.droprate, inplace=False, training=self.training) + return out + +class scale_kernel_conf(nn.Module): + def __init__(self): + super(scale_kernel_conf, self).__init__() + + self.conv1 = nn.Conv2d(6,16,3,1,1)#BottleneckBlock(35, 16) + self.trans_block1 = TransitionBlock1(16, 16) + self.conv2 = BottleneckBlockcf(16, 32) + self.trans_block2 = TransitionBlock1(32, 16) + self.conv3 = BottleneckBlockcf(16, 32) + self.trans_block3 = TransitionBlock1(32, 16) + self.conv4 = BottleneckBlockcf(16, 32) + self.trans_block4 = TransitionBlock3(32, 16) + self.conv_refin = nn.Conv2d(16, 3, 1, 1, 0) + self.sig = torch.nn.Sigmoid() + + self.relu = nn.LeakyReLU(0.2, inplace=True) + + def forward(self, x,target): + 
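# confidence estimator: concat the restored image with the reference (6 channels), encode down to 1x1, apply a sigmoid, and broadcast back to 128x128 +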
x1=self.conv1(torch.cat([x,target],1)) + x1 = self.trans_block1(x1) + x2=self.conv2(x1) + x2 = self.trans_block2(x2) + x3=self.conv3(x2) + x3 = self.trans_block3(x3) + x4=self.conv3(x3) + x4 = self.trans_block4(x4) + #print(x4.size()) + residual = self.sig(self.conv_refin(self.sig(F.avg_pool2d(x4,16)))) + #print(residual) + residual = F.upsample_nearest(residual, scale_factor=128) + #print(residual.size()) + return residual + + + +class Deblur_segdl(nn.Module): + def __init__(self): + super(Deblur_segdl, self).__init__() + self.deblur_class1 = Deblur_first(4) + self.deblur_class2 = Deblur_first(4) + self.deblur_class3 = Deblur_first(4) + self.deblur_class4 = Deblur_first(4) + self.dense_block1=BottleneckBlockrs(7,57) + self.dense_block_cl=BottleneckBlock(64,32) + #self.trans_block_cl=TransitionBlock1(64,32) + self.trans_block1=TransitionBlock1(64,32) + + ############# Block2-down 32-32 ############## + self.dense_block2=BottleneckBlockrs1(67,64) + self.trans_block2=TransitionBlock3(131,64) + + ############# Block3-down 16-16 ############## + self.dense_block3=BottleneckBlockdl(64,64) + self.trans_block3=TransitionBlock3(128,64) + + self.dense_block3_1=BottleneckBlockdl(64,64) + self.trans_block3_1=TransitionBlock3(128,64) + + self.dense_block3_2=BottleneckBlockdl(64,64) + self.trans_block3_2=TransitionBlock3(128,64) + + ############# Block4-up 8-8 ############## + self.dense_block4=BottleneckBlockdl(64,64) + self.trans_block4=TransitionBlock3(128,64) + + ############# Block5-up 16-16 ############## + self.dense_block5=BottleneckBlockrs1(128,64) + self.trans_block5=TransitionBlockbil(195,64) + + self.dense_block6=BottleneckBlockrs(71,64) + self.trans_block6=TransitionBlock3(135,16) + + + self.conv_refin=nn.Conv2d(23,16,3,1,1) + self.conv_refin_in=nn.Conv2d(7,16,3,1,1) + self.conv_refin_in64=nn.Conv2d(3,16,3,1,1) + self.conv_refin64=nn.Conv2d(192,16,3,1,1) + self.tanh=nn.Tanh() + + + self.conv1010 = nn.Conv2d(20, 1, kernel_size=1,stride=1,padding=0) # 1mm + self.conv1020 = nn.Conv2d(20, 1, kernel_size=1,stride=1,padding=0) # 1mm + self.conv1030 = nn.Conv2d(20, 1, kernel_size=1,stride=1,padding=0) # 1mm + self.conv1040 = nn.Conv2d(20, 1, kernel_size=1,stride=1,padding=0) # 1mm + + self.refine3= nn.Conv2d(16, 3, kernel_size=3,stride=1,padding=1) + # self.refine3= nn.Conv2d(20+4, 3, kernel_size=7,stride=1,padding=3) + + self.upsample = F.upsample_nearest + + self.relu=nn.LeakyReLU(0.2, inplace=True) + self.refineclean1= nn.Conv2d(3, 8, kernel_size=7,stride=1,padding=3) + self.refineclean2= nn.Conv2d(8, 3, kernel_size=3,stride=1,padding=1) + + self.conv11 = nn.Conv2d(64, 1, kernel_size=3,stride=1,padding=1) # 1mm + self.conv21 = nn.Conv2d(64, 1, kernel_size=3,stride=1,padding=1) # 1mm + self.conv31 = nn.Conv2d(64, 1, kernel_size=3,stride=1,padding=1) # 1mm + self.conv3_11 = nn.Conv2d(64, 1, kernel_size=3,stride=1,padding=1) # 1mm + self.conv3_21 = nn.Conv2d(64, 1, kernel_size=3,stride=1,padding=1) # 1mm + self.conv41 = nn.Conv2d(64, 1, kernel_size=3,stride=1,padding=1) # 1mm + self.conv51 = nn.Conv2d(64, 1, kernel_size=3,stride=1,padding=1) # 1mm + #self.conv61 = nn.Conv2d(8, 1, kernel_size=3,stride=1,padding=1) # 1mm + + + + + self.batchnorm20=nn.InstanceNorm2d(20) + self.batchnorm1=nn.InstanceNorm2d(1) + self.conf_ker = scale_kernel_conf() + + + + def forward(self, x,x_64,smaps,class1,class2,class3,class4,target,class_msk1,class_msk2,class_msk3,class_msk4): + ## 256x256 + xcl_class1,xh_class1 = self.deblur_class1(x,class1) + xcl_class2,xh_class2 = self.deblur_class2(x,class2) + 
xcl_class3,xh_class3 = self.deblur_class3(x,class3) + xcl_class4,xh_class4 = self.deblur_class4(x,class4) + x_cl = self.dense_block_cl(torch.cat([xh_class1,xh_class2,xh_class3,xh_class4],1)) + x1=self.dense_block1(torch.cat([x,smaps],1)) + x1=self.trans_block1(x1) + + ### 32x32 + x2=(self.dense_block2(torch.cat([x1,x_64,x_cl],1))) + x2=self.trans_block2(x2) + + #print x2.size() + ### 16 X 16 + x3=(self.dense_block3(x2)) + x3=self.trans_block3(x3) + + x3_1 = (self.dense_block3_1(x3)) + x3_1 = self.trans_block3_1(x3_1) + #print x3_1.size() + x3_2 = (self.dense_block3_2(x3_1)) + x3_2 = self.trans_block3_2(x3_2) + + ## Classifier ## + #x4_in = torch.cat([x3_2, x2], 1) + x4=(self.dense_block4(x3_2)) + x4=self.trans_block4(x4) + x5_in=torch.cat([x4, x1,x_cl], 1) + x5_i=(self.dense_block5(x5_in)) + + xhat64 = self.relu(self.conv_refin_in64(x_64)) - self.relu(self.conv_refin64(x5_i)) + xhat64 = self.tanh(self.refine3(xhat64)) + x5=self.trans_block5(torch.cat([x5_i,xhat64],1)) + x6=(self.dense_block6(torch.cat([x5,x,smaps],1))) + x6=(self.trans_block6(x6)) + shape_out = x6.data.size() + # print(shape_out) + shape_out = shape_out[2:4] + x11 = self.upsample(self.relu((self.conv11(torch.cat([x1,x_cl], 1)))), size=shape_out) + x21 = self.upsample(self.relu((self.conv21(x2))), size=shape_out) + x31 = self.upsample(self.relu((self.conv31(x3))), size=shape_out) + x3_11 = self.upsample(self.relu((self.conv3_11(x3_1))), size=shape_out) + x3_21 = self.upsample(self.relu((self.conv3_21(x3_2))), size=shape_out) + x41 = self.upsample(self.relu((self.conv41(x4))), size=shape_out) + x51 = self.upsample(self.relu((self.conv51(x5))), size=shape_out) + x6=torch.cat([x6,x51,x41,x3_21,x3_11,x31,x21,x11],1) + x7=self.relu(self.conv_refin_in(torch.cat([x,smaps],1))) - self.relu(self.conv_refin(x6)) + residual=self.tanh(self.refine3(x7)) + clean = x - residual + clean = self.relu(self.refineclean1(clean)) + clean = self.tanh(self.refineclean2(clean)) + + clean64 = x_64 - xhat64 + clean64 = self.relu(self.refineclean1(clean64)) + clean64 = self.tanh(self.refineclean2(clean64)) + + xmask1 = self.conf_ker(clean*class_msk1,target*class_msk1) + xmask2 = self.conf_ker(clean*class_msk2,target*class_msk2) + xmask3 = self.conf_ker(clean*class_msk3,target*class_msk3) + xmask4 = self.conf_ker(clean*class_msk4,target*class_msk4) + + return clean,clean64,xmask1,xmask2,xmask3,xmask4,xcl_class1,xcl_class2,xcl_class3,xcl_class4 + +class Segmentation(nn.Module): + def __init__(self): + super(Segmentation, self).__init__() + + self.dense_block1=BottleneckBlockrs1(3,61) + self.trans_block1=TransitionBlock1(64,64) + + ############# Block2-down 32-32 ############## + self.dense_block2=BottleneckBlockdls(67,64) + self.trans_block2=TransitionBlock1(131,64) + + ############# Block3-down 16-16 ############## + self.dense_block3=BottleneckBlockdls(64,64) + self.trans_block3=TransitionBlock3(128,64) + + self.dense_block3_1=BottleneckBlockdls(64,64) + self.trans_block3_1=TransitionBlock3(128,64) + + self.dense_block3_2=BottleneckBlockdls(64,64) + self.trans_block3_2=TransitionBlock3(128,64) + + ############# Block4-up 8-8 ############## + self.dense_block4=BottleneckBlockdls(128,64) + self.trans_block4=TransitionBlock(192,64) + + ############# Block5-up 16-16 ############## + self.dense_block5=BottleneckBlockdls(128,64) + self.trans_block5=TransitionBlock(196,64) + + self.dense_block6=BottleneckBlockrs1(64,64) + self.trans_block6=TransitionBlock3(128,16) + + + self.conv_refin=nn.Conv2d(23,16,3,1,1) + self.conv_refin64=nn.Conv2d(192,16,3,1,1) + 
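# note: 'tanh' below is actually a Sigmoid, so the 4-channel segmentation maps lie in [0, 1] +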
self.tanh=nn.Sigmoid() + + self.refine3= nn.Conv2d(16, 4, kernel_size=3,stride=1,padding=1) + self.refine3_i= nn.Conv2d(16, 4, kernel_size=3,stride=1,padding=1) + # self.refine3= nn.Conv2d(20+4, 3, kernel_size=7,stride=1,padding=3) + + self.upsample = F.upsample_nearest + + self.relu=nn.LeakyReLU(0.2, inplace=True) + self.refineclean1= nn.Conv2d(4, 8, kernel_size=7,stride=1,padding=3) + self.refineclean2= nn.Conv2d(8, 4, kernel_size=3,stride=1,padding=1) + + self.conv11 = nn.Conv2d(64, 1, kernel_size=3,stride=1,padding=1) # 1mm + self.conv21 = nn.Conv2d(64, 1, kernel_size=3,stride=1,padding=1) # 1mm + self.conv31 = nn.Conv2d(64, 1, kernel_size=3,stride=1,padding=1) # 1mm + self.conv3_11 = nn.Conv2d(64, 1, kernel_size=3,stride=1,padding=1) # 1mm + self.conv3_21 = nn.Conv2d(64, 1, kernel_size=3,stride=1,padding=1) # 1mm + self.conv41 = nn.Conv2d(64, 1, kernel_size=3,stride=1,padding=1) # 1mm + self.conv51 = nn.Conv2d(64, 1, kernel_size=3,stride=1,padding=1) # 1mm + #self.conv61 = nn.Conv2d(8, 1, kernel_size=3,stride=1,padding=1) # 1mm + + + + + self.batchnorm20=nn.BatchNorm2d(20) + self.batchnorm1=nn.BatchNorm2d(1) + + + + def forward(self, x,x_64): + ## 256x256 + x1=self.dense_block1(x) + x1=self.trans_block1(x1) + + ### 32x32 + x2=(self.dense_block2(torch.cat([x1,x_64],1))) + x2=self.trans_block2(x2) + + #print x2.size() + ### 16 X 16 + x3=(self.dense_block3(x2)) + x3=self.trans_block3(x3) + + x3_1 = (self.dense_block3_1(x3)) + x3_1 = self.trans_block3_1(x3_1) + #print x3_1.size() + x3_2 = (self.dense_block3_2(x3_1)) + x3_2 = self.trans_block3_2(x3_2) + + ## Classifier ## + x4_in = torch.cat([x3_2, x2], 1) + x4=(self.dense_block4(x4_in)) + x4=self.trans_block4(x4) + x5_in=torch.cat([x4, x1], 1) + x5_i=(self.dense_block5(x5_in)) + xhat64 = self.relu(self.conv_refin64(x5_i)) + xhat64 = self.tanh(self.refine3_i(xhat64)) + x5=self.trans_block5(torch.cat([x5_i, xhat64], 1)) + + x6=(self.dense_block6(x5)) + x6=(self.trans_block6(x6)) + shape_out = x6.data.size() + # print(shape_out) + shape_out = shape_out[2:4] + x11 = self.upsample(self.relu((self.conv11(x1))), size=shape_out) + x21 = self.upsample(self.relu((self.conv21(x2))), size=shape_out) + x31 = self.upsample(self.relu((self.conv31(x3))), size=shape_out) + x3_11 = self.upsample(self.relu((self.conv3_11(x3_1))), size=shape_out) + x3_21 = self.upsample(self.relu((self.conv3_21(x3_2))), size=shape_out) + x41 = self.upsample(self.relu((self.conv41(x4))), size=shape_out) + x51 = self.upsample(self.relu((self.conv51(x5))), size=shape_out) + x6 = torch.cat([x6,x51,x41,x3_21,x3_11,x31,x21,x11],1) + x7 = self.relu(self.conv_refin(x6)) + residual = self.tanh(self.refine3(x7)) + + return residual,xhat64 diff --git a/myutils/1 b/myutils/1 new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/myutils/1 @@ -0,0 +1 @@ + diff --git a/myutils/StyleLoader.py b/myutils/StyleLoader.py new file mode 100644 index 0000000..4c4ec44 --- /dev/null +++ b/myutils/StyleLoader.py @@ -0,0 +1,37 @@ +##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +## Created by: Hang Zhang +## ECE Department, Rutgers University +## Email: zhang.hang@rutgers.edu +## Copyright (c) 2017 +## +## This source code is licensed under the MIT-style license found in the +## LICENSE file in the root directory of this source tree +##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import os +from torch.autograd import Variable + +from myutils import utils + +class StyleLoader(): + def __init__(self, style_folder, style_size, 
cuda=True): + self.folder = style_folder + self.style_size = style_size + self.files = os.listdir(style_folder) + self.cuda = cuda + + def get(self, i): + idx = i%len(self.files) + filepath = os.path.join(self.folder, self.files[idx]) + style = utils.tensor_load_rgbimage(filepath, self.style_size) + style = style.unsqueeze(0) + style = utils.preprocess_batch(style) + if self.cuda: + style = style.cuda() + style_v = Variable(style, requires_grad=False) + return style_v + + def size(self): + return len(self.files) + + diff --git a/myutils/__init__.pyc b/myutils/__init__.pyc new file mode 100644 index 0000000..a211062 Binary files /dev/null and b/myutils/__init__.pyc differ diff --git a/myutils/convert_lua.py b/myutils/convert_lua.py new file mode 100644 index 0000000..6d72a64 --- /dev/null +++ b/myutils/convert_lua.py @@ -0,0 +1,310 @@ +from __future__ import print_function + +import os +import math +import torch +import argparse +import numpy as np +import torch.nn as nn +import torch.optim as optim +import torch.legacy.nn as lnn +import torch.nn.functional as F + +from functools import reduce +from torch.autograd import Variable +from torch.utils.serialization import load_lua + +class LambdaBase(nn.Sequential): + def __init__(self, fn, *args): + super(LambdaBase, self).__init__(*args) + self.lambda_func = fn + + def forward_prepare(self, input): + output = [] + for module in self._modules.values(): + output.append(module(input)) + return output if output else input + +class Lambda(LambdaBase): + def forward(self, input): + return self.lambda_func(self.forward_prepare(input)) + +class LambdaMap(LambdaBase): + def forward(self, input): + # result is Variables list [Variable1, Variable2, ...] + return list(map(self.lambda_func,self.forward_prepare(input))) + +class LambdaReduce(LambdaBase): + def forward(self, input): + # result is a Variable + return reduce(self.lambda_func,self.forward_prepare(input)) + + +def copy_param(m,n): + if m.weight is not None: n.weight.data.copy_(m.weight) + if m.bias is not None: n.bias.data.copy_(m.bias) + if hasattr(n,'running_mean'): n.running_mean.copy_(m.running_mean) + if hasattr(n,'running_var'): n.running_var.copy_(m.running_var) + +def add_submodule(seq, *args): + for n in args: + seq.add_module(str(len(seq._modules)),n) + +def lua_recursive_model(module,seq): + for m in module.modules: + name = type(m).__name__ + real = m + if name == 'TorchObject': + name = m._typename.replace('cudnn.','') + m = m._obj + + if name == 'SpatialConvolution' or name == 'nn.SpatialConvolutionMM': + if not hasattr(m,'groups') or m.groups is None: m.groups=1 + n = nn.Conv2d(m.nInputPlane,m.nOutputPlane,(m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH),1,m.groups,bias=(m.bias is not None)) + print(size(m.weight)) + print(n) + copy_param(m,n) + add_submodule(seq,n) + elif name == 'SpatialBatchNormalization': + n = nn.BatchNorm2d(m.running_mean.size(0), m.eps, m.momentum, m.affine) + copy_param(m,n) + add_submodule(seq,n) + elif name == 'VolumetricBatchNormalization': + n = nn.BatchNorm3d(m.running_mean.size(0), m.eps, m.momentum, m.affine) + copy_param(m, n) + add_submodule(seq, n) + elif name == 'ReLU': + n = nn.ReLU() + add_submodule(seq,n) + elif name == 'Sigmoid': + n = nn.Sigmoid() + add_submodule(seq,n) + elif name == 'SpatialMaxPooling': + n = nn.MaxPool2d((m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH),ceil_mode=m.ceil_mode) + add_submodule(seq,n) + elif name == 'SpatialAveragePooling': + n = nn.AvgPool2d((m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH),ceil_mode=m.ceil_mode) + 
add_submodule(seq,n) + elif name == 'SpatialUpSamplingNearest': + n = nn.UpsamplingNearest2d(scale_factor=m.scale_factor) + add_submodule(seq,n) + elif name == 'View': + n = Lambda(lambda x: x.view(x.size(0),-1)) + add_submodule(seq,n) + elif name == 'Reshape': + n = Lambda(lambda x: x.view(x.size(0),-1)) + add_submodule(seq,n) + elif name == 'Linear': + # Linear in pytorch only accept 2D input + n1 = Lambda(lambda x: x.view(1,-1) if 1==len(x.size()) else x ) + n2 = nn.Linear(m.weight.size(1),m.weight.size(0),bias=(m.bias is not None)) + copy_param(m,n2) + n = nn.Sequential(n1,n2) + add_submodule(seq,n) + elif name == 'Dropout': + m.inplace = False + n = nn.Dropout(m.p) + add_submodule(seq,n) + elif name == 'SoftMax': + n = nn.Softmax() + add_submodule(seq,n) + elif name == 'Identity': + n = Lambda(lambda x: x) # do nothing + add_submodule(seq,n) + elif name == 'SpatialFullConvolution': + n = nn.ConvTranspose2d(m.nInputPlane,m.nOutputPlane,(m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH),(m.adjW,m.adjH)) + copy_param(m,n) + add_submodule(seq,n) + elif name == 'VolumetricFullConvolution': + n = nn.ConvTranspose3d(m.nInputPlane,m.nOutputPlane,(m.kT,m.kW,m.kH),(m.dT,m.dW,m.dH),(m.padT,m.padW,m.padH),(m.adjT,m.adjW,m.adjH),m.groups) + copy_param(m,n) + add_submodule(seq, n) + elif name == 'SpatialReplicationPadding': + n = nn.ReplicationPad2d((m.pad_l,m.pad_r,m.pad_t,m.pad_b)) + add_submodule(seq,n) + elif name == 'SpatialReflectionPadding': + n = nn.ReflectionPad2d((m.pad_l,m.pad_r,m.pad_t,m.pad_b)) + add_submodule(seq,n) + elif name == 'Copy': + n = Lambda(lambda x: x) # do nothing + add_submodule(seq,n) + elif name == 'Narrow': + n = Lambda(lambda x,a=(m.dimension,m.index,m.length): x.narrow(*a)) + add_submodule(seq,n) + elif name == 'SpatialCrossMapLRN': + lrn = lnn.SpatialCrossMapLRN(m.size,m.alpha,m.beta,m.k) + n = Lambda(lambda x,lrn=lrn: Variable(lrn.forward(x.data))) + add_submodule(seq,n) + elif name == 'Sequential': + n = nn.Sequential() + lua_recursive_model(m,n) + add_submodule(seq,n) + elif name == 'ConcatTable': # output is list + n = LambdaMap(lambda x: x) + lua_recursive_model(m,n) + add_submodule(seq,n) + elif name == 'CAddTable': # input is list + n = LambdaReduce(lambda x,y: x+y) + add_submodule(seq,n) + elif name == 'Concat': + dim = m.dimension + n = LambdaReduce(lambda x,y,dim=dim: torch.cat((x,y),dim)) + lua_recursive_model(m,n) + add_submodule(seq,n) + elif name == 'TorchObject': + print('Not Implement',name,real._typename) + else: + print('Not Implement',name) + + +def lua_recursive_source(module): + s = [] + for m in module.modules: + name = type(m).__name__ + real = m + if name == 'TorchObject': + name = m._typename.replace('cudnn.','') + m = m._obj + + if name == 'SpatialConvolution' or name == 'nn.SpatialConvolutionMM': + if not hasattr(m,'groups') or m.groups is None: m.groups=1 + s += ['nn.Conv2d({},{},{},{},{},{},{},bias={}),#Conv2d'.format(m.nInputPlane, + m.nOutputPlane,(m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH),1,m.groups,m.bias is not None)] + elif name == 'SpatialBatchNormalization': + s += ['nn.BatchNorm2d({},{},{},{}),#BatchNorm2d'.format(m.running_mean.size(0), m.eps, m.momentum, m.affine)] + elif name == 'VolumetricBatchNormalization': + s += ['nn.BatchNorm3d({},{},{},{}),#BatchNorm3d'.format(m.running_mean.size(0), m.eps, m.momentum, m.affine)] + elif name == 'ReLU': + s += ['nn.ReLU()'] + elif name == 'Sigmoid': + s += ['nn.Sigmoid()'] + elif name == 'SpatialMaxPooling': + s += 
['nn.MaxPool2d({},{},{},ceil_mode={}),#MaxPool2d'.format((m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH),m.ceil_mode)] + elif name == 'SpatialAveragePooling': + s += ['nn.AvgPool2d({},{},{},ceil_mode={}),#AvgPool2d'.format((m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH),m.ceil_mode)] + elif name == 'SpatialUpSamplingNearest': + s += ['nn.UpsamplingNearest2d(scale_factor={})'.format(m.scale_factor)] + elif name == 'View': + s += ['Lambda(lambda x: x.view(x.size(0),-1)), # View'] + elif name == 'Reshape': + s += ['Lambda(lambda x: x.view(x.size(0),-1)), # Reshape'] + elif name == 'Linear': + s1 = 'Lambda(lambda x: x.view(1,-1) if 1==len(x.size()) else x )' + s2 = 'nn.Linear({},{},bias={})'.format(m.weight.size(1),m.weight.size(0),(m.bias is not None)) + s += ['nn.Sequential({},{}),#Linear'.format(s1,s2)] + elif name == 'Dropout': + s += ['nn.Dropout({})'.format(m.p)] + elif name == 'SoftMax': + s += ['nn.Softmax()'] + elif name == 'Identity': + s += ['Lambda(lambda x: x), # Identity'] + elif name == 'SpatialFullConvolution': + s += ['nn.ConvTranspose2d({},{},{},{},{},{})'.format(m.nInputPlane, + m.nOutputPlane,(m.kW,m.kH),(m.dW,m.dH),(m.padW,m.padH),(m.adjW,m.adjH))] + elif name == 'VolumetricFullConvolution': + s += ['nn.ConvTranspose3d({},{},{},{},{},{},{})'.format(m.nInputPlane, + m.nOutputPlane,(m.kT,m.kW,m.kH),(m.dT,m.dW,m.dH),(m.padT,m.padW,m.padH),(m.adjT,m.adjW,m.adjH),m.groups)] + elif name == 'SpatialReplicationPadding': + s += ['nn.ReplicationPad2d({})'.format((m.pad_l,m.pad_r,m.pad_t,m.pad_b))] + elif name == 'SpatialReflectionPadding': + s += ['nn.ReflectionPad2d({})'.format((m.pad_l,m.pad_r,m.pad_t,m.pad_b))] + elif name == 'Copy': + s += ['Lambda(lambda x: x), # Copy'] + elif name == 'Narrow': + s += ['Lambda(lambda x,a={}: x.narrow(*a))'.format((m.dimension,m.index,m.length))] + elif name == 'SpatialCrossMapLRN': + lrn = 'lnn.SpatialCrossMapLRN(*{})'.format((m.size,m.alpha,m.beta,m.k)) + s += ['Lambda(lambda x,lrn={}: Variable(lrn.forward(x.data)))'.format(lrn)] + + elif name == 'Sequential': + s += ['nn.Sequential( # Sequential'] + s += lua_recursive_source(m) + s += [')'] + elif name == 'ConcatTable': + s += ['LambdaMap(lambda x: x, # ConcatTable'] + s += lua_recursive_source(m) + s += [')'] + elif name == 'CAddTable': + s += ['LambdaReduce(lambda x,y: x+y), # CAddTable'] + elif name == 'Concat': + dim = m.dimension + s += ['LambdaReduce(lambda x,y,dim={}: torch.cat((x,y),dim), # Concat'.format(m.dimension)] + s += lua_recursive_source(m) + s += [')'] + else: + s += '# ' + name + ' Not Implement,\n' + s = map(lambda x: '\t{}'.format(x),s) + return s + +def simplify_source(s): + s = map(lambda x: x.replace(',(1, 1),(0, 0),1,1,bias=True),#Conv2d',')'),s) + s = map(lambda x: x.replace(',(0, 0),1,1,bias=True),#Conv2d',')'),s) + s = map(lambda x: x.replace(',1,1,bias=True),#Conv2d',')'),s) + s = map(lambda x: x.replace(',bias=True),#Conv2d',')'),s) + s = map(lambda x: x.replace('),#Conv2d',')'),s) + s = map(lambda x: x.replace(',1e-05,0.1,True),#BatchNorm2d',')'),s) + s = map(lambda x: x.replace('),#BatchNorm2d',')'),s) + s = map(lambda x: x.replace(',(0, 0),ceil_mode=False),#MaxPool2d',')'),s) + s = map(lambda x: x.replace(',ceil_mode=False),#MaxPool2d',')'),s) + s = map(lambda x: x.replace('),#MaxPool2d',')'),s) + s = map(lambda x: x.replace(',(0, 0),ceil_mode=False),#AvgPool2d',')'),s) + s = map(lambda x: x.replace(',ceil_mode=False),#AvgPool2d',')'),s) + s = map(lambda x: x.replace(',bias=True)),#Linear',')), # Linear'),s) + s = map(lambda x: x.replace(')),#Linear',')), # Linear'),s) + + s 
= map(lambda x: '{},\n'.format(x),s) + s = map(lambda x: x[1:],s) + s = reduce(lambda x,y: x+y, s) + return s + +def torch_to_pytorch(t7_filename,outputname=None): + model = load_lua(t7_filename,unknown_classes=True) + if type(model).__name__=='hashable_uniq_dict': model=model.model + model.gradInput = None + slist = lua_recursive_source(lnn.Sequential().add(model)) + s = simplify_source(slist) + header = ''' +import torch +import torch.nn as nn +import torch.legacy.nn as lnn +from functools import reduce +from torch.autograd import Variable +class LambdaBase(nn.Sequential): + def __init__(self, fn, *args): + super(LambdaBase, self).__init__(*args) + self.lambda_func = fn + def forward_prepare(self, input): + output = [] + for module in self._modules.values(): + output.append(module(input)) + return output if output else input +class Lambda(LambdaBase): + def forward(self, input): + return self.lambda_func(self.forward_prepare(input)) +class LambdaMap(LambdaBase): + def forward(self, input): + return list(map(self.lambda_func,self.forward_prepare(input))) +class LambdaReduce(LambdaBase): + def forward(self, input): + return reduce(self.lambda_func,self.forward_prepare(input)) +''' + varname = t7_filename.replace('.t7','').replace('.','_').replace('-','_') + s = '{}\n\n{} = {}'.format(header,varname,s[:-2]) + + if outputname is None: outputname=varname + with open(outputname+'.py', "w") as pyfile: + pyfile.write(s) + + n = nn.Sequential() + lua_recursive_model(model,n) + torch.save(n.state_dict(),outputname+'.pth') + + +parser = argparse.ArgumentParser(description='Convert torch t7 model to pytorch') +parser.add_argument('--model','-m', type=str, required=True, + help='torch model file in t7 format') +parser.add_argument('--output', '-o', type=str, default=None, + help='output file name prefix, xxx.py xxx.pth') +args = parser.parse_args() + +torch_to_pytorch(args.model,args.output) \ No newline at end of file diff --git a/myutils/utils.py b/myutils/utils.py new file mode 100644 index 0000000..92f5eac --- /dev/null +++ b/myutils/utils.py @@ -0,0 +1,94 @@ +import os + +import numpy as np +import torch +from PIL import Image +from torch.autograd import Variable +from torch.utils.serialization import load_lua + +from myutils.vgg16 import Vgg16 + +def tensor_load_rgbimage(filename, size=None, scale=None, keep_asp=False): + img = Image.open(filename).convert('RGB') + if size is not None: + if keep_asp: + size2 = int(size * 1.0 / img.size[0] * img.size[1]) + img = img.resize((size, size2), Image.ANTIALIAS) + else: + img = img.resize((size, size), Image.ANTIALIAS) + + elif scale is not None: + img = img.resize((int(img.size[0] / scale), int(img.size[1] / scale)), Image.ANTIALIAS) + img = np.array(img).transpose(2, 0, 1) + img = torch.from_numpy(img).float() + return img + + +def tensor_save_rgbimage(tensor, filename, cuda=False): + if cuda: + img = tensor.clone().cpu().clamp(0, 255).numpy() + else: + img = tensor.clone().clamp(0, 255).numpy() + img = img.transpose(1, 2, 0).astype('uint8') + img = Image.fromarray(img) + img.save(filename) + + +def tensor_save_bgrimage(tensor, filename, cuda=False): + (b, g, r) = torch.chunk(tensor, 3) + tensor = torch.cat((r, g, b)) + tensor_save_rgbimage(tensor, filename, cuda) + + +def gram_matrix(y): + (b, ch, h, w) = y.size() + features = y.view(b, ch, w * h) + features_t = features.transpose(1, 2) + gram = features.bmm(features_t) / (ch * h * w) + return gram + + +def subtract_imagenet_mean_batch(batch): + """Subtract ImageNet mean pixel-wise from a BGR 
image.""" + tensortype = type(batch.data) + mean = tensortype(batch.data.size()) + mean[:, 0, :, :] = 103.939 + mean[:, 1, :, :] = 116.779 + mean[:, 2, :, :] = 123.680 + return batch - Variable(mean) + + +def add_imagenet_mean_batch(batch): + """Add ImageNet mean pixel-wise from a BGR image.""" + tensortype = type(batch.data) + mean = tensortype(batch.data.size()) + mean[:, 0, :, :] = 103.939 + mean[:, 1, :, :] = 116.779 + mean[:, 2, :, :] = 123.680 + return batch + Variable(mean) + +def imagenet_clamp_batch(batch, low, high): + batch[:,0,:,:].data.clamp_(low-103.939, high-103.939) + batch[:,1,:,:].data.clamp_(low-116.779, high-116.779) + batch[:,2,:,:].data.clamp_(low-123.680, high-123.680) + + +def preprocess_batch(batch): + batch = batch.transpose(0, 1) + (r, g, b) = torch.chunk(batch, 3) + batch = torch.cat((b, g, r)) + batch = batch.transpose(0, 1) + return batch + + +def init_vgg16(model_folder): + """load the vgg16 model feature""" + if not os.path.exists(os.path.join(model_folder, 'vgg16.weight')): + if not os.path.exists(os.path.join(model_folder, 'vgg16.t7')): + os.system( + 'wget http://cs.stanford.edu/people/jcjohns/fast-neural-style/models/vgg16.t7 -O ' + os.path.join(model_folder, 'vgg16.t7')) + vgglua = load_lua(os.path.join(model_folder, 'vgg16.t7')) + vgg = Vgg16() + for (src, dst) in zip(vgglua.parameters()[0], vgg.parameters()): + dst.data[:] = src + torch.save(vgg.state_dict(), os.path.join(model_folder, 'vgg16.weight')) diff --git a/myutils/utils.pyc b/myutils/utils.pyc new file mode 100644 index 0000000..d58d726 Binary files /dev/null and b/myutils/utils.pyc differ diff --git a/myutils/vgg16.py b/myutils/vgg16.py new file mode 100644 index 0000000..109acf4 --- /dev/null +++ b/myutils/vgg16.py @@ -0,0 +1,49 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Vgg16(torch.nn.Module): + def __init__(self): + super(Vgg16, self).__init__() + self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1) + self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) + + self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1) + self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) + + self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1) + self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) + self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) + + self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1) + self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + + self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + + def forward(self, X): + h = F.relu(self.conv1_1(X)) + h = F.relu(self.conv1_2(h)) + relu1_2 = h + h = F.max_pool2d(h, kernel_size=2, stride=2) + + h = F.relu(self.conv2_1(h)) + h = F.relu(self.conv2_2(h)) + relu2_2 = h + h = F.max_pool2d(h, kernel_size=2, stride=2) + + h = F.relu(self.conv3_1(h)) + h = F.relu(self.conv3_2(h)) + h = F.relu(self.conv3_3(h)) + relu3_3 = h + h = F.max_pool2d(h, kernel_size=2, stride=2) + + h = F.relu(self.conv4_1(h)) + h = F.relu(self.conv4_2(h)) + h = F.relu(self.conv4_3(h)) + relu4_3 = h + + return [relu1_2, relu2_2, relu3_3, relu4_3] diff --git a/myutils/vgg16.pyc b/myutils/vgg16.pyc new file mode 100644 index 0000000..842f416 Binary 
files /dev/null and b/myutils/vgg16.pyc differ diff --git a/pretrained_models/Deblur_epoch_Best.pth b/pretrained_models/Deblur_epoch_Best.pth new file mode 100644 index 0000000..55d61c4 Binary files /dev/null and b/pretrained_models/Deblur_epoch_Best.pth differ diff --git a/pretrained_models/SMaps_Best.pth b/pretrained_models/SMaps_Best.pth new file mode 100644 index 0000000..3772fc5 Binary files /dev/null and b/pretrained_models/SMaps_Best.pth differ diff --git a/test_face_deblur.py b/test_face_deblur.py new file mode 100644 index 0000000..ddd09b3 --- /dev/null +++ b/test_face_deblur.py @@ -0,0 +1,314 @@ +from __future__ import print_function +import argparse +import os +import sys +import random +import torch +import torch.nn as nn +import torch.nn.parallel +import torch.backends.cudnn as cudnn +cudnn.benchmark = True +cudnn.fastest = True +import torch.optim as optim +import torchvision.utils as vutils +from torch.autograd import Variable + +from misc import * +import models.face_fed as net + +from myutils.vgg16 import Vgg16 +from myutils import utils +import pdb + +# Pre-defined Parameters +parser = argparse.ArgumentParser() +parser.add_argument('--dataset', required=False, + default='pix2pix', help='') +parser.add_argument('--dataroot', required=False, + default='', help='path to trn dataset') +parser.add_argument('--valDataroot', required=False, + default='', help='path to val dataset') +parser.add_argument('--mode', type=str, default='B2A', help='B2A: facade, A2B: edges2shoes') +parser.add_argument('--batchSize', type=int, default=1, help='input batch size') +parser.add_argument('--valBatchSize', type=int, default=1, help='input batch size') +parser.add_argument('--originalSize', type=int, + default=128, help='the height / width of the original input image') +parser.add_argument('--imageSize', type=int, + default=128, help='the height / width of the cropped input image to network') +parser.add_argument('--inputChannelSize', type=int, + default=3, help='size of the input channels') +parser.add_argument('--outputChannelSize', type=int, + default=3, help='size of the output channels') +parser.add_argument('--ngf', type=int, default=64) +parser.add_argument('--ndf', type=int, default=64) +parser.add_argument('--niter', type=int, default=400, help='number of epochs to train for') +parser.add_argument('--lrD', type=float, default=0.0002, help='learning rate, default=0.0002') +parser.add_argument('--lrG', type=float, default=0.0002, help='learning rate, default=0.0002') +parser.add_argument('--annealStart', type=int, default=0, help='annealing learning rate start to') +parser.add_argument('--annealEvery', type=int, default=400, help='epoch to reaching at learning rate of 0') +parser.add_argument('--lambdaGAN', type=float, default=0.01, help='lambdaGAN') +parser.add_argument('--lambdaIMG', type=float, default=1, help='lambdaIMG') +parser.add_argument('--poolSize', type=int, default=50, help='Buffer size for storing previously generated samples from G') +parser.add_argument('--wd', type=float, default=0.0000, help='weight decay in D') +parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam') +parser.add_argument('--netG', default='', help="path to netG (to continue training)") +parser.add_argument('--netD', default='', help="path to netD (to continue training)") +parser.add_argument('--workers', type=int, help='number of data loading workers', default=1) +parser.add_argument('--exp', default='sample', help='folder to output images and model checkpoints') 
+parser.add_argument('--display', type=int, default=5, help='interval for displaying train-logs') +parser.add_argument('--evalIter', type=int, default=500, help='interval for evauating(generating) images from valDataroot') +opt = parser.parse_args() +print(opt) + + + +create_exp_dir(opt.exp) +opt.manualSeed = random.randint(1, 10000) +random.seed(opt.manualSeed) +torch.manual_seed(opt.manualSeed) +torch.cuda.manual_seed_all(opt.manualSeed) +print("Random Seed: ", opt.manualSeed) + +# Initialize dataloader +dataloader = getLoader(opt.dataset, + opt.dataroot, + opt.originalSize, + opt.imageSize, + opt.batchSize, + opt.workers, + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + split='val', + shuffle=True, + seed=opt.manualSeed) +opt.dataset='pix2pix_val' + +valDataloader = getLoader(opt.dataset, + opt.valDataroot, + opt.originalSize, #opt.originalSize, + opt.imageSize, + opt.valBatchSize, + opt.workers, + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + split='val', + shuffle=False, + seed=opt.manualSeed) + +# get logger +trainLogger = open('%s/train.log' % opt.exp, 'w') + + + +ngf = opt.ngf +ndf = opt.ndf +inputChannelSize = opt.inputChannelSize +outputChannelSize= opt.outputChannelSize + + +# Load Pre-trained derain model +netS=net.Segmentation() +netG=net.Deblur_segdl() + +#netC.apply(weights_init) + + +netG.apply(weights_init) +if opt.netG != '': + state_dict_g = torch.load(opt.netG) + new_state_dict_g = {} + for k, v in state_dict_g.items(): + name = k[7:] # remove `module.` + #print(k) + new_state_dict_g[name] = v + # load params + netG.load_state_dict(new_state_dict_g) + #netG.load_state_dict(torch.load(opt.netG)) +print(netG) +netG.eval() +#netS.apply(weights_init) +netS.load_state_dict(torch.load('./pretrained_models/SMaps_Best.pth')) +#netS.eval() +netS.cuda() +netG.cuda() + +# Initialize testing data +target= torch.FloatTensor(opt.batchSize, outputChannelSize, opt.imageSize, opt.imageSize) +input = torch.FloatTensor(opt.batchSize, inputChannelSize, opt.imageSize, opt.imageSize) + +val_target= torch.FloatTensor(opt.valBatchSize, outputChannelSize, opt.imageSize, opt.imageSize) +val_input = torch.FloatTensor(opt.valBatchSize, inputChannelSize, opt.imageSize, opt.imageSize) +label_d = torch.FloatTensor(opt.batchSize) + + +target = torch.FloatTensor(opt.batchSize, outputChannelSize, opt.imageSize, opt.imageSize) +input = torch.FloatTensor(opt.batchSize, inputChannelSize, opt.imageSize, opt.imageSize) +depth = torch.FloatTensor(opt.batchSize, inputChannelSize, opt.imageSize, opt.imageSize) +ato = torch.FloatTensor(opt.batchSize, inputChannelSize, opt.imageSize, opt.imageSize) + + +val_target = torch.FloatTensor(opt.valBatchSize, outputChannelSize, opt.imageSize, opt.imageSize) +val_input = torch.FloatTensor(opt.valBatchSize, inputChannelSize, opt.imageSize, opt.imageSize) +val_depth = torch.FloatTensor(opt.valBatchSize, inputChannelSize, opt.imageSize, opt.imageSize) +val_ato = torch.FloatTensor(opt.valBatchSize, inputChannelSize, opt.imageSize, opt.imageSize) + + +target_128= torch.FloatTensor(opt.batchSize, outputChannelSize, (opt.imageSize/4), (opt.imageSize/4)) +input_128 = torch.FloatTensor(opt.batchSize, inputChannelSize, (opt.imageSize/4), (opt.imageSize/4)) +target_256= torch.FloatTensor(opt.batchSize, outputChannelSize, (opt.imageSize/2), (opt.imageSize/2)) +input_256 = torch.FloatTensor(opt.batchSize, inputChannelSize, (opt.imageSize/2), (opt.imageSize/2)) + +val_target_128= torch.FloatTensor(opt.batchSize, outputChannelSize, (opt.imageSize/4), (opt.imageSize/4)) +val_input_128 = 
torch.FloatTensor(opt.batchSize, inputChannelSize, (opt.imageSize/4), (opt.imageSize/4)) +val_target_256= torch.FloatTensor(opt.batchSize, outputChannelSize, (opt.imageSize/2), (opt.imageSize/2)) +val_input_256 = torch.FloatTensor(opt.batchSize, inputChannelSize, (opt.imageSize/2), (opt.imageSize/2)) + +target, input, depth, ato = target.cuda(), input.cuda(), depth.cuda(), ato.cuda() +val_target, val_input, val_depth, val_ato = val_target.cuda(), val_input.cuda(), val_depth.cuda(), val_ato.cuda() + +target = Variable(target, volatile=True) +input = Variable(input,volatile=True) +depth = Variable(depth,volatile=True) +ato = Variable(ato,volatile=True) + +target_128, input_128 = target_128.cuda(), input_128.cuda() +val_target_128, val_input_128 = val_target_128.cuda(), val_input_128.cuda() +target_256, input_256 = target_256.cuda(), input_256.cuda() +val_target_256, val_input_256 = val_target_256.cuda(), val_input_256.cuda() + +target_128 = Variable(target_128) +input_128 = Variable(input_128) +target_256 = Variable(target_256) +input_256 = Variable(input_256) + +label_d = Variable(label_d.cuda()) + + + + +def norm_ip(img, min, max): + img.clamp_(min=min, max=max) + img.add_(-min).div_(max - min) + + return img + + +def norm_range(t, range): + if range is not None: + norm_ip(t, range[0], range[1]) + else: + norm_ip(t, -1, +1) + + return t#norm_ip(t, t.min(), t.max()) + +# get optimizer +optimizerG = optim.Adam(netG.parameters(), lr = opt.lrG, betas = (opt.beta1, 0.999), weight_decay=0.00005) + + +# Begin Testing +for epoch in range(1): + heavy, medium, light=200, 200, 200 + for i, data in enumerate(valDataloader, 0): + if 1: + print('Image:'+str(i)) + import time + data_val = data + + t0 = time.time() + + val_input_cpu, val_target_cpu = data_val + + val_target_cpu, val_input_cpu = val_target_cpu.float().cuda(), val_input_cpu.float().cuda() + val_batch_output = torch.FloatTensor(val_input.size()).fill_(0) + + val_input.resize_as_(val_input_cpu).copy_(val_input_cpu) + val_target=Variable(val_target_cpu, volatile=True) + + + z=0 + + + + with torch.no_grad(): + for idx in range(val_input.size(0)): + single_img = val_input[idx,:,:,:].unsqueeze(0) + val_inputv = Variable(single_img, volatile=True) + print (val_inputv.size()) + # val_inputv = val_inputv.float().cuda() + val_inputv_256 = torch.nn.functional.interpolate(val_inputv,scale_factor=0.5) + val_inputv_128 = torch.nn.functional.interpolate(val_inputv,scale_factor=0.25) + + ## Get de-rained results ## + #residual_val, x_hat_val, x_hatlv128, x_hatvl256 = netG(val_inputv, val_inputv_256, val_inputv_128) + + t1 = time.time() + print('running time:'+str(t1 - t0)) + from PIL import Image + + #x_hat_val = netG(val_inputv) + #smaps_vl = netS(val_inputv) + #S_valinput = torch.cat([smaps_vl,val_inputv],1) + """smaps,smaps64 = netS(val_inputv,val_inputv_256) + S_input = torch.cat([smaps,val_inputv],1) + x_hat_val, x_hat_val64 = netG(val_inputv,val_inputv_256,smaps,smaps64)""" + + + #x_hatcls1,x_hatcls2,x_hatcls3,x_hatcls4,x_lst1,x_lst2,x_lst3,x_lst4 = netG(val_inputv,val_inputv_256,smaps_i,smaps_i64,class1,class2,class3,class4) + smaps,smaps64 = netS(val_inputv,val_inputv_256) + class1 = torch.zeros([1,1,128,128], dtype=torch.float32) + class1[:,0,:,:] = smaps[:,0,:,:] + class2 = torch.zeros([1,1,128,128], dtype=torch.float32) + class2[:,0,:,:] = smaps[:,1,:,:] + class3 = torch.zeros([1,1,128,128], dtype=torch.float32) + class3[:,0,:,:] = smaps[:,2,:,:] + class4 = torch.zeros([1,1,128,128], dtype=torch.float32) + class4[:,0,:,:] = smaps[:,3,:,:] + 
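# replicate each single-channel semantic map across RGB to form 3-channel class masks +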
class_msk1 = torch.zeros([1,3,128,128], dtype=torch.float32) + class_msk1[:,0,:,:] = smaps[:,0,:,:] + class_msk1[:,1,:,:] = smaps[:,0,:,:] + class_msk1[:,2,:,:] = smaps[:,0,:,:] + class_msk2 = torch.zeros([1,3,128,128], dtype=torch.float32) + class_msk2[:,0,:,:] = smaps[:,1,:,:] + class_msk2[:,1,:,:] = smaps[:,1,:,:] + class_msk2[:,2,:,:] = smaps[:,1,:,:] + class_msk3 = torch.zeros([1,3,128,128], dtype=torch.float32) + class_msk3[:,0,:,:] = smaps[:,2,:,:] + class_msk3[:,1,:,:] = smaps[:,2,:,:] + class_msk3[:,2,:,:] = smaps[:,2,:,:] + class_msk4 = torch.zeros([1,3,128,128], dtype=torch.float32) + class_msk4[:,0,:,:] = smaps[:,3,:,:] + class_msk4[:,1,:,:] = smaps[:,3,:,:] + class_msk4[:,2,:,:] = smaps[:,3,:,:] + class1 = class1.float().cuda() + class2 = class2.float().cuda() + class3 = class3.float().cuda() + class4 = class4.float().cuda() + class_msk4 = class_msk4.float().cuda() + class_msk3 = class_msk3.float().cuda() + class_msk2 = class_msk2.float().cuda() + class_msk1 = class_msk1.float().cuda() + x_hat_val, x_hat_val64,xmask1,xmask2,xmask3,xmask4,xcl_class1,xcl_class2,xcl_class3,xcl_class4 = netG(val_inputv,val_inputv_256,smaps,class1,class2,class3,class4,val_inputv,class_msk1,class_msk2,class_msk3,class_msk4) + # x_hat1,x_hat64,xmask1,xmask2,xmask3,xmask4,xcl_class1,xcl_class2,xcl_class3,xcl_class4 = netG(input,input_256,smaps_i,class1,class2,class3,class4,target,class_msk1,class_msk2,class_msk3,class_msk4) + #x_hat_val.data + #val_batch_output[idx,:,:,:].copy_(x_hat_val.data[0,1,:,:]) + # print(torch.mean(xmask1),torch.mean(xmask2),torch.mean(xmask3),torch.mean(xmask4)) + print (smaps.size()) + tensor = x_hat_val.data.cpu() + + + ### Save the de-rained results ##### + from PIL import Image + directory = './result_all/deblurh/'#'./result_all/new_model_data/DID-MDN/' + if not os.path.exists(directory): + os.makedirs(directory) + + tensor = torch.squeeze(tensor) + tensor=norm_range(tensor, None) + print(tensor.min(),tensor.max()) + + filename='./result_all/deblurh/'+str(i+1)+'.png' + ndarr = tensor.mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy() + im = Image.fromarray(ndarr) + + im.save(filename) + + + diff --git a/train_face_deblur.py b/train_face_deblur.py new file mode 100644 index 0000000..db294dd --- /dev/null +++ b/train_face_deblur.py @@ -0,0 +1,514 @@ +from __future__ import print_function +import argparse +import os +import sys +import random +import torch +import torch.nn as nn +import torch.nn.parallel +import torch.backends.cudnn as cudnn +cudnn.benchmark = True +cudnn.fastest = True +import torch.optim as optim +import torchvision.utils as vutils +from torch.autograd import Variable + +from misc import * +import models.face_fed as net + + +from myutils.vgg16 import Vgg16 +from myutils import utils +import pdb +import torch.nn.functional as F +#from PIL import Image +from torchvision import transforms + +import h5py +from os import listdir +from os.path import isfile, join + +parser = argparse.ArgumentParser() +parser.add_argument('--dataset', required=False, + default='pix2pix_class', help='') +parser.add_argument('--dataroot', required=False, + default='', help='path to trn dataset') +parser.add_argument('--valDataroot', required=False, + default='', help='path to val dataset') +parser.add_argument('--mode', type=str, default='B2A', help='B2A: facade, A2B: edges2shoes') +parser.add_argument('--batchSize', type=int, default=1, help='input batch size') +parser.add_argument('--valBatchSize', type=int, default=120, help='input batch size') 
+parser.add_argument('--originalSize', type=int, + default=175, help='the height / width of the original input image') +parser.add_argument('--imageSize', type=int, + default=128, help='the height / width of the cropped input image to network') +parser.add_argument('--inputChannelSize', type=int, + default=3, help='size of the input channels') +parser.add_argument('--outputChannelSize', type=int, + default=3, help='size of the output channels') +parser.add_argument('--ngf', type=int, default=64) +parser.add_argument('--ndf', type=int, default=64) +parser.add_argument('--niter', type=int, default=5000, help='number of epochs to train for') +parser.add_argument('--lrD', type=float, default=0.0002, help='learning rate, default=0.0002') +parser.add_argument('--lrG', type=float, default=0.0002, help='learning rate, default=0.0002') +parser.add_argument('--annealStart', type=int, default=0, help='annealing learning rate start to') +parser.add_argument('--annealEvery', type=int, default=400, help='epoch to reaching at learning rate of 0') +parser.add_argument('--lambdaGAN', type=float, default=0.01, help='lambdaGAN') +parser.add_argument('--lambdaIMG', type=float, default=2.0, help='lambdaIMG') +parser.add_argument('--poolSize', type=int, default=50, help='Buffer size for storing previously generated samples from G') +parser.add_argument('--wd', type=float, default=0.0000, help='weight decay in D') +parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam') +parser.add_argument('--netG', default='', help="path to netG (to continue training)") +parser.add_argument('--netD', default='', help="path to netD (to continue training)") +parser.add_argument('--workers', type=int, help='number of data loading workers', default=1) +parser.add_argument('--exp', default='sample', help='folder to output images and model checkpoints') +parser.add_argument('--display', type=int, default=5, help='interval for displaying train-logs') +parser.add_argument('--evalIter', type=int, default=500, help='interval for evauating(generating) images from valDataroot') +opt = parser.parse_args() +print(opt) + +from scipy import signal +import h5py +from scipy import signal +import random +k_filename ='./kernel.mat' +kfp = h5py.File(k_filename) +kernels = np.array(kfp['kernels']) +kernels = kernels.transpose([0,2,1]) + +create_exp_dir(opt.exp) +opt.manualSeed = random.randint(1, 10000) +# opt.manualSeed = 101 +random.seed(opt.manualSeed) +torch.manual_seed(opt.manualSeed) +torch.cuda.manual_seed_all(opt.manualSeed) +print("Random Seed: ", opt.manualSeed) + +# get dataloader +opt.dataset='pix2pix_val' +print (opt.dataroot) +dataloader = getLoader(opt.dataset, + opt.dataroot, + opt.originalSize, + opt.imageSize, + opt.batchSize, + opt.workers, + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + split='train', + shuffle=True, + seed=opt.manualSeed) + +opt.dataset='pix2pix_val' +valDataloader = getLoader(opt.dataset, + opt.valDataroot, + opt.originalSize, + opt.imageSize, + opt.valBatchSize, + opt.workers, + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + split='val', + shuffle=False, + seed=opt.manualSeed) + + +# get logger +trainLogger = open('%s/train.log' % opt.exp, 'w') + +def gradient(y): + gradient_h=torch.abs(y[:, :, :, :-1] - y[:, :, :, 1:]) + gradient_y=torch.abs(y[:, :, :-1, :] - y[:, :, 1:, :]) + + return gradient_h, gradient_y + + +ngf = opt.ngf +ndf = opt.ndf +inputChannelSize = opt.inputChannelSize +outputChannelSize= opt.outputChannelSize + +# get models +netS=net.Segmentation() +netG=net.Deblur_segdl() 
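+# Segmentation predicts 4-class semantic maps; Deblur_segdl is the class-guided deblurring generator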
+
+
+netS.load_state_dict(torch.load('./pretrained_models/SMaps_Best.pth'))
+
+# state_dict_g = torch.load('./face_deblur/Deblur_epoch_46.pth')
+# new_state_dict_g = {}
+# for k, v in state_dict_g.items():
+#     name = k[7:]  # remove `module.`
+#     new_state_dict_g[name] = v
+# # load params
+# netG.load_state_dict(new_state_dict_g)
+
+netG = torch.nn.DataParallel(netG)
+netS = torch.nn.DataParallel(netS)
+netG.train()
+criterionCAE = nn.L1Loss()
+criterionCAE1 = nn.SmoothL1Loss()
+
+# multi-scale buffers: full resolution plus imageSize//2 and imageSize//4
+target = torch.FloatTensor(opt.batchSize, outputChannelSize, opt.imageSize, opt.imageSize)
+input = torch.FloatTensor(opt.batchSize, inputChannelSize, opt.imageSize, opt.imageSize)
+target_128 = torch.FloatTensor(opt.batchSize, outputChannelSize, (opt.imageSize//4), (opt.imageSize//4))
+input_128 = torch.FloatTensor(opt.batchSize, inputChannelSize, (opt.imageSize//4), (opt.imageSize//4))
+target_256 = torch.FloatTensor(opt.batchSize, outputChannelSize, (opt.imageSize//2), (opt.imageSize//2))
+input_256 = torch.FloatTensor(opt.batchSize, inputChannelSize, (opt.imageSize//2), (opt.imageSize//2))
+
+val_target = torch.FloatTensor(opt.valBatchSize, outputChannelSize, opt.imageSize, opt.imageSize)
+val_input = torch.FloatTensor(opt.valBatchSize, inputChannelSize, opt.imageSize, opt.imageSize)
+val_target_128 = torch.FloatTensor(opt.batchSize, outputChannelSize, (opt.imageSize//4), (opt.imageSize//4))
+val_input_128 = torch.FloatTensor(opt.batchSize, inputChannelSize, (opt.imageSize//4), (opt.imageSize//4))
+val_target_256 = torch.FloatTensor(opt.batchSize, outputChannelSize, (opt.imageSize//2), (opt.imageSize//2))
+val_input_256 = torch.FloatTensor(opt.batchSize, inputChannelSize, (opt.imageSize//2), (opt.imageSize//2))
+label_d = torch.FloatTensor(opt.batchSize)
+
+target = torch.FloatTensor(opt.batchSize, outputChannelSize, opt.imageSize, opt.imageSize)
+input = torch.FloatTensor(opt.batchSize, inputChannelSize, opt.imageSize, opt.imageSize)
+depth = torch.FloatTensor(opt.batchSize, inputChannelSize, opt.imageSize, opt.imageSize)
+ato = torch.FloatTensor(opt.batchSize, inputChannelSize, opt.imageSize, opt.imageSize)
+
+val_target = torch.FloatTensor(opt.valBatchSize, outputChannelSize, opt.imageSize, opt.imageSize)
+val_input = torch.FloatTensor(opt.valBatchSize, inputChannelSize, opt.imageSize, opt.imageSize)
+val_depth = torch.FloatTensor(opt.valBatchSize, inputChannelSize, opt.imageSize, opt.imageSize)
+val_ato = torch.FloatTensor(opt.valBatchSize, inputChannelSize, opt.imageSize, opt.imageSize)
+
+# NOTE: size of 2D output maps in the discriminator
+sizePatchGAN = 30
+real_label = 1
+fake_label = 0
+
+# image pool storing previously generated samples from G
+imagePool = ImagePool(opt.poolSize)
+
+# NOTE weight for L_cGAN and L_L1 (i.e.
Eq.(4) in the paper) +lambdaGAN = opt.lambdaGAN +lambdaIMG = opt.lambdaIMG + +netG.cuda() +#netC.cuda() +netS.cuda() +criterionCAE.cuda() +criterionCAE1.cuda() + + + + +target, input, depth, ato = target.cuda(), input.cuda(), depth.cuda(), ato.cuda() +val_target, val_input, val_depth, val_ato = val_target.cuda(), val_input.cuda(), val_depth.cuda(), val_ato.cuda() + +target = Variable(target) +input = Variable(input) + +target_128, input_128 = target_128.cuda(), input_128.cuda() +val_target_128, val_input_128 = val_target_128.cuda(), val_input_128.cuda() +target_256, input_256 = target_256.cuda(), input_256.cuda() +val_target_256, val_input_256 = val_target_256.cuda(), val_input_256.cuda() + +target_128 = Variable(target_128) +input_128 = Variable(input_128) +target_256 = Variable(target_256) +input_256 = Variable(input_256) +# input = Variable(input,requires_grad=False) +# depth = Variable(depth) +ato = Variable(ato) + +# Initialize VGG-16 +vgg = Vgg16() +utils.init_vgg16('./models/') +vgg.load_state_dict(torch.load(os.path.join('./models/', "vgg16.weight"))) +vgg.cuda() + + +label_d = Variable(label_d.cuda()) + +# get randomly sampled validation images and save it +print(len(dataloader)) +val_iter = iter(valDataloader) +data_val = val_iter.next() + + +val_input_cpu, val_target_cpu = data_val + +val_target_cpu, val_input_cpu = val_target_cpu.float().cuda(), val_input_cpu.float().cuda() + + + +val_target.resize_as_(val_target_cpu).copy_(val_target_cpu) +val_input.resize_as_(val_input_cpu).copy_(val_input_cpu) + +vutils.save_image(val_target, '%s/real_target.png' % opt.exp, normalize=True) +vutils.save_image(val_input, '%s/real_input.png' % opt.exp, normalize=True) + + + + +optimizerG = optim.Adam(netG.parameters(), lr = opt.lrG, betas = (opt.beta1, 0.999), weight_decay=0.00005) +# NOTE training loop +ganIterations = 0 +count = 48 + + +for epoch in range(opt.niter): + if epoch % 19 == 0 and epoch>0: + opt.lrG = opt.lrG/2.0 + for param_group in optimizerG.param_groups: + param_group['lr'] = opt.lrG + if epoch >= opt.annealStart: + adjust_learning_rate(optimizerG, opt.lrG, epoch, None, opt.annealEvery) + + + for i, data in enumerate(dataloader, 0): + + input_cpu, target_cpu = data + batch_size = target_cpu.size(0) + b,ch,x,y = target_cpu.size() + x1 = int((x-opt.imageSize)/2) + y1 = int((y-opt.imageSize)/2) + input_cpu = input_cpu.numpy() + target_cpu = target_cpu.numpy() + for j in range(batch_size): + index = random.randint(0,24500) + input_cpu[j,0,:,:]= signal.convolve(input_cpu[j,0,:,:],kernels[index,:,:],mode='same') + input_cpu[j,1,:,:]= signal.convolve(input_cpu[j,1,:,:],kernels[index,:,:],mode='same') + input_cpu[j,2,:,:]= signal.convolve(input_cpu[j,2,:,:],kernels[index,:,:],mode='same') + input_cpu = input_cpu + (1.0/255.0)* np.random.normal(0,4,input_cpu.shape) + input_cpu = input_cpu[:,:,x1:x1+opt.imageSize,y1:y1+opt.imageSize] + target_cpu = target_cpu[:,:,x1:x1+opt.imageSize,y1:y1+opt.imageSize] + input_cpu = torch.from_numpy(input_cpu) + target_cpu = torch.from_numpy(target_cpu) + + + + + target_cpu, input_cpu = target_cpu.float().cuda(), input_cpu.float().cuda() + + + # get paired data + target.data.resize_as_(target_cpu).copy_(target_cpu) + input.data.resize_as_(input_cpu).copy_(input_cpu) + input_256 = torch.nn.functional.interpolate(input,scale_factor=0.5) + target_256 = torch.nn.functional.interpolate(target,scale_factor=0.5) + + + + with torch.no_grad(): + smaps_i,smaps_i64 = netS(input,input_256) + smaps,smaps64 = netS(target,target_256) + class1 = 
torch.zeros([batch_size,1,128,128], dtype=torch.float32) + class1[:,0,:,:] = smaps_i[:,0,:,:] + class2 = torch.zeros([batch_size,1,128,128], dtype=torch.float32) + class2[:,0,:,:] = smaps_i[:,1,:,:] + class3 = torch.zeros([batch_size,1,128,128], dtype=torch.float32) + class3[:,0,:,:] = smaps_i[:,2,:,:] + class4 = torch.zeros([batch_size,1,128,128], dtype=torch.float32) + class4[:,0,:,:] = smaps_i[:,3,:,:] + class_msk1 = torch.zeros([batch_size,3,128,128], dtype=torch.float32) + class_msk1[:,0,:,:] = smaps[:,0,:,:] + class_msk1[:,1,:,:] = smaps[:,0,:,:] + class_msk1[:,2,:,:] = smaps[:,0,:,:] + class_msk2 = torch.zeros([batch_size,3,128,128], dtype=torch.float32) + class_msk2[:,0,:,:] = smaps[:,1,:,:] + class_msk2[:,1,:,:] = smaps[:,1,:,:] + class_msk2[:,2,:,:] = smaps[:,1,:,:] + class_msk3 = torch.zeros([batch_size,3,128,128], dtype=torch.float32) + class_msk3[:,0,:,:] = smaps[:,2,:,:] + class_msk3[:,1,:,:] = smaps[:,2,:,:] + class_msk3[:,2,:,:] = smaps[:,2,:,:] + class_msk4 = torch.zeros([batch_size,3,128,128], dtype=torch.float32) + class_msk4[:,0,:,:] = smaps[:,3,:,:] + class_msk4[:,1,:,:] = smaps[:,3,:,:] + class_msk4[:,2,:,:] = smaps[:,3,:,:] + + + + + + class1 = class1.float().cuda() + class2 = class2.float().cuda() + class3 = class3.float().cuda() + class4 = class4.float().cuda() + class_msk4 = class_msk4.float().cuda() + class_msk3 = class_msk3.float().cuda() + class_msk2 = class_msk2.float().cuda() + class_msk1 = class_msk1.float().cuda() + x_hat1,x_hat64,xmask1,xmask2,xmask3,xmask4,xcl_class1,xcl_class2,xcl_class3,xcl_class4 = netG(input,input_256,smaps_i,class1,class2,class3,class4,target,class_msk1,class_msk2,class_msk3,class_msk4) + + x_hat = x_hat1 + + + #xeff = conf*x_hat+(1-conf)*target + #xeff_64 = conf_64*x_hat64+(1-conf_64)*target_256 + + + #print(x_hat.size()) + if ganIterations % 2 == 0: + netG.zero_grad() # start to update G + + #x1 = xmask1*class_msk1*x_hat+(1-xmask1)*class_msk1*target + #smaps_hat,smaps64_hat = netS(x_hat1,x_hat64) + if epoch>4 or (epoch<4 and epoch%2 == 0): + with torch.no_grad(): + smaps,smaps64 = netS(target,target_256) + L_img_ = 0.33*criterionCAE(x_hat64, target_256) #+ 0.5*criterionCAE(smaps_hat, smaps) + L_img_ = L_img_ + 1.2 *criterionCAE(xmask1*class_msk1*x_hat+(1-xmask1)*class_msk1*target, class_msk1*target) + L_img_ = L_img_ + 1.2 *criterionCAE(xmask2*class_msk2*x_hat+(1-xmask2)*class_msk2*target, class_msk2*target) + L_img_ = L_img_ + 3.6 *criterionCAE(xmask3*class_msk3*x_hat+(1-xmask3)*class_msk3*target, class_msk3*target) + L_img_ = L_img_ + 1.2 *criterionCAE(xmask4*class_msk4*x_hat+(1-xmask4)*class_msk4*target, class_msk4*target) + if ganIterations % (25*opt.display) == 0: + print(L_img_.data[0]) + sys.stdout.flush() + if ganIterations< 1000: + lam_cmp = 1.0 + else : + lam_cmp = 0.09 + sng = 0.00000001 + L_img_ = L_img_ - (lam_cmp/(4.0))*torch.mean(torch.log(xmask1+sng)) + L_img_ = L_img_ - (lam_cmp/(4.0))*torch.mean(torch.log(xmask2+sng)) + L_img_ = L_img_ - (lam_cmp/(4.0))*torch.mean(torch.log(xmask3+sng)) + L_img_ = L_img_ - (lam_cmp/(4.0))*torch.mean(torch.log(xmask4+sng)) + if ganIterations % (50*opt.display) == 0: + print(L_img_.data[0]) + sys.stdout.flush() + + #L_img_ = L_img_ + 2*criterionCAE(class_msk3*x_hat,class_msk3*target) + # L_res = lambdaIMG * L_res_ + gradh_xhat,gradv_xhat=gradient(x_hat) + gradh_tar,gradv_tar=gradient(target) + gradh_xhat64,gradv_xhat64=gradient(x_hat64) + gradh_tar64,gradv_tar64=gradient(target_256) + L_img_ = L_img_ + 0.15*criterionCAE(gradh_xhat,gradh_tar)+ 0.15*criterionCAE(gradv_xhat,gradv_tar)+ 
0.08*criterionCAE(gradh_xhat64,gradh_tar64)+0.08*criterionCAE(gradv_xhat64,gradv_tar64) + if ganIterations % (25*opt.display) == 0: + print(L_img_.data[0]) + print((torch.mean(torch.log(xmask1)).data),(torch.mean(torch.log(xmask2)).data),(torch.mean(xmask3).data),(torch.mean(xmask4).data)) + sys.stdout.flush() + # L_res = lambdaIMG * L_res_ + L_img = lambdaIMG * L_img_ + + if lambdaIMG <> 0: + #L_img.backward(retain_graph=True) # in case of current version of pytorch + L_img.backward(retain_graph=True) + # L_res.backward(retain_variables=True) + + # Perceptual Loss 1 + features_content = vgg(target) + f_xc_c = Variable(features_content[1].data, requires_grad=False) + features_y = vgg(x_hat) + + features_content = vgg(target_256) + f_xc_c64 = Variable(features_content[1].data, requires_grad=False) + features_y64 = vgg(x_hat64) + + content_loss = 1.8*lambdaIMG* criterionCAE(features_y[1], f_xc_c) + 1.8*0.33*lambdaIMG* criterionCAE(features_y64[1], f_xc_c64) + content_loss.backward(retain_graph=True) + + # Perceptual Loss 2 + features_content = vgg(target) + f_xc_c = Variable(features_content[0].data, requires_grad=False) + features_y = vgg(x_hat) + + features_content = vgg(target_256) + f_xc_c64 = Variable(features_content[0].data, requires_grad=False) + features_y64 = vgg(x_hat64) + + content_loss1 = 1.8*lambdaIMG* criterionCAE(features_y[0], f_xc_c) + 1.8*0.33*lambdaIMG* criterionCAE(features_y64[0], f_xc_c64) + content_loss1.backward(retain_graph=True) + + + else: + L_img_ = 1.2 *criterionCAE(xcl_class1, target) + L_img_ = L_img_ + 1.2 *criterionCAE(xcl_class2, target) + L_img_ = L_img_ + 3.6 *criterionCAE(xcl_class3, target) + L_img_ = L_img_ + 1.2 *criterionCAE(xcl_class4, target) + L_img = lambdaIMG * L_img_ + if lambdaIMG <> 0: + L_img.backward(retain_graph=True) + if ganIterations % (25*opt.display) == 0: + print(L_img_.data[0]) + print("updating fisrt stage parameters") + sys.stdout.flush() + + + + if ganIterations % 2 == 0: + optimizerG.step() + ganIterations += 1 + + if ganIterations % opt.display == 0: + print('[%d/%d][%d/%d] Loss: %f ' + % (epoch, opt.niter, i, len(dataloader), + L_img.data[0])) + sys.stdout.flush() + trainLogger.write('%d\t%f\n' % \ + (i, L_img.data[0])) + trainLogger.flush() + if ganIterations % (int(len(dataloader)/2)) == 0: + val_batch_output = torch.zeros([16,3,128,128], dtype=torch.float32)#torch.FloatTensor([10,3,128,128]).fill_(0) + for idx in range(val_input.size(0)): + single_img = val_input[idx,:,:,:].unsqueeze(0) + val_inputv = Variable(single_img, volatile=True) + with torch.no_grad(): + #smaps_vl = netS(val_inputv) + #S_valinput = torch.cat([smaps_vl,val_inputv],1) + index = idx+24500 + #rint(val_inputv.size()) + val_inputv = val_inputv.cpu().numpy() + val_inputv[0,0,:,:]= signal.convolve(val_inputv[0,0,:,:],kernels[index,:,:],mode='same') + val_inputv[0,1,:,:]= signal.convolve(val_inputv[0,1,:,:],kernels[index,:,:],mode='same') + val_inputv[0,2,:,:]= signal.convolve(val_inputv[0,2,:,:],kernels[index,:,:],mode='same') + val_inputv = val_inputv[:,:,x1:x1+opt.imageSize,y1:y1+opt.imageSize] + val_inputv = val_inputv + (1.0/255.0)* np.random.normal(0,4,val_inputv.shape) + val_inputv = torch.from_numpy(val_inputv) + val_inputv = val_inputv.float().cuda() + val_inputv_256 = torch.nn.functional.interpolate(val_inputv,scale_factor=0.5) + #rint(val_inputv.size()) + smaps,smaps64 = netS(val_inputv,val_inputv_256) + class1 = torch.zeros([1,1,128,128], dtype=torch.float32) + class1[:,0,:,:] = smaps[:,0,:,:] + class2 = torch.zeros([1,1,128,128], 
dtype=torch.float32) + class2[:,0,:,:] = smaps[:,1,:,:] + class3 = torch.zeros([1,1,128,128], dtype=torch.float32) + class3[:,0,:,:] = smaps[:,2,:,:] + class4 = torch.zeros([1,1,128,128], dtype=torch.float32) + class4[:,0,:,:] = smaps[:,3,:,:] + class_msk1 = torch.zeros([1,3,128,128], dtype=torch.float32) + class_msk1[:,0,:,:] = smaps[:,0,:,:] + class_msk1[:,1,:,:] = smaps[:,0,:,:] + class_msk1[:,2,:,:] = smaps[:,0,:,:] + class_msk2 = torch.zeros([1,3,128,128], dtype=torch.float32) + class_msk2[:,0,:,:] = smaps[:,1,:,:] + class_msk2[:,1,:,:] = smaps[:,1,:,:] + class_msk2[:,2,:,:] = smaps[:,1,:,:] + class_msk3 = torch.zeros([1,3,128,128], dtype=torch.float32) + class_msk3[:,0,:,:] = smaps[:,2,:,:] + class_msk3[:,1,:,:] = smaps[:,2,:,:] + class_msk3[:,2,:,:] = smaps[:,2,:,:] + class_msk4 = torch.zeros([1,3,128,128], dtype=torch.float32) + class_msk4[:,0,:,:] = smaps[:,3,:,:] + class_msk4[:,1,:,:] = smaps[:,3,:,:] + class_msk4[:,2,:,:] = smaps[:,3,:,:] + x_hat_val, x_hat_val64,xmask1,xmask2,xmask3,xmask4,xcl_class1,xcl_class2,xcl_class3,xcl_class4 = netG(val_inputv,val_inputv_256,smaps,class1,class2,class3,class4,val_inputv,class_msk1,class_msk2,class_msk3,class_msk4) + #x_hat_val.data[0,:,:,:] = masks*x_hat_val.data[0,:,:,:] + val_batch_output[idx,:,:,:].copy_(x_hat_val.data[0,:,:,:]) + ### We use a random label here just for intermediate result visuliztion (No need to worry about the label here) ## + + + if ganIterations % (int(len(dataloader)/2)) == 0: + vutils.save_image(val_batch_output, '%s/generated_epoch_iter%08d.png' % \ + (opt.exp, ganIterations), normalize=True, scale_each=False) + del val_batch_output + if ganIterations % (int(len(dataloader)/2)) == 0: + torch.save(netG.state_dict(), '%s/Deblur_epoch_%d.pth' % (opt.exp, count)) + #torch.save(netC.state_dict(), '%s/Deblur_first_epoch_%d.pth' % (opt.exp, count)) + count = count +1 +trainLogger.close() + diff --git a/transforms/__init__.py b/transforms/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/transforms/__init__.py @@ -0,0 +1 @@ + diff --git a/transforms/__init__.pyc b/transforms/__init__.pyc new file mode 100644 index 0000000..74b8dbb Binary files /dev/null and b/transforms/__init__.pyc differ diff --git a/transforms/pix2pix.py b/transforms/pix2pix.py new file mode 100644 index 0000000..3985a9a --- /dev/null +++ b/transforms/pix2pix.py @@ -0,0 +1,227 @@ +from __future__ import division +import torch +import math +import random +from PIL import Image, ImageOps +import numpy as np +import numbers +import types + +class Compose(object): + """Composes several transforms together. + Args: + transforms (List[Transform]): list of transforms to compose. + Example: + >>> transforms.Compose([ + >>> transforms.CenterCrop(10), + >>> transforms.ToTensor(), + >>> ]) + """ + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, imgA, imgB): + for t in self.transforms: + imgA, imgB = t(imgA, imgB) + return imgA, imgB + +class ToTensor(object): + """Converts a PIL.Image or numpy.ndarray (H x W x C) in the range + [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]. 
+ """ + def __call__(self, picA, picB): + pics = [picA, picB] + output = [] + for pic in pics: + if isinstance(pic, np.ndarray): + # handle numpy array + img = torch.from_numpy(pic.transpose((2, 0, 1))) + else: + # handle PIL Image + img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) + # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK + if pic.mode == 'YCbCr': + nchannel = 3 + else: + nchannel = len(pic.mode) + img = img.view(pic.size[1], pic.size[0], nchannel) + # put it from HWC to CHW format + # yikes, this transpose takes 80% of the loading time/CPU + img = img.transpose(0, 1).transpose(0, 2).contiguous() + img = img.float().div(255.) + output.append(img) + return output[0], output[1] + +class ToPILImage(object): + """Converts a torch.*Tensor of range [0, 1] and shape C x H x W + or numpy ndarray of dtype=uint8, range[0, 255] and shape H x W x C + to a PIL.Image of range [0, 255] + """ + def __call__(self, picA, picB): + pics = [picA, picB] + output = [] + for pic in pics: + npimg = pic + mode = None + if not isinstance(npimg, np.ndarray): + npimg = pic.mul(255).byte().numpy() + npimg = np.transpose(npimg, (1, 2, 0)) + + if npimg.shape[2] == 1: + npimg = npimg[:, :, 0] + mode = "L" + output.append(Image.fromarray(npimg, mode=mode)) + + return output[0], output[1] + +class Normalize(object): + """Given mean: (R, G, B) and std: (R, G, B), + will normalize each channel of the torch.*Tensor, i.e. + channel = (channel - mean) / std + """ + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, tensorA, tensorB): + tensors = [tensorA, tensorB] + output = [] + for tensor in tensors: + # TODO: make efficient + for t, m, s in zip(tensor, self.mean, self.std): + t.sub_(m).div_(s) + output.append(tensor) + return output[0], output[1] + +class Scale(object): + """Rescales the input PIL.Image to the given 'size'. + 'size' will be the size of the smaller edge. + For example, if height > width, then image will be + rescaled to (size * height / width, size) + size: size of the smaller edge + interpolation: Default: PIL.Image.BILINEAR + """ + def __init__(self, size, interpolation=Image.BILINEAR): + self.size = size + self.interpolation = interpolation + + def __call__(self, imgA, imgB): + imgs = [imgA, imgB] + output = [] + for img in imgs: + # w, h = img.size + # if (w <= h and w == self.size) or (h <= w and h == self.size): + # output.append(img) + # continue + # if w < h: + # ow = self.size + # oh = int(self.size * h / w) + # output.append(img.resize((ow, oh), self.interpolation)) + # continue + # else: + # oh = self.size + # ow = int(self.size * w / h) + oh = self.size + ow = self.size + output.append(img.resize((ow, oh), self.interpolation)) + # print output[0].size + return output[0], output[1] + +class CenterCrop(object): + """Crops the given PIL.Image at the center to have a region of + the given size. 
size can be a tuple (target_height, target_width) + or an integer, in which case the target will be of a square shape (size, size) + """ + def __init__(self, size): + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + + def __call__(self, imgA, imgB): + imgs = [imgA, imgB] + output = [] + for img in imgs: + w, h = img.size + th, tw = self.size + x1 = int(round((w - tw) / 2.)) + y1 = int(round((h - th) / 2.)) + output.append(img.crop((x1, y1, x1 + tw, y1 + th))) + return output[0], output[1] + +class Pad(object): + """Pads the given PIL.Image on all sides with the given "pad" value""" + def __init__(self, padding, fill=0): + assert isinstance(padding, numbers.Number) + assert isinstance(fill, numbers.Number) or isinstance(fill, str) or isinstance(fill, tuple) + self.padding = padding + self.fill = fill + + def __call__(self, imgA, imgB): + imgs = [imgA, imgB] + output = [] + for img in imgs: + output.append(ImageOps.expand(img, border=self.padding, fill=self.fill)) + return output[0], output[1] + +class Lambda(object): + """Applies a lambda as a transform.""" + def __init__(self, lambd): + assert isinstance(lambd, types.LambdaType) + self.lambd = lambd + + def __call__(self, imgA, imgB): + imgs = [imgA, imgB] + output = [] + for img in imgs: + output.append(self.lambd(img)) + return output[0], output[1] + +class RandomCrop(object): + """Crops the given PIL.Image at a random location to have a region of + the given size. size can be a tuple (target_height, target_width) + or an integer, in which case the target will be of a square shape (size, size) + """ + def __init__(self, size, padding=0): + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + self.padding = padding + + def __call__(self, imgA, imgB): + imgs = [imgA, imgB] + output = [] + x1 = -1 + y1 = -1 + for img in imgs: + if self.padding > 0: + img = ImageOps.expand(img, border=self.padding, fill=0) + + w, h = img.size + th, tw = self.size + if w == tw and h == th: + output.append(img) + continue + + if x1 == -1 and y1 == -1: + x1 = random.randint(0, w - tw) + y1 = random.randint(0, h - th) + output.append(img.crop((x1, y1, x1 + tw, y1 + th))) + return output[0], output[1] + +class RandomHorizontalFlip(object): + """Randomly horizontally flips the given PIL.Image with a probability of 0.5 + """ + def __call__(self, imgA, imgB): + imgs = [imgA, imgB] + output = [] + # flag = random.random() < 0.5 + flag = random.random() < -1 + + for img in imgs: + if flag: + output.append(img.transpose(Image.FLIP_LEFT_RIGHT)) + else: + output.append(img) + return output[0], output[1] diff --git a/transforms/pix2pix.pyc b/transforms/pix2pix.pyc new file mode 100644 index 0000000..416cd9c Binary files /dev/null and b/transforms/pix2pix.pyc differ diff --git a/transforms/pix2pix3.py b/transforms/pix2pix3.py new file mode 100644 index 0000000..755e071 --- /dev/null +++ b/transforms/pix2pix3.py @@ -0,0 +1,224 @@ +from __future__ import division +import torch +import math +import random +from PIL import Image, ImageOps +import numpy as np +import numbers +import types + +class Compose(object): + """Composes several transforms together. + Args: + transforms (List[Transform]): list of transforms to compose. 
+ Example: + >>> transforms.Compose([ + >>> transforms.CenterCrop(10), + >>> transforms.ToTensor(), + >>> ]) + """ + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, imgA, imgB, imgC): + for t in self.transforms: + imgA, imgB, imgC = t(imgA, imgB, imgC) + return imgA, imgB, imgC + +class ToTensor(object): + """Converts a PIL.Image or numpy.ndarray (H x W x C) in the range + [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]. + """ + def __call__(self, picA, picB, picC): + pics = [picA, picB, picC] + output = [] + for pic in pics: + if isinstance(pic, np.ndarray): + # handle numpy array + img = torch.from_numpy(pic.transpose((2, 0, 1))) + else: + # handle PIL Image + img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) + # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK + if pic.mode == 'YCbCr': + nchannel = 3 + else: + nchannel = len(pic.mode) + img = img.view(pic.size[1], pic.size[0], nchannel) + # put it from HWC to CHW format + # yikes, this transpose takes 80% of the loading time/CPU + img = img.transpose(0, 1).transpose(0, 2).contiguous() + img = img.float().div(255.) + output.append(img) + return output[0], output[1], output[2] + +class ToPILImage(object): + """Converts a torch.*Tensor of range [0, 1] and shape C x H x W + or numpy ndarray of dtype=uint8, range[0, 255] and shape H x W x C + to a PIL.Image of range [0, 255] + """ + def __call__(self, picA, picB, picC): + pics = [picA, picB, picC] + output = [] + for pic in pics: + npimg = pic + mode = None + if not isinstance(npimg, np.ndarray): + npimg = pic.mul(255).byte().numpy() + npimg = np.transpose(npimg, (1, 2, 0)) + + if npimg.shape[2] == 1: + npimg = npimg[:, :, 0] + mode = "L" + output.append(Image.fromarray(npimg, mode=mode)) + + return output[0], output[1], output[2] + +class Normalize(object): + """Given mean: (R, G, B) and std: (R, G, B), + will normalize each channel of the torch.*Tensor, i.e. + channel = (channel - mean) / std + """ + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, tensorA, tensorB, tensorC): + tensors = [tensorA, tensorB, tensorC] + output = [] + for tensor in tensors: + # TODO: make efficient + for t, m, s in zip(tensor, self.mean, self.std): + t.sub_(m).div_(s) + output.append(tensor) + return output[0], output[1], output[2] + +class Scale(object): + """Rescales the input PIL.Image to the given 'size'. + 'size' will be the size of the smaller edge. + For example, if height > width, then image will be + rescaled to (size * height / width, size) + size: size of the smaller edge + interpolation: Default: PIL.Image.BILINEAR + """ + def __init__(self, size, interpolation=Image.BILINEAR): + self.size = size + self.interpolation = interpolation + + def __call__(self, imgA, imgB, imgC): + imgs = [imgA, imgB, imgC] + output = [] + for img in imgs: + w, h = img.size + if (w <= h and w == self.size) or (h <= w and h == self.size): + output.append(img) + continue + if w < h: + ow = self.size + oh = int(self.size * h / w) + output.append(img.resize((ow, oh), self.interpolation)) + continue + else: + oh = self.size + ow = int(self.size * w / h) + output.append(img.resize((ow, oh), self.interpolation)) + return output[0], output[1], output[2] + +class CenterCrop(object): + """Crops the given PIL.Image at the center to have a region of + the given size. 
size can be a tuple (target_height, target_width) + or an integer, in which case the target will be of a square shape (size, size) + """ + def __init__(self, size): + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + + def __call__(self, imgA, imgB, imgC): + imgs = [imgA, imgB, imgC] + output = [] + for img in imgs: + w, h = img.size + th, tw = self.size + x1 = int(round((w - tw) / 2.)) + y1 = int(round((h - th) / 2.)) + output.append(img.crop((x1, y1, x1 + tw, y1 + th))) + return output[0], output[1], output[2] + +class Pad(object): + """Pads the given PIL.Image on all sides with the given "pad" value""" + def __init__(self, padding, fill=0): + assert isinstance(padding, numbers.Number) + assert isinstance(fill, numbers.Number) or isinstance(fill, str) or isinstance(fill, tuple) + self.padding = padding + self.fill = fill + + def __call__(self, imgA, imgB, imgC): + imgs = [imgA, imgB, imgC] + output = [] + for img in imgs: + output.append(ImageOps.expand(img, border=self.padding, fill=self.fill)) + return output[0], output[1], output[2] + +class Lambda(object): + """Applies a lambda as a transform.""" + def __init__(self, lambd): + assert isinstance(lambd, types.LambdaType) + self.lambd = lambd + + def __call__(self, imgA, imgB, imgC): + imgs = [imgA, imgB, imgC] + output = [] + for img in imgs: + output.append(self.lambd(img)) + return output[0], output[1], output[2] + +class RandomCrop(object): + """Crops the given PIL.Image at a random location to have a region of + the given size. size can be a tuple (target_height, target_width) + or an integer, in which case the target will be of a square shape (size, size) + """ + def __init__(self, size, padding=0): + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + self.padding = padding + + def __call__(self, imgA, imgB, imgC): + imgs = [imgA, imgB, imgC] + output = [] + x1 = -1 + y1 = -1 + for img in imgs: + if self.padding > 0: + img = ImageOps.expand(img, border=self.padding, fill=0) + + w, h = img.size + th, tw = self.size + if w == tw and h == th: + output.append(img) + continue + + if x1 == -1 and y1 == -1: + x1 = random.randint(0, w - tw) + y1 = random.randint(0, h - th) + output.append(img.crop((x1, y1, x1 + tw, y1 + th))) + return output[0], output[1], output[2] + +class RandomHorizontalFlip(object): + """Randomly horizontally flips the given PIL.Image with a probability of 0.5 + """ + def __call__(self, imgA, imgB, imgC): + imgs = [imgA, imgB, imgC] + output = [] + # flag = random.random() < 0.5 + flag = random.random() < -1 + + for img in imgs: + if flag: + output.append(img.transpose(Image.FLIP_LEFT_RIGHT)) + else: + output.append(img) + return output[0], output[1], output[2] diff --git a/transforms/pix2pix_val.py b/transforms/pix2pix_val.py new file mode 100644 index 0000000..a49c9fe --- /dev/null +++ b/transforms/pix2pix_val.py @@ -0,0 +1,225 @@ +from __future__ import division +import torch +import math +import random +from PIL import Image, ImageOps +import numpy as np +import numbers +import types + +class Compose(object): + """Composes several transforms together. + Args: + transforms (List[Transform]): list of transforms to compose. 
+ Example: + >>> transforms.Compose([ + >>> transforms.CenterCrop(10), + >>> transforms.ToTensor(), + >>> ]) + """ + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, imgA, imgB): + for t in self.transforms: + imgA, imgB = t(imgA, imgB) + return imgA, imgB + +class ToTensor(object): + """Converts a PIL.Image or numpy.ndarray (H x W x C) in the range + [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]. + """ + def __call__(self, picA, picB): + pics = [picA, picB] + output = [] + for pic in pics: + if isinstance(pic, np.ndarray): + # handle numpy array + img = torch.from_numpy(pic.transpose((2, 0, 1))) + else: + # handle PIL Image + img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) + # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK + if pic.mode == 'YCbCr': + nchannel = 3 + else: + nchannel = len(pic.mode) + img = img.view(pic.size[1], pic.size[0], nchannel) + # put it from HWC to CHW format + # yikes, this transpose takes 80% of the loading time/CPU + img = img.transpose(0, 1).transpose(0, 2).contiguous() + img = img.float().div(255.) + output.append(img) + return output[0], output[1] + +class ToPILImage(object): + """Converts a torch.*Tensor of range [0, 1] and shape C x H x W + or numpy ndarray of dtype=uint8, range[0, 255] and shape H x W x C + to a PIL.Image of range [0, 255] + """ + def __call__(self, picA, picB): + pics = [picA, picB] + output = [] + for pic in pics: + npimg = pic + mode = None + if not isinstance(npimg, np.ndarray): + npimg = pic.mul(255).byte().numpy() + npimg = np.transpose(npimg, (1, 2, 0)) + + if npimg.shape[2] == 1: + npimg = npimg[:, :, 0] + mode = "L" + output.append(Image.fromarray(npimg, mode=mode)) + + return output[0], output[1] + +class Normalize(object): + """Given mean: (R, G, B) and std: (R, G, B), + will normalize each channel of the torch.*Tensor, i.e. + channel = (channel - mean) / std + """ + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, tensorA, tensorB): + tensors = [tensorA, tensorB] + output = [] + for tensor in tensors: + # TODO: make efficient + for t, m, s in zip(tensor, self.mean, self.std): + t.sub_(m).div_(s) + output.append(tensor) + return output[0], output[1] + +class Scale(object): + """Rescales the input PIL.Image to the given 'size'. + 'size' will be the size of the smaller edge. + For example, if height > width, then image will be + rescaled to (size * height / width, size) + size: size of the smaller edge + interpolation: Default: PIL.Image.BILINEAR + """ + def __init__(self, size, interpolation=Image.BILINEAR): + self.size = size + self.interpolation = interpolation + + def __call__(self, imgA, imgB): + imgs = [imgA, imgB] + output = [] + for img in imgs: + w, h = img.size + if (w <= h and w == self.size) or (h <= w and h == self.size): + output.append(img) + continue + if w < h: + ow = self.size + oh = int(self.size * h / w) + output.append(img.resize((ow, oh), self.interpolation)) + continue + else: + oh = self.size + ow = int(self.size * w / h) + output.append(img.resize((ow, oh), self.interpolation)) + return output[0], output[1] + +class CenterCrop(object): + """Crops the given PIL.Image at the center to have a region of + the given size. 
size can be a tuple (target_height, target_width) + or an integer, in which case the target will be of a square shape (size, size) + """ + def __init__(self, size): + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + + def __call__(self, imgA, imgB): + imgs = [imgA, imgB] + output = [] + for img in imgs: + w, h = img.size + th, tw = self.size + x1 = int(round((w - tw) / 2.)) + y1 = int(round((h - th) / 2.)) + # output.append(img.crop((x1, y1, x1 + tw, y1 + th))) + + output.append(img) + + return output[0], output[1] + +class Pad(object): + """Pads the given PIL.Image on all sides with the given "pad" value""" + def __init__(self, padding, fill=0): + assert isinstance(padding, numbers.Number) + assert isinstance(fill, numbers.Number) or isinstance(fill, str) or isinstance(fill, tuple) + self.padding = padding + self.fill = fill + + def __call__(self, imgA, imgB): + imgs = [imgA, imgB] + output = [] + for img in imgs: + output.append(ImageOps.expand(img, border=self.padding, fill=self.fill)) + return output[0], output[1] + +class Lambda(object): + """Applies a lambda as a transform.""" + def __init__(self, lambd): + assert isinstance(lambd, types.LambdaType) + self.lambd = lambd + + def __call__(self, imgA, imgB): + imgs = [imgA, imgB] + output = [] + for img in imgs: + output.append(self.lambd(img)) + return output[0], output[1] + +class RandomCrop(object): + """Crops the given PIL.Image at a random location to have a region of + the given size. size can be a tuple (target_height, target_width) + or an integer, in which case the target will be of a square shape (size, size) + """ + def __init__(self, size, padding=0): + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + self.padding = padding + + def __call__(self, imgA, imgB): + imgs = [imgA, imgB] + output = [] + x1 = -1 + y1 = -1 + for img in imgs: + if self.padding > 0: + img = ImageOps.expand(img, border=self.padding, fill=0) + + w, h = img.size + th, tw = self.size + if w == tw and h == th: + output.append(img) + continue + + if x1 == -1 and y1 == -1: + x1 = random.randint(0, w - tw) + y1 = random.randint(0, h - th) + output.append(img.crop((x1, y1, x1 + tw, y1 + th))) + return output[0], output[1] + +class RandomHorizontalFlip(object): + """Randomly horizontally flips the given PIL.Image with a probability of 0.5 + """ + def __call__(self, imgA, imgB): + imgs = [imgA, imgB] + output = [] + flag = random.random() < 0.5 + for img in imgs: + if flag: + output.append(img.transpose(Image.FLIP_LEFT_RIGHT)) + else: + output.append(img) + return output[0], output[1] diff --git a/transforms/pix2pix_val3.py b/transforms/pix2pix_val3.py new file mode 100644 index 0000000..c8a914d --- /dev/null +++ b/transforms/pix2pix_val3.py @@ -0,0 +1,224 @@ +from __future__ import division +import torch +import math +import random +from PIL import Image, ImageOps +import numpy as np +import numbers +import types + +class Compose(object): + """Composes several transforms together. + Args: + transforms (List[Transform]): list of transforms to compose. 
+ Example: + >>> transforms.Compose([ + >>> transforms.CenterCrop(10), + >>> transforms.ToTensor(), + >>> ]) + """ + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, imgA, imgB, imgC): + for t in self.transforms: + imgA, imgB, imgC = t(imgA, imgB, imgC) + return imgA, imgB, imgC + +class ToTensor(object): + """Converts a PIL.Image or numpy.ndarray (H x W x C) in the range + [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]. + """ + def __call__(self, picA, picB, picC): + pics = [picA, picB, picC] + output = [] + for pic in pics: + if isinstance(pic, np.ndarray): + # handle numpy array + img = torch.from_numpy(pic.transpose((2, 0, 1))) + else: + # handle PIL Image + img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) + # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK + if pic.mode == 'YCbCr': + nchannel = 3 + else: + nchannel = len(pic.mode) + img = img.view(pic.size[1], pic.size[0], nchannel) + # put it from HWC to CHW format + # yikes, this transpose takes 80% of the loading time/CPU + img = img.transpose(0, 1).transpose(0, 2).contiguous() + img = img.float().div(255.) + output.append(img) + return output[0], output[1], output[2] + +class ToPILImage(object): + """Converts a torch.*Tensor of range [0, 1] and shape C x H x W + or numpy ndarray of dtype=uint8, range[0, 255] and shape H x W x C + to a PIL.Image of range [0, 255] + """ + def __call__(self, picA, picB, picC): + pics = [picA, picB, picC] + output = [] + for pic in pics: + npimg = pic + mode = None + if not isinstance(npimg, np.ndarray): + npimg = pic.mul(255).byte().numpy() + npimg = np.transpose(npimg, (1, 2, 0)) + + if npimg.shape[2] == 1: + npimg = npimg[:, :, 0] + mode = "L" + output.append(Image.fromarray(npimg, mode=mode)) + + return output[0], output[1], output[2] + +class Normalize(object): + """Given mean: (R, G, B) and std: (R, G, B), + will normalize each channel of the torch.*Tensor, i.e. + channel = (channel - mean) / std + """ + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, tensorA, tensorB, tensorC): + tensors = [tensorA, tensorB, tensorC] + output = [] + for tensor in tensors: + # TODO: make efficient + for t, m, s in zip(tensor, self.mean, self.std): + t.sub_(m).div_(s) + output.append(tensor) + return output[0], output[1], output[2] + +class Scale(object): + """Rescales the input PIL.Image to the given 'size'. + 'size' will be the size of the smaller edge. + For example, if height > width, then image will be + rescaled to (size * height / width, size) + size: size of the smaller edge + interpolation: Default: PIL.Image.BILINEAR + """ + def __init__(self, size, interpolation=Image.BILINEAR): + self.size = size + self.interpolation = interpolation + + def __call__(self, imgA, imgB, imgC): + imgs = [imgA, imgB, imgC] + output = [] + for img in imgs: + w, h = img.size + if (w <= h and w == self.size) or (h <= w and h == self.size): + output.append(img) + continue + if w < h: + ow = self.size + oh = int(self.size * h / w) + output.append(img.resize((ow, oh), self.interpolation)) + continue + else: + oh = self.size + ow = int(self.size * w / h) + output.append(img.resize((ow, oh), self.interpolation)) + return output[0], output[1], output[2] + +class CenterCrop(object): + """Crops the given PIL.Image at the center to have a region of + the given size. 
size can be a tuple (target_height, target_width) + or an integer, in which case the target will be of a square shape (size, size) + """ + def __init__(self, size): + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + + def __call__(self, imgA, imgB, imgC): + imgs = [imgA, imgB, imgC] + output = [] + for img in imgs: + w, h = img.size + th, tw = self.size + x1 = int(round((w - tw) / 2.)) + y1 = int(round((h - th) / 2.)) + output.append(img.crop((x1, y1, x1 + tw, y1 + th))) + return output[0], output[1], output[2] + +class Pad(object): + """Pads the given PIL.Image on all sides with the given "pad" value""" + def __init__(self, padding, fill=0): + assert isinstance(padding, numbers.Number) + assert isinstance(fill, numbers.Number) or isinstance(fill, str) or isinstance(fill, tuple) + self.padding = padding + self.fill = fill + + def __call__(self, imgA, imgB, imgC): + imgs = [imgA, imgB, imgC] + output = [] + for img in imgs: + output.append(ImageOps.expand(img, border=self.padding, fill=self.fill)) + return output[0], output[1], output[2] + +class Lambda(object): + """Applies a lambda as a transform.""" + def __init__(self, lambd): + assert isinstance(lambd, types.LambdaType) + self.lambd = lambd + + def __call__(self, imgA, imgB, imgC): + imgs = [imgA, imgB, imgC] + output = [] + for img in imgs: + output.append(self.lambd(img)) + return output[0], output[1], output[2] + +class RandomCrop(object): + """Crops the given PIL.Image at a random location to have a region of + the given size. size can be a tuple (target_height, target_width) + or an integer, in which case the target will be of a square shape (size, size) + """ + def __init__(self, size, padding=0): + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + self.padding = padding + + def __call__(self, imgA, imgB, imgC): + imgs = [imgA, imgB, imgC] + output = [] + x1 = -1 + y1 = -1 + for img in imgs: + if self.padding > 0: + img = ImageOps.expand(img, border=self.padding, fill=0) + + w, h = img.size + th, tw = self.size + if w == tw and h == th: + output.append(img) + continue + + if x1 == -1 and y1 == -1: + x1 = random.randint(0, w - tw) + y1 = random.randint(0, h - th) + output.append(img.crop((x1, y1, x1 + tw, y1 + th))) + return output[0], output[1], output[2] + +class RandomHorizontalFlip(object): + """Randomly horizontally flips the given PIL.Image with a probability of 0.5 + """ + def __call__(self, imgA, imgB, imgC): + imgs = [imgA, imgB, imgC] + output = [] + # flag = random.random() < 0.5 + flag = random.random() < -1 + + for img in imgs: + if flag: + output.append(img.transpose(Image.FLIP_LEFT_RIGHT)) + else: + output.append(img) + return output[0], output[1], output[2]
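+
+# Illustrative usage sketch (not part of the original code; Compose, Scale,
+# RandomCrop, ToTensor and Normalize are the triple-image classes defined above,
+# and the 175/128 sizes mirror the --originalSize/--imageSize defaults of the
+# training script, whose composition is presumably built by the getLoader helper):
+#
+#   transform = Compose([
+#       Scale(175),
+#       RandomCrop(128),
+#       ToTensor(),
+#       Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
+#   ])
+#   imgA_t, imgB_t, imgC_t = transform(imgA, imgB, imgC)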