def get_transform()

in datasets/__init__.py [0:0]


def get_transform(dataset, aug, is_train):
    """Build the torchvision preprocessing pipeline for ``dataset``.

    Args:
        dataset: One of ``"cifar10"``, ``"cifar100"``, ``"mnist"`` or
            ``"imagenet"`` (matched exactly, case-sensitive).
        aug: Whether data augmentation is requested.
        is_train: Whether the pipeline is for the training split. Augmentation
            is applied only when both ``aug`` and ``is_train`` are truthy;
            MNIST is never augmented.

    Returns:
        A ``transforms.Compose`` ending in ``ToTensor()`` followed by a
        per-dataset ``Normalize``.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    # Per-channel (mean, std) handed to transforms.Normalize for each dataset.
    norm_stats = {
        "cifar10": ((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        "mnist": ((0.1307,), (0.3081,)),
        "imagenet": ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        # CIFAR-100 statistics given on a 0-255 scale, rescaled to 0-1.
        "cifar100": (
            [n / 255 for n in [129.3, 124.1, 112.4]],
            [n / 255 for n in [68.2, 65.4, 70.4]],
        ),
    }
    # Validate before building anything: an unknown name previously fell
    # through every branch and crashed with UnboundLocalError.
    if dataset not in norm_stats:
        raise ValueError(
            f"Unsupported dataset {dataset!r}; "
            f"expected one of {sorted(norm_stats)}"
        )

    mean, std = norm_stats[dataset]
    to_normalized_tensor = [
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]

    if dataset == "mnist":
        # MNIST uses no augmentation and logs nothing.
        return transforms.Compose(to_normalized_tensor)

    use_aug = aug and is_train
    # NOTE: the "Not using ..." message is also printed for eval pipelines
    # (is_train falsy); the wording is kept as-is for log compatibility.
    if use_aug:
        print('Using data augmentation to train model')
    else:
        print('Not using data augmentation to train model')

    if dataset == "imagenet":
        if use_aug:
            # NOTE(review): resizing to 256 before RandomResizedCrop narrows
            # the crop distribution relative to cropping the raw image;
            # preserved from the original pipeline.
            prefix = [
                transforms.Resize(256),
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
            ]
        else:
            prefix = [transforms.Resize(256), transforms.CenterCrop(224)]
    else:
        # cifar10 / cifar100: 32x32 images, padded random crop + flip.
        if use_aug:
            prefix = [
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
            ]
        else:
            prefix = []

    return transforms.Compose(prefix + to_normalized_tensor)