def get_data_loader()

in src/dataset.py [0:0]


def get_data_loader(params, split, transform, shuffle, distributed_sampler, watermark_path=""):
    """
    Build a data loader over one split of the dataset selected by ``params.dataset``.

    Only "imagenet" and "places205" currently have a data-loading path; any
    other dataset raises ``NotImplementedError``.

    Args:
        params: experiment parameters. Reads ``dataset``, ``num_classes``,
            ``img_size``, ``crop_size``, ``seed``, ``batch_size`` and
            ``nb_workers``. If ``params.num_classes == -1`` it is overwritten
            in place with ``DATASETS[params.dataset]['num_classes']``.
        split: key into ``DATASETS[params.dataset]`` giving the ImageFolder root
            (presumably "train" / "valid" -- verify against DATASETS).
        transform: passed as the first argument to ``getCifarTransform`` or
            ``getImagenetTransform`` for the datasets those handle; used
            unchanged for any other dataset.
        shuffle: shuffle the data; only honoured when no sampler is used,
            because DataLoader does not accept both.
        distributed_sampler: if true, draw samples through a
            ``SeededDistributedSampler`` seeded with ``params.seed``.
        watermark_path: if not "", wrap the ImageFolder in a ``WatermarkedSet``
            built from this path.

    Returns:
        ``(data_loader, sampler, class_mapper)``. ``sampler`` is ``None`` unless
        ``distributed_sampler`` is set. ``class_mapper`` maps an original class
        id to the label the loader yields: identity when all classes are kept,
        otherwise the position of the id in the selected subclass list.

    Raises:
        NotImplementedError: if ``params.dataset`` has no data-loading path here.
    """
    assert params.dataset in DATASETS, f"unknown dataset: {params.dataset}"
    dataset_info = DATASETS[params.dataset]
    if params.num_classes == -1:
        params.num_classes = dataset_info['num_classes']
    class_mapper = lambda x: x

    # Transform
    if params.dataset.startswith("cifar") or params.dataset == "tiny" or params.dataset == "mini_imagenet":
        transform = getCifarTransform(transform, img_size=params.img_size, crop_size=params.crop_size, normalization=True)
    elif params.dataset in ["imagenet", "flickr", "cub", "places205"]:
        transform = getImagenetTransform(transform, img_size=params.img_size, crop_size=params.crop_size, normalization=True)

    # Data
    if params.dataset in ["imagenet", "places205"]:
        data = ImageFolder(root=dataset_info[split], transform=transform)
        if watermark_path != "":
            data = WatermarkedSet(data, watermark_path=watermark_path, transform=transform)
    else:
        # "cifar10" / "mini_imagenet" used to fall through with `pass`, leaving
        # `data` unbound and failing later with an UnboundLocalError.
        # TODO: restore CIFAR-10 / mini-ImageNet loading.
        raise NotImplementedError(f"data loading is not implemented for dataset {params.dataset!r}")

    # Restrict the number of classes and remap them to [0, n_cl - 1]
    if params.num_classes != dataset_info['num_classes']:
        subclasses = SUBCLASSES[params.dataset][params.num_classes]
        indices = []
        for cl_id in subclasses:
            indices.extend(data.class2position[cl_id])
        data = Subset(data, indices)
        # Bound method rather than a lambda: it does not late-bind `params`,
        # and unlike a lambda it can be pickled (needed when DataLoader workers
        # are started with the "spawn" method).
        class_mapper = subclasses.index
        data = TargetTransformDataset(data, class_mapper)

    # Sampler
    sampler = SeededDistributedSampler(data, seed=params.seed) if distributed_sampler else None

    # Data loader
    data_loader = torch.utils.data.DataLoader(
        data,
        batch_size=params.batch_size,
        shuffle=shuffle and sampler is None,
        num_workers=params.nb_workers,
        pin_memory=True,
        sampler=sampler,
    )

    return data_loader, sampler, class_mapper