def get_transform()

in src/sagemaker_defect_detection/transforms.py [0:0]


def get_transform(split: str) -> Callable:
    """
    Build the image preprocessing pipeline for the classification task.

    Parameters
    ----------
    split : str
        Dataset split name. ``"train"`` selects the augmenting pipeline;
        any other value selects the deterministic evaluation pipeline.

    Returns
    -------
    Callable
        A ``transforms.Compose`` that applies the split-specific geometric
        steps, converts the image to a tensor, and normalizes it per channel.
    """
    # Per-channel statistics, shared by both pipelines.
    channel_mean = [MEAN_RED, MEAN_GREEN, MEAN_BLUE]
    channel_std = [STD_RED, STD_GREEN, STD_BLUE]
    normalize = transforms.Normalize(mean=channel_mean, std=channel_std)

    if split == "train":
        # Training: random crop, rotation and horizontal flip for augmentation.
        steps = [
            transforms.RandomResizedCrop(IMAGE_HEIGHT),
            transforms.RandomRotation(ROTATION_ANGLE),
            transforms.RandomHorizontalFlip(),
        ]
    else:
        # Evaluation: deterministic resize, then a center crop to the model input size.
        steps = [
            transforms.Resize(IMAGE_RESIZE_HEIGHT),
            transforms.CenterCrop(IMAGE_HEIGHT),
        ]

    # Both pipelines finish identically: tensor conversion, then normalization.
    steps += [transforms.ToTensor(), normalize]
    return transforms.Compose(steps)