# weak_to_strong/datasets.py
import functools
from dataclasses import dataclass
from random import Random
from typing import Any, Callable, Optional
from datasets import Dataset as HfDataset
from datasets import load_dataset as hf_load_dataset
@dataclass
class DatasetConfig:
# split -> unshuffled dataset of items
loader: Callable[[str], HfDataset]
    # formats an item to have keys 'txt' and 'hard_label'; takes the item plus a
    # random.Random rng keyword argument
    formatter: Callable[..., Any]
# mapping from dataset name to load function and format function
_REGISTRY: dict[str, DatasetConfig] = {}
def register_dataset(name: str, config: DatasetConfig):
_REGISTRY[name] = config
def load_dataset(ds_name: str, seed: int = 0, split_sizes: Optional[dict] = None):
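    """Load and preprocess the registered dataset ds_name.

    split_sizes maps split name -> number of documents (None means use the full
    split). Returns a dict mapping each requested split to a shuffled HF dataset
    with 'txt', 'hard_label', and 'soft_label' columns alongside the original fields.
    """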
if split_sizes is None:
split_sizes = dict(train=None, test=None)
if ds_name not in _REGISTRY:
raise ValueError(f"Unknown dataset {ds_name}, please register")
cfg = _REGISTRY[ds_name]
results = {}
for split, n_docs in split_sizes.items():
ds = cfg.loader(split)
        # n_docs may be None (the default), meaning "use the whole split"
        if n_docs is not None:
            try:
                ds = ds.select(range(n_docs))
            except IndexError as e:
                print(f"Warning: {ds_name} has fewer than {n_docs} docs, using all: {e}")
ds = ds.map(functools.partial(cfg.formatter, rng=Random(seed)))
ds = ds.map(
lambda ex: {"soft_label": [1 - float(ex["hard_label"]), float(ex["hard_label"])]}
)
        ds = ds.shuffle(seed=seed)  # shuffling is mostly pointless for the test split, but harmless
results[split] = ds
return results
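# Illustrative sketch of the rows load_dataset yields (the values below are made
# up, not from any real split): each row keeps its original columns plus the
# formatter's fields and the derived soft label, e.g.
#   {"txt": "Q: ... A: ...", "hard_label": 1, "soft_label": [0.0, 1.0], ...}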
def tokenize_dataset(
raw_ds: HfDataset,
tokenizer: Callable,
max_ctx: int,
):
"""
This function prepares the dataset for training. It takes the raw dataset, a formatting function,
a tokenizer, a maximum context length
Parameters:
raw_ds: The raw dataset to be processed.
tokenizer: The tokenizer to be used on the formatted dataset.
max_ctx: The maximum context length for the tokenizer.
Returns:
ds: The processed and shuffled dataset ready for training.
"""
def process_function(res):
toks = tokenizer(res["txt"])
return dict(
input_ids=toks["input_ids"],
)
ds = raw_ds.map(process_function, batched=False).filter(lambda x: len(x["input_ids"]) < max_ctx)
return ds
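# A minimal usage sketch for tokenize_dataset (assumes the `transformers` package
# is installed; the tokenizer name and split sizes are illustrative only):
#
#   from transformers import AutoTokenizer
#
#   splits = load_dataset("sciq", seed=0, split_sizes=dict(train=100, test=10))
#   tokenizer = AutoTokenizer.from_pretrained("gpt2")
#   train_ds = tokenize_dataset(splits["train"], tokenizer, max_ctx=1024)
#   print(train_ds[0]["input_ids"][:10])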
def hf_loader(*hf_name, split_names=None):
if split_names is None:
split_names = dict()
return lambda split: hf_load_dataset(*hf_name, split=split_names.get(split, split))
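# Usage sketch: the returned closure maps this module's split names onto HF split
# names, so hf_loader("boolq", split_names=dict(test="validation"))("test") loads
# the HF "validation" split, while hf_loader("sciq")("train") loads "train" as-is.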
##########
# ACTUAL DATASETS
##########
def format_amazon_polarity(ex, rng):
return dict(txt=f"{ex['title']} {ex['content']}", hard_label=ex["label"])
register_dataset(
"amazon_polarity",
DatasetConfig(loader=hf_loader("amazon_polarity"), formatter=format_amazon_polarity),
)
def format_sciq(ex, rng):
hard_label = int(rng.random() < 0.5)
if hard_label:
ans = ex["correct_answer"]
else:
ans = rng.choice([ex["distractor1"], ex["distractor2"], ex["distractor3"]])
txt = f"Q: {ex['question']} A: {ans}"
return dict(txt=txt, hard_label=hard_label)
register_dataset(
"sciq",
DatasetConfig(loader=hf_loader("sciq"), formatter=format_sciq),
)
def format_anthropic_hh(ex, rng):
hard_label = int(rng.random() < 0.5)
txt = ex["chosen"] if hard_label else ex["rejected"]
return dict(txt=txt, hard_label=hard_label)
register_dataset(
"anthropic_hh",
DatasetConfig(loader=hf_loader("Anthropic/hh-rlhf"), formatter=format_anthropic_hh),
)
def format_cosmosqa(ex, rng):
true_answer = ex["answer" + str(ex["label"])]
if "None of the above choices ." in true_answer:
hard_label = 0
else:
assert "None of the above choices" not in true_answer, true_answer
hard_label = int(rng.random() < 0.5)
if hard_label:
answer = true_answer
else:
candidate_answers = [ex["answer" + str(i)] for i in range(4)]
answer = rng.choice([x for x in candidate_answers if x != true_answer])
txt = f"Context: {ex['context']}\nQuestion: {ex['question']}\nAnswer: {answer}"
return dict(txt=txt, hard_label=hard_label)
register_dataset(
"cosmos_qa",
DatasetConfig(
loader=hf_loader("cosmos_qa", split_names=dict(test="validation")),
formatter=format_cosmosqa,
),
)
def format_boolq(ex, rng):
hard_label = int(ex["answer"])
txt = f"Passage: {ex['passage']}\nQuestion: {ex['question']}"
return dict(txt=txt, hard_label=hard_label)
register_dataset(
"boolq",
DatasetConfig(
loader=hf_loader("boolq", split_names=dict(test="validation")), formatter=format_boolq
),
)
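# Sketch of how an additional dataset could be registered (the dataset name, HF
# path, and field names below are hypothetical, for illustration only): the loader
# must return an unshuffled HF split and the formatter must emit 'txt' and
# 'hard_label'.
#
#   def format_my_pairs(ex, rng):
#       hard_label = int(rng.random() < 0.5)
#       txt = ex["positive"] if hard_label else ex["negative"]
#       return dict(txt=txt, hard_label=hard_label)
#
#   register_dataset(
#       "my_pairs",
#       DatasetConfig(loader=hf_loader("my_org/my_pairs"), formatter=format_my_pairs),
#   )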
VALID_DATASETS: list[str] = list(_REGISTRY.keys())
"""
from datasets import disable_caching
disable_caching()
from weak_to_strong.datasets import load_dataset, VALID_DATASETS
import numpy as np
ds_name = "boolq"
print(VALID_DATASETS)
ds = load_dataset(ds_name, split_sizes=dict(train=500, test=10))
train = list(ds['train'])
test = list(ds['test'])
print(test[0])
print(np.mean([x['hard_label'] for x in train]))
"""