vision/data.py (15 lines of code) (raw):
import torch
import torchvision
RESIZE, CROP = 256, 224
TRANSFORM = torchvision.transforms.Compose(
[
torchvision.transforms.Resize(RESIZE),
torchvision.transforms.CenterCrop(CROP),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
)
def get_imagenet(datapath, split, batch_size, shuffle, transform=TRANSFORM):
ds = torchvision.datasets.ImageNet(root=datapath, split=split, transform=transform)
loader = torch.utils.data.DataLoader(ds, shuffle=shuffle, batch_size=batch_size)
return ds, loader