in utils.py [0:0]
def get_cifar10_data(batch_size, *, root="./data", num_workers=2, pin_memory=None):
    """Build the CIFAR-10 train/test datasets and their DataLoaders.

    Both splits are fetched with ``download=True`` into ``root``. Training
    images are augmented with a 4-pixel-padded random 32x32 crop and a
    random horizontal flip. Both splits are then converted to tensors and
    normalized with the same per-channel mean/std.

    Args:
        batch_size: Number of samples per batch for both loaders.
        root: Directory where the CIFAR-10 files are stored / downloaded.
        num_workers: Number of worker processes used by each DataLoader.
        pin_memory: Whether the loaders use pinned host memory. ``None``
            (the default) enables it only when CUDA is available, which
            avoids the warning PyTorch emits when pinning without an
            accelerator.

    Returns:
        A ``(trainset, trainloader, testset, testloader)`` tuple.
    """
    print("==> Preparing CIFAR-10 data...")
    if pin_memory is None:
        pin_memory = torch.cuda.is_available()

    # Shared by both splits so train and test inputs are normalized identically.
    # NOTE(review): this std triple is the widely-copied value; the per-channel
    # pixel std of the CIFAR-10 training set is commonly reported as roughly
    # (0.2470, 0.2435, 0.2616). Left unchanged so existing checkpoints keep
    # seeing identically normalized inputs -- confirm before changing.
    normalize = transforms.Normalize(
        (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
    )
    transform_train = transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]
    )
    transform_test = transforms.Compose([transforms.ToTensor(), normalize])

    trainset = torchvision.datasets.CIFAR10(
        root=root, train=True, download=True, transform=transform_train
    )
    trainloader = torch.utils.data.DataLoader(
        trainset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    testset = torchvision.datasets.CIFAR10(
        root=root, train=False, download=True, transform=transform_test
    )
    # Note that, we perform an analysis of BatchNorm statistics when validating,
    # so we *must* shuffle the validation set.
    testloader = torch.utils.data.DataLoader(
        testset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    return trainset, trainloader, testset, testloader