in datasets/__init__.py [0:0]
def get_transform(dataset, aug, is_train):
    """Build the torchvision preprocessing pipeline for ``dataset``.

    Random augmentation is applied only when both ``aug`` and ``is_train``
    are true; otherwise a deterministic pipeline is returned. MNIST never
    uses augmentation.

    Args:
        dataset: One of ``"cifar10"``, ``"cifar100"``, ``"mnist"`` or
            ``"imagenet"``.
        aug: Whether random augmentation is requested.
        is_train: Whether the pipeline is for training data.

    Returns:
        A ``transforms.Compose`` pipeline ending in ``ToTensor`` followed by
        ``Normalize`` with the dataset's mean/std.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
            (Previously an unknown name surfaced as an ``UnboundLocalError``.)
    """
    if dataset == "mnist":
        # MNIST has no augmentation path and prints no status message.
        return transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )

    # Per-dataset normalisation statistics.
    if dataset == "cifar10":
        mean, std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
    elif dataset == "cifar100":
        # Stats are given on the 0-255 scale; ToTensor yields values in [0, 1].
        mean = [n / 255 for n in [129.3, 124.1, 112.4]]
        std = [n / 255 for n in [68.2, 65.4, 70.4]]
    elif dataset == "imagenet":
        mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    else:
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected one of "
            "'cifar10', 'cifar100', 'mnist', 'imagenet'"
        )

    is_imagenet = dataset == "imagenet"
    if aug and is_train:
        print('Using data augmentation to train model')
        if is_imagenet:
            # NOTE(review): the common torchvision ImageNet recipe applies
            # RandomResizedCrop to the full image; the Resize(256) in front is
            # kept here to preserve existing behavior.
            spatial = [
                transforms.Resize(256),
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
            ]
        else:
            # CIFAR-style: pad-and-crop back to 32x32 plus horizontal flip.
            spatial = [
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
            ]
    else:
        # This message is printed for evaluation pipelines too (is_train False).
        print('Not using data augmentation to train model')
        spatial = (
            [transforms.Resize(256), transforms.CenterCrop(224)] if is_imagenet else []
        )

    return transforms.Compose(
        spatial + [transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)]
    )