lmms_eval/tasks/multidocvqa/utils.py

import os
import re
import ast
import json
import logging

from lmms_eval.api.metrics import levenshtein_distance
from lmms_eval.tasks._task_utils.file_utils import generate_submission_file

lmms_logger = logging.getLogger("lmms-eval")


def multidocvqa_doc_to_text(doc, model_specific_prompt_kwargs):
    """Wrap the question in the model-specific pre/post prompts."""
    question = doc["question"]
    pre_prompt = model_specific_prompt_kwargs["pre_prompt"]
    post_prompt = model_specific_prompt_kwargs["post_prompt"]
    return f"{pre_prompt}{question}{post_prompt}"


def multidocvqa_doc_to_visual(doc):
    """Collect the document's page images (up to 20), skipping missing pages."""
    return [doc[f"image_{i}"].convert("RGB") for i in range(1, 21) if doc[f"image_{i}"] is not None]


def multidocvqa_process_results(doc, results):
    """Pair the model prediction with the ground-truth answers for both metrics."""
    pred_answer = results[0]
    answer = ast.literal_eval(doc["answers"])  # "answers" is stored as a stringified list

    return {
        "anls": {"questionId": int(doc["questionId"]), "answer": answer, "pred_answer": pred_answer},
        "accuracy": {"questionId": int(doc["questionId"]), "answer": answer, "pred_answer": pred_answer},
    }


def multidocvqa_aggregate_results_anls(results):
    """Average ANLS over all questions."""
    keys = {k for result in results for k in result}
    results = {key: [result.get(key, None) for result in results] for key in keys}
    evaluator = Evaluator(case_sensitive=False)
    metric = evaluator.get_metrics(results["answer"], results["pred_answer"])
    return sum(metric["anls"]) / len(metric["anls"])


def multidocvqa_aggregate_results_accuracy(results):
    """Average exact-match accuracy over all questions."""
    keys = {k for result in results for k in result}
    results = {key: [result.get(key, None) for result in results] for key in keys}
    evaluator = Evaluator(case_sensitive=False)
    metric = evaluator.get_metrics(results["answer"], results["pred_answer"])
    return sum(metric["accuracy"]) / len(metric["accuracy"])


def multidocvqa_process_test_results_for_submission(doc, results):
    """Format a single test prediction for the submission file."""
    answer = results[0]
    return {"submission": {"questionId": int(doc["questionId"]), "answer": answer, "answer_page": None}}


def multidocvqa_test_aggregate_results_for_submission(results, args):
    """Write all test predictions to a JSON submission file."""
    path = generate_submission_file("multidocvqa_test_for_submission.json", args)
    with open(path, "w") as f:
        json.dump(results, f)
    lmms_logger.info(f"Results saved to {path}.")


##################
# Helper functions
##################


class Evaluator:
    """Computes exact-match accuracy and ANLS (Average Normalized Levenshtein Similarity)."""

    def __init__(self, case_sensitive=False):
        self.case_sensitive = case_sensitive
        self.get_edit_distance = levenshtein_distance
        self.anls_threshold = 0.5

    def get_metrics(self, gt_answers, preds):
        batch_accuracy = []
        batch_anls = []
        for batch_idx in range(len(preds)):
            gt = [self._preprocess_str(gt_elm) for gt_elm in gt_answers[batch_idx]]
            pred = self._preprocess_str(preds[batch_idx])
            batch_accuracy.append(self._calculate_accuracy(gt, pred))
            batch_anls.append(self._calculate_anls(gt, pred))
        return {"accuracy": batch_accuracy, "anls": batch_anls}

    def _preprocess_str(self, string):
        if not self.case_sensitive:
            string = string.lower()
        return string.strip()

    def _calculate_accuracy(self, gt, pred):
        # A literal "none" prediction is treated as abstention and scored 0.
        if pred == "none":
            return 0
        for gt_elm in gt:
            if gt_elm == pred:
                return 1
        return 0

    def _calculate_anls(self, gt, pred):
        if len(pred) == 0:
            return 0
        if pred == "none":
            return 0
        # Normalized similarity against each ground truth; keep the best one,
        # zeroing out anything below the ANLS threshold (0.5).
        answers_similarity = [1 - self.get_edit_distance(gt_elm, pred) / max(len(gt_elm), len(pred)) for gt_elm in gt]
        max_similarity = max(answers_similarity)
        anls = max_similarity if max_similarity >= self.anls_threshold else 0
        return anls


if __name__ == "__main__":
    print("-----------------")
    print(
        multidocvqa_aggregate_results_anls(
            [
                {"questionId": 1, "answer": ["answer"], "pred_answer": "pred_answer"},
                {"questionId": 2, "answer": ["nswer"], "pred_answer": "nswer"},
            ]
        )
    )
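    # A minimal extra sanity check (an illustrative addition, not from the
    # original file): calling Evaluator.get_metrics directly shows how the
    # 0.5 ANLS threshold behaves. Expected values assume levenshtein_distance
    # returns the standard integer edit distance.
    evaluator = Evaluator(case_sensitive=False)
    demo = evaluator.get_metrics([["answer"], ["answer"]], ["Answer", "answr"])
    # "Answer" matches exactly after lowercasing -> accuracy 1, ANLS 1.0;
    # "answr" is one edit away -> accuracy 0, ANLS 1 - 1/6 ≈ 0.83 (>= 0.5).
    print(demo)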