in mmlu_eval.py [0:0]
def __call__(self, sampler: SamplerBase) -> EvalResult:
    """Run the eval: grade every example in ``self.examples`` with *sampler*
    and return the aggregated result."""

    def grade_row(row: dict):
        """Pose one multiple-choice question, score the reply, and build
        the per-example HTML report."""
        messages = [
            sampler._pack_message(
                content=format_multichoice_question(row), role="user"
            )
        ]
        reply = normalize_response(sampler(messages))

        # Scan the multilingual answer patterns; the first match wins.
        extracted_answer = None
        for pattern in MULTILINGUAL_ANSWER_REGEXES:
            found = re.search(
                MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(pattern), reply
            )
            if found:
                extracted_answer = normalize_extracted_answer(found.group(1))
                break

        # Exact-match scoring: 1.0 for a correct extracted answer, else 0.0.
        score = float(extracted_answer == row["Answer"])

        rendered = common.jinja_env.from_string(HTML_JINJA).render(
            prompt_messages=messages,
            next_message=dict(content=reply, role="assistant"),
            score=score,
            correct_answer=row["Answer"],
            extracted_answer=extracted_answer,
        )
        return SingleEvalResult(
            html=rendered,
            score=score,
            # Bucket the metric under the example's subject category.
            metrics={subject2category.get(row["Subject"], "other"): score},
            convo=messages + [dict(content=reply, role="assistant")],
        )

    return common.aggregate_results(
        common.map_with_progress(grade_row, self.examples)
    )