vision/m4/evaluation/custom_metrics/open_ended_vqa_metrics.py

import logging
from enum import Enum
from typing import List

import datasets
import evaluate
import numpy as np
from scipy import stats

from m4.evaluation.custom_metrics.utils import (
    VQANormalizationGtVisionLab,
    check_is_number,
    convert_to_number,
    normalize_str_mmmu,
    parse_open_response_mmmu,
    vqa_normalize_text,
)


logger = logging.getLogger(__name__)

_DESCRIPTION = ""
_CITATION = ""
_KWARGS_DESCRIPTION = ""

MCQ_POSSIBLE_CHOICES = ["A", "B", "C", "D", "E", "F", "G", "H", "I"]


class OEVQAMetrics(Enum):
    FIRST_WORD_VQA_ACCURACY = "first_word_vqa_accuracy"
    OE_VQA_ACCURACY = "oe_vqa_accuracy"
    OE_MMMU_STYLE_VQA_ACCURACY = "oe_mmmu_style_vqa_accuracy"
    OE_MMMU_STYLE_PER_BUCKET_ACCURACY = "oe_mmmu_style_per_bucket_accuracy"
    OE_ONLY_MMMU_STYLE_VQA_ACCURACY = "oe_only_mmmu_style_vqa_accuracy"
    OE_ONLY_MMMU_STYLE_PER_BUCKET_ACCURACY = "oe_only_mmmu_style_per_bucket_accuracy"
    OE_RELAXED_VQA_ACCURACY = "oe_relaxed_vqa_accuracy"
    GT_VISION_LAB_FIRST_WORD_VQA_ACCURACY = "gt_vision_lab_first_word_vqa_accuracy"
    GT_VISION_LAB_OE_VQA_ACCURACY = "gt_vision_lab_oe_vqa_accuracy"
    DEFAULT_TO_SERVER_RESULTS = "default_to_server_results"
    DEFAULT_TO_SERVER_RESULTS_MMVET = "default_to_server_results_mmvet"
    DEFAULT_TO_SERVER_RESULTS_LLAVA_WILD = "default_to_server_results_llava_wild"


class OpenEndedVQAMetrics(evaluate.Metric):
    """This class takes care of computing the metrics listed in `metrics`."""

    def __init__(self, metrics: List[OEVQAMetrics], save_generations: bool, **kwargs):
        super().__init__(**kwargs)
        self.metrics = metrics
        self.save_generations = save_generations
        self.gt_vision_lab_normalization = VQANormalizationGtVisionLab()
        if metrics is None:
            raise ValueError("`metrics` must be specified")

    def _info(self):
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "example_ids": datasets.Value("string"),
                    "generated_texts": datasets.Value("string"),
                    "answers": datasets.Sequence(datasets.Value("string")),
                    "buckets": datasets.Value("string"),
                }
            ),
        )

    def _compute_vqa_accuracy(self, generated_texts_unique, answers_unique, normalize_text_fn):
        """
        Official VQA-style soft accuracy: with several reference answers, each leave-one-out split scores
        min(1, #matching remaining references / 3), and the per-example score is the average over splits.
        Computed both on the full generated text and on its first word.
        """
        first_word_vqa_accuracy_scores = []
        oe_accuracy_scores = []
        for generated_text, answers_ in zip(generated_texts_unique, answers_unique):
            generated_text = normalize_text_fn(generated_text)
            generated_first_word = generated_text.split(" ")[0]
            answers_ = [normalize_text_fn(answer_) for answer_ in answers_]
            if len(answers_) == 1:
                # This is the case for GQA for example
                first_word_vqa_accuracy_scores.append((generated_first_word == answers_[0]) * 1.0)
                oe_accuracy_scores.append((generated_text == answers_[0]) * 1.0)
            else:
                gt_first_word_acc = []
                gt_oe_acc = []
                for idx_ref in range(len(answers_)):
                    other_answers_ = [other_answer for idx, other_answer in enumerate(answers_) if idx != idx_ref]
                    matched_with_first_word = [
                        other_answer for other_answer in other_answers_ if other_answer == generated_first_word
                    ]
                    matched_with_oe_text = [
                        other_answer for other_answer in other_answers_ if other_answer == generated_text
                    ]
                    first_word_acc = min(1, len(matched_with_first_word) / 3)
                    oe_acc = min(1, len(matched_with_oe_text) / 3)
                    gt_first_word_acc.append(first_word_acc)
                    gt_oe_acc.append(oe_acc)
                first_word_vqa_accuracy_scores.append(sum(gt_first_word_acc) / len(gt_first_word_acc))
                oe_accuracy_scores.append(sum(gt_oe_acc) / len(gt_oe_acc))
        return first_word_vqa_accuracy_scores, oe_accuracy_scores

    def _compute_mmmu_style_vqa_accuracy(
        self, generated_texts_unique, answers_unique,
        normalize_text_fn, accept_mcq
    ):
        """
        MMMU-style scoring: a prediction counts as correct if any parsed answer candidate matches any gold
        answer (exact match for MCQ letter answers, substring or number match for open-ended answers).
        """
        oe_accuracy_scores = []
        for generated_text, answers in zip(generated_texts_unique, answers_unique):
            is_mcq = answers[0] in MCQ_POSSIBLE_CHOICES and accept_mcq
            if is_mcq:
                generated_text_extracted_answer_candidates = [normalize_text_fn(generated_text)]
            else:
                generated_text_extracted_answer_candidates = parse_open_response_mmmu(
                    generated_text, normalize_text_fn
                )
            answers = [normalize_text_fn(answer) for answer in answers]
            correct = 0
            for answer in answers:
                for generated_answer_candidate in generated_text_extracted_answer_candidates:
                    if isinstance(answer, str) and isinstance(generated_answer_candidate, str):
                        # In the case of an mcq question, there is only one answer, and the answer has to be exact.
                        if is_mcq and generated_answer_candidate == answer:
                            correct = 1
                            break
                        elif answer in generated_answer_candidate:
                            correct = 1
                            break
                    # If it's a number, it has been converted to a float rounded to 2 decimals
                    elif (
                        isinstance(answer, float)
                        and isinstance(generated_answer_candidate, float)
                        and generated_answer_candidate == answer
                    ):
                        correct = 1
                        break
                    else:
                        pass  # This is the case of a number and a string, we don't want to compare them
            oe_accuracy_scores.append(correct)
        return oe_accuracy_scores

    def _compute_relaxed_vqa_accuracy(self, generated_texts_unique, answers_unique, normalize_text_fn):
        """
        From https://aclanthology.org/2022.findings-acl.177.pdf:
        We use a relaxed accuracy measure for the numeric answers to allow a minor inaccuracy that may result
        from the automatic data extraction process. We consider an answer to be correct if it is within 5% of
        the gold answer. For non-numeric answers, we still need an exact match to consider an answer to be
        correct.
        """
        oe_accuracy_scores = []
        for generated_text, answers in zip(generated_texts_unique, answers_unique):
            generated_text = normalize_text_fn(generated_text)
            answers = [normalize_text_fn(a) for a in answers]
            correct = 0
            for answer in answers:
                if check_is_number(answer):
                    if check_is_number(generated_text):
                        generated_text_f = convert_to_number(generated_text)
                        answer_f = convert_to_number(answer)
                        if answer_f != 0.0:
                            correct = abs(generated_text_f - answer_f) / answer_f < 0.05 or correct
                        else:
                            correct = generated_text_f == answer_f
                    break
                elif generated_text == answer:
                    correct = 1
            oe_accuracy_scores.append(correct)
        return oe_accuracy_scores

    def _compute(self, example_ids, generated_texts, answers, buckets):
        data_per_id = {}
        for ex_id, gen_text, ans, bucket in zip(example_ids, generated_texts, answers, buckets):
            # The `if` condition below is a workaround for distributed evaluation, where some instances can be
            # repeated over a few processes to make the batches even.
            # In this case, we just verify that all processes predicted the same thing, and only take one copy of
            # predictions in order to not mess up metrics. Ideally this "unique" logic should be handled outside
            # of the metric or maybe in the add_batch call...
            if ex_id not in data_per_id:
                data_per_id[ex_id] = {
                    "generated_text": gen_text,
                    "answers": ans,
                    "bucket": bucket,
                }
            else:
                if data_per_id[ex_id]["generated_text"] != gen_text:
                    logger.warning(
                        f"Example {ex_id} has different predictions across processes. We have: {gen_text} and"
                        f" {data_per_id[ex_id]['generated_text']}"
                    )
                if data_per_id[ex_id]["answers"] != ans:
                    logger.warning(
                        f"Example {ex_id} has different answers across processes."
                        f" We have: {ans} and {data_per_id[ex_id]['answers']}"
                    )

        # assert list(range(len(data_per_id))) == sorted(data_per_id.keys())
        generated_texts_unique = [data_per_id[i]["generated_text"] for i in set(example_ids)]
        answers_unique = [data_per_id[i]["answers"] for i in set(example_ids)]

        results = {}

        default_to_save_generations = (
            answers_unique[0] is None or answers_unique[0][0] == ""
        ) and OEVQAMetrics.DEFAULT_TO_SERVER_RESULTS in self.metrics
        if self.save_generations or default_to_save_generations:
            # If answers are None, we default to the server results
            if (
                OEVQAMetrics.OE_MMMU_STYLE_VQA_ACCURACY in self.metrics
                or OEVQAMetrics.OE_ONLY_MMMU_STYLE_VQA_ACCURACY in self.metrics
            ):
                results["server_results"] = [
                    {
                        "question_id": ex_id,
                        "answer": data["generated_text"],
                    }
                    for ex_id, data in data_per_id.items()
                ]
            else:
                results["server_results"] = [
                    {
                        "question_id": ex_id,
                        "answer": self.gt_vision_lab_normalization.vqa_normalize_text(data["generated_text"]),
                    }
                    for ex_id, data in data_per_id.items()
                ]
            if default_to_save_generations:
                return results

        if OEVQAMetrics.DEFAULT_TO_SERVER_RESULTS_MMVET in self.metrics:
            results["server_results"] = {ex_id: data["generated_text"] for ex_id, data in data_per_id.items()}
            return results
        elif OEVQAMetrics.DEFAULT_TO_SERVER_RESULTS_LLAVA_WILD in self.metrics:
            results["server_results"] = [
                {
                    "question_id": ex_id,
                    "answer": data["generated_text"],
                }
                for ex_id, data in data_per_id.items()
            ]
            return results

        # VQA Accuracy
        # From the "VQA: Visual Question Answering" paper:
        # an answer is deemed 100% accurate if at least 3 workers provided that exact answer. Before comparison,
        # all responses are made lowercase, numbers converted to digits, and punctuation & articles removed.
        if (
            OEVQAMetrics.GT_VISION_LAB_FIRST_WORD_VQA_ACCURACY in self.metrics
            or OEVQAMetrics.GT_VISION_LAB_OE_VQA_ACCURACY in self.metrics
        ):
            gt_vision_lab_first_word_vqa_accuracy_scores, gt_vision_lab_oe_accuracy_scores = (
                self._compute_vqa_accuracy(
                    generated_texts_unique, answers_unique, self.gt_vision_lab_normalization.vqa_normalize_text
                )
            )
            if OEVQAMetrics.GT_VISION_LAB_FIRST_WORD_VQA_ACCURACY in self.metrics:
                results["gt_vision_lab_first_word_vqa_accuracy"] = float(
                    sum(gt_vision_lab_first_word_vqa_accuracy_scores)
                    / len(gt_vision_lab_first_word_vqa_accuracy_scores)
                )
            if OEVQAMetrics.GT_VISION_LAB_OE_VQA_ACCURACY in self.metrics:
                results["gt_vision_lab_oe_vqa_accuracy"] = float(
                    sum(gt_vision_lab_oe_accuracy_scores) / len(gt_vision_lab_oe_accuracy_scores)
                )
                # Report the dispersion of per-example scores and a normal-approximation margin of error.
                confidence_level = 0.95
                z_score = stats.norm.ppf((1 + confidence_level) / 2)
                std_dev = np.std(gt_vision_lab_oe_accuracy_scores)
                results["gt_vision_lab_oe_vqa_accuracy_std"] = std_dev
                results["gt_vision_lab_oe_vqa_accuracy_margin_of_error"] = z_score * (
                    std_dev / np.sqrt(len(gt_vision_lab_oe_accuracy_scores))
                )

        if (
            OEVQAMetrics.OE_MMMU_STYLE_PER_BUCKET_ACCURACY in self.metrics
            or OEVQAMetrics.OE_ONLY_MMMU_STYLE_PER_BUCKET_ACCURACY in self.metrics
        ):
            if (
                OEVQAMetrics.OE_MMMU_STYLE_PER_BUCKET_ACCURACY in self.metrics
                and OEVQAMetrics.OE_ONLY_MMMU_STYLE_PER_BUCKET_ACCURACY in self.metrics
            ):
                raise ValueError(
                    "Cannot compute both OE_MMMU_STYLE_PER_BUCKET_ACCURACY and OE_ONLY_MMMU_STYLE_PER_BUCKET_ACCURACY"
                    " at the same time."
                )
            # Here each bucket has the form "bucket_col_0=sub_bucket_name_x/bucket_col_1=sub_bucket_name_y/... etc."
            buckets_aggregated = [data_per_id[example_id]["bucket"] for example_id in set(example_ids)]

            # Get columns and unique buckets
            unique_buckets = set(buckets_aggregated)
            bucket_columns = [column_buckets.split("=")[0] for column_buckets in buckets_aggregated[0].split("/")]

            # Initialize the scores dict
            scores_dict = {}
            for bucket_column in bucket_columns:
                scores_dict[bucket_column] = {}
            for unique_bucket in unique_buckets:
                column_sub_bucket_names = [column_bucket.split("=")[1] for column_bucket in unique_bucket.split("/")]
                for bucket_column, sub_bucket_name in zip(bucket_columns, column_sub_bucket_names):
                    scores_dict[bucket_column][sub_bucket_name] = []

            # Need np array to use .where
            generated_texts_unique_np = np.array(generated_texts_unique)
            answers_unique_np = np.array(answers_unique, dtype=object)
            buckets_aggregated = np.array(buckets_aggregated)
            for b_ in unique_buckets:
                # Find the positions of the unique_bucket in the buckets_aggregated to compute the scores
                bucket_position = np.where(buckets_aggregated == b_)[0]
                oe_mmmu_style_bucket_scores = self._compute_mmmu_style_vqa_accuracy(
                    generated_texts_unique_np[bucket_position],
                    answers_unique_np[bucket_position],
                    normalize_str_mmmu,
                    # Do not accept mcq when using OE_ONLY_MMMU_STYLE_PER_BUCKET_ACCURACY metric
                    accept_mcq=OEVQAMetrics.OE_ONLY_MMMU_STYLE_PER_BUCKET_ACCURACY not in self.metrics,
                )
                # Each sub_bucket (column, name) pair from this buckets_aggregated entry
                # extends the list of scores of its corresponding entry in the scores_dict
                # with the oe_mmmu_style_bucket_scores.
                sub_buckets_tuples = [
                    (column_bucket.split("=")[0], column_bucket.split("=")[1]) for column_bucket in b_.split("/")
                ]
                for sub_bucket_col, sub_bucket_name in sub_buckets_tuples:
                    scores_dict[sub_bucket_col][sub_bucket_name].extend(oe_mmmu_style_bucket_scores)

            for key, value in scores_dict.items():
                for k, v in value.items():
                    scores_dict[key][k] = {"accuracy": float(sum(v) / len(v)), "std": np.std(v)}

            if OEVQAMetrics.OE_ONLY_MMMU_STYLE_PER_BUCKET_ACCURACY in self.metrics:
                results["oe_only_mmmu_style_per_bucket_accuracy"] = scores_dict
            elif OEVQAMetrics.OE_MMMU_STYLE_PER_BUCKET_ACCURACY in self.metrics:
                results["oe_mmmu_style_per_bucket_accuracy"] = scores_dict

        if OEVQAMetrics.OE_MMMU_STYLE_VQA_ACCURACY in self.metrics:
            oe_mmmu_style_accuracy_scores = self._compute_mmmu_style_vqa_accuracy(
                generated_texts_unique, answers_unique, normalize_str_mmmu, accept_mcq=True
            )
            results["oe_mmmu_style_vqa_accuracy"] = float(
                sum(oe_mmmu_style_accuracy_scores) / len(oe_mmmu_style_accuracy_scores)
            )

        if OEVQAMetrics.OE_ONLY_MMMU_STYLE_VQA_ACCURACY in self.metrics:
            oe_mmmu_style_accuracy_scores = self._compute_mmmu_style_vqa_accuracy(
                generated_texts_unique, answers_unique, normalize_str_mmmu, accept_mcq=False
            )
            results["oe_only_mmmu_style_vqa_accuracy"] = float(
                sum(oe_mmmu_style_accuracy_scores) / len(oe_mmmu_style_accuracy_scores)
            )

        if OEVQAMetrics.OE_RELAXED_VQA_ACCURACY in self.metrics:
            oe_relaxed_vqa_accuracy = self._compute_relaxed_vqa_accuracy(
                generated_texts_unique, answers_unique, lambda txt: txt.strip(".")
            )
            results["oe_relaxed_vqa_accuracy"] = float(sum(oe_relaxed_vqa_accuracy) / len(oe_relaxed_vqa_accuracy))

        if OEVQAMetrics.FIRST_WORD_VQA_ACCURACY in self.metrics or OEVQAMetrics.OE_VQA_ACCURACY in self.metrics:
            first_word_vqa_accuracy_scores, oe_accuracy_scores = self._compute_vqa_accuracy(
                generated_texts_unique, answers_unique, vqa_normalize_text
            )
            if OEVQAMetrics.FIRST_WORD_VQA_ACCURACY in self.metrics:
                results["first_word_vqa_accuracy"] = float(
                    sum(first_word_vqa_accuracy_scores) / len(first_word_vqa_accuracy_scores)
                )
            if OEVQAMetrics.OE_VQA_ACCURACY in self.metrics:
                results["oe_vqa_accuracy"] = float(sum(oe_accuracy_scores) / len(oe_accuracy_scores))

        return results
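

# A minimal usage sketch (not part of the original module): it assumes that instantiating this
# `evaluate.Metric` subclass directly and using the standard `add_batch`/`compute` flow is sufficient,
# and the example ids, generations, answers and bucket strings below are made up for illustration.
# The real m4 evaluation harness wires these inputs in from its task runners.
if __name__ == "__main__":
    metric = OpenEndedVQAMetrics(
        metrics=[OEVQAMetrics.FIRST_WORD_VQA_ACCURACY, OEVQAMetrics.OE_VQA_ACCURACY],
        save_generations=False,
    )
    metric.add_batch(
        example_ids=["0", "1"],
        generated_texts=["two dogs", "blue"],
        answers=[["two", "2", "two dogs", "two"], ["blue", "navy blue", "blue"]],
        # Buckets follow the "bucket_col=sub_bucket_name" convention documented above; they are only
        # used by the per-bucket MMMU-style metrics.
        buckets=["question_type=counting", "question_type=color"],
    )
    print(metric.compute())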