You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
63 lines
2.1 KiB
63 lines
2.1 KiB
from typing import List
|
|
|
|
import albumentations as albu
|
|
from torchvision import transforms
|
|
|
|
def get_transforms(size: int, scope: str = 'geometric', crop='random'):
|
|
augs = {'strong': albu.Compose([albu.HorizontalFlip(),
|
|
albu.ShiftScaleRotate(shift_limit=0.0, scale_limit=0.2, rotate_limit=20, p=.4),
|
|
albu.ElasticTransform(),
|
|
albu.OpticalDistortion(),
|
|
albu.OneOf([
|
|
albu.CLAHE(clip_limit=2),
|
|
albu.IAASharpen(),
|
|
albu.IAAEmboss(),
|
|
albu.RandomBrightnessContrast(),
|
|
albu.RandomGamma()
|
|
], p=0.5),
|
|
albu.OneOf([
|
|
albu.RGBShift(),
|
|
albu.HueSaturationValue(),
|
|
], p=0.5),
|
|
]),
|
|
'weak': albu.Compose([albu.HorizontalFlip(),
|
|
]),
|
|
'geometric': albu.Compose([albu.HorizontalFlip(),
|
|
albu.VerticalFlip(),
|
|
albu.RandomRotate90(),
|
|
]),
|
|
'None': None
|
|
}
|
|
|
|
aug_fn = augs[scope]
|
|
crop_fn = {'random': albu.RandomCrop(size, size, always_apply=True),
|
|
'center': albu.CenterCrop(size, size, always_apply=True)}[crop]
|
|
|
|
pipeline = albu.Compose([aug_fn, crop_fn], additional_targets={'target': 'image'})
|
|
|
|
|
|
def process(a, b):
|
|
r = pipeline(image=a, target=b)
|
|
return r['image'], r['target']
|
|
|
|
return process
|
|
|
|
|
|
def get_normalize():
|
|
transform = transforms.Compose([
|
|
transforms.ToTensor()
|
|
])
|
|
|
|
def process(a, b):
|
|
image = transform(a).permute(1, 2, 0) - 0.5
|
|
target = transform(b).permute(1, 2, 0) - 0.5
|
|
return image, target
|
|
|
|
return process
|
|
|
|
|
|
|
|
|
|
|
|
|