in weak_to_strong/datasets.py [0:0]
def load_dataset(ds_name: str, seed: int = 0, split_sizes: Optional[dict] = None):
    """Load a registered dataset and return its processed splits.

    For each requested split the dataset is loaded via the registered config's
    ``loader``, optionally truncated, formatted with the config's ``formatter``,
    given a ``"soft_label"`` column, and shuffled.

    Args:
        ds_name: Key into ``_REGISTRY`` selecting the dataset config.
        seed: Seed for the ``Random`` instance passed to the formatter as
            ``rng`` and for the final ``shuffle``.
        split_sizes: Mapping from split name to the number of leading docs to
            keep from that split. A value of ``None`` keeps the whole split.
            Defaults to ``{"train": None, "test": None}``.

    Returns:
        dict mapping each split name in ``split_sizes`` to its processed
        dataset.

    Raises:
        ValueError: If ``ds_name`` is not in ``_REGISTRY``.
    """
    if split_sizes is None:
        split_sizes = dict(train=None, test=None)
    if ds_name not in _REGISTRY:
        raise ValueError(f"Unknown dataset {ds_name}, please register")
    cfg = _REGISTRY[ds_name]
    results = {}
    for split, n_docs in split_sizes.items():
        ds = cfg.loader(split)
        # None means "use the whole split". Without this guard the default
        # split_sizes would call range(None), which raises TypeError (not
        # IndexError) and aborts the load.
        if n_docs is not None:
            try:
                ds = ds.select(range(n_docs))
            except IndexError as e:
                # Presumably raised by select() when the split has fewer than
                # n_docs rows; fall back to the full split.
                print(f"Warning {ds_name} has less than {n_docs} docs, using all: {e}")
        # A fresh Random(seed) per split keeps each split's formatting
        # reproducible independently of the order splits are processed in.
        ds = ds.map(functools.partial(cfg.formatter, rng=Random(seed)))
        # Two-element soft label [1 - hard_label, hard_label]; assumes a binary
        # hard_label field produced by the formatter — TODO confirm.
        ds = ds.map(
            lambda ex: {"soft_label": [1 - float(ex["hard_label"]), float(ex["hard_label"])]}
        )
        ds = ds.shuffle(seed=seed)  # unnecessary for the test split, but harmless
        results[split] = ds
    return results