in src/sagemaker_defect_detection/transforms.py [0:0]
def get_transform(split: str) -> Callable:
    """
    Build the image preprocessing pipeline for the classification task

    Parameters
    ----------
    split : str
        Dataset split name. Exactly ``"train"`` selects the augmenting
        training pipeline; any other value selects the deterministic
        evaluation pipeline

    Returns
    -------
    Callable
        A ``transforms.Compose`` that applies the split-specific resizing /
        augmentation steps, converts the image with ``transforms.ToTensor``
        and then normalizes it with the per-channel MEAN_* / STD_* constants
    """
    channel_mean = [MEAN_RED, MEAN_GREEN, MEAN_BLUE]
    channel_std = [STD_RED, STD_GREEN, STD_BLUE]
    normalize = transforms.Normalize(mean=channel_mean, std=channel_std)

    if split != "train":
        # Evaluation: deterministic resize, then crop the central region
        return transforms.Compose(
            [
                transforms.Resize(IMAGE_RESIZE_HEIGHT),
                transforms.CenterCrop(IMAGE_HEIGHT),
                transforms.ToTensor(),
                normalize,
            ]
        )

    # Training: random crop, rotation and horizontal flip for augmentation
    augmentations = [
        transforms.RandomResizedCrop(IMAGE_HEIGHT),
        transforms.RandomRotation(ROTATION_ANGLE),
        transforms.RandomHorizontalFlip(),
    ]
    return transforms.Compose(augmentations + [transforms.ToTensor(), normalize])