|
|
|
@ -248,11 +248,7 @@ target_256 = Variable(target_256)
|
|
|
|
input_256 = Variable(input_256)
|
|
|
|
input_256 = Variable(input_256)
|
|
|
|
ato = Variable(ato)
|
|
|
|
ato = Variable(ato)
|
|
|
|
|
|
|
|
|
|
|
|
# Initialize VGG-16
|
|
|
|
|
|
|
|
vgg = Vgg16()
|
|
|
|
|
|
|
|
utils.init_vgg16('./models/')
|
|
|
|
|
|
|
|
vgg.load_state_dict(torch.load(os.path.join('./models/', "vgg16.weight")))
|
|
|
|
|
|
|
|
vgg.cuda()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
label_d = Variable(label_d.cuda())
|
|
|
|
label_d = Variable(label_d.cuda())
|
|
|
|
|