# simpleqa_eval.py
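# Excerpt: the __call__ method of the SimpleQA eval class. The names used below
# (common, SamplerBase, EvalResult, SingleEvalResult, self.grade_sample,
# self.examples) are assumed to come from the surrounding simple-evals module
# and class definition, which are not shown here.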
    def __call__(self, sampler: SamplerBase) -> EvalResult:
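        """Run the SimpleQA grading loop over self.examples with the given sampler
        and return the aggregated evaluation result."""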
        def fn(row: dict):
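            """Sample a response for a single example, grade it, and package the
            per-sample score, metrics, HTML report, and conversation into a
            SingleEvalResult."""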
            prompt_messages = [
                sampler._pack_message(content=row.get("problem", ""), role="user")
            ]
            response_text = sampler(prompt_messages)
            grade_letter = self.grade_sample(row.get("problem", ""), row.get("answer", ""), response_text)
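            # The grade letter is interpreted below as "A" = correct,
            # "B" = incorrect, "C" = not attempted.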
            # Metrics based on grading response
            is_correct = grade_letter == "A"
            is_incorrect = grade_letter == "B"
            is_not_attempted = grade_letter == "C"
            score = is_correct
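            # Only a correct answer ("A") counts toward the headline score;
            # incorrect and not-attempted responses both score 0.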
            # Create HTML for each sample result
            html = common.jinja_env.from_string(common.HTML_JINJA).render(
                prompt_messages=prompt_messages,
                next_message=dict(content=response_text, role="assistant"),
                score=score,
                correct_answer=row["answer"],
                extracted_answer=response_text,
            )
            convo = prompt_messages + [dict(content=response_text, role="assistant")]
            return SingleEvalResult(html=html, score=score, convo=convo, metrics={
                "is_correct": is_correct,
                "is_incorrect": is_incorrect,
                "is_not_attempted": is_not_attempted,
            })

        # Run evaluation and collect results
        results = common.map_with_progress(fn, self.examples)

        # Aggregate metrics
        aggregate_metrics = {
            "is_correct": sum(result.metrics["is_correct"] for result in results) / len(results),
            "is_incorrect": sum(result.metrics["is_incorrect"] for result in results) / len(results),
            "is_not_attempted": sum(result.metrics["is_not_attempted"] for result in results) / len(results),
        }
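        # Each aggregate above is the fraction of all examples in that bucket
        # (the boolean per-sample metrics average to a rate in [0, 1]).
        # is_given_attempted below is the fraction of questions the model
        # actually attempted (answered, whether right or wrong).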
        aggregate_metrics["is_given_attempted"] = aggregate_metrics["is_correct"] + aggregate_metrics["is_incorrect"]

        # Calculate accuracy_given_attempted
        aggregate_metrics["accuracy_given_attempted"] = (
            aggregate_metrics["is_correct"]
            / aggregate_metrics["is_given_attempted"]
            if aggregate_metrics["is_given_attempted"] > 0
            else 0
        )
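        # accuracy_given_attempted is conditional accuracy: of the questions the
        # model chose to answer, the fraction it answered correctly.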
print("AGGREGATE METRICS")
print(aggregate_metrics)
print("##################")
        output_d = {
            "accuracy_given_attempted": aggregate_metrics["accuracy_given_attempted"],
            "f1": (
                2 * aggregate_metrics["accuracy_given_attempted"] * aggregate_metrics["is_correct"]
                / (aggregate_metrics["accuracy_given_attempted"] + aggregate_metrics["is_correct"])
                if (aggregate_metrics["accuracy_given_attempted"] + aggregate_metrics["is_correct"]) > 0
                else 0
            ),
        }
print(f"Accuracy Given Attempted: {output_d['accuracy_given_attempted']:.3f}")
print(f"F1 Score: {output_d['f1']:.3f}")
return common.aggregate_results(results)
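
# Usage sketch (hypothetical; the sampler class and constructor arguments below
# are assumptions about the surrounding simple-evals project, not part of this file):
#
#     grader = ChatCompletionSampler(model="gpt-4o")      # grading model
#     simpleqa_eval = SimpleQAEval(grader_model=grader)   # eval defined in this file
#     report = simpleqa_eval(ChatCompletionSampler(model="gpt-4o-mini"))
#     print(report.metrics)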