# NOTE: web-scrape artifacts (Gitea topic-picker text and file-size stats)
# were accidentally captured at the top of this file; replaced with this
# comment so the module parses.
import os

import numpy as np
import torch
from PIL import Image
from torch.autograd import Variable
from torch.utils.serialization import load_lua

from myutils.vgg16 import Vgg16

def tensor_load_rgbimage(filename, size=None, scale=None, keep_asp=False):
    """Load an image file as a CHW float tensor with RGB channel order.

    Args:
        filename: path to the image file.
        size: target width in pixels; if given, the image is resized to
            ``size`` x ``size`` (or aspect-preserving, see ``keep_asp``).
        scale: downscale factor, used only when ``size`` is None.
        keep_asp: when resizing by ``size``, derive the height from the
            original aspect ratio instead of forcing a square.

    Returns:
        torch.FloatTensor of shape (3, H, W) with values in [0, 255].
    """
    img = Image.open(filename).convert('RGB')
    if size is not None:
        if keep_asp:
            # Height that preserves the original W:H ratio at the new width.
            size2 = int(size * 1.0 / img.size[0] * img.size[1])
            # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the
            # long-standing equivalent alias.
            img = img.resize((size, size2), Image.LANCZOS)
        else:
            img = img.resize((size, size), Image.LANCZOS)
    elif scale is not None:
        img = img.resize((int(img.size[0] / scale), int(img.size[1] / scale)), Image.LANCZOS)
    # HWC uint8 array -> CHW float tensor.
    img = np.array(img).transpose(2, 0, 1)
    img = torch.from_numpy(img).float()
    return img
def tensor_save_rgbimage(tensor, filename, cuda=False):
    """Save a (3, H, W) tensor with values in [0, 255] to *filename* as an RGB image."""
    work = tensor.clone()
    if cuda:
        # Bring the data back to host memory before the numpy conversion.
        work = work.cpu()
    arr = work.clamp(0, 255).numpy()
    # CHW float -> HWC uint8, which is what PIL expects.
    arr = arr.transpose(1, 2, 0).astype('uint8')
    Image.fromarray(arr).save(filename)
def tensor_save_bgrimage(tensor, filename, cuda=False):
    """Save a BGR-ordered (3, H, W) tensor as an image, converting to RGB first."""
    blue, green, red = torch.chunk(tensor, 3)
    rgb = torch.cat((red, green, blue))
    tensor_save_rgbimage(rgb, filename, cuda)
def gram_matrix(y):
    """Return the batched Gram matrix of feature maps *y*.

    Args:
        y: tensor of shape (B, C, H, W).

    Returns:
        Tensor of shape (B, C, C), normalized by C * H * W.
    """
    batch, channels, height, width = y.size()
    flat = y.view(batch, channels, width * height)
    # Inner products between every pair of channel feature vectors.
    return flat.bmm(flat.transpose(1, 2)) / (channels * height * width)
def subtract_imagenet_mean_batch(batch):
    """Subtract the ImageNet mean pixel from a BGR batch, channel-wise."""
    tensor_cls = type(batch.data)
    mean = tensor_cls(batch.data.size())
    # Per-channel ImageNet means, BGR order.
    for channel, value in enumerate((103.939, 116.779, 123.680)):
        mean[:, channel, :, :] = value
    return batch - Variable(mean)
def add_imagenet_mean_batch(batch):
    """Add the ImageNet mean pixel back onto a BGR batch, channel-wise."""
    tensor_cls = type(batch.data)
    mean = tensor_cls(batch.data.size())
    # Per-channel ImageNet means, BGR order.
    for channel, value in enumerate((103.939, 116.779, 123.680)):
        mean[:, channel, :, :] = value
    return batch + Variable(mean)
def imagenet_clamp_batch(batch, low, high):
    """Clamp a mean-subtracted BGR batch in place so that adding the
    ImageNet mean back would yield values in [low, high]."""
    for channel, mean in enumerate((103.939, 116.779, 123.680)):
        batch[:, channel, :, :].data.clamp_(low - mean, high - mean)
def preprocess_batch(batch):
    """Return *batch* (N, 3, H, W) with the channel axis reversed (RGB <-> BGR)."""
    channels_first = batch.transpose(0, 1)
    first, middle, last = torch.chunk(channels_first, 3)
    swapped = torch.cat((last, middle, first))
    return swapped.transpose(0, 1)
def init_vgg16(model_folder):
    """load the vgg16 model feature

    Ensures a PyTorch-format weight file ``vgg16.weight`` exists in
    *model_folder*: if it is missing, downloads the Lua Torch ``vgg16.t7``
    checkpoint (unless already present), copies its parameters into a
    ``Vgg16`` module, and saves that module's state dict.
    """
    if not os.path.exists(os.path.join(model_folder, 'vgg16.weight')):
        if not os.path.exists(os.path.join(model_folder, 'vgg16.t7')):
            # NOTE(review): shells out to wget over plain http; assumes wget
            # is installed and the Stanford URL is still reachable — verify.
            os.system(
                'wget http://cs.stanford.edu/people/jcjohns/fast-neural-style/models/vgg16.t7 -O ' + os.path.join(model_folder, 'vgg16.t7'))
        # load_lua comes from torch.utils.serialization, which only exists in
        # old PyTorch releases (removed in >= 1.0) — confirm the pinned version.
        vgglua = load_lua(os.path.join(model_folder, 'vgg16.t7'))
        vgg = Vgg16()
        # Copy each Lua-Torch parameter tensor into the matching PyTorch
        # parameter; zip relies on both models exposing parameters in the
        # same order.
        for (src, dst) in zip(vgglua.parameters()[0], vgg.parameters()):
            dst.data[:] = src
        torch.save(vgg.state_dict(), os.path.join(model_folder, 'vgg16.weight'))