lmms_eval/tasks/vqav2/utils.py
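"""Utilities for the VQAv2 task: visual and text preprocessing, VQA-accuracy
scoring, and aggregation of test-split predictions into a submission file."""
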
import json
import logging
import datetime
import statistics
import lmms_eval.tasks._task_utils.file_utils as file_utils
from lmms_eval.tasks._task_utils.vqa_eval_metric import EvalAIAnswerProcessor
eval_logger = logging.getLogger("lmms-eval")
def vqav2_doc_to_visual(doc):
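    """Return the document's image, converted to RGB, as a single-element list."""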
    return [doc["image"].convert("RGB")]

def vqav2_process_results(doc, result):
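    """Score one prediction against the VQAv2 ground-truth answers.

    Uses the standard VQA accuracy: for each annotator, the prediction gets
    min(1, matches_with_other_annotators / 3), and the per-annotator scores
    are averaged. Also builds the submission entry (question_id and
    normalized answer).
    """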
    eval_ai_processor = EvalAIAnswerProcessor()
    assert len(result) == 1, f"The result should be a list of length 1, but got {len(result)}."
    resAns = eval_ai_processor(result[0])
    accuracy = 0

    if "answers" in doc and doc["answers"] is not None:
        # Normalize whitespace in the ground-truth answers.
        for ansDic in doc["answers"]:
            ansDic["answer"] = ansDic["answer"].replace("\n", " ")
            ansDic["answer"] = ansDic["answer"].replace("\t", " ")
            ansDic["answer"] = ansDic["answer"].strip()
        gtAcc = []
        gtAnswers = [ans["answer"] for ans in doc["answers"]]

        # Apply the heavier punctuation/article normalization only when annotators disagree.
        if len(set(gtAnswers)) > 1:
            for ansDic in doc["answers"]:
                ansDic["answer"] = eval_ai_processor.process_punctuation(ansDic["answer"])
                ansDic["answer"] = eval_ai_processor.process_digit_article(ansDic["answer"])
            resAns = eval_ai_processor.process_punctuation(resAns)
            resAns = eval_ai_processor.process_digit_article(resAns)

        # Leave-one-out VQA accuracy: for each annotator, count how many of the
        # remaining answers match the prediction; 3 or more matches give full credit.
        for gtAnsDatum in doc["answers"]:
            otherGTAns = [item for item in doc["answers"] if item != gtAnsDatum]
            matchingAns = [item for item in otherGTAns if item["answer"] == resAns]
            acc = min(1, float(len(matchingAns)) / 3)
            gtAcc.append(acc)
        accuracy = statistics.mean(gtAcc)

    return {
        "exact_match": accuracy,
        "submission": {
            "question_id": doc["question_id"],
            "answer": resAns,
        },
    }

def vqav2_process_results_test(doc, result):
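    """Test split: ground truth is unavailable, so only the submission entry is returned."""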
    res = vqav2_process_results(doc, result)
    return {
        "submission": res["submission"],
    }

def vqav2_process_results_val(doc, result):
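    """Validation split: only the exact_match accuracy is returned."""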
    res = vqav2_process_results(doc, result)
    return {
        "exact_match": res["exact_match"],
    }

def vqav2_doc_to_text(doc, model_specific_prompt_kwargs=None):
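    """Wrap the question with optional pre/post prompts (typically supplied per model via the task config)."""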
    if model_specific_prompt_kwargs is None:
        model_specific_prompt_kwargs = {}
    pre_prompt = ""
    post_prompt = ""
    if "pre_prompt" in model_specific_prompt_kwargs:
        pre_prompt = model_specific_prompt_kwargs["pre_prompt"]
    if "post_prompt" in model_specific_prompt_kwargs:
        post_prompt = model_specific_prompt_kwargs["post_prompt"]
    return f"{pre_prompt}{doc['question']}{post_prompt}"

def vqav2_aggreate_submissions(results, args):
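    """Dump all test-split submission entries to a timestamped JSON file."""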
    now_date_time = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    submission_file_name = f"vqav2-test-submission-{now_date_time}.json"
    path = file_utils.generate_submission_file(submission_file_name, args)
    with open(path, "w") as f:
        json.dump(results, f)
    eval_logger.info(f"Submission file saved to {path}")