in grade_school_math/dataset.py [0:0]
def __init__(self, tokenizer, examples, loss_on_prefix=True):
    """Tokenize every question/answer pair and record the longest combined length.

    Args:
        tokenizer: Callable invoked as ``tokenizer(list_of_str, padding=False)``.
            Its result is indexed with ``["input_ids"]`` and is expected to
            hold one token sequence per input string.
        examples: Sequence of mappings providing ``"question"`` and
            ``"answer"`` keys. It is stored on the instance as-is.
        loss_on_prefix: Stored unchanged as ``self.loss_on_prefix``; this
            method does not otherwise use it.

    Side effects:
        Prints ``Max tokens: <max_len>`` to stdout.
    """
    self.examples = examples

    # Keep the raw strings in locals so ``self.qns`` / ``self.ans`` only ever
    # hold tokenizer output, never the intermediate string lists.
    questions = [ex["question"] for ex in self.examples]
    answers = [ex["answer"] for ex in self.examples]
    self.qns = tokenizer(questions, padding=False)
    self.ans = tokenizer(answers, padding=False)
    self.loss_on_prefix = loss_on_prefix

    # Longest question+answer pair, measured in tokens. ``default=0`` keeps an
    # empty ``examples`` from raising ``ValueError`` inside ``max``.
    self.max_len = max(
        (
            len(q_ids) + len(a_ids)
            for q_ids, a_ids in zip(self.qns["input_ids"], self.ans["input_ids"])
        ),
        default=0,
    )
    print(f"Max tokens: {self.max_len}")