def get_imagenet_data()

in utils.py [0:0]


def get_imagenet_data(data_dir, batch_size, *, num_workers=16):
    """Build the ImageNet train/validation datasets and their DataLoaders.

    Images are read with ``torchvision.datasets.ImageFolder`` from
    ``<data_dir>/training`` and ``<data_dir>/validation``.

    Training images get a random resized 224px crop and a random horizontal
    flip. Validation images are resized to 256 and center-cropped to 224.
    Both splits are then converted to tensors and normalized with the
    commonly used ImageNet per-channel mean/std.

    Args:
        data_dir: Root directory containing the ``training`` and
            ``validation`` sub-directories.
        batch_size: Batch size used by both loaders.
        num_workers: Number of DataLoader worker processes per loader.
            Defaults to 16, the value that was previously hard-coded.

    Returns:
        A tuple ``(trainset, trainloader, testset, testloader)``.
    """
    print("==> Preparing ImageNet data...")

    traindir = os.path.join(data_dir, "training")
    valdir = os.path.join(data_dir, "validation")
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )

    # Settings shared by both loaders, kept in one place so they cannot drift.
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
    )

    trainset = torchvision.datasets.ImageFolder(
        traindir,
        transforms.Compose(
            [
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]
        ),
    )

    trainloader = torch.utils.data.DataLoader(
        trainset,
        shuffle=True,
        **loader_kwargs,
    )

    testset = torchvision.datasets.ImageFolder(
        valdir,
        transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ]
        ),
    )

    # Note that, we perform an analysis of BatchNorm statistics when validating,
    # so we *must* shuffle the validation set.
    testloader = torch.utils.data.DataLoader(
        testset,
        shuffle=True,
        **loader_kwargs,
    )

    return trainset, trainloader, testset, testloader