grade_school_math/dataset.py (59 lines of code) (raw):

import json
import os
import re

import torch as th


def read_jsonl(path: str) -> list:
    """Parse a JSON Lines file into a list of decoded objects.

    Args:
        path: Location of the ``.jsonl`` file.

    Returns:
        One decoded JSON value per non-blank line, in file order.
    """
    # Lines read from a file keep their trailing newline, so a bare
    # truthiness check never filters anything; strip first so that blank
    # lines are skipped instead of crashing json.loads.
    with open(path, encoding="utf-8") as fh:
        return [json.loads(line) for line in fh if line.strip()]


def get_examples(split):
    """Load ``data/<split>.jsonl`` and add the training delimiters.

    Each example is modified in place: a newline is appended to its
    ``"question"`` and ``"<|endoftext|>"`` is appended to its ``"answer"``.
    The number of loaded examples is printed.

    Args:
        split: File stem of the split to load (joined as ``<split>.jsonl``).

    Returns:
        The list of example dicts.
    """
    path = os.path.join("data/", f"{split}.jsonl")
    examples = read_jsonl(path)

    for ex in examples:
        ex.update(question=ex["question"] + "\n")
        ex.update(answer=ex["answer"] + "<|endoftext|>")

    print(f"{len(examples)} {split} examples")
    return examples


# Matches the final-answer marker "#### <number>" and captures the number,
# which may carry a leading minus sign, decimal points and thousands commas.
ANS_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
# Sentinel returned when no "#### <number>" marker is present.
INVALID_ANS = "[invalid]"


def extract_answer(completion):
    """Return the number following the first ``####`` marker in *completion*.

    Commas are removed from the captured number. ``INVALID_ANS`` is returned
    when the text contains no marker.
    """
    match = ANS_RE.search(completion)
    if not match:
        return INVALID_ANS
    return match.group(1).strip().replace(",", "")


def is_correct(model_completion, gt_example):
    """Tell whether *model_completion* gives the same answer as *gt_example*.

    Args:
        model_completion: Text produced by the model.
        gt_example: Example dict whose ``"answer"`` holds the reference
            solution, including its ``####`` marker.

    Returns:
        True when the extracted answers are equal as strings.

    Raises:
        AssertionError: If the reference answer has no ``####`` marker.
    """
    gt_answer = extract_answer(gt_example["answer"])
    assert gt_answer != INVALID_ANS, "ground-truth answer has no '#### <number>' marker"
    return extract_answer(model_completion) == gt_answer


class GSMDataset(th.utils.data.Dataset):
    """Dataset of tokenized question/answer pairs padded to a common length."""

    def __init__(self, tokenizer, examples, loss_on_prefix=True):
        """Tokenize every example and record the longest combined length.

        Args:
            tokenizer: Callable invoked as ``tokenizer(texts, padding=False)``
                whose result is indexed with ``["input_ids"]`` to obtain one
                token-id list per text.
            examples: Sequence of dicts with ``"question"`` and ``"answer"``.
            loss_on_prefix: When false, question tokens get mask value 0.
        """
        self.examples = examples
        self.qns = [ex["question"] for ex in self.examples]
        self.ans = [ex["answer"] for ex in self.examples]
        self.qns = tokenizer(self.qns, padding=False)
        self.ans = tokenizer(self.ans, padding=False)
        self.loss_on_prefix = loss_on_prefix
        # default=0 keeps construction from failing on an empty example list.
        self.max_len = max(
            (
                len(qn_ids) + len(ans_ids)
                for qn_ids, ans_ids in zip(
                    self.qns["input_ids"], self.ans["input_ids"]
                )
            ),
            default=0,
        )
        print(f"Max tokens: {self.max_len}")

    def __len__(self):
        """Return the number of examples."""
        return len(self.examples)

    def __getitem__(self, idx):
        """Build the padded token and mask tensors for example *idx*.

        Returns:
            A dict with ``input_ids`` (question tokens, answer tokens, then
            zero padding up to ``max_len``) and ``attention_mask``
            (``int(loss_on_prefix)`` over the question, 1 over the answer,
            0 over the padding).
        """
        qn_tokens = self.qns["input_ids"][idx]
        ans_tokens = self.ans["input_ids"][idx]
        pad_tokens = [0] * (self.max_len - len(qn_tokens) - len(ans_tokens))
        tokens = qn_tokens + ans_tokens + pad_tokens
        mask = (
            ([int(self.loss_on_prefix)] * len(qn_tokens))
            + ([1] * len(ans_tokens))
            + ([0] * len(pad_tokens))
        )
        tokens = th.tensor(tokens)
        mask = th.tensor(mask)
        return dict(input_ids=tokens, attention_mask=mask)