You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
38 lines
1.0 KiB
38 lines
1.0 KiB
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
|
## Created by: Hang Zhang
|
|
## ECE Department, Rutgers University
|
|
## Email: zhang.hang@rutgers.edu
|
|
## Copyright (c) 2017
|
|
##
|
|
## This source code is licensed under the MIT-style license found in the
|
|
## LICENSE file in the root directory of this source tree
|
|
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
|
|
|
import os
|
|
from torch.autograd import Variable
|
|
|
|
from myutils import utils
|
|
|
|
class StyleLoader():
|
|
def __init__(self, style_folder, style_size, cuda=True):
|
|
self.folder = style_folder
|
|
self.style_size = style_size
|
|
self.files = os.listdir(style_folder)
|
|
self.cuda = cuda
|
|
|
|
def get(self, i):
|
|
idx = i%len(self.files)
|
|
filepath = os.path.join(self.folder, self.files[idx])
|
|
style = utils.tensor_load_rgbimage(filepath, self.style_size)
|
|
style = style.unsqueeze(0)
|
|
style = utils.preprocess_batch(style)
|
|
if self.cuda:
|
|
style = style.cuda()
|
|
style_v = Variable(style, requires_grad=False)
|
|
return style_v
|
|
|
|
def size(self):
|
|
return len(self.files)
|
|
|
|
|