in src/dataset.py [0:0]
def get_data_loader(params, split, transform, shuffle, distributed_sampler, watermark_path=""):
"""
Get data loader over imagenet dataset.
"""
assert params.dataset in DATASETS
    if params.num_classes == -1:
        # -1 means: use every class of the dataset
        params.num_classes = DATASETS[params.dataset]['num_classes']
    class_mapper = lambda x: x  # identity; replaced below if classes are restricted
    # Transform: build the augmentation/normalization pipeline matching the dataset family
    if params.dataset.startswith("cifar") or params.dataset in ["tiny", "mini_imagenet"]:
        transform = getCifarTransform(transform, img_size=params.img_size, crop_size=params.crop_size, normalization=True)
    elif params.dataset in ["imagenet", "flickr", "cub", "places205"]:
        transform = getImagenetTransform(transform, img_size=params.img_size, crop_size=params.crop_size, normalization=True)
# Data
if params.dataset in ["cifar10", "mini_imagenet"]:
pass
# if split == "valid":
# if data_path == "":
# data = CIFAR10(root=DATASETS[params.dataset][split], transform=transform, return_index=return_index, overlay=overlay, blend_type=blend_type, alpha=alpha, overlay_class=overlay_class)
# else:
# data = CIFAR10(root=join(dirname(DATASETS[params.dataset][split]), data_path), transform=transform, return_index=return_index, overlay=overlay, blend_type=blend_type, alpha=alpha, overlay_class=overlay_class)
# else:
# data = CIFAR10(root=join(DATASETS[params.dataset][split], data_path), transform=transform, return_index=return_index, overlay=overlay, blend_type=blend_type, alpha=alpha, overlay_class=overlay_class)
elif params.dataset in ["imagenet", "places205"]:
vanilla_data = ImageFolder(root=DATASETS[params.dataset][split], transform=transform)
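        # When a watermark path is given, wrap the dataset in its
        # watermarked variant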
if watermark_path != "":
data = WatermarkedSet(vanilla_data, watermark_path=watermark_path, transform=transform)
else:
data = vanilla_data
else:
raise NotImplementedError()
    # Restrict the dataset to a subset of classes and remap labels to [0, num_classes - 1]
    if params.num_classes != DATASETS[params.dataset]['num_classes']:
        kept_classes = SUBCLASSES[params.dataset][params.num_classes]
        indices = []
        for cl_id in kept_classes:
            indices.extend(data.class2position[cl_id])
        data = Subset(data, indices)
        # dict lookup instead of list.index on every sample
        remap = {cl_id: new_id for new_id, cl_id in enumerate(kept_classes)}
        class_mapper = lambda i: remap[i]
        data = TargetTransformDataset(data, class_mapper)
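
    # Note: TargetTransformDataset is defined elsewhere in this repo. As a
    # mental model only (this sketch is an assumption, not the actual
    # implementation), it behaves like a thin wrapper:
    #
    #   class TargetTransformDataset(torch.utils.data.Dataset):
    #       def __init__(self, dataset, target_transform):
    #           self.dataset = dataset
    #           self.target_transform = target_transform
    #       def __len__(self):
    #           return len(self.dataset)
    #       def __getitem__(self, i):
    #           x, y = self.dataset[i]
    #           return x, self.target_transform(y)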
    # Sampler: in distributed training each process must read a distinct shard
    sampler = None
    if distributed_sampler:
        # seeded variant of torch.utils.data.distributed.DistributedSampler,
        # so the shuffling order is reproducible across runs
        sampler = SeededDistributedSampler(data, seed=params.seed)
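
    # Note: SeededDistributedSampler is defined elsewhere in this repo. A
    # plausible minimal version (an assumption, not the actual code) simply
    # forwards the seed to PyTorch's DistributedSampler, which uses it to
    # derive the per-epoch shuffle:
    #
    #   from torch.utils.data.distributed import DistributedSampler
    #
    #   class SeededDistributedSampler(DistributedSampler):
    #       def __init__(self, dataset, seed=0):
    #           super().__init__(dataset, seed=seed)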
    # Data loader: when a sampler is set it owns the shuffling, so the
    # DataLoader's own shuffle must be disabled
    data_loader = torch.utils.data.DataLoader(
        data,
        batch_size=params.batch_size,
        shuffle=shuffle and sampler is None,
num_workers=params.nb_workers,
pin_memory=True,
sampler=sampler
)
return data_loader, sampler, class_mapper
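

if __name__ == "__main__":
    # Smoke-test sketch: the attribute names on `params` mirror what
    # get_data_loader reads above; the concrete values and the "center"
    # transform name are illustrative assumptions, not repo defaults.
    from argparse import Namespace

    params = Namespace(
        dataset="imagenet", num_classes=-1, img_size=256, crop_size=224,
        batch_size=64, nb_workers=4, seed=0,
    )
    loader, sampler, class_mapper = get_data_loader(
        params, split="valid", transform="center",
        shuffle=False, distributed_sampler=False,
    )
    images, targets = next(iter(loader))
    print(images.shape, targets.shape)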