def __call__()

in mgsm_eval.py [0:0]


    def __call__(self, sampler: SamplerBase) -> EvalResult:
        """Evaluate ``sampler`` on every example in ``self.examples``.

        For each example, the language-specific instruction is filled with the
        example's input and sent to the sampler as a single user message. The
        answer is parsed from the response using the language's answer prefix
        and scored against the example's target with ``score_mgsm``. Each
        example's score is also recorded under its language code and under a
        script-group key ("group_latin" / "group_non_latin").

        A sampler call that raises is logged and scored as an empty response,
        so one failing request does not abort the whole run.

        Args:
            sampler: Model sampler. It is called with the packed prompt
                messages, and its return value is treated as the response text.

        Returns:
            The result of ``common.aggregate_results`` over the per-example
            results, requesting "mean" and "std" as the default statistics.
        """
        import logging  # function-scope import: no file-level import change needed

        logger = logging.getLogger(__name__)

        def fn(example: dict[str, str]):
            # Keys read from each example (from usage below): "lang", "inputs", "targets".
            language = example["lang"]
            # Script group lets results also be aggregated across latin vs. non-latin languages.
            latin_language = (
                "group_latin" if language in LATIN_LANGUAGES else "group_non_latin"
            )
            correct_answer = example["targets"]
            instruction = LANG_TO_INSTRUCTIONS[language]
            prompt_messages = [
                sampler._pack_message(
                    content=instruction.format(input=example["inputs"]), role="user"
                )
            ]
            try:
                response_text = sampler(prompt_messages)
            except Exception:
                # Deliberate best-effort: score a failed call as an empty response
                # instead of aborting the run, but keep the traceback visible.
                logger.exception(
                    "Sampler call failed for MGSM example (lang=%s); "
                    "scoring it as an empty response",
                    language,
                )
                response_text = ""

            answer_prefix = LANG_TO_ANSWER_PREFIX[language]
            extracted_answer = parse_answer(response_text, answer_prefix)
            score = score_mgsm(correct_answer, extracted_answer)

            # Build the assistant turn once; it is used for rendering and for the convo.
            response_message = dict(content=response_text, role="assistant")
            html = common.jinja_env.from_string(HTML_JINJA).render(
                prompt_messages=prompt_messages,
                next_message=response_message,
                score=score,
                correct_answer=correct_answer,
                # A falsy extraction (e.g. "") is passed to the template as None.
                extracted_answer=extracted_answer or None,
            )
            convo = prompt_messages + [response_message]
            return SingleEvalResult(
                html=html,
                score=score,
                convo=convo,
                metrics={language: score, latin_language: score},
            )

        results = common.map_with_progress(fn, self.examples)
        return common.aggregate_results(results, default_stats=("mean", "std"))