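# NOTE(review): the next lines are web-page residue captured with this file;
# they are not Python and are commented out so the module can be parsed.
# You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
#
# 313 lines
# 13 KiB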

from __future__ import print_function
import argparse
import os
import sys
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
cudnn.benchmark = True
cudnn.fastest = True
import torch.optim as optim
import torchvision.utils as vutils
from torch.autograd import Variable
from misc import *
import models.face_fed as net
from myutils.vgg16 import Vgg16
from myutils import utils
import pdb
# Pre-defined Parameters
# Command-line options. Several of them (e.g. lrD, lambdaGAN, poolSize, netD)
# are parsed but not read anywhere else in this testing script.
parser = argparse.ArgumentParser()

# -- data and loader settings ------------------------------------------------
parser.add_argument('--dataset', required=False, default='pix2pix', help='')
parser.add_argument('--dataroot', required=False, default='', help='path to trn dataset')
parser.add_argument('--valDataroot', required=False, default='', help='path to val dataset')
parser.add_argument('--mode', type=str, default='B2A', help='B2A: facade, A2B: edges2shoes')
parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
parser.add_argument('--valBatchSize', type=int, default=1, help='input batch size')
parser.add_argument('--originalSize', type=int, default=128,
                    help='the height / width of the original input image')
parser.add_argument('--imageSize', type=int, default=128,
                    help='the height / width of the cropped input image to network')
parser.add_argument('--inputChannelSize', type=int, default=3, help='size of the input channels')
parser.add_argument('--outputChannelSize', type=int, default=3, help='size of the output channels')

# -- network widths -----------------------------------------------------------
parser.add_argument('--ngf', type=int, default=64)
parser.add_argument('--ndf', type=int, default=64)

# -- optimisation ---------------------------------------------------------------
parser.add_argument('--niter', type=int, default=400, help='number of epochs to train for')
parser.add_argument('--lrD', type=float, default=0.0002, help='learning rate, default=0.0002')
parser.add_argument('--lrG', type=float, default=0.0002, help='learning rate, default=0.0002')
parser.add_argument('--annealStart', type=int, default=0, help='annealing learning rate start to')
parser.add_argument('--annealEvery', type=int, default=400,
                    help='epoch to reaching at learning rate of 0')
parser.add_argument('--lambdaGAN', type=float, default=0.01, help='lambdaGAN')
parser.add_argument('--lambdaIMG', type=float, default=1, help='lambdaIMG')
parser.add_argument('--poolSize', type=int, default=50,
                    help='Buffer size for storing previously generated samples from G')
parser.add_argument('--wd', type=float, default=0.0000, help='weight decay in D')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam')

# -- checkpoints, workers, output and logging ------------------------------------
parser.add_argument('--netG', default='', help="path to netG (to continue training)")
parser.add_argument('--netD', default='', help="path to netD (to continue training)")
parser.add_argument('--workers', type=int, default=1, help='number of data loading workers')
parser.add_argument('--exp', default='sample',
                    help='folder to output images and model checkpoints')
parser.add_argument('--display', type=int, default=5, help='interval for displaying train-logs')
parser.add_argument('--evalIter', type=int, default=500,
                    help='interval for evauating(generating) images from valDataroot')

opt = parser.parse_args()
print(opt)
create_exp_dir(opt.exp)

# A fresh seed in [1, 10000] is drawn on every run (there is no CLI flag for
# it); it is applied to Python's `random` and torch's CPU and CUDA generators.
opt.manualSeed = random.randint(1, 10000)
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)
torch.cuda.manual_seed_all(opt.manualSeed)
print("Random Seed: ", opt.manualSeed)
# Initialize dataloader
# Both loaders are built with mean=std=(0.5, 0.5, 0.5), split='val' and the
# run's seed; they differ in data root, batch size and shuffling.
# NOTE(review): `dataloader` reads opt.dataroot with split='val' and is never
# iterated in this script -- confirm it is still needed.
dataloader = getLoader(
    opt.dataset, opt.dataroot, opt.originalSize, opt.imageSize,
    opt.batchSize, opt.workers,
    mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
    split='val', shuffle=True, seed=opt.manualSeed,
)

# The validation loader always uses the 'pix2pix_val' dataset, overriding
# whatever --dataset was given on the command line.
opt.dataset = 'pix2pix_val'
valDataloader = getLoader(
    opt.dataset, opt.valDataroot, opt.originalSize, opt.imageSize,
    opt.valBatchSize, opt.workers,
    mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
    split='val', shuffle=False, seed=opt.manualSeed,
)
# get logger
# NOTE(review): train.log is opened for writing here but nothing in this
# script writes to or closes it.
trainLogger = open('%s/train.log' % opt.exp, 'w')
ngf = opt.ngf
ndf = opt.ndf
inputChannelSize = opt.inputChannelSize
outputChannelSize = opt.outputChannelSize

# Build the segmentation network (netS) and the generator (netG); netG is
# initialized with weights_init and then optionally overwritten from --netG.
netS = net.Segmentation()
netG = net.Deblur_segdl()
netG.apply(weights_init)
if opt.netG != '':
    state_dict_g = torch.load(opt.netG)
    # Checkpoints saved from an nn.DataParallel wrapper prefix every key with
    # 'module.'. Strip it only where present: the previous unconditional
    # k[7:] mangled every key of a checkpoint saved without the wrapper.
    new_state_dict_g = {
        (k[len('module.'):] if k.startswith('module.') else k): v
        for k, v in state_dict_g.items()
    }
    netG.load_state_dict(new_state_dict_g)
print(netG)
netG.eval()

netS.load_state_dict(torch.load('./pretrained_models/SMaps_Best.pth'))
# NOTE(review): netS.eval() was commented out upstream, so netS stays in
# training mode while testing (affects BatchNorm/Dropout if it has any);
# confirm this is intentional.
netS.cuda()
netG.cuda()
# Initialize testing data
# Pre-allocated GPU buffers. Of these, only val_input is read by the testing
# loop in this script; the others are kept so the module-level names still
# exist. Each buffer is allocated once (several were previously allocated
# twice). torch.autograd.Variable(..., volatile=True) has been a no-op since
# PyTorch 0.4 (it only warns); inference runs under torch.no_grad() instead.
target = torch.FloatTensor(opt.batchSize, outputChannelSize, opt.imageSize, opt.imageSize).cuda()
input = torch.FloatTensor(opt.batchSize, inputChannelSize, opt.imageSize, opt.imageSize).cuda()
depth = torch.FloatTensor(opt.batchSize, inputChannelSize, opt.imageSize, opt.imageSize).cuda()
ato = torch.FloatTensor(opt.batchSize, inputChannelSize, opt.imageSize, opt.imageSize).cuda()

val_target = torch.FloatTensor(opt.valBatchSize, outputChannelSize, opt.imageSize, opt.imageSize).cuda()
val_input = torch.FloatTensor(opt.valBatchSize, inputChannelSize, opt.imageSize, opt.imageSize).cuda()
val_depth = torch.FloatTensor(opt.valBatchSize, inputChannelSize, opt.imageSize, opt.imageSize).cuda()
val_ato = torch.FloatTensor(opt.valBatchSize, inputChannelSize, opt.imageSize, opt.imageSize).cuda()

# Quarter-resolution (*_128) and half-resolution (*_256) buffers.
target_128 = torch.FloatTensor(opt.batchSize, outputChannelSize, opt.imageSize // 4, opt.imageSize // 4).cuda()
input_128 = torch.FloatTensor(opt.batchSize, inputChannelSize, opt.imageSize // 4, opt.imageSize // 4).cuda()
target_256 = torch.FloatTensor(opt.batchSize, outputChannelSize, opt.imageSize // 2, opt.imageSize // 2).cuda()
input_256 = torch.FloatTensor(opt.batchSize, inputChannelSize, opt.imageSize // 2, opt.imageSize // 2).cuda()

# Validation variants are sized by valBatchSize, like val_input/val_target
# above (they were previously sized by batchSize).
val_target_128 = torch.FloatTensor(opt.valBatchSize, outputChannelSize, opt.imageSize // 4, opt.imageSize // 4).cuda()
val_input_128 = torch.FloatTensor(opt.valBatchSize, inputChannelSize, opt.imageSize // 4, opt.imageSize // 4).cuda()
val_target_256 = torch.FloatTensor(opt.valBatchSize, outputChannelSize, opt.imageSize // 2, opt.imageSize // 2).cuda()
val_input_256 = torch.FloatTensor(opt.valBatchSize, inputChannelSize, opt.imageSize // 2, opt.imageSize // 2).cuda()

label_d = torch.FloatTensor(opt.batchSize).cuda()
def norm_ip(img, min, max):
    """Clamp ``img`` to ``[min, max]`` and rescale it in place to ``[0, 1]``.

    Args:
        img: tensor modified in place.
        min: lower bound of the input range (name shadows the builtin; kept
            for backward compatibility with keyword callers).
        max: upper bound of the input range.

    Returns:
        ``img`` itself, after the in-place update.
    """
    img.clamp_(min=min, max=max)
    img.add_(-min)
    span = max - min
    # With a degenerate range every element is 0 after the shift; skip the
    # division so the result is all zeros instead of NaN.
    if span != 0:
        img.div_(span)
    return img
def norm_range(t, range):
    """Rescale ``t`` in place to ``[0, 1]`` using ``norm_ip`` and return it.

    ``range`` is a ``(low, high)`` pair giving the input range; when it is
    ``None`` the bounds ``(-1, +1)`` are used.
    """
    low, high = (-1, +1) if range is None else (range[0], range[1])
    norm_ip(t, low, high)
    return t
# get optimizer
# Adam over netG's parameters. NOTE(review): nothing in this script steps it
# (presumably carried over from the training script), and weight_decay is
# fixed at 5e-5 rather than read from --wd.
optimizerG = optim.Adam(
    netG.parameters(),
    lr=opt.lrG,
    betas=(opt.beta1, 0.999),
    weight_decay=0.00005,
)
# Begin Testing
# Runs netS and netG on every validation image, one image at a time, and
# writes each result as <n>.png (n = 1, 2, ...) under ./result_all/deblurh/.
import time
from PIL import Image

directory = './result_all/deblurh/'  # alt: './result_all/new_model_data/DID-MDN/'
if not os.path.exists(directory):
    os.makedirs(directory)

img_count = 0  # running image index across batches; names the output files
with torch.no_grad():
    for i, data in enumerate(valDataloader, 0):
        print('Image:' + str(i))
        val_input_cpu, val_target_cpu = data
        # Despite the *_cpu names, both tensors are moved to the GPU here.
        val_target_cpu = val_target_cpu.float().cuda()
        val_input_cpu = val_input_cpu.float().cuda()
        val_input.resize_as_(val_input_cpu).copy_(val_input_cpu)
        val_target = val_target_cpu

        for idx in range(val_input.size(0)):
            t0 = time.time()
            # Each image is fed to the networks as a batch of size 1.
            val_inputv = val_input[idx, :, :, :].unsqueeze(0)
            val_inputv_256 = torch.nn.functional.interpolate(val_inputv, scale_factor=0.5)

            smaps, smaps64 = netS(val_inputv, val_inputv_256)
            # Per-class inputs for netG from the first four channels of smaps,
            # at smaps' own spatial size. Hard-coding opt.originalSize here
            # raised a shape error whenever smaps was a different size.
            #   class<k>:     channel k as a 1-channel map
            #   class_msk<k>: channel k repeated to 3 channels
            class1, class2, class3, class4 = [
                smaps[:, c:c + 1].float().clone() for c in range(4)]
            class_msk1, class_msk2, class_msk3, class_msk4 = [
                smaps[:, c:c + 1].float().repeat(1, 3, 1, 1) for c in range(4)]

            (x_hat_val, x_hat_val64, xmask1, xmask2, xmask3, xmask4,
             xcl_class1, xcl_class2, xcl_class3, xcl_class4) = netG(
                val_inputv, val_inputv_256, smaps,
                class1, class2, class3, class4,
                val_inputv,
                class_msk1, class_msk2, class_msk3, class_msk4)

            # CUDA kernels run asynchronously: wait for them so the reported
            # time covers netS + netG (it was previously taken before either
            # network had run).
            torch.cuda.synchronize()
            t1 = time.time()
            print('running time:' + str(t1 - t0))

            # Drop the batch dim and map [-1, 1] -> [0, 1] via norm_range.
            # Assumes x_hat_val is (1, C, H, W) -- TODO confirm C == 3.
            out_img = norm_range(x_hat_val.cpu().squeeze(0), None)
            img_count += 1
            # One file per image; naming by batch index alone overwrote
            # results whenever valBatchSize > 1.
            filename = os.path.join(directory, str(img_count) + '.png')
            ndarr = out_img.mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy()
            Image.fromarray(ndarr).save(filename)