in datasets.py [0:0]
# Label strings of the CLEVR counting task, in the order that defines their
# integer class indices. The order is lexicographic by string ('count_10'
# sorts before 'count_3'); do not reorder it, or labels get silently remapped.
_CLEVR_COUNTS_CLASSES = ('count_10', 'count_3', 'count_4', 'count_5',
                         'count_6', 'count_7', 'count_8', 'count_9')


def get_downstream_dataset(catalog, name, is_train, transform):
    """Build the train or test split of a downstream evaluation dataset.

    Args:
        catalog: Mapping from dataset name to an entry dict. Every entry has
            'path' (root directory) and 'type' ('imagefolder', 'special' or
            'filelist'); 'imagefolder' and 'filelist' entries also carry
            'train' and 'test' keys naming each split's sub-path / file prefix.
        name: Key of the dataset in ``catalog``. For 'special' entries it also
            selects the torchvision dataset class.
        is_train: Build the training split when truthy, the test split
            otherwise.
        transform: Image transform handed to the dataset constructor.

    Returns:
        The constructed dataset object.

    Raises:
        KeyError: if ``name`` or a required entry key is missing (assuming
            ``catalog`` and its entries are plain dicts).
        ValueError: if the entry type is unknown, or a 'special' entry names a
            dataset this function cannot build.
    """
    entry = catalog[name]
    root = entry['path']
    kind = entry['type']
    split = 'train' if is_train else 'test'

    if kind == 'imagefolder':
        return t_datasets.ImageFolder(os.path.join(root, entry[split]),
                                      transform=transform)

    if kind == 'special':
        if name == 'cifar10':
            return t_datasets.CIFAR10(root, train=is_train,
                                      transform=transform, download=True)
        if name == 'cifar100':
            return t_datasets.CIFAR100(root, train=is_train,
                                       transform=transform, download=True)
        if name == 'stl10':
            return t_datasets.STL10(root, split=split,
                                    transform=transform, download=True)
        if name == 'mnist':
            return t_datasets.MNIST(root, train=is_train,
                                    transform=transform, download=True)
        # Without this guard an unlisted name would reach the end of the
        # function with nothing to return.
        raise ValueError(f'Unknown special dataset: {name!r}')

    if kind == 'filelist':
        prefix = entry[split]
        images_path = os.path.join(root, prefix + '_images.npy')
        labels_path = os.path.join(root, prefix + '_labels.npy')
        # A bound method instead of a lambda keeps the transform picklable,
        # e.g. for spawn-based data-loader workers (assumes FileListDataset
        # stores target_transform as-is -- TODO confirm).
        if name == 'clevr_counts':
            target_transform = _CLEVR_COUNTS_CLASSES.index
        else:
            target_transform = None
        return FileListDataset(images_path, labels_path, transform,
                               target_transform)

    raise ValueError(f'Unknown dataset: {name!r} (type {kind!r})')