in browsecomp_eval.py
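# Module-level context assumed for this excerpt (not shown here): the method
# depends on helpers defined elsewhere in browsecomp_eval.py and in the
# surrounding simple-evals package, roughly along the lines of:
#   from . import common
#   from .types import EvalResult, SamplerBase, SingleEvalResult
# plus decrypt(), QUERY_TEMPLATE, self.grade_sample(), and self.examples.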
def __call__(self, sampler: SamplerBase) -> EvalResult:
    def fn(row: dict):
        # Decrypt the canary-protected problem and reference answer
        problem = decrypt(row.get("problem", ""), row.get("canary", ""))
        answer = decrypt(row.get("answer", ""), row.get("canary", ""))

        # Ask the model under evaluation to answer the question
        prompt_messages = [
            sampler._pack_message(content=QUERY_TEMPLATE.format(Question=problem), role="user")
        ]
        response_text = sampler(prompt_messages)

        # Grade the response against the reference answer
        grade_result = self.grade_sample(problem, answer, response_text)

        # Metrics based on grading response
        is_correct = grade_result == "yes"
        is_incorrect = grade_result == "no"
        score = is_correct
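        # Note: if the grader returns anything other than "yes" or "no",
        # the sample counts as neither correct nor incorrect, so the two
        # aggregate rates computed below may sum to less than 1.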
        # Create HTML for each sample result
        html = common.jinja_env.from_string(common.HTML_JINJA).render(
            prompt_messages=prompt_messages,
            next_message=dict(content=response_text, role="assistant"),
            score=score,
            correct_answer=answer,  # show the decrypted answer, not the encrypted row value
            extracted_answer=response_text,
        )
        convo = prompt_messages + [dict(content=response_text, role="assistant")]
        return SingleEvalResult(html=html, score=score, convo=convo, metrics={
            "is_correct": is_correct,
            "is_incorrect": is_incorrect,
        })
    # Run evaluation and collect results
    results = common.map_with_progress(fn, self.examples)

    # Aggregate metrics
    aggregate_metrics = {
        "is_correct": sum(result.metrics["is_correct"] for result in results) / len(results),
        "is_incorrect": sum(result.metrics["is_incorrect"] for result in results) / len(results),
    }
    print("AGGREGATE METRICS")
    print(aggregate_metrics)
    print("##################")

    output_d = {
        "accuracy": aggregate_metrics["is_correct"],
    }
    print(f"Accuracy: {output_d['accuracy']:.3f}")

    return common.aggregate_results(results)
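
# Hypothetical usage sketch (class and sampler names are assumptions, not taken
# from this excerpt): any SamplerBase implementation can be passed in.
#   eval_obj = BrowseCompEval(grader_model=grader_sampler)
#   sampler = ChatCompletionSampler(model="gpt-4o")
#   result = eval_obj(sampler)  # EvalResult with per-sample HTML and aggregate score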