datasets/tinyimages_300k.py (32 lines of code) (raw):
import os
import numpy as np
from PIL import Image
import torch
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
class TinyImages(torch.utils.data.Dataset):
    """300K random images (32x32 RGB) drawn from the Tiny Images collection, used as an unlabeled auxiliary set."""

    def __init__(self, root, transform=None):
        super(TinyImages, self).__init__()
        # Expects the pre-extracted numpy array of 300K images under <root>/tinyimages80m/.
        self.data = np.load(os.path.join(root, 'tinyimages80m', '300K_random_images.npy'))
        self.transform = transform
        print("TinyImages contains {} images".format(len(self.data)))

    def __getitem__(self, index):
        img = self.data[index]
        img = Image.fromarray(img)
        if self.transform is not None:
            img = self.transform(img)
        return img, -1  # -1 marks the sample as unlabeled

    def __len__(self):
        return len(self.data)
def tinyimages300k_dataloaders(num_samples=300000, train_batch_size=64, num_workers=8,
                               data_root_path='/ssd1/haotao/datasets'):
    num_samples = int(num_samples)
    data_dir = os.path.join(data_root_path)

    # Standard CIFAR-style augmentation for 32x32 images.
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])

    # TinyImages only accepts (root, transform); it has no train/download arguments.
    train_set = Subset(TinyImages(data_dir, transform=train_transform), list(range(num_samples)))
    train_loader = DataLoader(train_set, batch_size=train_batch_size, shuffle=True, num_workers=num_workers,
                              drop_last=True, pin_memory=True)
    return train_loader
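

# Minimal usage sketch (not part of the original file): build the loader and inspect one batch.
# It assumes '300K_random_images.npy' is available under <data_root_path>/tinyimages80m/;
# the paths and argument values below are illustrative only.
if __name__ == '__main__':
    loader = tinyimages300k_dataloaders(num_samples=10000, train_batch_size=64, num_workers=0,
                                        data_root_path='/ssd1/haotao/datasets')
    images, labels = next(iter(loader))
    print(images.shape)     # expected: torch.Size([64, 3, 32, 32])
    print(labels.unique())  # expected: tensor([-1]); every sample is unlabeled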