def get_augmentations()

in data/transforms.py [0:0]


def get_augmentations(aug_type):
    """Return the list of image transforms for a named augmentation recipe.

    Supported recipes (all crops are to 224 pixels):

    * ``'DefaultTrain'`` -- ``RandomResizedCrop(224)`` then a random
      horizontal flip.
    * ``'DefaultVal'`` -- ``Resize(256)`` then ``CenterCrop(224)``.
    * ``'RandAugment'`` -- the ``'DefaultTrain'`` steps followed by
      ``RandAugment(n=2, m=10)``.
    * ``'MoCoV1'`` -- resized crop with scale (0.2, 1.0), random grayscale
      (p=0.2), ``ColorJitter(0.4, 0.4, 0.4, 0.4)``, horizontal flip.
    * ``'MoCoV2'`` -- resized crop with scale (0.2, 1.0),
      ``ColorJitter(0.4, 0.4, 0.4, 0.1)`` applied with p=0.8, random
      grayscale (p=0.2), ``GaussianBlur([.1, 2.])`` applied with p=0.5,
      horizontal flip.

    Every recipe ends with ``transforms.ToTensor()`` followed by a
    ``transforms.Normalize`` using the common ImageNet channel mean/std.

    Args:
        aug_type: Name of the recipe, compared with ``==`` against the names
            listed above.

    Returns:
        A newly built ``list`` of transform objects (not wrapped in a
        ``Compose``).

    Raises:
        NotImplementedError: If ``aug_type`` matches none of the recipes.
    """
    imagenet_norm = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

    def random_crop_and_flip():
        # Head shared by 'DefaultTrain' and 'RandAugment'.
        return [
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
        ]

    # Ordered (name, builder) pairs; each builder returns the recipe-specific
    # head of the pipeline. A tuple scanned with == (rather than a dict
    # lookup) keeps if/elif semantics, so unhashable aug_type values still
    # fall through to the NotImplementedError below instead of a TypeError.
    recipes = (
        ('DefaultTrain', random_crop_and_flip),
        ('DefaultVal', lambda: [
            transforms.Resize(256),
            transforms.CenterCrop(224),
        ]),
        ('RandAugment', lambda: (
            random_crop_and_flip() + [RandAugment(n=2, m=10)]
        )),
        ('MoCoV1', lambda: [
            transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
            transforms.RandomGrayscale(p=0.2),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            transforms.RandomHorizontalFlip(),
        ]),
        ('MoCoV2', lambda: [
            transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
            transforms.RandomApply([
                transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)  # not strengthened
            ], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
            transforms.RandomHorizontalFlip(),
        ]),
    )

    for name, build_head in recipes:
        if aug_type == name:
            # Common tail: convert to tensor, then normalize channels.
            return build_head() + [transforms.ToTensor(), imagenet_norm]

    raise NotImplementedError('augmentation type not found: {}'.format(aug_type))