You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

528 lines
22 KiB

from __future__ import print_function
# NOTE: the original import order is kept on purpose. `from misc import *`
# can rebind names, so moving it relative to the other imports could change
# which object a name refers to.
import argparse
import os
import sys
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
cudnn.benchmark = True
cudnn.fastest = True
import torch.optim as optim
import torchvision.utils as vutils
from torch.autograd import Variable
from misc import *
import models.face_fed as net
from myutils.vgg16 import Vgg16
from myutils import utils
import pdb
import torch.nn.functional as F
#from PIL import Image
from torchvision import transforms
import h5py
from os import listdir
from os.path import isfile, join
# Command-line interface.
# Each option is declared once in the table below and registered in table
# order, so `--help` and the printed namespace list them in that sequence.
parser = argparse.ArgumentParser()
_OPTION_TABLE = [
    ('--dataset', dict(required=False, default='pix2pix_class', help='')),
    ('--dataroot', dict(required=False, default='', help='path to trn dataset')),
    ('--valDataroot', dict(required=False, default='', help='path to val dataset')),
    ('--mode', dict(type=str, default='B2A', help='B2A: facade, A2B: edges2shoes')),
    ('--batchSize', dict(type=int, default=1, help='input batch size')),
    ('--valBatchSize', dict(type=int, default=120, help='input batch size')),
    ('--originalSize', dict(type=int, default=175,
                            help='the height / width of the original input image')),
    ('--imageSize', dict(type=int, default=128,
                         help='the height / width of the cropped input image to network')),
    ('--inputChannelSize', dict(type=int, default=3, help='size of the input channels')),
    ('--outputChannelSize', dict(type=int, default=3, help='size of the output channels')),
    ('--ngf', dict(type=int, default=64)),
    ('--ndf', dict(type=int, default=64)),
    ('--niter', dict(type=int, default=5000, help='number of epochs to train for')),
    ('--lrD', dict(type=float, default=0.0002, help='learning rate, default=0.0002')),
    ('--lrG', dict(type=float, default=0.0002, help='learning rate, default=0.0002')),
    ('--annealStart', dict(type=int, default=0, help='annealing learning rate start to')),
    ('--annealEvery', dict(type=int, default=400,
                           help='epoch to reaching at learning rate of 0')),
    ('--lambdaGAN', dict(type=float, default=0.01, help='lambdaGAN')),
    ('--lambdaIMG', dict(type=float, default=2.0, help='lambdaIMG')),
    ('--poolSize', dict(type=int, default=50,
                        help='Buffer size for storing previously generated samples from G')),
    ('--wd', dict(type=float, default=0.0000, help='weight decay in D')),
    ('--beta1', dict(type=float, default=0.5, help='beta1 for adam')),
    ('--netG', dict(default='', help="path to netG (to continue training)")),
    ('--netD', dict(default='', help="path to netD (to continue training)")),
    ('--workers', dict(type=int, help='number of data loading workers', default=1)),
    ('--exp', dict(default='sample', help='folder to output images and model checkpoints')),
    ('--display', dict(type=int, default=5, help='interval for displaying train-logs')),
    ('--evalIter', dict(type=int, default=500,
                        help='interval for evauating(generating) images from valDataroot')),
]
for _flag, _settings in _OPTION_TABLE:
    parser.add_argument(_flag, **_settings)

# Parse the real command line and echo the resulting configuration.
opt = parser.parse_args()
print(opt)
from scipy import signal
import h5py
import random

# Load the blur-kernel bank from the MATLAB/HDF5 file.
# The file is opened read-only: older h5py releases defaulted to a read/write
# mode that would create an empty file when the path was missing. The `with`
# block closes the handle once the dataset has been copied into memory.
k_filename = './kernel.mat'
with h5py.File(k_filename, 'r') as kfp:
    kernels = np.array(kfp['kernels'])
# Swap the last two axes of every kernel; the array is indexed below as
# kernels[index, :, :]. Presumably this undoes MATLAB's column-major layout
# -- TODO confirm against the file that produced kernel.mat.
kernels = kernels.transpose([0, 2, 1])
# Build the VGG feature extractor and fill it from 'VGG_FACE.pth'.
# That checkpoint keys its tensors by numeric position ("0.weight", "2.bias",
# ...), while `vgg.load_state_dict` below is given named keys
# ("conv1_1.weight", ...). The table lists the numeric index for each name.
vgg = Vgg16()
#utils.init_vgg16('./models/')
state_dict_g = torch.load('VGG_FACE.pth')
_VGG_FACE_CONV_INDEX = (
    ("conv1_1", 0), ("conv1_2", 2),
    ("conv2_1", 5), ("conv2_2", 7),
    ("conv3_1", 10), ("conv3_2", 12), ("conv3_3", 14),
    ("conv4_1", 17), ("conv4_2", 19), ("conv4_3", 21),
    ("conv5_1", 24), ("conv5_2", 26), ("conv5_3", 28),
)
new_state_dict_g = {}
for _layer_name, _seq_index in _VGG_FACE_CONV_INDEX:
    for _param in ("weight", "bias"):
        _source_key = "%d.%s" % (_seq_index, _param)
        new_state_dict_g["%s.%s" % (_layer_name, _param)] = state_dict_g[_source_key]
vgg.load_state_dict(new_state_dict_g)

# Wrap for multi-GPU execution and move the weights to the GPU.
vgg = torch.nn.DataParallel(vgg)
vgg.cuda()
# Prepare the output folder for logs, images and checkpoints.
create_exp_dir(opt.exp)

# Draw a fresh seed for this run and apply it to every random source the
# script uses. NumPy is included because the additive noise in the training
# and validation code comes from np.random.normal; without it the printed
# seed would not cover that noise.
opt.manualSeed = random.randint(1, 10000)
random.seed(opt.manualSeed)
np.random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)
torch.cuda.manual_seed_all(opt.manualSeed)
print("Random Seed: ", opt.manualSeed)

# get dataloader
# NOTE(review): the dataset type is forced to 'pix2pix_val' for both loaders,
# so any --dataset value given on the command line is ignored. Confirm this
# is intended.
opt.dataset = 'pix2pix_val'
print (opt.dataroot)
dataloader = getLoader(opt.dataset,
                       opt.dataroot,
                       opt.originalSize,
                       opt.imageSize,
                       opt.batchSize,
                       opt.workers,
                       mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
                       split='train',
                       shuffle=True,
                       seed=opt.manualSeed)
valDataloader = getLoader(opt.dataset,
                          opt.valDataroot,
                          opt.originalSize,
                          opt.imageSize,
                          opt.valBatchSize,
                          opt.workers,
                          mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
                          split='val',
                          shuffle=False,
                          seed=opt.manualSeed)

# get logger
# Opened in 'w' mode, so an existing train.log in the folder is overwritten.
trainLogger = open('%s/train.log' % opt.exp, 'w')
def gradient(y):
    """Return absolute differences between neighbouring elements of `y`.

    Args:
        y: a 4-D tensor (it is indexed with four subscripts).

    Returns:
        A tuple ``(gradient_h, gradient_y)``:
        ``gradient_h`` is ``|y[..., :-1] - y[..., 1:]|`` along the last axis,
        one element shorter than `y` on that axis.
        ``gradient_y`` is the same difference along the second-to-last axis,
        one element shorter than `y` on that axis.
    """
    gradient_h = torch.abs(y[:, :, :, :-1] - y[:, :, :, 1:])
    gradient_y = torch.abs(y[:, :, :-1, :] - y[:, :, 1:, :])
    return gradient_h, gradient_y
ngf = opt.ngf
ndf = opt.ndf
inputChannelSize = opt.inputChannelSize
outputChannelSize = opt.outputChannelSize

# get models
# netS is loaded from a fixed pretrained checkpoint. netG starts from
# `weights_init` and is optionally overwritten from --netG.
netS = net.Segmentation()
netG = net.Deblur_segdl()
netS.load_state_dict(torch.load('./pretrained_models/SMaps_Best.pth'))
netG.apply(weights_init)
if opt.netG != '':
    state_dict_g = torch.load(opt.netG)
    # Checkpoints written by this script come from a DataParallel-wrapped
    # model, whose keys start with 'module.'. Strip that prefix only when it
    # is present, so a checkpoint saved from the bare model loads as well
    # and does not lose the first seven characters of every key.
    dp_prefix = 'module.'
    new_state_dict_g = {}
    for k, v in state_dict_g.items():
        name = k[len(dp_prefix):] if k.startswith(dp_prefix) else k
        new_state_dict_g[name] = v
    # load params
    netG.load_state_dict(new_state_dict_g)
print(netG)

netG = torch.nn.DataParallel(netG)
netS = torch.nn.DataParallel(netS)
# NOTE(review): only netG is switched to train mode here. netS is never put
# in eval mode in this file, although it is only called under no_grad.
# Confirm that is intended.
netG.train()
# Loss criteria.
criterionCAE = nn.L1Loss()
criterionCAE1 = nn.SmoothL1Loss()


def _square_buffer(batch, channels, side):
    """Allocate an uninitialised float tensor of shape (batch, channels, side, side)."""
    return torch.FloatTensor(batch, channels, side, side)


_full_side = opt.imageSize
_half_side = opt.imageSize // 2
_quarter_side = opt.imageSize // 4

# Full-resolution training buffers.
target = _square_buffer(opt.batchSize, outputChannelSize, _full_side)
input = _square_buffer(opt.batchSize, inputChannelSize, _full_side)
depth = _square_buffer(opt.batchSize, inputChannelSize, _full_side)
ato = _square_buffer(opt.batchSize, inputChannelSize, _full_side)

# Reduced-resolution training buffers: "_128" is a quarter of imageSize and
# "_256" is half of it, whatever imageSize is.
target_128 = _square_buffer(opt.batchSize, outputChannelSize, _quarter_side)
input_128 = _square_buffer(opt.batchSize, inputChannelSize, _quarter_side)
target_256 = _square_buffer(opt.batchSize, outputChannelSize, _half_side)
input_256 = _square_buffer(opt.batchSize, inputChannelSize, _half_side)

# Validation buffers. The reduced-resolution ones use opt.batchSize for
# their batch dimension, not opt.valBatchSize.
val_target = _square_buffer(opt.valBatchSize, outputChannelSize, _full_side)
val_input = _square_buffer(opt.valBatchSize, inputChannelSize, _full_side)
val_depth = _square_buffer(opt.valBatchSize, inputChannelSize, _full_side)
val_ato = _square_buffer(opt.valBatchSize, inputChannelSize, _full_side)
val_target_128 = _square_buffer(opt.batchSize, outputChannelSize, _quarter_side)
val_input_128 = _square_buffer(opt.batchSize, inputChannelSize, _quarter_side)
val_target_256 = _square_buffer(opt.batchSize, outputChannelSize, _half_side)
val_input_256 = _square_buffer(opt.batchSize, inputChannelSize, _half_side)

label_d = torch.FloatTensor(opt.batchSize)

# image pool storing previously generated samples from G
lambdaGAN = opt.lambdaGAN
lambdaIMG = opt.lambdaIMG

# Move the networks and criteria to the GPU.
netG.cuda()
netS.cuda()
criterionCAE.cuda()
criterionCAE1.cuda()

# Move the buffers to the GPU. The same names that were wrapped in Variable
# before are wrapped here; depth and the val_* buffers stay plain tensors.
target = Variable(target.cuda())
input = Variable(input.cuda())
depth = depth.cuda()
ato = Variable(ato.cuda())
val_target = val_target.cuda()
val_input = val_input.cuda()
val_depth = val_depth.cuda()
val_ato = val_ato.cuda()
target_128 = Variable(target_128.cuda())
input_128 = Variable(input_128.cuda())
val_target_128 = val_target_128.cuda()
val_input_128 = val_input_128.cuda()
target_256 = Variable(target_256.cuda())
input_256 = Variable(input_256.cuda())
val_target_256 = val_target_256.cuda()
val_input_256 = val_input_256.cuda()
label_d = Variable(label_d.cuda())
# Take the first validation batch and save it for reference.
print(len(dataloader))
val_iter = iter(valDataloader)
# Use the builtin next(): the `.next()` method is Python 2 style and is not
# part of the Python 3 iterator protocol.
data_val = next(val_iter)
val_input_cpu, val_target_cpu = data_val
# Despite the "_cpu" suffix, both tensors are on the GPU after this line.
val_target_cpu, val_input_cpu = val_target_cpu.float().cuda(), val_input_cpu.float().cuda()
# Resize the preallocated buffers to the actual batch and copy the data in.
val_target.resize_as_(val_target_cpu).copy_(val_target_cpu)
val_input.resize_as_(val_input_cpu).copy_(val_input_cpu)
vutils.save_image(val_target, '%s/real_target.png' % opt.exp, normalize=True)
vutils.save_image(val_input, '%s/real_input.png' % opt.exp, normalize=True)

# Adam on the generator only. The weight decay is the literal below; the
# --wd option is not used here.
optimizerG = optim.Adam(netG.parameters(), lr = opt.lrG, betas = (opt.beta1, 0.999), weight_decay=0.00005)
# NOTE training loop
# Training draws its blur kernel from kernels[0:NUM_TRAIN_KERNELS];
# validation image `idx` always uses kernels[NUM_TRAIN_KERNELS + idx].
NUM_TRAIN_KERNELS = 24500


def _class_channels(maps):
    """Return channels 0-3 of `maps` as four separate (N, 1, H, W) float CUDA tensors.

    Each result is an independent copy, so later in-place edits of one do not
    touch `maps` or the others.
    """
    maps = maps.float().cuda()
    return [maps[:, c:c + 1, :, :].clone() for c in range(4)]


def _rgb_class_masks(maps):
    """Return channels 0-3 of `maps`, each repeated over three channels, as (N, 3, H, W) tensors."""
    maps = maps.float().cuda()
    return [maps[:, c:c + 1, :, :].repeat(1, 3, 1, 1) for c in range(4)]


ganIterations = 0
count = 0
# Validation and checkpointing run every half pass over the training loader.
# max() keeps the modulus non-zero when the loader has fewer than two batches.
val_every = max(1, len(dataloader) // 2)
# Weights of the four per-class masked L1 terms, in class order.
class_weights = (1.2, 1.2, 3.6, 1.2)

try:
    for epoch in range(opt.niter):
        # Halve the generator learning rate every 19 epochs.
        if epoch % 19 == 0 and epoch > 0:
            opt.lrG = opt.lrG / 2.0
            for param_group in optimizerG.param_groups:
                param_group['lr'] = opt.lrG
        if epoch >= opt.annealStart:
            adjust_learning_rate(optimizerG, opt.lrG, epoch, None, opt.annealEvery)

        for i, data in enumerate(dataloader, 0):
            input_cpu, target_cpu = data
            batch_size = target_cpu.size(0)
            _, _, x, y = target_cpu.size()
            # Offsets of the centred imageSize x imageSize crop.
            x1 = int((x - opt.imageSize) / 2)
            y1 = int((y - opt.imageSize) / 2)

            # generating blurry image
            input_cpu = input_cpu.numpy()
            target_cpu = target_cpu.numpy()
            for j in range(batch_size):
                # randrange excludes NUM_TRAIN_KERNELS itself, which is the
                # kernel used for validation image 0.
                index = random.randrange(NUM_TRAIN_KERNELS)
                for c in range(3):
                    input_cpu[j, c, :, :] = signal.convolve(
                        input_cpu[j, c, :, :], kernels[index, :, :], mode='same')
            # Additive Gaussian noise with standard deviation 4/255.
            input_cpu = input_cpu + (1.0 / 255.0) * np.random.normal(0, 4, input_cpu.shape)
            input_cpu = input_cpu[:, :, x1:x1 + opt.imageSize, y1:y1 + opt.imageSize]
            target_cpu = target_cpu[:, :, x1:x1 + opt.imageSize, y1:y1 + opt.imageSize]
            input_cpu = torch.from_numpy(input_cpu)
            target_cpu = torch.from_numpy(target_cpu)
            target_cpu, input_cpu = target_cpu.float().cuda(), input_cpu.float().cuda()

            # getting input and target image at 0.5 scale
            target.data.resize_as_(target_cpu).copy_(target_cpu)
            input.data.resize_as_(input_cpu).copy_(input_cpu)
            input_256 = torch.nn.functional.interpolate(input, scale_factor=0.5)
            target_256 = torch.nn.functional.interpolate(target, scale_factor=0.5)

            # computing segmentation masks for input and target
            with torch.no_grad():
                smaps_i, smaps_i64 = netS(input, input_256)
                smaps, smaps64 = netS(target, target_256)
                # Single-channel maps come from the blurry input, three-channel
                # masks from the sharp target. Sizes follow netS's output
                # rather than a fixed 128x128.
                class1, class2, class3, class4 = _class_channels(smaps_i)
                class_msk1, class_msk2, class_msk3, class_msk4 = _rgb_class_masks(smaps)

            # Forward step
            x_hat1, x_hat64, xmask1, xmask2, xmask3, xmask4, xcl_class1, xcl_class2, xcl_class3, xcl_class4 = netG(
                input, input_256, smaps_i, class1, class2, class3, class4,
                target, class_msk1, class_msk2, class_msk3, class_msk4)
            x_hat = x_hat1

            # NOTE(review): gradients are zeroed and applied only on even
            # iterations. As written, gradients from odd iterations are
            # cleared by the next zero_grad() before any step uses them.
            # Confirm this is intended.
            if ganIterations % 2 == 0:
                netG.zero_grad()  # start to update G

            if epoch > -1:  # always true for range(); the else branch is a disabled alternative
                xmasks = (xmask1, xmask2, xmask3, xmask4)
                class_msks = (class_msk1, class_msk2, class_msk3, class_msk4)

                # Half-scale L1 term plus one masked L1 term per class.
                L_img_ = 0.33 * criterionCAE(x_hat64, target_256)
                for weight, xmask, class_msk in zip(class_weights, xmasks, class_msks):
                    blended = xmask * class_msk * x_hat + (1 - xmask) * class_msk * target
                    L_img_ = L_img_ + weight * criterionCAE(blended, class_msk * target)
                if ganIterations % (25 * opt.display) == 0:
                    # .item() is the supported way to read a 0-dim loss;
                    # indexing it with [0] raises on current PyTorch.
                    print(L_img_.item())
                    sys.stdout.flush()

                # Subtract the mean log of each xmask. ganIterations is never
                # below -1, so lam_cmp is 0.06 in practice.
                lam_cmp = 1.0 if ganIterations < -1 else 0.06
                sng = 0.00000001  # keeps log() finite when a mask value is 0
                for xmask in xmasks:
                    L_img_ = L_img_ - (lam_cmp / 4.0) * torch.mean(torch.log(xmask + sng))
                if ganIterations % (50 * opt.display) == 0:
                    print(L_img_.item())
                    sys.stdout.flush()

                # L1 between the `gradient` outputs of prediction and target,
                # at full and half scale.
                gradh_xhat, gradv_xhat = gradient(x_hat)
                gradh_tar, gradv_tar = gradient(target)
                gradh_xhat64, gradv_xhat64 = gradient(x_hat64)
                gradh_tar64, gradv_tar64 = gradient(target_256)
                L_img_ = (L_img_
                          + 0.15 * criterionCAE(gradh_xhat, gradh_tar)
                          + 0.15 * criterionCAE(gradv_xhat, gradv_tar)
                          + 0.08 * criterionCAE(gradh_xhat64, gradh_tar64)
                          + 0.08 * criterionCAE(gradv_xhat64, gradv_tar64))
                if ganIterations % (25 * opt.display) == 0:
                    print(L_img_.item())
                    print((torch.mean(torch.log(xmask1)).data),(torch.mean(torch.log(xmask2)).data),(torch.mean(xmask3).data),(torch.mean(xmask4).data))
                    sys.stdout.flush()

                L_img = lambdaIMG * L_img_
                # Backward step or computing gradients
                if lambdaIMG != 0:
                    L_img.backward(retain_graph=True)

                # Perceptual Loss 1: VGG outputs at indices 1 and 4.
                features_content = vgg(target)
                f_xc_c = features_content[1].detach()
                f_xc_c5 = features_content[4].detach()
                features_y = vgg(x_hat)
                features_content = vgg(target_256)
                f_xc_c64 = features_content[1].detach()
                features_y64 = vgg(x_hat64)
                lambda_p = 0.00018
                content_loss = (lambda_p * lambdaIMG * criterionCAE(features_y[1], f_xc_c)
                                + lambda_p * 0.33 * lambdaIMG * criterionCAE(features_y64[1], f_xc_c64)
                                + lambda_p * lambdaIMG * criterionCAE(features_y[4], f_xc_c5))
                content_loss.backward(retain_graph=True)

                # Perceptual Loss 2: VGG output at index 0.
                # NOTE(review): this repeats the four VGG forward passes from
                # Perceptual Loss 1; reusing those features would roughly
                # halve the VGG cost if the extractor is deterministic.
                features_content = vgg(target)
                f_xc_c = features_content[0].detach()
                features_y = vgg(x_hat)
                features_content = vgg(target_256)
                f_xc_c64 = features_content[0].detach()
                features_y64 = vgg(x_hat64)
                content_loss1 = (lambda_p * lambdaIMG * criterionCAE(features_y[0], f_xc_c)
                                 + lambda_p * 0.33 * lambdaIMG * criterionCAE(features_y64[0], f_xc_c64))
                content_loss1.backward(retain_graph=True)
            else:
                # L1 between each xcl_class output and the target.
                L_img_ = 1.2 * criterionCAE(xcl_class1, target)
                L_img_ = L_img_ + 1.2 * criterionCAE(xcl_class2, target)
                L_img_ = L_img_ + 3.6 * criterionCAE(xcl_class3, target)
                L_img_ = L_img_ + 1.2 * criterionCAE(xcl_class4, target)
                L_img = lambdaIMG * L_img_
                if lambdaIMG != 0:
                    L_img.backward(retain_graph=True)
                if ganIterations % (25 * opt.display) == 0:
                    print(L_img_.item())
                    print("updating fisrt stage parameters")
                    sys.stdout.flush()

            if ganIterations % 2 == 0:
                optimizerG.step()
            ganIterations += 1

            if ganIterations % opt.display == 0:
                print('[%d/%d][%d/%d] Loss: %f '
                      % (epoch, opt.niter, i, len(dataloader),
                         L_img.item()))
                sys.stdout.flush()
                trainLogger.write('%d\t%f\n' % (i, L_img.item()))
                trainLogger.flush()

            # validation
            if ganIterations % val_every == 0:
                # One output per validation image. Collecting them in a list
                # removes the fixed 16-slot buffer, which overflowed for
                # larger validation batches.
                val_outputs = []
                for idx in range(val_input.size(0)):
                    single_img = val_input[idx, :, :, :].unsqueeze(0)
                    with torch.no_grad():
                        index = idx + NUM_TRAIN_KERNELS
                        # .cpu() makes a copy, so val_input itself stays unblurred.
                        val_inputv = single_img.cpu().numpy()
                        for c in range(3):
                            val_inputv[0, c, :, :] = signal.convolve(
                                val_inputv[0, c, :, :], kernels[index, :, :], mode='same')
                        # x1/y1 are the crop offsets of the latest training batch.
                        val_inputv = val_inputv[:, :, x1:x1 + opt.imageSize, y1:y1 + opt.imageSize]
                        val_inputv = val_inputv + (1.0 / 255.0) * np.random.normal(0, 4, val_inputv.shape)
                        val_inputv = torch.from_numpy(val_inputv)
                        val_inputv = val_inputv.float().cuda()
                        val_inputv_256 = torch.nn.functional.interpolate(val_inputv, scale_factor=0.5)
                        smaps, smaps64 = netS(val_inputv, val_inputv_256)
                        class1, class2, class3, class4 = _class_channels(smaps)
                        class_msk1, class_msk2, class_msk3, class_msk4 = _rgb_class_masks(smaps)
                        # The blurred input is passed in the target slot.
                        x_hat_val, x_hat_val64, xmask1, xmask2, xmask3, xmask4, xcl_class1, xcl_class2, xcl_class3, xcl_class4 = netG(
                            val_inputv, val_inputv_256, smaps, class1, class2, class3, class4,
                            val_inputv, class_msk1, class_msk2, class_msk3, class_msk4)
                    val_outputs.append(x_hat_val.data[0:1].cpu())

                if val_outputs:
                    vutils.save_image(torch.cat(val_outputs, 0),
                                      '%s/generated_epoch_iter%08d.png' % (opt.exp, ganIterations),
                                      normalize=True, scale_each=False)
                del val_outputs

                # Checkpoint the generator. Keys carry the DataParallel
                # 'module.' prefix.
                torch.save(netG.state_dict(), '%s/Deblur_epoch_%d.pth' % (opt.exp, count))
                #torch.save(netC.state_dict(), '%s/Deblur_first_epoch_%d.pth' % (opt.exp, count))
                count = count + 1
finally:
    # Close the log even when training stops with an exception or Ctrl-C.
    trainLogger.close()