# grade_school_math/calculator.py
def sample(model, qn, tokenizer, device, sample_len):
    """Greedily extend `qn` one token at a time, invoking the calculator tool.

    Deliberately slow reference implementation: a single sequence (no
    batching), and the full prompt is re-encoded and re-run through the
    model for every generated token (no activation caching).
    """
    # Token ids that correspond to an "=" having just been emitted — the
    # trigger for evaluating the expression via the external calculator.
    equals_tokens = {28, 796, 47505}

    for _ in range(sample_len):
        with th.no_grad():
            encoded = tokenizer([qn], padding=False, return_tensors="pt").to(device)
            prompt_len = encoded["input_ids"].shape[1]
            out = model.generate(
                **encoded,
                max_length=prompt_len + 1,
                pad_token_id=model.config.eos_token_id,
            )
            text = tokenizer.batch_decode(out)[0]

        last_token = out[0, -1].item()
        if last_token in equals_tokens:
            answer = use_calculator(text)
            if answer is not None:
                print("Triggered calculator, answer", answer)
                text = text + str(answer) + ">>"

        # Feed the (possibly calculator-augmented) text back in as the next prompt.
        qn = text

    return qn