@@ -71,14 +71,51 @@ from scipy import signal
import h5py
from scipy import signal
import random

# loading the blur kernels from a MATLAB .mat file
k_filename = './kernel.mat'
kfp = h5py.File(k_filename, 'r')  # open read-only
kernels = np.array(kfp['kernels'])
kernels = kernels.transpose([0, 2, 1])
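# (Added note) MATLAB v7.3 .mat files are HDF5 containers, and h5py exposes
# MATLAB arrays with their dimension order reversed; the [0,2,1] transpose
# restores each kernel to (H, W) while keeping the kernel index first, i.e.
# kernels[i,:,:] is one 2D blur kernel (an assumption, based on how the
# stack is indexed during validation below).
kfp.close()  # the dataset was copied into a NumPy array above, so the handle can be released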

# initialize VGG-16 and load the pretrained VGG-Face weights
vgg = Vgg16()
utils.init_vgg16('./models/')
state_dict_g = torch.load('VGG_FACE.pth')

# remap the numerically indexed checkpoint keys (e.g. "0.weight") to the
# named conv layers expected by Vgg16
vgg_layer_map = {
    'conv1_1': 0, 'conv1_2': 2,
    'conv2_1': 5, 'conv2_2': 7,
    'conv3_1': 10, 'conv3_2': 12, 'conv3_3': 14,
    'conv4_1': 17, 'conv4_2': 19, 'conv4_3': 21,
    'conv5_1': 24, 'conv5_2': 26, 'conv5_3': 28,
}
new_state_dict_g = {}
for name, idx in vgg_layer_map.items():
    new_state_dict_g[name + '.weight'] = state_dict_g['%d.weight' % idx]
    new_state_dict_g[name + '.bias'] = state_dict_g['%d.bias' % idx]

vgg.load_state_dict(new_state_dict_g)
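# (Added note) load_state_dict is strict by default, so a mistake in the key
# remapping above would raise here rather than silently skip weights.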

vgg = torch.nn.DataParallel(vgg)
vgg.cuda()

create_exp_dir(opt.exp)
opt.manualSeed = random.randint(1, 10000)
# opt.manualSeed = 101
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)
torch.cuda.manual_seed_all(opt.manualSeed)
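# (Added note) seeding Python's `random`, the CPU RNG, and all CUDA RNGs
# covers the sources of randomness used here; since manualSeed is drawn at
# random per run, it needs to be logged (e.g. printed or saved under
# opt.exp) for a run to be reproducible.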

@@ -133,15 +170,17 @@ netG=net.Deblur_segdl()

# load the pretrained segmentation network and initialize the generator
netS.load_state_dict(torch.load('./pretrained_models/SMaps_Best.pth'))
netG.apply(weights_init)

# optionally resume the generator from a checkpoint
if opt.netG != '':
    state_dict_g = torch.load(opt.netG)
    new_state_dict_g = {}
    for k, v in state_dict_g.items():
        name = k[7:]  # remove the `module.` prefix added by DataParallel
        new_state_dict_g[name] = v
    # load params
    netG.load_state_dict(new_state_dict_g)
print(netG)
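# (Added note) `k[7:]` assumes every checkpoint key starts with the
# 'module.' prefix that torch.nn.DataParallel adds when a wrapped model is
# saved; a guarded variant would be:
#     name = k[7:] if k.startswith('module.') else k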

# state_dict_g = torch.load('./face_deblur/Deblur_epoch_46.pth')
# new_state_dict_g = {}
# for k, v in state_dict_g.items():
#     name = k[7:]  # remove `module.`
#     # print(k)
#     new_state_dict_g[name] = v
# # load params
# netG.load_state_dict(new_state_dict_g)

netG = torch.nn.DataParallel(netG)
netS = torch.nn.DataParallel(netS)

@@ -180,22 +219,11 @@ val_depth = torch.FloatTensor(opt.valBatchSize, inputChannelSize, opt.imageSize,
val_ato = torch.FloatTensor(opt.valBatchSize, inputChannelSize, opt.imageSize, opt.imageSize)

# NOTE: size of 2D output maps in the discriminator
sizePatchGAN = 30
real_label = 1
fake_label = 0

# image pool storing previously generated samples from G
imagePool = ImagePool(opt.poolSize)
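# (Added note) the image pool is the history-buffer trick used in
# pix2pix/CycleGAN-style training: the discriminator sees a mix of current
# and previously generated samples, which damps oscillations between G and D.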

# NOTE weight for L_cGAN and L_L1 (i.e. Eq.(4) in the paper)
lambdaGAN = opt.lambdaGAN
lambdaIMG = opt.lambdaIMG

netG.cuda()
# netC.cuda()
netS.cuda()
criterionCAE.cuda()
criterionCAE1.cuda()

@@ -218,8 +246,6 @@ target_128 = Variable(target_128)
input_128 = Variable(input_128)
target_256 = Variable(target_256)
input_256 = Variable(input_256)
# input = Variable(input,requires_grad=False)
# depth = Variable(depth)
ato = Variable(ato)

# Initialize VGG-16

@@ -255,7 +281,7 @@ vutils.save_image(val_input, '%s/real_input.png' % opt.exp, normalize=True)

optimizerG = optim.Adam(netG.parameters(), lr=opt.lrG, betas=(opt.beta1, 0.999), weight_decay=0.00005)

# NOTE training loop
ganIterations = 0
count = 0  # was: count = 48

for epoch in range(opt.niter):
@@ -274,6 +300,8 @@ for epoch in range(opt.niter):
    b, ch, x, y = target_cpu.size()
    x1 = int((x - opt.imageSize) / 2)
    y1 = int((y - opt.imageSize) / 2)

    # generating the blurry image
    input_cpu = input_cpu.numpy()
    target_cpu = target_cpu.numpy()
    for j in range(batch_size):
@@ -292,8 +320,7 @@ for epoch in range(opt.niter):
    target_cpu, input_cpu = target_cpu.float().cuda(), input_cpu.float().cuda()

    # get paired data
    # getting the input and target images at 0.5 scale
    target.data.resize_as_(target_cpu).copy_(target_cpu)
    input.data.resize_as_(input_cpu).copy_(input_cpu)
    input_256 = torch.nn.functional.interpolate(input, scale_factor=0.5)
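    # (Added note) despite the name, input_256 is simply a 0.5-scale copy of
    # the batch; the generator and segmentation network each take the image
    # at two scales (a coarse-to-fine design, inferred from the calls below).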
@@ -301,6 +328,7 @@ for epoch in range(opt.niter):

    # computing segmentation masks for input and target
    with torch.no_grad():
        smaps_i, smaps_i64 = netS(input, input_256)
        smaps, smaps64 = netS(target, target_256)
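    # (Added note) netS runs under no_grad, so its segmentation maps act as
    # fixed pseudo-labels and no gradients flow back into the pretrained
    # parsing network.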
@@ -329,10 +357,6 @@ for epoch in range(opt.niter):
    class_msk4[:,1,:,:] = smaps[:,3,:,:]
    class_msk4[:,2,:,:] = smaps[:,3,:,:]
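    # (Added note) channel 3 of smaps (the fourth semantic class) is
    # replicated across what appear to be the three RGB channels of
    # class_msk4, so the class probability can gate a 3-channel image
    # elementwise; channel 0 is presumably filled in lines elided by this hunk.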

    class1 = class1.float().cuda()
    class2 = class2.float().cuda()
    class3 = class3.float().cuda()
@@ -341,24 +365,21 @@ for epoch in range(opt.niter):
    class_msk3 = class_msk3.float().cuda()
    class_msk2 = class_msk2.float().cuda()
    class_msk1 = class_msk1.float().cuda()

    # Forward step
    x_hat1, x_hat64, xmask1, xmask2, xmask3, xmask4, xcl_class1, xcl_class2, xcl_class3, xcl_class4 = netG(
        input, input_256, smaps_i, class1, class2, class3, class4,
        target, class_msk1, class_msk2, class_msk3, class_msk4)
    x_hat = x_hat1

    # xeff = conf*x_hat+(1-conf)*target
    # xeff_64 = conf_64*x_hat64+(1-conf_64)*target_256
    # print(x_hat.size())

    if ganIterations % 2 == 0:
        netG.zero_grad()  # start to update G

    # x1 = xmask1*class_msk1*x_hat+(1-xmask1)*class_msk1*target
    # smaps_hat,smaps64_hat = netS(x_hat1,x_hat64)

    if epoch > -1:  # always true as written
        with torch.no_grad():
            smaps, smaps64 = netS(target, target_256)
    # with torch.no_grad():
    #     smaps,smaps64 = netS(target,target_256)

    L_img_ = 0.33*criterionCAE(x_hat64, target_256)  # + 0.5*criterionCAE(smaps_hat, smaps)
    L_img_ = L_img_ + 1.2*criterionCAE(xmask1*class_msk1*x_hat + (1-xmask1)*class_msk1*target, class_msk1*target)
    L_img_ = L_img_ + 1.2*criterionCAE(xmask2*class_msk2*x_hat + (1-xmask2)*class_msk2*target, class_msk2*target)
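    # (Added note) each term above is a confidence-gated compositing loss:
    # within class region k, the blend xmask_k*x_hat + (1-xmask_k)*target is
    # compared against the target, so the network is penalized only where it
    # claims confidence (xmask_k -> 1) but deblurs incorrectly.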
@@ -367,10 +388,10 @@ for epoch in range(opt.niter):

    if ganIterations % (25*opt.display) == 0:
        print(L_img_.data[0])
        sys.stdout.flush()

    if ganIterations < -1:  # was: ganIterations < 1000; the warm-up branch is now disabled
        lam_cmp = 1.0
    else:
        lam_cmp = 0.06  # was: 0.09
    sng = 0.00000001  # small epsilon so log() stays finite
    L_img_ = L_img_ - (lam_cmp/4.0)*torch.mean(torch.log(xmask1 + sng))
    L_img_ = L_img_ - (lam_cmp/4.0)*torch.mean(torch.log(xmask2 + sng))
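    # (Added note) the -mean(log(xmask + sng)) terms are a log-barrier that
    # pushes the confidence masks toward 1, counteracting the degenerate
    # solution xmask -> 0 (which would zero out the compositing losses above).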
@@ -380,8 +401,6 @@ for epoch in range(opt.niter):
        print(L_img_.data[0])
        sys.stdout.flush()

    # L_img_ = L_img_ + 2*criterionCAE(class_msk3*x_hat,class_msk3*target)
    # L_res = lambdaIMG * L_res_
    gradh_xhat, gradv_xhat = gradient(x_hat)
    gradh_tar, gradv_tar = gradient(target)
    gradh_xhat64, gradv_xhat64 = gradient(x_hat64)
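    # (Added note) gradient() is defined elsewhere in this repo; it
    # presumably returns horizontal/vertical finite differences, and the
    # matching terms between x_hat and target (in lines elided here) act as
    # an edge-preserving first-order loss.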
@@ -391,24 +410,23 @@ for epoch in range(opt.niter):
        print(L_img_.data[0])
        print((torch.mean(torch.log(xmask1)).data), (torch.mean(torch.log(xmask2)).data), (torch.mean(xmask3).data), (torch.mean(xmask4).data))
        sys.stdout.flush()

    # L_res = lambdaIMG * L_res_
    L_img = lambdaIMG * L_img_

    # Backward step, i.e. computing gradients
    if lambdaIMG != 0:  # was the Python 2 spelling `lambdaIMG <> 0`
        L_img.backward(retain_graph=True)
    # L_res.backward(retain_variables=True)
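    # (Added note) retain_graph=True is needed because several losses
    # (L_img here, the perceptual losses below) each call backward() through
    # the same forward graph; their gradients accumulate in netG until the
    # optimizer step.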

    # Perceptual Loss 1
    features_content = vgg(target)
    f_xc_c = Variable(features_content[1].data, requires_grad=False)
    f_xc_c5 = Variable(features_content[4].data, requires_grad=False)
    features_y = vgg(x_hat)

    features_content = vgg(target_256)
    f_xc_c64 = Variable(features_content[1].data, requires_grad=False)
    features_y64 = vgg(x_hat64)

    lambda_p = 0.00018  # was a fixed 1.8 weighting without the features_y[4] term
    content_loss = lambda_p*lambdaIMG*criterionCAE(features_y[1], f_xc_c) + lambda_p*0.33*lambdaIMG*criterionCAE(features_y64[1], f_xc_c64) + lambda_p*lambdaIMG*criterionCAE(features_y[4], f_xc_c5)
    content_loss.backward(retain_graph=True)
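    # (Added note) this is a standard VGG perceptual loss: activations of
    # the fixed VGG-Face network at matching depths (indices 1 and 4 of the
    # returned feature list) are compared between the restored image and the
    # sharp target, at both scales; wrapping the targets in
    # Variable(..., requires_grad=False) keeps gradients out of the target
    # branch.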

    # Perceptual Loss 2
@@ -420,7 +438,7 @@ for epoch in range(opt.niter):
    f_xc_c64 = Variable(features_content[0].data, requires_grad=False)
    features_y64 = vgg(x_hat64)

    content_loss1 = lambda_p*lambdaIMG*criterionCAE(features_y[0], f_xc_c) + lambda_p*0.33*lambdaIMG*criterionCAE(features_y64[0], f_xc_c64)  # was weighted 1.8
    content_loss1.backward(retain_graph=True)

@@ -430,7 +448,7 @@ for epoch in range(opt.niter):
    L_img_ = L_img_ + 3.6*criterionCAE(xcl_class3, target)
    L_img_ = L_img_ + 1.2*criterionCAE(xcl_class4, target)
    L_img = lambdaIMG * L_img_
    if lambdaIMG != 0:  # was the Python 2 spelling `lambdaIMG <> 0`
        L_img.backward(retain_graph=True)
    if ganIterations % (25*opt.display) == 0:
        print(L_img_.data[0])
@@ -451,16 +469,15 @@ for epoch in range(opt.niter):
        trainLogger.write('%d\t%f\n' % (i, L_img.data[0]))
        trainLogger.flush()

    # validation
    if ganIterations % (int(len(dataloader)/2)) == 0:
        val_batch_output = torch.zeros([16, 3, 128, 128], dtype=torch.float32)  # torch.FloatTensor([10,3,128,128]).fill_(0)
        for idx in range(val_input.size(0)):
            single_img = val_input[idx,:,:,:].unsqueeze(0)
            val_inputv = Variable(single_img)  # volatile=True is deprecated; the no_grad block below disables autograd
            with torch.no_grad():
                # smaps_vl = netS(val_inputv)
                # S_valinput = torch.cat([smaps_vl,val_inputv],1)
                index = idx + 24500
                # print(val_inputv.size())
                val_inputv = val_inputv.cpu().numpy()
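                # (Added note) each color channel is convolved with the same
                # 2D motion kernel to synthesize the blurry validation input;
                # mode='same' preserves the spatial size. The offset 24500
                # presumably indexes a held-out slice of the kernel stack
                # (an assumption; the train/val split is not shown in this
                # excerpt).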
                val_inputv[0,0,:,:] = signal.convolve(val_inputv[0,0,:,:], kernels[index,:,:], mode='same')
                val_inputv[0,1,:,:] = signal.convolve(val_inputv[0,1,:,:], kernels[index,:,:], mode='same')