in ppuda/vision/loader.py [0:0]
def image_loader(dataset='imagenet', data_dir='./data/', test=True, fine_tune=False,
                 batch_size=64, test_batch_size=64, num_workers=0,
                 cutout=False, cutout_length=16, noise=False,
                 seed=1111, load_train_anyway=False, n_shots=None):
    """
    Build training and evaluation DataLoaders for ImageNet or a torchvision dataset.

    :param dataset: image dataset: imagenet, cifar10, cifar100, etc.
    :param data_dir: location of the dataset
    :param test: True to load the test data for evaluation, False to load the validation data.
    :param fine_tune: True when fine-tuning an ImageNet model on CIFAR-10
    :param batch_size: training batch size of images
    :param test_batch_size: evaluation batch size of images
    :param num_workers: number of threads to load/preprocess images
    :param cutout: use Cutout for data augmentation from
    "Terrance DeVries, Graham W. Taylor. Improved Regularization of Convolutional Neural Networks with Cutout. 2017."
    :param cutout_length: Cutout hyperparameter
    :param noise: evaluate on the images with added Gaussian noise
    :param seed: random seed to shuffle validation images on ImageNet
    :param load_train_anyway: load training images even when evaluating on test data (test=True)
    :param n_shots: the number of training images per class (only for CIFAR-10 and other torchvision datasets and when test=True)
    :return: training and evaluation torch DataLoaders and number of classes in the dataset
    """
    train_data = None
    if dataset.lower() == 'imagenet':
        train_transform, valid_transform = transforms_imagenet(noise=noise, cifar_style=False)
        imagenet_dir = os.path.join(data_dir, 'imagenet')
        if not test or load_train_anyway:
            # has_validation=not test: when validating (not testing), the train split
            # holds out a validation subset inside ImageNetDataset.
            train_data = ImageNetDataset(imagenet_dir, 'train', transform=train_transform, has_validation=not test)
        valid_data = ImageNetDataset(imagenet_dir, 'val', transform=valid_transform, has_validation=not test)
        # Shuffle eval images so models with batch norm can be evaluated in training
        # mode (in case there are no running statistics).
        shuffle_val = True
        n_classes = len(valid_data.classes)
        # Seeded generator makes the shuffled evaluation order reproducible.
        generator = torch.Generator()
        generator.manual_seed(seed)
    else:
        dataset = dataset.upper()
        train_transform, valid_transform = transforms_cifar(cutout=cutout, cutout_length=cutout_length, noise=noise, sz=224 if fine_tune else 32)
        # Resolve the torchvision dataset class (e.g. CIFAR10, CIFAR100) from the
        # module namespace instead of eval()-ing a source string built from input.
        dataset_cls = globals()[dataset]
        if test:
            valid_data = dataset_cls(data_dir, train=False, download=True, transform=valid_transform)
            if load_train_anyway:
                train_data = dataset_cls(data_dir, train=True, download=True, transform=train_transform)
                if n_shots is not None:
                    # Keep only n_shots training images per class.
                    train_data = to_few_shot(train_data, n_shots=n_shots)
        else:
            if n_shots is not None:
                print('few shot regime is only supported for evaluation on the test data')
            # Held out 10% (e.g. 5000 images in case of CIFAR-10) of training data as the validation set
            train_data = dataset_cls(data_dir, train=True, download=True, transform=train_transform)
            valid_data = dataset_cls(data_dir, train=True, download=True, transform=valid_transform)
            n_all = len(train_data.targets)
            n_val = n_all // 10
            # Deterministic split: first 90% for training, last 10% for validation.
            idx_train, idx_val = torch.split(torch.arange(n_all), [n_all - n_val, n_val])
            train_data.data = train_data.data[idx_train]
            train_data.targets = [train_data.targets[i] for i in idx_train]
            valid_data.data = valid_data.data[idx_val]
            valid_data.targets = [valid_data.targets[i] for i in idx_val]
        if train_data is not None:
            # checksum = mean pixel value, used below to report what was loaded.
            # NOTE(review): assumes the (possibly few-shot-wrapped) dataset still
            # exposes .data / .targets — confirm against to_few_shot.
            train_data.checksum = train_data.data.mean()
            train_data.num_examples = len(train_data.targets)
        shuffle_val = False
        n_classes = len(torch.unique(torch.tensor(valid_data.targets)))
        generator = None
        valid_data.checksum = valid_data.data.mean()
        valid_data.num_examples = len(valid_data.targets)
    print('loaded {}: {} classes, {} train samples (checksum={}), '
          '{} {} samples (checksum={:.3f})'.format(dataset,
                                                   n_classes,
                                                   train_data.num_examples if train_data else 'none',
                                                   ('%.3f' % train_data.checksum) if train_data else 'none',
                                                   valid_data.num_examples,
                                                   'test' if test else 'val',
                                                   valid_data.checksum))
    if train_data is None:
        train_loader = None
    else:
        train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True,
                                  pin_memory=True, num_workers=num_workers)
    valid_loader = DataLoader(valid_data, batch_size=test_batch_size, shuffle=shuffle_val,
                              pin_memory=True, num_workers=num_workers, generator=generator)
    return train_loader, valid_loader, n_classes