You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

225 lines
6.8 KiB

from __future__ import division
import torch
import math
import random
from PIL import Image, ImageOps
import numpy as np
import numbers
import types
class Compose(object):
    """Chains several triplet transforms into a single callable.

    Args:
        transforms (List[Transform]): transforms applied in order; each must
            accept and return three images.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, imgA, imgB, imgC):
        triple = (imgA, imgB, imgC)
        # Thread the triplet through every transform in sequence.
        for transform in self.transforms:
            triple = transform(*triple)
        return triple
class ToTensor(object):
    """Converts each of three PIL.Images or numpy.ndarrays (H x W x C) in the
    range [0, 255] to a torch.FloatTensor of shape (C x H x W) in [0.0, 1.0].
    """
    def __call__(self, picA, picB, picC):
        return tuple(self._to_tensor(pic) for pic in (picA, picB, picC))

    def _to_tensor(self, pic):
        # Convert a single image; shared by all three inputs.
        if isinstance(pic, np.ndarray):
            # numpy array: already HWC, just reorder axes to CHW.
            tensor = torch.from_numpy(pic.transpose((2, 0, 1)))
        else:
            # PIL Image: copy the raw bytes, then reshape to H x W x C.
            tensor = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
            # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
            nchannel = 3 if pic.mode == 'YCbCr' else len(pic.mode)
            tensor = tensor.view(pic.size[1], pic.size[0], nchannel)
            # HWC -> CHW (this transpose dominates the loading time/CPU).
            tensor = tensor.transpose(0, 1).transpose(0, 2).contiguous()
        return tensor.float().div(255.)
class ToPILImage(object):
    """Converts a torch.*Tensor of range [0, 1] and shape C x H x W
    or numpy ndarray of dtype=uint8, range [0, 255] and shape H x W x C
    to a PIL.Image of range [0, 255].
    """
    def __call__(self, picA, picB, picC):
        images = []
        for pic in (picA, picB, picC):
            if isinstance(pic, np.ndarray):
                # ndarray inputs are assumed uint8 HWC already.
                arr = pic
            else:
                # Tensor in [0, 1]: scale to bytes and lay out as H x W x C.
                arr = np.transpose(pic.mul(255).byte().numpy(), (1, 2, 0))
            mode = None
            if arr.shape[2] == 1:
                # Single-channel: drop the channel axis and use grayscale mode.
                arr = arr[:, :, 0]
                mode = "L"
            images.append(Image.fromarray(arr, mode=mode))
        return images[0], images[1], images[2]
class Normalize(object):
    """Normalizes each channel in place: channel = (channel - mean) / std,
    given mean (R, G, B) and std (R, G, B).
    """
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensorA, tensorB, tensorC):
        tensors = (tensorA, tensorB, tensorC)
        for tensor in tensors:
            # TODO: make efficient
            for channel, m, s in zip(tensor, self.mean, self.std):
                channel.sub_(m).div_(s)
        return tensors
class Scale(object):
    """Rescales the input PIL.Image so that its smaller edge equals 'size'.

    For example, if height > width, the image is rescaled to
    (size * height / width, size).

    Args:
        size: target length of the smaller edge.
        interpolation: resampling filter, default PIL.Image.BILINEAR.
    """
    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, imgA, imgB, imgC):
        scaled = [self._scale(img) for img in (imgA, imgB, imgC)]
        return scaled[0], scaled[1], scaled[2]

    def _scale(self, img):
        # Rescale one image, keeping aspect ratio.
        w, h = img.size
        # Smaller edge already matches the target: return untouched.
        if (w <= h and w == self.size) or (h <= w and h == self.size):
            return img
        if w < h:
            new_w, new_h = self.size, int(self.size * h / w)
        else:
            new_h, new_w = self.size, int(self.size * w / h)
        return img.resize((new_w, new_h), self.interpolation)
class CenterCrop(object):
    """Crops the given PIL.Image at the center to the given size.

    size may be a (target_height, target_width) tuple or a single integer,
    in which case a square (size, size) region is taken.
    """
    def __init__(self, size):
        self.size = (int(size), int(size)) if isinstance(size, numbers.Number) else size

    def __call__(self, imgA, imgB, imgC):
        return tuple(self._crop(img) for img in (imgA, imgB, imgC))

    def _crop(self, img):
        # Center one image's crop box and cut it out.
        w, h = img.size
        th, tw = self.size
        left = int(round((w - tw) / 2.))
        top = int(round((h - th) / 2.))
        return img.crop((left, top, left + tw, top + th))
class Pad(object):
    """Pads the given PIL.Image on all sides with the given "pad" value."""
    def __init__(self, padding, fill=0):
        assert isinstance(padding, numbers.Number)
        assert isinstance(fill, (numbers.Number, str, tuple))
        self.padding = padding
        self.fill = fill

    def __call__(self, imgA, imgB, imgC):
        padded = [ImageOps.expand(img, border=self.padding, fill=self.fill)
                  for img in (imgA, imgB, imgC)]
        return padded[0], padded[1], padded[2]
class Lambda(object):
    """Applies a user-supplied lambda as a transform to each image."""
    def __init__(self, lambd):
        assert isinstance(lambd, types.LambdaType)
        self.lambd = lambd

    def __call__(self, imgA, imgB, imgC):
        return self.lambd(imgA), self.lambd(imgB), self.lambd(imgC)
class RandomCrop(object):
    """Crops the given PIL.Image at a random location to the given size.

    size may be a (target_height, target_width) tuple or a single integer,
    in which case a square (size, size) region is taken. The random offset
    is drawn once and shared by all three images so aligned triplets stay
    aligned.
    """
    def __init__(self, size, padding=0):
        self.size = (int(size), int(size)) if isinstance(size, numbers.Number) else size
        self.padding = padding

    def __call__(self, imgA, imgB, imgC):
        output = []
        left = top = None  # crop offset, drawn once and reused
        for img in (imgA, imgB, imgC):
            if self.padding > 0:
                img = ImageOps.expand(img, border=self.padding, fill=0)
            w, h = img.size
            th, tw = self.size
            # Image already matches the target: pass it through unchanged.
            if w == tw and h == th:
                output.append(img)
                continue
            if left is None:
                left = random.randint(0, w - tw)
                top = random.randint(0, h - th)
            output.append(img.crop((left, top, left + tw, top + th)))
        return output[0], output[1], output[2]
class RandomHorizontalFlip(object):
    """Randomly horizontally flips the given PIL.Images with a probability
    of 0.5.

    One coin flip is shared by all three images so aligned triplets stay
    aligned.
    """
    def __call__(self, imgA, imgB, imgC):
        # Bug fix: the flip had been disabled with a debug leftover
        # `random.random() < -1` (always False); restore the documented
        # probability of 0.5.
        flip = random.random() < 0.5
        if flip:
            return (imgA.transpose(Image.FLIP_LEFT_RIGHT),
                    imgB.transpose(Image.FLIP_LEFT_RIGHT),
                    imgC.transpose(Image.FLIP_LEFT_RIGHT))
        return imgA, imgB, imgC