diff --git a/VGG_FACE.pth b/VGG_FACE.pth
new file mode 100644
index 0000000..3d7fc7f
Binary files /dev/null and b/VGG_FACE.pth differ
diff --git a/train_face_deblur.py b/train_face_deblur.py
index dfc61f1..d62f4d7 100644
--- a/train_face_deblur.py
+++ b/train_face_deblur.py
@@ -71,14 +71,51 @@ from scipy import signal
 import h5py
 from scipy import signal
 import random
+
+#loading kernels mat file
 k_filename ='./kernel.mat'
 kfp = h5py.File(k_filename)
 kernels = np.array(kfp['kernels'])
 kernels = kernels.transpose([0,2,1])
+
+vgg = Vgg16()
+utils.init_vgg16('./models/')
+state_dict_g = torch.load('VGG_FACE.pth')
+new_state_dict_g = {}
+new_state_dict_g["conv1_1.weight"]= state_dict_g["0.weight"]
+new_state_dict_g["conv1_1.bias"]= state_dict_g["0.bias"]
+new_state_dict_g["conv1_2.weight"]= state_dict_g["2.weight"]
+new_state_dict_g["conv1_2.bias"]= state_dict_g["2.bias"]
+new_state_dict_g["conv2_1.weight"]= state_dict_g["5.weight"]
+new_state_dict_g["conv2_1.bias"]= state_dict_g["5.bias"]
+new_state_dict_g["conv2_2.weight"]= state_dict_g["7.weight"]
+new_state_dict_g["conv2_2.bias"]= state_dict_g["7.bias"]
+new_state_dict_g["conv3_1.weight"]= state_dict_g["10.weight"]
+new_state_dict_g["conv3_1.bias"]= state_dict_g["10.bias"]
+new_state_dict_g["conv3_2.weight"]= state_dict_g["12.weight"]
+new_state_dict_g["conv3_2.bias"]= state_dict_g["12.bias"]
+new_state_dict_g["conv3_3.weight"]= state_dict_g["14.weight"]
+new_state_dict_g["conv3_3.bias"]= state_dict_g["14.bias"]
+new_state_dict_g["conv4_1.weight"]= state_dict_g["17.weight"]
+new_state_dict_g["conv4_1.bias"]= state_dict_g["17.bias"]
+new_state_dict_g["conv4_2.weight"]= state_dict_g["19.weight"]
+new_state_dict_g["conv4_2.bias"]= state_dict_g["19.bias"]
+new_state_dict_g["conv4_3.weight"]= state_dict_g["21.weight"]
+new_state_dict_g["conv4_3.bias"]= state_dict_g["21.bias"]
+new_state_dict_g["conv5_1.weight"]= state_dict_g["24.weight"]
+new_state_dict_g["conv5_1.bias"]= state_dict_g["24.bias"]
+new_state_dict_g["conv5_2.weight"]= state_dict_g["26.weight"]
+new_state_dict_g["conv5_2.bias"]= state_dict_g["26.bias"]
+new_state_dict_g["conv5_3.weight"]= state_dict_g["28.weight"]
+new_state_dict_g["conv5_3.bias"]= state_dict_g["28.bias"]
+vgg.load_state_dict(new_state_dict_g)
+
+vgg = torch.nn.DataParallel(vgg)
+vgg.cuda()
+
 create_exp_dir(opt.exp)
 opt.manualSeed = random.randint(1, 10000)
-# opt.manualSeed = 101
 random.seed(opt.manualSeed)
 torch.manual_seed(opt.manualSeed)
 torch.cuda.manual_seed_all(opt.manualSeed)
@@ -133,15 +170,17 @@ netG=net.Deblur_segdl()
 netS.load_state_dict(torch.load('./pretrained_models/SMaps_Best.pth'))
 
+netG.apply(weights_init)
+if opt.netG != '':
+    state_dict_g = torch.load(opt.netG)
+    new_state_dict_g = {}
+    for k, v in state_dict_g.items():
+        name = k[7:]
+        new_state_dict_g[name] = v
+    # load params
+    netG.load_state_dict(new_state_dict_g)
+print(netG)
-# state_dict_g = torch.load('./face_deblur/Deblur_epoch_46.pth')
-# new_state_dict_g = {}
-# for k, v in state_dict_g.items():
-#     name = k[7:] # remove `module.`
-#     #print(k)
-#     new_state_dict_g[name] = v
-# # load params
-# netG.load_state_dict(new_state_dict_g)
 
 netG = torch.nn.DataParallel(netG)
 netS = torch.nn.DataParallel(netS)
@@ -180,22 +219,11 @@ val_depth = torch.FloatTensor(opt.valBatchSize, inputChannelSize, opt.imageSize, opt.imageSize)
 val_ato = torch.FloatTensor(opt.valBatchSize, inputChannelSize, opt.imageSize, opt.imageSize)
-
-
-# NOTE: size of 2D output maps in the discriminator
-sizePatchGAN = 30
-real_label = 1
-fake_label = 0
-
-# image pool storing previously generated samples from G
-imagePool = ImagePool(opt.poolSize)
-
-# NOTE weight for L_cGAN and L_L1 (i.e. Eq.(4) in the paper)
 lambdaGAN = opt.lambdaGAN
 lambdaIMG = opt.lambdaIMG
 
 netG.cuda()
-#netC.cuda()
 netS.cuda()
 criterionCAE.cuda()
 criterionCAE1.cuda()
@@ -218,8 +246,6 @@ target_128 = Variable(target_128)
 input_128 = Variable(input_128)
 target_256 = Variable(target_256)
 input_256 = Variable(input_256)
-# input = Variable(input,requires_grad=False)
-# depth = Variable(depth)
 ato = Variable(ato)
 
 # Initialize VGG-16
@@ -255,7 +281,7 @@ vutils.save_image(val_input, '%s/real_input.png' % opt.exp, normalize=True)
 optimizerG = optim.Adam(netG.parameters(), lr = opt.lrG, betas = (opt.beta1, 0.999), weight_decay=0.00005)
 # NOTE training loop
 ganIterations = 0
-count = 48
+count = 0
 
 for epoch in range(opt.niter):
@@ -274,6 +300,8 @@ for epoch in range(opt.niter):
         b,ch,x,y = target_cpu.size()
         x1 = int((x-opt.imageSize)/2)
         y1 = int((y-opt.imageSize)/2)
+
+        #generating blurry image
         input_cpu = input_cpu.numpy()
         target_cpu = target_cpu.numpy()
         for j in range(batch_size):
@@ -292,8 +320,7 @@ for epoch in range(opt.niter):
         target_cpu, input_cpu = target_cpu.float().cuda(), input_cpu.float().cuda()
 
-
-        # get paired data
+        # getting input and target image at 0.5 scale
         target.data.resize_as_(target_cpu).copy_(target_cpu)
         input.data.resize_as_(input_cpu).copy_(input_cpu)
         input_256 = torch.nn.functional.interpolate(input,scale_factor=0.5)
@@ -301,6 +328,7 @@ for epoch in range(opt.niter):
 
+        # computing segmentation masks for input and target
        with torch.no_grad():
             smaps_i,smaps_i64 = netS(input,input_256)
             smaps,smaps64 = netS(target,target_256)
@@ -329,10 +357,6 @@ for epoch in range(opt.niter):
         class_msk4[:,1,:,:] = smaps[:,3,:,:]
         class_msk4[:,2,:,:] = smaps[:,3,:,:]
-
-
-
-
         class1 = class1.float().cuda()
         class2 = class2.float().cuda()
         class3 = class3.float().cuda()
@@ -341,24 +365,21 @@ for epoch in range(opt.niter):
         class_msk3 = class_msk3.float().cuda()
         class_msk2 = class_msk2.float().cuda()
         class_msk1 = class_msk1.float().cuda()
+
+        # Forward step
         x_hat1,x_hat64,xmask1,xmask2,xmask3,xmask4,xcl_class1,xcl_class2,xcl_class3,xcl_class4 = netG(input,input_256,smaps_i,class1,class2,class3,class4,target,class_msk1,class_msk2,class_msk3,class_msk4)
         x_hat = x_hat1
-        #xeff = conf*x_hat+(1-conf)*target
-        #xeff_64 = conf_64*x_hat64+(1-conf_64)*target_256
-
-
-        #print(x_hat.size())
+
         if ganIterations % 2 == 0:
             netG.zero_grad() # start to update G
-            #x1 = xmask1*class_msk1*x_hat+(1-xmask1)*class_msk1*target
-            #smaps_hat,smaps64_hat = netS(x_hat1,x_hat64)
+
            if epoch>-1:
-                with torch.no_grad():
-                    smaps,smaps64 = netS(target,target_256)
+                # with torch.no_grad():
+                #     smaps,smaps64 = netS(target,target_256)
                 L_img_ = 0.33*criterionCAE(x_hat64, target_256) #+ 0.5*criterionCAE(smaps_hat, smaps)
                 L_img_ = L_img_ + 1.2 *criterionCAE(xmask1*class_msk1*x_hat+(1-xmask1)*class_msk1*target, class_msk1*target)
                 L_img_ = L_img_ + 1.2 *criterionCAE(xmask2*class_msk2*x_hat+(1-xmask2)*class_msk2*target, class_msk2*target)
@@ -367,10 +388,10 @@ for epoch in range(opt.niter):
                 if ganIterations % (25*opt.display) == 0:
                     print(L_img_.data[0])
                     sys.stdout.flush()
-                if ganIterations< 1000:
+                if ganIterations< -1:
                     lam_cmp = 1.0
                 else :
-                    lam_cmp = 0.09
+                    lam_cmp = 0.06
                 sng = 0.00000001
                 L_img_ = L_img_ - (lam_cmp/(4.0))*torch.mean(torch.log(xmask1+sng))
                 L_img_ = L_img_ - (lam_cmp/(4.0))*torch.mean(torch.log(xmask2+sng))
@@ -380,8 +401,6 @@ for epoch in range(opt.niter):
                     print(L_img_.data[0])
                     sys.stdout.flush()
-                #L_img_ = L_img_ + 2*criterionCAE(class_msk3*x_hat,class_msk3*target)
-                # L_res = lambdaIMG * L_res_
                 gradh_xhat,gradv_xhat=gradient(x_hat)
                 gradh_tar,gradv_tar=gradient(target)
                 gradh_xhat64,gradv_xhat64=gradient(x_hat64)
@@ -391,24 +410,23 @@ for epoch in range(opt.niter):
                     print(L_img_.data[0])
                     print((torch.mean(torch.log(xmask1)).data),(torch.mean(torch.log(xmask2)).data),(torch.mean(xmask3).data),(torch.mean(xmask4).data))
                     sys.stdout.flush()
-                # L_res = lambdaIMG * L_res_
-                L_img = lambdaIMG * L_img_
-                if lambdaIMG <> 0:
-                    #L_img.backward(retain_graph=True) # in case of current version of pytorch
+                L_img = lambdaIMG * L_img_
+                #Backward step or computing gradients
+                if lambdaIMG != 0:
                     L_img.backward(retain_graph=True)
-                # L_res.backward(retain_variables=True)
 
                 # Perceptual Loss 1
                 features_content = vgg(target)
                 f_xc_c = Variable(features_content[1].data, requires_grad=False)
+                f_xc_c5 = Variable(features_content[4].data, requires_grad=False)
                 features_y = vgg(x_hat)
                 features_content = vgg(target_256)
                 f_xc_c64 = Variable(features_content[1].data, requires_grad=False)
                 features_y64 = vgg(x_hat64)
-
-                content_loss = 1.8*lambdaIMG* criterionCAE(features_y[1], f_xc_c) + 1.8*0.33*lambdaIMG* criterionCAE(features_y64[1], f_xc_c64)
+                lambda_p=0.00018
+                content_loss = lambda_p*lambdaIMG* criterionCAE(features_y[1], f_xc_c) + lambda_p*0.33*lambdaIMG* criterionCAE(features_y64[1], f_xc_c64) + lambda_p*lambdaIMG* criterionCAE(features_y[4], f_xc_c5)
                 content_loss.backward(retain_graph=True)
 
                 # Perceptual Loss 2
@@ -420,7 +438,7 @@ for epoch in range(opt.niter):
                 f_xc_c64 = Variable(features_content[0].data, requires_grad=False)
                 features_y64 = vgg(x_hat64)
 
-                content_loss1 = 1.8*lambdaIMG* criterionCAE(features_y[0], f_xc_c) + 1.8*0.33*lambdaIMG* criterionCAE(features_y64[0], f_xc_c64)
+                content_loss1 = lambda_p*lambdaIMG* criterionCAE(features_y[0], f_xc_c) + lambda_p*0.33*lambdaIMG* criterionCAE(features_y64[0], f_xc_c64)
 
                 content_loss1.backward(retain_graph=True)
@@ -430,7 +448,7 @@ for epoch in range(opt.niter):
                 L_img_ = L_img_ + 3.6 *criterionCAE(xcl_class3, target)
                 L_img_ = L_img_ + 1.2 *criterionCAE(xcl_class4, target)
                 L_img = lambdaIMG * L_img_
-                if lambdaIMG <> 0:
+                if lambdaIMG != 0:
                     L_img.backward(retain_graph=True)
                 if ganIterations % (25*opt.display) == 0:
                     print(L_img_.data[0])
@@ -451,16 +469,15 @@ for epoch in range(opt.niter):
                 trainLogger.write('%d\t%f\n' % \
                     (i, L_img.data[0]))
                 trainLogger.flush()
+
+        #validation
        if ganIterations % (int(len(dataloader)/2)) == 0:
             val_batch_output = torch.zeros([16,3,128,128], dtype=torch.float32)#torch.FloatTensor([10,3,128,128]).fill_(0)
             for idx in range(val_input.size(0)):
                 single_img = val_input[idx,:,:,:].unsqueeze(0)
                 val_inputv = Variable(single_img, volatile=True)
                 with torch.no_grad():
-                    #smaps_vl = netS(val_inputv)
-                    #S_valinput = torch.cat([smaps_vl,val_inputv],1)
                     index = idx+24500
-                    #rint(val_inputv.size())
                     val_inputv = val_inputv.cpu().numpy()
                     val_inputv[0,0,:,:]= signal.convolve(val_inputv[0,0,:,:],kernels[index,:,:],mode='same')
                     val_inputv[0,1,:,:]= signal.convolve(val_inputv[0,1,:,:],kernels[index,:,:],mode='same')
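
Note: the VGG-Face weight loading added in the first hunk writes out all 26 key renames explicitly. A minimal sketch of the same remapping as a table-driven loop; the helper name remap_vgg_face_keys is hypothetical, and it assumes exactly the numeric-key to conv-name correspondence listed in the patch:

    # (numeric index in VGG_FACE.pth, layer name expected by Vgg16);
    # the same 13 conv layers that the patch maps explicitly.
    VGG_FACE_LAYERS = [
        (0, "conv1_1"), (2, "conv1_2"),
        (5, "conv2_1"), (7, "conv2_2"),
        (10, "conv3_1"), (12, "conv3_2"), (14, "conv3_3"),
        (17, "conv4_1"), (19, "conv4_2"), (21, "conv4_3"),
        (24, "conv5_1"), (26, "conv5_2"), (28, "conv5_3"),
    ]

    def remap_vgg_face_keys(state_dict):
        # Rename "<n>.weight"/"<n>.bias" keys to "conv<b>_<i>.weight"/".bias".
        remapped = {}
        for idx, name in VGG_FACE_LAYERS:
            remapped[name + ".weight"] = state_dict["%d.weight" % idx]
            remapped[name + ".bias"] = state_dict["%d.bias" % idx]
        return remapped

    # usage, equivalent to the explicit block in the patch:
    # vgg.load_state_dict(remap_vgg_face_keys(torch.load('VGG_FACE.pth')))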
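Note: both the training loop (under "#generating blurry image") and the validation loop blur each RGB channel separately with a sampled 2-D kernel via scipy.signal.convolve(..., mode='same'). A sketch of that step factored into a helper; the name blur_channels and the (C, H, W) float-array layout are assumptions, not part of the patch:

    import numpy as np
    from scipy import signal

    def blur_channels(img, kernel):
        # img: (C, H, W) float array; kernel: 2-D blur kernel.
        # mode='same' keeps each channel at its original spatial size.
        return np.stack([signal.convolve(img[c], kernel, mode='same')
                         for c in range(img.shape[0])])

With this, the three per-channel lines at the end of the last hunk would be equivalent to val_inputv[0] = blur_channels(val_inputv[0], kernels[index]).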
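Note: the subtracted (lam_cmp/(4.0))*torch.mean(torch.log(xmask+sng)) terms act as a log-barrier that rewards confidence masks close to 1 and penalizes them sharply near 0; sng = 1e-8 only guards against log(0), and the /4.0 factor suggests one term per mask. A sketch of the same regularizer collected into a single function; mask_log_penalty is a hypothetical name:

    import torch

    def mask_log_penalty(masks, lam_cmp, eps=1e-8):
        # -(lam_cmp / len(masks)) * sum_i mean(log(mask_i + eps));
        # matches the subtracted terms in the patch when len(masks) == 4.
        return -(lam_cmp / len(masks)) * sum(
            torch.mean(torch.log(m + eps)) for m in masks)

    # usage:
    # L_img_ = L_img_ + mask_log_penalty([xmask1, xmask2, xmask3, xmask4], lam_cmp)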