in utils.py [0:0]
def get_imagenet_data(data_dir, batch_size, *, num_workers=16):
    """Build ImageNet-style datasets and DataLoaders for training and validation.

    Images are read with ``torchvision.datasets.ImageFolder`` from the
    ``training`` and ``validation`` subdirectories of ``data_dir``.

    Args:
        data_dir: Root directory containing ``training`` and ``validation``
            folders in the layout ``ImageFolder`` expects.
        batch_size: Batch size used by both the train and validation loaders.
        num_workers: Number of DataLoader worker processes for each loader.
            Defaults to 16, the value previously hard-coded here.

    Returns:
        Tuple ``(trainset, trainloader, testset, testloader)``.
    """
    print("==> Preparing ImageNet data...")
    traindir = os.path.join(data_dir, "training")
    valdir = os.path.join(data_dir, "validation")

    # Per-channel RGB normalization constants; these are the values commonly
    # used for ImageNet in the torchvision reference examples.
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )

    # Training: random-resized crop plus horizontal flip for augmentation.
    trainset = torchvision.datasets.ImageFolder(
        traindir,
        transforms.Compose(
            [
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]
        ),
    )
    # Validation: deterministic resize followed by a 224x224 center crop.
    testset = torchvision.datasets.ImageFolder(
        valdir,
        transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ]
        ),
    )

    # Loader settings shared by both splits.
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
    )
    trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **loader_kwargs)
    # Note: BatchNorm statistics are analysed during validation, so the
    # validation set *must* be shuffled as well.
    testloader = torch.utils.data.DataLoader(testset, shuffle=True, **loader_kwargs)
    return trainset, trainloader, testset, testloader