You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

76 lines
1.9 KiB

import torch.utils.data as data
from PIL import Image
import os
import os.path
import numpy as np
import h5py
import glob
# Accepted image file suffixes. Each lowercase suffix is immediately followed
# by its uppercase form, since suffix matching with str.endswith is
# case-sensitive.
IMG_EXTENSIONS = [
    variant
    for suffix in ('.jpg', '.jpeg', '.png', '.ppm', '.bmp')
    for variant in (suffix, suffix.upper())
]
def is_image_file(filename):
    """Return True when ``filename`` ends with one of ``IMG_EXTENSIONS``.

    The comparison is an exact, case-sensitive suffix match.
    """
    return filename.endswith(tuple(IMG_EXTENSIONS))
def make_dataset(dir):
    """Recursively collect the paths of all image files under ``dir``.

    Args:
        dir: Root directory to search; subdirectories are walked too.

    Returns:
        A list of paths, each joined onto the directory the file was found
        in, for every file name accepted by ``is_image_file``. Directories
        are visited in sorted order and file names are sorted within each
        directory, so the result is deterministic across runs.

    Raises:
        Exception: If ``dir`` is not an existing directory.
    """
    if not os.path.isdir(dir):
        raise Exception('Check dataroot')
    images = []
    for root, _, fnames in sorted(os.walk(dir)):
        # os.walk yields file names in arbitrary filesystem order; sort them
        # so the dataset ordering is reproducible.
        for fname in sorted(fnames):
            if is_image_file(fname):
                # Join onto ``root`` (the directory currently being walked),
                # not the top-level ``dir``: otherwise files found in
                # subdirectories get paths that do not exist.
                images.append(os.path.join(root, fname))
    return images
def default_loader(path):
    """Open the image file at ``path`` and return an RGB copy of it.

    The file is opened in a ``with`` block so the underlying file handle is
    released once the conversion is done. ``convert`` returns a new image
    object, which stays usable after the source image is closed.
    """
    with Image.open(path) as img:
        return img.convert('RGB')
class classification(data.Dataset):
    """Dataset that returns random samples from numbered ``.h5`` files in ``root``.

    ``__getitem__`` opens ``<root>/<index>.h5`` and reads its ``haze`` and
    ``label`` datasets.

    Args:
        root: Directory holding the ``.h5`` sample files.
        transform: Stored as ``self.transform``; not applied by ``__getitem__``.
        loader: Stored as ``self.loader``; not used by ``__getitem__``.
        seed: If not None, passed to ``np.random.seed``. This seeds NumPy's
            *global* RNG, so it also affects every other user of ``np.random``.
    """

    def __init__(self, root, transform=None, loader=default_loader, seed=None):
        self.root = root
        self.transform = transform
        self.loader = loader
        if seed is not None:
            np.random.seed(seed)

    def __getitem__(self, _):
        """Return ``(haze_image, label)`` read from a randomly chosen file.

        The index passed by the caller is ignored. A file number is drawn
        with ``np.random.randint(1, len(self))``, i.e. uniformly from
        ``1 .. len(self) - 1``.
        """
        # NOTE(review): randint's upper bound is exclusive, so neither 0 nor
        # len(self) is ever drawn. Depending on whether files are numbered
        # 0..N-1 or 1..N, one real file is never sampled. Confirm the naming
        # scheme before changing the range. This also raises ValueError when
        # fewer than two files are present.
        index = np.random.randint(1, self.__len__())
        file_name = os.path.join(self.root, str(index) + '.h5')
        # The context manager guarantees the HDF5 file is closed; the previous
        # version leaked one open handle per sample. Slicing with [:] copies
        # the data into NumPy arrays, which remain valid after the file closes.
        with h5py.File(file_name, 'r') as f:
            haze_image = f['haze'][:]
            label = f['label'][:]
        # Collapse the label array to a scalar and subtract one
        # (presumably 1-based class ids -> 0-based; TODO confirm).
        label = label.mean() - 1
        # Axis order (a, b, c) -> (c, a, b). This is identical to the previous
        # swapaxes(0, 2) followed by swapaxes(1, 2). It is presumably
        # HWC -> CHW for PyTorch; TODO confirm the stored layout.
        haze_image = np.transpose(haze_image, (2, 0, 1))
        return haze_image, label

    def __len__(self):
        """Return the number of entries in ``root`` whose names end in ``h5``."""
        # NOTE: this rescans the directory on every call, including once per
        # __getitem__. Consider caching the count if the directory is static.
        return len(glob.glob(os.path.join(self.root, '*h5')))