in datasets.py [0:0]
def get_loaders(
    data_path,
    dataset_name,
    batch_size,
    method="erm",
    duplicates=None,
    *,
    eval_batch_size=128,
    num_workers=4,
):
    """Build the train/validation/test DataLoaders for a dataset and method.

    Args:
        data_path: Root path forwarded to the dataset constructor for every
            split.
        dataset_name: One of "waterbirds", "celeba", "multinli",
            "civilcomments" or "toy".
        batch_size: Batch size of the training loader.
        method: Training method; it controls how the training split is
            built and sampled:
              - "subg": training dataset built with subsample_what="groups";
              - "suby": training dataset built with subsample_what="classes";
              - "rwg" / "dro": training batches drawn by a
                WeightedRandomSampler over ``dataset_tr.wg``;
              - "rwy": same, but over ``dataset_tr.wy``;
              - anything else (e.g. "erm"): plain shuffled training loader.
            For "civilcomments", methods "subg" and "rwg" select
            ``CivilCommentsFine`` instead of ``CivilComments``.
        duplicates: Forwarded to the training dataset constructor only;
            validation and test datasets never receive it.
        eval_batch_size: Batch size of the validation and test loaders
            (default 128).
        num_workers: Number of DataLoader worker processes for every loader
            (default 4); 0 loads data in the main process.

    Returns:
        dict with keys "tr", "va", "te" mapping to DataLoaders. Only the
        training loader shuffles or reweights; "va" and "te" iterate in order.

    Raises:
        KeyError: if ``dataset_name`` is not one of the known datasets.
    """
    datasets = {
        "waterbirds": Waterbirds,
        "celeba": CelebA,
        "multinli": MultiNLI,
        # NOTE(review): CivilCommentsFine is picked only for "subg"/"rwg";
        # presumably it exposes a finer group split those methods need —
        # confirm against the class definitions.
        "civilcomments": (
            CivilCommentsFine if method in ("subg", "rwg") else CivilComments
        ),
        "toy": Toy,
    }
    try:
        Dataset = datasets[dataset_name]
    except KeyError:
        raise KeyError(
            f"unknown dataset_name {dataset_name!r}; "
            f"expected one of {sorted(datasets)}"
        ) from None

    def dl(dataset, bs, shuffle, weights):
        # Wrap `dataset` in a DataLoader, drawing examples according to
        # `weights` (with replacement, len(weights) draws per epoch) when given.
        sampler = (
            torch.utils.data.WeightedRandomSampler(weights, len(weights))
            if weights is not None
            else None
        )
        return DataLoader(
            dataset,
            batch_size=bs,
            shuffle=shuffle,
            sampler=sampler,
            num_workers=num_workers,
            pin_memory=True,
        )

    # Value forwarded to the dataset constructor; None means no subsampling
    # is requested for the training split.
    subsample_what = {"subg": "groups", "suby": "classes"}.get(method)
    dataset_tr = Dataset(data_path, "tr", subsample_what, duplicates)

    if method in ("rwg", "dro"):
        weights_tr = dataset_tr.wg
    elif method == "rwy":
        weights_tr = dataset_tr.wy
    else:
        weights_tr = None

    return {
        # DataLoader forbids shuffle=True together with a sampler, so the
        # training loader shuffles only when no sampling weights are used.
        "tr": dl(dataset_tr, batch_size, weights_tr is None, weights_tr),
        "va": dl(Dataset(data_path, "va", None), eval_batch_size, False, None),
        "te": dl(Dataset(data_path, "te", None), eval_batch_size, False, None),
    }