def get_loader()

in src/data_loader.py [0:0]


# Dataset names accepted by get_loader.
_SUPPORTED_DATASETS = ('coco', 'voc', 'nuswide', 'ade20k', 'recipe1m')


def _read_split_ids(splits_dir, dataset_name, split):
    """Return the integer sample ids listed in ``<splits_dir>/<dataset_name>/<split>.txt``.

    The file is read one integer per line; blank lines (e.g. an extra newline at
    EOF) are skipped instead of crashing ``int()``.
    """
    perm_file = os.path.join(splits_dir, dataset_name, split + '.txt')
    with open(perm_file, 'r', encoding='utf-8') as f:
        # int() tolerates surrounding whitespace, so '\n' and '\r\n' endings both parse.
        return np.array([int(line) for line in f if line.strip()])


def _build_dataset(dataset_name, dataset_root, split, transform, shuffle_labels, perm, include_eos):
    """Instantiate the dataset object for ``dataset_name``.

    Every dataset receives the same transform / label-shuffle / perm / include_eos
    options; only the location arguments and a few dataset-specific options differ.
    """
    common = dict(transform=transform, shuffle=shuffle_labels, perm=perm, include_eos=include_eos)

    if dataset_name == 'coco':
        # 'train' and 'val' both read the train2014 images/annotations; any other
        # split reads val2014. Presumably `perm` selects the per-split subset -- verify in COCO.
        if split in ('train', 'val'):
            ann_file = os.path.join(dataset_root, 'annotations', 'instances_train2014.json')
            image_dir = os.path.join(dataset_root, 'train2014')
        else:
            ann_file = os.path.join(dataset_root, 'annotations', 'instances_val2014.json')
            image_dir = os.path.join(dataset_root, 'val2014')
        return COCO(root=image_dir, annFile=ann_file, **common)

    if dataset_name == 'voc':
        return VOC(root=dataset_root, year='2007', image_set=split, download=False, **common)

    if dataset_name == 'nuswide':
        return NUSWIDE(dataset_root, split, **common)

    if dataset_name == 'ade20k':
        return ADE20K(dataset_root, split, **common)

    if dataset_name == 'recipe1m':
        return Recipe1M(
            dataset_root,
            split,
            maxnumims=5,
            use_lmdb=False,
            suff='final_',
            **common)

    raise ValueError('Unknown dataset {!r}; expected one of {}'.format(
        dataset_name, ', '.join(_SUPPORTED_DATASETS)))


def _seed_numpy_worker(seed, worker_id):
    """Seed NumPy's global RNG inside a DataLoader worker process."""
    # Offsetting by worker_id keeps runs reproducible while giving each worker its
    # own stream; seeding all workers identically made their random draws identical.
    # The modulo keeps the value inside NumPy's accepted seed range [0, 2**32).
    np.random.seed((seed + worker_id) % 2 ** 32)


def get_loader(dataset,
               dataset_root,
               split,
               transform,
               batch_size,
               shuffle,
               num_workers,
               include_eos,
               drop_last=False,
               shuffle_labels=False,
               seed=1234,
               checkpoint=None,
               splits_dir='../data/splits/'):
    """Build the dataset for ``dataset``/``split`` and wrap it in a DataLoader.

    Args:
        dataset: dataset name: 'coco', 'voc', 'nuswide', 'ade20k' or 'recipe1m'.
        dataset_root: root directory of the dataset files.
        split: split name; also selects the id file ``<splits_dir>/<dataset>/<split>.txt``.
        transform: transform forwarded to the dataset.
        batch_size: samples per batch (also passed to RandomSamplerWithState).
        shuffle: if True, iterate with RandomSamplerWithState (training);
            otherwise with a SequentialSampler (validation/test).
        num_workers: number of DataLoader worker processes.
        include_eos: forwarded to the dataset.
        drop_last: drop the last incomplete batch.
        shuffle_labels: forwarded to the dataset as ``shuffle``.
        seed: seed for RandomSamplerWithState and base seed for NumPy in workers.
        checkpoint: optional dict; when given together with ``shuffle=True`` the
            sampler state is restored from ``checkpoint['args'].current_epoch``
            and ``checkpoint['current_step']``.
        splits_dir: directory containing the per-dataset split id files. The
            default is relative to the current working directory.

    Returns:
        A ``(data_loader, dataset_object)`` tuple.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """
    import functools  # local import keeps this block self-contained

    # Validate the name before touching the filesystem so a typo fails clearly.
    if dataset not in _SUPPORTED_DATASETS:
        raise ValueError('Unknown dataset {!r}; expected one of {}'.format(
            dataset, ', '.join(_SUPPORTED_DATASETS)))

    perm = _read_split_ids(splits_dir, dataset, split)
    ds = _build_dataset(dataset, dataset_root, split, transform, shuffle_labels, perm, include_eos)

    if shuffle:
        # Training: random order; restore the sampler position when resuming.
        sampler = RandomSamplerWithState(ds, batch_size, seed)
        if checkpoint is not None:
            sampler.set_state(checkpoint['args'].current_epoch, checkpoint['current_step'])
    else:
        # Validation / test: in-order iteration.
        sampler = SequentialSampler(ds)

    data_loader = torch.utils.data.DataLoader(
        dataset=ds,
        batch_size=batch_size,
        shuffle=False,  # ordering comes from `sampler`; DataLoader rejects shuffle=True with a sampler
        num_workers=num_workers,
        drop_last=drop_last,
        pin_memory=True,
        collate_fn=collate_fn,
        # A partial over a module-level function is picklable (a nested closure is
        # not), so worker seeding also works with the 'spawn' start method.
        worker_init_fn=functools.partial(_seed_numpy_worker, seed),
        sampler=sampler)

    return data_loader, ds