lmms_eval/tasks/textvqa/utils.py (53 lines of code) (raw):

import datetime
import json
import logging
import os
import pathlib
import re
import statistics

import yaml

from lmms_eval.tasks._task_utils.file_utils import generate_submission_file
from lmms_eval.tasks._task_utils.vqa_eval_metric import EvalAIAnswerProcessor

eval_logger = logging.getLogger("lmms-eval")


def textvqa_doc_to_visual(doc):
    """Return the document's image, converted to RGB, as a one-element list."""
    return [doc["image"].convert("RGB")]


def textvqa_process_results(doc, result):
    """Score a single model response against the document's reference answers.

    The response and every reference answer are normalised with
    ``EvalAIAnswerProcessor``.  For each reference answer in turn, the
    remaining answers are compared with the response and the score for that
    round is ``min(1, matches / 3)``; the final accuracy is the mean over all
    rounds.  When the document has no ``"answers"`` entry, or it is ``None``
    or empty, the accuracy is 0.

    Args:
        doc: Dataset record. Reads ``doc["question_id"]`` and, when present,
            ``doc["answers"]``. The record is not modified.
        result: List holding exactly one model response.

    Returns:
        A dict with ``"exact_match"`` (the accuracy) and ``"submission"``
        (``question_id`` plus the normalised response).

    Raises:
        AssertionError: If ``result`` does not contain exactly one element.
    """
    eval_ai_processor = EvalAIAnswerProcessor()
    assert len(result) == 1, f"The result should be a list of length 1, but got {len(result)}."
    res_ans = eval_ai_processor(result[0])

    accuracy = 0
    raw_answers = doc["answers"] if "answers" in doc else None
    # An empty answer list would make statistics.mean() raise, so it is
    # scored as 0 just like a missing one.
    if raw_answers is not None and len(raw_answers) > 0:
        # Normalise into a local list so the caller's record is left intact.
        gt_answers = [eval_ai_processor(answer) for answer in raw_answers]
        total_matches = sum(1 for answer in gt_answers if answer == res_ans)

        gt_acc = []
        for answer in gt_answers:
            # Matches among the *other* answers: drop this answer's own hit.
            other_matches = total_matches - (1 if answer == res_ans else 0)
            gt_acc.append(min(1, other_matches / 3))
        accuracy = statistics.mean(gt_acc)

    return {
        "exact_match": accuracy,
        "submission": {
            "question_id": doc["question_id"],
            "answer": res_ans,
        },
    }


def textvqa_doc_to_text(doc, model_specific_prompt_kwargs=None):
    """Build the text prompt for a document.

    The prompt is ``pre_prompt`` + capitalised question + optional OCR
    reference line + ``post_prompt``.

    Args:
        doc: Dataset record. Reads ``doc["question"]`` and, when OCR is
            enabled, ``doc["ocr_tokens"]``.
        model_specific_prompt_kwargs: Optional mapping that may define
            ``"pre_prompt"``, ``"post_prompt"`` and a truthy ``"ocr"`` flag.
            Missing keys fall back to an empty string / no OCR line.

    Returns:
        The assembled prompt string.
    """
    pre_prompt = ""
    # Must be bound even when no kwargs are supplied; otherwise the return
    # statement below would hit an unbound local.
    post_prompt = ""
    ocr_ref = ""
    if model_specific_prompt_kwargs:
        if "pre_prompt" in model_specific_prompt_kwargs:
            pre_prompt = model_specific_prompt_kwargs["pre_prompt"]
        if "post_prompt" in model_specific_prompt_kwargs:
            post_prompt = model_specific_prompt_kwargs["post_prompt"]
        if "ocr" in model_specific_prompt_kwargs and model_specific_prompt_kwargs["ocr"]:
            ocr_ref = f"\nReference OCR token: {', '.join(doc['ocr_tokens'])}"
    return f"{pre_prompt}{doc['question'].capitalize()}{ocr_ref}{post_prompt}"


def textvqa_aggreate_submissions(results, args):
    """Dump the collected submission entries to a timestamped JSON file.

    The target path is obtained from ``generate_submission_file`` using a
    file name of the form ``textvqa_submission_<YYYY-mm-dd-HH-MM-SS>.json``;
    the path is logged once the file has been written.

    Args:
        results: JSON-serialisable collection of submission entries.
        args: Passed through unchanged to ``generate_submission_file``.
    """
    now_date_time = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    path = generate_submission_file(f"textvqa_submission_{now_date_time}.json", args)
    with open(path, "w") as f:
        json.dump(results, f)
    eval_logger.info("Submission file saved to %s", path)