lmms_eval/tasks/vizwiz_vqa/utils.py (56 lines of code) (raw):

import re
import os
import json
import yaml
import pathlib
import logging
import datetime
import statistics

from lmms_eval.tasks._task_utils.file_utils import generate_submission_file
from lmms_eval.tasks._task_utils.vqa_eval_metric import EvalAIAnswerProcessor

eval_logger = logging.getLogger("lmms-eval")


def vizwiz_vqa_doc_to_visual(doc):
    """Return the document's image as a one-element list of RGB images.

    ``doc["image"]`` must provide ``.convert("RGB")`` (presumably a PIL image
    from the dataset loader -- confirm against the loader).
    """
    return [doc["image"].convert("RGB")]


def vizwiz_vqa_process_results(doc, result):
    """Score a single model prediction with the VQA soft-accuracy metric.

    Args:
        doc: dataset row. Reads ``doc["question_id"]`` and, when the key is
            present and not None, ``doc["answers"]`` (ground-truth answers).
            The row is NOT modified.
        result: model outputs for this doc; must contain exactly one item.

    Returns:
        A dict with:
          * ``"exact_match"``: mean over every ground-truth answer ``i`` of
            ``min(1, (#other answers equal to the prediction) / 3)``, computed
            after normalizing both sides with ``EvalAIAnswerProcessor``;
            0 when no ground-truth answers are available.
          * ``"submission"``: ``{"image": "<question_id>.jpg",
            "answer": <normalized prediction>}``.
    """
    eval_ai_processor = EvalAIAnswerProcessor()
    assert len(result) == 1, f"The result should be a list of length 1, but got {len(result)}."
    res_ans = eval_ai_processor(result[0])

    accuracy = 0
    raw_answers = doc["answers"] if "answers" in doc else None
    if raw_answers is not None:
        # Normalize into a local list so the dataset row is left untouched;
        # writing back into doc["answers"] would make a second scoring pass
        # over the same doc re-normalize already-normalized answers.
        gt_answers = [eval_ai_processor(ans) for ans in raw_answers]

        # Leave-one-out accuracy: for each held-out answer, the number of the
        # *other* answers matching the prediction equals the total number of
        # matches minus one if the held-out answer itself matches.
        total_matches = gt_answers.count(res_ans)
        gt_acc = [min(1, (total_matches - (ans == res_ans)) / 3) for ans in gt_answers]
        if gt_acc:
            accuracy = statistics.mean(gt_acc)

    return {
        "exact_match": accuracy,
        "submission": {
            "image": f"{doc['question_id']}.jpg",
            "answer": res_ans,
        },
    }


def vizwiz_vqa_doc_to_text(doc, model_specific_prompt_kwargs=None):
    """Build the text prompt: ``pre_prompt + question + post_prompt``.

    Args:
        doc: dataset row; reads ``doc["question"]``.
        model_specific_prompt_kwargs: optional dict; its ``"pre_prompt"`` and
            ``"post_prompt"`` entries are used when present, otherwise "".

    Returns:
        The assembled prompt string.
    """
    if model_specific_prompt_kwargs is None:
        model_specific_prompt_kwargs = {}
    pre_prompt = ""
    post_prompt = ""
    if "pre_prompt" in model_specific_prompt_kwargs:
        pre_prompt = model_specific_prompt_kwargs["pre_prompt"]
    if "post_prompt" in model_specific_prompt_kwargs:
        post_prompt = model_specific_prompt_kwargs["post_prompt"]
    # NOTE(review): str.capitalize() also lowercases every character after the
    # first (e.g. "TV" -> "tv"). Left unchanged so prompts stay identical to
    # earlier evaluation runs.
    text = f"{pre_prompt}{doc['question'].capitalize()}{post_prompt}"
    return text


def vizwiz_vqa_aggreate_submissions(results, args):
    """Dump the collected submission entries to a timestamped JSON file.

    The (misspelled) function name is kept as-is because it is presumably
    referenced by name from the task configuration -- verify before renaming.

    Args:
        results: JSON-serializable submission data (presumably the list of
            ``"submission"`` dicts produced by ``vizwiz_vqa_process_results``).
        args: forwarded to ``generate_submission_file`` to resolve the path.
    """
    now_date_time = datetime.datetime.now().strftime("%Y-%m%d-%H%M-%S")
    submission_file_name = f"vizwiz_vqa-test-submission-{now_date_time}.json"
    path = generate_submission_file(submission_file_name, args)
    with open(path, "w") as f:
        json.dump(results, f)
    # Report through the module logger rather than bare print().
    eval_logger.info(f"Submission file saved to {path}")