in mgsm_eval.py [0:0]
def __call__(self, sampler: SamplerBase) -> EvalResult:
    """Evaluate ``sampler`` on every example in ``self.examples``.

    For each example a single user message is built from the
    language-specific instruction template, the sampler is called on it,
    and the answer parsed from the response is scored against the
    example's target.

    Args:
        sampler: Object that is called with the prompt messages to obtain
            the response text; its ``_pack_message`` method is used to
            build the prompt message.

    Returns:
        The value of ``common.aggregate_results`` over the per-example
        results, requested with the "mean" and "std" default statistics.
    """
    # Function-scope import: the file's import block is not part of this block.
    import logging

    logger = logging.getLogger(__name__)

    def fn(example: dict[str, str]):
        """Score one example and package it as a ``SingleEvalResult``."""
        language = example["lang"]
        # Every example is reported under its own language and under one
        # of two script groups.
        latin_language = (
            "group_latin" if language in LATIN_LANGUAGES else "group_non_latin"
        )
        correct_answer = example["targets"]
        instruction = LANG_TO_INSTRUCTIONS[language]
        prompt_messages = [
            sampler._pack_message(
                content=instruction.format(input=example["inputs"]),
                role="user",
            )
        ]
        try:
            response_text = sampler(prompt_messages)
        except Exception:
            # Deliberate best-effort: one failed request must not abort the
            # whole run, so it is scored as an empty response. The failure
            # is logged so it can be told apart from a wrong answer.
            logger.warning(
                "Sampler call failed for language %r; scoring an empty response",
                language,
                exc_info=True,
            )
            response_text = ""

        answer_prefix = LANG_TO_ANSWER_PREFIX[language]
        extracted_answer = parse_answer(response_text, answer_prefix)
        score = score_mgsm(correct_answer, extracted_answer)

        assistant_message = dict(content=response_text, role="assistant")
        html = common.jinja_env.from_string(HTML_JINJA).render(
            prompt_messages=prompt_messages,
            next_message=assistant_message,
            score=score,
            correct_answer=correct_answer,
            # A falsy extraction (e.g. empty string) is passed as None.
            extracted_answer=extracted_answer or None,
        )
        # Build a separate dict for the conversation so it does not share
        # the object handed to the template.
        convo = prompt_messages + [dict(content=response_text, role="assistant")]
        return SingleEvalResult(
            html=html,
            score=score,
            convo=convo,
            metrics={language: score, latin_language: score},
        )

    results = common.map_with_progress(fn, self.examples)
    return common.aggregate_results(results, default_stats=("mean", "std"))