lmms_eval/tasks/vqav2/utils.py

import datetime
import json
import logging
import os
import re
import statistics

import lmms_eval.tasks._task_utils.file_utils as file_utils
from lmms_eval.tasks._task_utils.vqa_eval_metric import EvalAIAnswerProcessor

eval_logger = logging.getLogger("lmms-eval")


def vqav2_doc_to_visual(doc):
    return [doc["image"].convert("RGB")]


def vqav2_process_results(doc, result):
    eval_ai_processor = EvalAIAnswerProcessor()
    assert len(result) == 1, f"The result should be a list of length 1, but got {len(result)}."
    resAns = eval_ai_processor(result[0])
    accuracy = 0

    if "answers" in doc and doc["answers"] is not None:
        # Normalize whitespace in the ground-truth answers.
        for ansDic in doc["answers"]:
            ansDic["answer"] = ansDic["answer"].replace("\n", " ")
            ansDic["answer"] = ansDic["answer"].replace("\t", " ")
            ansDic["answer"] = ansDic["answer"].strip()
        gtAcc = []
        gtAnswers = [ans["answer"] for ans in doc["answers"]]

        # Apply punctuation and digit/article normalization only when annotators
        # disagree, mirroring the official VQA evaluation script.
        if len(set(gtAnswers)) > 1:
            for ansDic in doc["answers"]:
                ansDic["answer"] = eval_ai_processor.process_punctuation(ansDic["answer"])
                ansDic["answer"] = eval_ai_processor.process_digit_article(ansDic["answer"])
            resAns = eval_ai_processor.process_punctuation(resAns)
            resAns = eval_ai_processor.process_digit_article(resAns)

        # Leave-one-out VQA accuracy: for each annotator, count how many of the
        # remaining annotators match the prediction, cap at 3 matches, then average.
        for gtAnsDatum in doc["answers"]:
            otherGTAns = [item for item in doc["answers"] if item != gtAnsDatum]
            matchingAns = [item for item in otherGTAns if item["answer"] == resAns]
            acc = min(1, float(len(matchingAns)) / 3)
            gtAcc.append(acc)
        accuracy = statistics.mean(gtAcc)

    return {
        "exact_match": accuracy,
        "submission": {
            "question_id": doc["question_id"],
            "answer": resAns,
        },
    }


def vqav2_process_results_test(doc, result):
    # The test split has no public answers, so only the submission entry is kept.
    res = vqav2_process_results(doc, result)
    return {
        "submission": res["submission"],
    }


def vqav2_process_results_val(doc, result):
    # The validation split is scored locally, so only the accuracy is kept.
    res = vqav2_process_results(doc, result)
    return {
        "exact_match": res["exact_match"],
    }


def vqav2_doc_to_text(doc, model_specific_prompt_kwargs=None):
    if model_specific_prompt_kwargs is None:
        model_specific_prompt_kwargs = {}
    pre_prompt = ""
    post_prompt = ""
    if "pre_prompt" in model_specific_prompt_kwargs:
        pre_prompt = model_specific_prompt_kwargs["pre_prompt"]
    if "post_prompt" in model_specific_prompt_kwargs:
        post_prompt = model_specific_prompt_kwargs["post_prompt"]
    return f"{pre_prompt}{doc['question']}{post_prompt}"


def vqav2_aggreate_submissions(results, args):
    now_date_time = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    submission_file_name = f"vqav2-test-submission-{now_date_time}.json"
    path = file_utils.generate_submission_file(submission_file_name, args)
    with open(path, "w") as f:
        json.dump(results, f)
    eval_logger.info(f"Submission file saved to {path}")
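

if __name__ == "__main__":
    # Illustrative sketch only (not part of the upstream task code): a minimal,
    # made-up document and model answer showing how vqav2_process_results scores
    # a prediction. Field names mirror the VQAv2 records used above; the values
    # are hypothetical.
    sample_doc = {
        "question_id": 0,
        "question": "What color is the bus?",
        "answers": [{"answer": "red", "answer_id": i} for i in range(7)]
        + [{"answer": "maroon", "answer_id": 7 + i} for i in range(3)],
    }
    sample_result = ["Red."]

    # Each annotator is scored leave-one-out as min(1, matches_among_others / 3)
    # and the results are averaged, so 7 of 10 annotators agreeing with the
    # normalized prediction yields an accuracy of 1.0 here.
    print(vqav2_process_results(sample_doc, sample_result))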