def get_loaders()

in datasets.py [0:0]


def get_loaders(
    data_path,
    dataset_name,
    batch_size,
    method="erm",
    duplicates=None,
    *,
    eval_batch_size=128,
    num_workers=4,
):
    """Build train / validation / test DataLoaders for one dataset and method.

    Args:
        data_path: Root path forwarded to the dataset constructor.
        dataset_name: One of "waterbirds", "celeba", "multinli",
            "civilcomments", "toy".
        batch_size: Batch size of the training loader.
        method: Training method name. It controls how the training split is
            built and sampled:
              * "subg" / "suby": the training dataset is constructed with
                subsample_what="groups" / "classes";
              * "rwg" / "dro": training examples are drawn with a
                WeightedRandomSampler over ``dataset_tr.wg``;
              * "rwy": same, over ``dataset_tr.wy``;
              * anything else (e.g. "erm"): plain shuffling, no subsampling.
            For "civilcomments", "subg" and "rwg" select CivilCommentsFine,
            every other method selects CivilComments.
        duplicates: Forwarded to the training dataset constructor only; the
            validation and test datasets are built without it.
        eval_batch_size: Batch size of the validation and test loaders.
        num_workers: Number of worker processes for every loader.

    Returns:
        dict mapping "tr", "va", "te" to the corresponding DataLoader.

    Raises:
        KeyError: If ``dataset_name`` is not one of the known datasets.
    """
    datasets = {
        "waterbirds": Waterbirds,
        "celeba": CelebA,
        "multinli": MultiNLI,
        # CivilCommentsFine is used only for the "subg" and "rwg" methods.
        "civilcomments": CivilCommentsFine
        if method in ("subg", "rwg")
        else CivilComments,
        "toy": Toy,
    }
    try:
        Dataset = datasets[dataset_name]
    except KeyError:
        raise KeyError(
            f"unknown dataset {dataset_name!r}; expected one of {sorted(datasets)}"
        ) from None

    def dl(dataset, bs, shuffle, weights):
        # Wrap `dataset` in a DataLoader, sampling with replacement according
        # to `weights` when they are given. DataLoader does not accept
        # shuffle=True together with a sampler, so callers must pass
        # shuffle=False whenever weights is not None.
        sampler = (
            None
            if weights is None
            else torch.utils.data.WeightedRandomSampler(weights, len(weights))
        )
        return DataLoader(
            dataset,
            batch_size=bs,
            shuffle=shuffle,
            sampler=sampler,
            num_workers=num_workers,
            pin_memory=True,
        )

    # Only the subsampling methods shrink the training split; others keep it whole.
    subsample_what = {"subg": "groups", "suby": "classes"}.get(method)

    dataset_tr = Dataset(data_path, "tr", subsample_what, duplicates)

    # Per-example sampling weights for the reweighting methods. Presumably
    # `wg` / `wy` balance groups / classes respectively — confirm in the
    # dataset classes.
    if method in ("rwg", "dro"):
        weights_tr = dataset_tr.wg
    elif method == "rwy":
        weights_tr = dataset_tr.wy
    else:
        weights_tr = None

    return {
        # Shuffle only when no weighted sampler is used (see `dl`).
        "tr": dl(dataset_tr, batch_size, weights_tr is None, weights_tr),
        "va": dl(Dataset(data_path, "va", None), eval_batch_size, False, None),
        "te": dl(Dataset(data_path, "te", None), eval_batch_size, False, None),
    }