def getImagenetTransform()

in src/dataset.py [0:0]


def getImagenetTransform(name, img_size=256, crop_size=224, normalization=True, as_list=False, differentiable=False):
    """Build the image preprocessing pipeline selected by ``name``.

    The pipeline has two parts: an augmentation/cropping stage chosen by
    ``name`` and a post-processing stage that converts to tensors and,
    optionally, applies ``NORMALIZE_IMAGENET``.

    Args:
        name: Augmentation mode. ``"random"``, ``"center"`` and ``"none"`` are
            accepted in both modes; ``"tencrop"`` is accepted only when
            ``differentiable`` is false.
        img_size: Size handed to ``transforms.Resize`` for ``"tencrop"`` and
            ``"center"`` (and to the differentiable ``CenterCrop``).
        crop_size: Size handed to the crop transform of the selected mode.
        normalization: When true, ``NORMALIZE_IMAGENET`` is appended to the
            post-processing stage.
        as_list: When true, return ``transform + postprocess`` without
            wrapping it in ``transforms.Compose``.
        differentiable: When true, the augmentation stage is a single
            project-defined object (``RandomResizedCropFlip``, ``CenterCrop``
            or ``DifferentiableDataAugmentation``) instead of a list of
            torchvision transforms.

    Returns:
        * ``as_list`` true: the result of ``transform + postprocess``.
        * ``differentiable`` true (and ``as_list`` false): a tuple
          ``(augmentation, transforms.Compose(postprocess))``.
        * otherwise: a single ``transforms.Compose`` over both stages.

        For ``"tencrop"`` the post-processing yields a list of per-crop
        tensors when ``normalization`` is false and a ``torch.stack``-ed
        tensor when it is true.

    Raises:
        ValueError: If ``name`` is not a mode supported for the requested
            ``differentiable`` setting.
    """
    # Validate explicitly rather than with ``assert``: assertions are removed
    # under ``python -O``, which would let an unknown name silently behave
    # like "none" (or pair a differentiable augmentation with ten-crop
    # post-processing).
    if differentiable:
        supported = ("random", "center", "none")
    else:
        supported = ("random", "tencrop", "center", "none")
    if name not in supported:
        raise ValueError(
            "Unknown transform name %r (differentiable=%r); expected one of %s"
            % (name, differentiable, ", ".join(supported))
        )

    # Default for the non-differentiable "none" mode: no augmentation steps.
    transform = []
    if differentiable:
        if name == "random":
            transform = RandomResizedCropFlip(crop_size)
        elif name == "center":
            transform = CenterCrop(img_size, crop_size)
        else:  # name == "none"
            transform = DifferentiableDataAugmentation()
    else:
        if name == "random":
            transform = [
                transforms.RandomResizedCrop(crop_size),
                transforms.RandomHorizontalFlip(),
            ]
        elif name == "tencrop":
            transform = [
                transforms.Resize(img_size),
                transforms.TenCrop(crop_size),
            ]
        elif name == "center":
            transform = [
                transforms.Resize(img_size),
                transforms.CenterCrop(crop_size),
            ]

    if name == "tencrop":
        # TenCrop produces several crops per image, so each one is converted
        # to a tensor individually.
        postprocess = [
            transforms.Lambda(lambda crops: [transforms.ToTensor()(crop) for crop in crops])
        ]
    else:
        postprocess = [
            transforms.ToTensor()
        ]

    if normalization:
        if name == "tencrop":
            # Normalize every crop, then stack them into a single tensor.
            postprocess.append(transforms.Lambda(lambda crops: torch.stack([NORMALIZE_IMAGENET(crop) for crop in crops])))
        else:
            postprocess.append(NORMALIZE_IMAGENET)

    if as_list:
        # NOTE(review): with differentiable=True, ``transform`` is a single
        # object rather than a list, so this concatenation only works if that
        # object supports ``+`` with a list -- confirm whether this
        # combination is meant to be supported.
        return transform + postprocess

    if differentiable:
        return transform, transforms.Compose(postprocess)
    return transforms.Compose(transform + postprocess)