# lmms_eval/tasks/olympiadbench/cn_utils.py (58 lines of code) (raw):
import os
import json
import datetime
from lmms_eval.tasks.olympiadbench.olympiadbench_evals import OlympiadBenchEvaluator
from lmms_eval.tasks._task_utils.file_utils import generate_submission_file
import logging
# Logger retrieved under the "lmms-eval" name.
eval_logger = logging.getLogger("lmms-eval")
# Absolute path of the directory that contains this module.
dir_name = os.path.dirname(os.path.abspath(__file__))
# One evaluator instance, created at import time; its judge() method is what
# olympiadbench_process_results calls to score a prediction.
olympiadbench_evaluator = OlympiadBenchEvaluator()
def olympiadbench_doc_to_visual(doc):
    """Return every image of ``doc["images"]`` converted to RGB mode.

    Args:
        doc: Mapping with an ``"images"`` entry holding an iterable of
            objects that expose ``convert(mode)`` (presumably PIL images;
            confirm against the dataset loader).

    Returns:
        A list holding one ``image.convert("RGB")`` result per input image,
        in the original order.
    """
    rgb_images = []
    for picture in doc["images"]:
        rgb_images.append(picture.convert("RGB"))
    return rgb_images
def olympiadbench_doc_to_text(doc):
    """Assemble the Chinese prompt sent to the model for one problem.

    The prompt is: a header naming the subfield, the question itself, a line
    stating the expected answer type, and an instruction asking for LaTeX
    notation plus a final ``\\boxed{...}`` answer sentence.

    Args:
        doc: Mapping providing ``"question"``, ``"subfield"``,
            ``"is_multiple_answer"`` (may be ``None``) and ``"answer_type"``.

    Returns:
        The full prompt string.
    """
    question = doc["question"]
    subject = doc["subfield"]
    # A missing flag (None) is handled exactly like a single-answer problem.
    has_multiple_answers = bool(doc["is_multiple_answer"])
    raw_answer_type = doc["answer_type"]
    # Problems marked for human evaluation are described as proof based.
    answer_kind = "proof based" if raw_answer_type == "Need_human_evaluate" else raw_answer_type

    header = f"以下是中国{subject}竞赛中的解答题。\n"

    if has_multiple_answers:
        type_line = f"题目有多个答案,答案类型均为{answer_kind}。\n"
        answer_template = '"所以最终答案是\\boxed{用英⽂逗号连接的多个答案}。"\n'
    else:
        type_line = f"答案类型为{answer_kind}。\n"
        answer_template = '"所以最终答案是\\boxed{答案}。"\n'

    # NOTE(review): the instruction sentence stops right after the quoted
    # answer template; confirm whether a trailing clause was intended. The
    # text is kept as-is so benchmark prompts stay unchanged.
    instructions = "请根据题目的要求和所提供的信息计算得出答案。解答过程和结果中使用的变量和公式请使用LaTeX格式表示。请在最后以"

    return header + question + "\n" + type_line + instructions + answer_template
def olympiadbench_process_results(doc, results):
    """Turn one model response into a metric entry.

    Args:
        doc: Mapping providing ``"error"`` (tolerance handed to the
            evaluator, may be ``None``), ``"source"`` and, for non-proof
            problems, ``"final_answer"``.
        results: Sequence whose first element is the model's response text.

    Returns:
        ``{"submission": <stripped response>}`` when ``"TP"`` occurs in
        ``doc["source"]``; otherwise ``{"exact_match": 0 or 1}`` obtained
        from ``olympiadbench_evaluator.judge``.
    """
    tolerance = doc["error"]
    is_proof_task = "TP" in doc["source"]
    model_output = results[0].strip()

    # Proof problems are not scored here; the raw text is collected instead.
    if is_proof_task:
        return {"submission": model_output}

    if tolerance is None:
        tolerance = 0

    # Keep only what follows the last occurrence of the answer marker.
    answer_part = model_output.split("所以最终答案是")[-1]
    # Drop double quotes, newlines and spaces, then trailing/leading periods
    # (ASCII first, full-width second, in that order).
    answer_part = answer_part.translate(str.maketrans("", "", '"\n '))
    answer_part = answer_part.strip(".").strip("。")

    verdict = olympiadbench_evaluator.judge(answer_part, doc["final_answer"][0], tolerance)
    return {"exact_match": int(verdict)}
def olympiadbench_aggregate_results(results, args):
    """Dump the collected results to a timestamped JSON submission file.

    Args:
        results: JSON-serialisable object to write.
        args: Forwarded to ``generate_submission_file``, which returns the
            path the file is written to.

    Returns:
        None. The file is written and its path is printed.
    """
    now_date_time = datetime.datetime.now().strftime("%Y-%m%d-%H%M-%S")
    submission_file_name = f"olympiadbench-test-cn-submission-{now_date_time}.json"
    path = generate_submission_file(submission_file_name, args)
    # ensure_ascii=False writes non-ASCII characters verbatim, so the file is
    # opened as UTF-8 explicitly; relying on the locale default can raise
    # UnicodeEncodeError or write a non-UTF-8 file on some platforms.
    with open(path, "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False)
    print(f"Submission file saved to {path}")