diff --git a/train_face_deblur.py b/train_face_deblur.py index c81a225..f1164da 100644 --- a/train_face_deblur.py +++ b/train_face_deblur.py @@ -248,11 +248,7 @@ target_256 = Variable(target_256) input_256 = Variable(input_256) ato = Variable(ato) -# Initialize VGG-16 -vgg = Vgg16() -utils.init_vgg16('./models/') -vgg.load_state_dict(torch.load(os.path.join('./models/', "vgg16.weight"))) -vgg.cuda() + label_d = Variable(label_d.cuda())