def _compute()

in vision/m4/evaluation/custom_metrics/open_ended_vqa_metrics.py


    def _compute(self, example_ids, generated_texts, answers, buckets):
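        """Deduplicate predictions gathered across processes and compute the requested
        open-ended VQA metrics (and/or server-submission results) over the unique examples.

        The four arguments are parallel lists with one entry per (possibly repeated) example.
        Returns a dict mapping result names to metric values or server-submission payloads.
        """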
        data_per_id = {}
        for ex_id, gen_text, ans, bucket in zip(example_ids, generated_texts, answers, buckets):
            # The if condition is a dirty trick to handle distributed evaluation, where some instances can be
            # repeated over a few processes to make the batches even.
            # In that case, we just verify that all processes predicted the same thing and only keep one copy of
            # the predictions so as not to skew the metrics. Ideally this deduplication logic should be handled
            # outside of the metric, or maybe in the add_batch call...
            if ex_id not in data_per_id:
                data_per_id[ex_id] = {
                    "generated_text": gen_text,
                    "answers": ans,
                    "bucket": bucket,
                }
            else:
                if data_per_id[ex_id]["generated_text"] != gen_text:
                    logger.warning(
                        f"Example {ex_id} has different predictions accross processes. We have: {gen_text} and"
                        f" {data_per_id[ex_id]['generated_text']}"
                    )
                if data_per_id[ex_id]["answers"] != ans:
                    logger.warning(
                        f"Example {ex_id} has different answers accross processes. We have: {ans} and"
                        f" {data_per_id[ex_id]['answers']}"
                    )

        # assert list(range(len(data_per_id))) == sorted(data_per_id.keys())
        # Rely on dict insertion order so that the aggregated lists below stay aligned with each other.
        unique_example_ids = list(data_per_id.keys())
        generated_texts_unique = [data_per_id[i]["generated_text"] for i in unique_example_ids]
        answers_unique = [data_per_id[i]["answers"] for i in unique_example_ids]
        results = {}
        # If the reference answers are missing (e.g. a hidden test split), default to saving the generations
        # in the server-submission format instead of computing accuracies.
        default_to_save_generations = (
            answers_unique[0] is None or answers_unique[0][0] == ""
        ) and OEVQAMetrics.DEFAULT_TO_SERVER_RESULTS in self.metrics
        if self.save_generations or default_to_save_generations:
            if (
                OEVQAMetrics.OE_MMMU_STYLE_VQA_ACCURACY in self.metrics
                or OEVQAMetrics.OE_ONLY_MMMU_STYLE_VQA_ACCURACY in self.metrics
            ):
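                # For the MMMU-style metrics, keep the raw generated text; the other branch applies the
                # GT vision-lab VQA normalization before saving.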
                results["server_results"] = [
                    {
                        "question_id": ex_id,
                        "answer": data["generated_text"],
                    }
                    for ex_id, data in data_per_id.items()
                ]
            else:
                results["server_results"] = [
                    {
                        "question_id": ex_id,
                        "answer": self.gt_vision_lab_normalization.vqa_normalize_text(data["generated_text"]),
                    }
                    for ex_id, data in data_per_id.items()
                ]

        if default_to_save_generations:
            return results

        if OEVQAMetrics.DEFAULT_TO_SERVER_RESULTS_MMVET in self.metrics:
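            # MM-Vet results are saved as a flat {example_id: generated_text} mapping, unlike the
            # list-of-records format used for the other server results.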
            results["server_results"] = {ex_id: data["generated_text"] for ex_id, data in data_per_id.items()}
            return results
        elif OEVQAMetrics.DEFAULT_TO_SERVER_RESULTS_LLAVA_WILD in self.metrics:
            results["server_results"] = [
                {
                    "question_id": ex_id,
                    "answer": data["generated_text"],
                }
                for ex_id, data in data_per_id.items()
            ]
            return results

        # VQA Accuracy
        # From the "VQA: Visual Question Answering" paper:
        # an answer is deemed 100% accurate if at least 3 workers provided that exact answer. Before comparison,
        # all responses are made lowercase, numbers converted to digits, and punctuation & articles removed.
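        # i.e. for a single prediction, the paper's rule is accuracy = min(# matching human answers / 3, 1),
        # averaged over the dataset.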
        if (
            OEVQAMetrics.GT_VISION_LAB_FIRST_WORD_VQA_ACCURACY in self.metrics
            or OEVQAMetrics.GT_VISION_LAB_OE_VQA_ACCURACY in self.metrics
        ):
            gt_vision_lab_first_word_vqa_accuracy_scores, gt_vision_lab_oe_accuracy_scores = (
                self._compute_vqa_accuracy(
                    generated_texts_unique, answers_unique, self.gt_vision_lab_normalization.vqa_normalize_text
                )
            )

            if OEVQAMetrics.GT_VISION_LAB_FIRST_WORD_VQA_ACCURACY in self.metrics:
                results["gt_vision_lab_first_word_vqa_accuracy"] = float(
                    sum(gt_vision_lab_first_word_vqa_accuracy_scores)
                    / len(gt_vision_lab_first_word_vqa_accuracy_scores)
                )

            if OEVQAMetrics.GT_VISION_LAB_OE_VQA_ACCURACY in self.metrics:
                results["gt_vision_lab_oe_vqa_accuracy"] = float(
                    sum(gt_vision_lab_oe_accuracy_scores) / len(gt_vision_lab_oe_accuracy_scores)
                )
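                # Also report the sample standard deviation and the half-width of a 95% confidence
                # interval under a normal approximation: z * std / sqrt(n).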
                confidence_level = 0.95
                z_score = stats.norm.ppf((1 + confidence_level) / 2)
                std_dev = np.std(gt_vision_lab_oe_accuracy_scores)
                results["gt_vision_lab_oe_vqa_accuracy_std"] = float(std_dev)
                results["gt_vision_lab_oe_vqa_accuracy_margin_of_error"] = float(
                    z_score * std_dev / np.sqrt(len(gt_vision_lab_oe_accuracy_scores))
                )

        if (
            OEVQAMetrics.OE_MMMU_STYLE_PER_BUCKET_ACCURACY in self.metrics
            or OEVQAMetrics.OE_ONLY_MMMU_STYLE_PER_BUCKET_ACCURACY in self.metrics
        ):
            if (
                OEVQAMetrics.OE_MMMU_STYLE_PER_BUCKET_ACCURACY in self.metrics
                and OEVQAMetrics.OE_ONLY_MMMU_STYLE_PER_BUCKET_ACCURACY in self.metrics
            ):
                raise ValueError(
                    "Cannot compute both OE_MMMU_STYLE_PER_BUCKET_ACCURACY and OE_ONLY_MMMU_STYLE_PER_BUCKET_ACCURACY"
                    " at the same time."
                )

            # Here each bucket has the form "bucket_col_0=sub_bucket_name_x/bucket_col_1=sub_bucket_name_y/... etc."
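            # For example (hypothetical values), "difficulty=easy/category=counting" contributes its scores
            # to both scores_dict["difficulty"]["easy"] and scores_dict["category"]["counting"].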
            buckets_aggregated = [data_per_id[example_id]["bucket"] for example_id in unique_example_ids]

            # Get columns and unique buckets
            unique_buckets = set(buckets_aggregated)
            bucket_columns = [column_buckets.split("=")[0] for column_buckets in buckets_aggregated[0].split("/")]

            # Initialize the scores dict
            scores_dict = {}
            for bucket_column in bucket_columns:
                scores_dict[bucket_column] = {}
            for unique_bucket in unique_buckets:
                column_sub_bucket_names = [column_bucket.split("=")[1] for column_bucket in unique_bucket.split("/")]
                for bucket_column, sub_bucket_name in zip(bucket_columns, column_sub_bucket_names):
                    scores_dict[bucket_column][sub_bucket_name] = []

            # Need np array to use .where
            generated_texts_unique_np = np.array(generated_texts_unique)
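            # dtype=object allows examples to have a varying number of reference answers (ragged lists).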
            answers_unique_np = np.array(answers_unique, dtype=object)
            buckets_aggregated = np.array(buckets_aggregated)
            for b_ in unique_buckets:
                # Find the positions of the unique_bucket in the buckets_aggregated to compute the scores
                bucket_position = np.where(buckets_aggregated == b_)[0]
                oe_mmmu_style_bucket_scores = self._compute_mmmu_style_vqa_accuracy(
                    generated_texts_unique_np[bucket_position],
                    answers_unique_np[bucket_position],
                    normalize_str_mmmu,
                    # Do not accept mcq when using OE_ONLY_MMMU_STYLE_PER_BUCKET_ACCURACY metric
                    accept_mcq=OEVQAMetrics.OE_ONLY_MMMU_STYLE_PER_BUCKET_ACCURACY not in self.metrics,
                )

                # Each sub_bucket (column, name) pair from this buckets_aggregated entry
                # extends the list of scores of its corresponding entry in the scores_dict
                # with the oe_mmmu_style_bucket_scores.
                sub_buckets_tuples = [
                    (column_bucket.split("=")[0], column_bucket.split("=")[1]) for column_bucket in b_.split("/")
                ]
                for sub_bucket_col, sub_bucket_name in sub_buckets_tuples:
                    scores_dict[sub_bucket_col][sub_bucket_name].extend(oe_mmmu_style_bucket_scores)

            for key, value in scores_dict.items():
                for k, v in value.items():
                    scores_dict[key][k] = {"accuracy": float(sum(v) / len(v)), "std": float(np.std(v))}

            if OEVQAMetrics.OE_ONLY_MMMU_STYLE_PER_BUCKET_ACCURACY in self.metrics:
                results["oe_only_mmmu_style_per_bucket_accuracy"] = scores_dict
            elif OEVQAMetrics.OE_MMMU_STYLE_PER_BUCKET_ACCURACY in self.metrics:
                results["oe_mmmu_style_per_bucket_accuracy"] = scores_dict

        if OEVQAMetrics.OE_MMMU_STYLE_VQA_ACCURACY in self.metrics:
            oe_mmmu_style_accuracy_scores = self._compute_mmmu_style_vqa_accuracy(
                generated_texts_unique, answers_unique, normalize_str_mmmu, accept_mcq=True
            )

            results["oe_mmmu_style_vqa_accuracy"] = float(
                sum(oe_mmmu_style_accuracy_scores) / len(oe_mmmu_style_accuracy_scores)
            )
        if OEVQAMetrics.OE_ONLY_MMMU_STYLE_VQA_ACCURACY in self.metrics:
            oe_mmmu_style_accuracy_scores = self._compute_mmmu_style_vqa_accuracy(
                generated_texts_unique, answers_unique, normalize_str_mmmu, accept_mcq=False
            )

            results["oe_only_mmmu_style_vqa_accuracy"] = float(
                sum(oe_mmmu_style_accuracy_scores) / len(oe_mmmu_style_accuracy_scores)
            )

        if OEVQAMetrics.OE_RELAXED_VQA_ACCURACY in self.metrics:
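            # The normalizer passed here only strips leading and trailing periods before comparison.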
            oe_relaxed_vqa_accuracy = self._compute_relaxed_vqa_accuracy(
                generated_texts_unique, answers_unique, lambda txt: txt.strip(".")
            )

            results["oe_relaxed_vqa_accuracy"] = float(sum(oe_relaxed_vqa_accuracy) / len(oe_relaxed_vqa_accuracy))

        if OEVQAMetrics.FIRST_WORD_VQA_ACCURACY in self.metrics or OEVQAMetrics.OE_VQA_ACCURACY in self.metrics:
            first_word_vqa_accuracy_scores, oe_accuracy_scores = self._compute_vqa_accuracy(
                generated_texts_unique, answers_unique, vqa_normalize_text
            )

            if OEVQAMetrics.FIRST_WORD_VQA_ACCURACY in self.metrics:
                results["first_word_vqa_accuracy"] = float(
                    sum(first_word_vqa_accuracy_scores) / len(first_word_vqa_accuracy_scores)
                )
            if OEVQAMetrics.OE_VQA_ACCURACY in self.metrics:
                results["oe_vqa_accuracy"] = float(sum(oe_accuracy_scores) / len(oe_accuracy_scores))
        return results
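
Hypothetical usage sketch (names and values invented for illustration, not taken from the repo): assuming
`metric` is an instance of the metric class this method belongs to, configured with
`metrics={OEVQAMetrics.OE_VQA_ACCURACY}` and `save_generations=False`, the four arguments are parallel
lists and repeated example ids (from distributed padding) are deduplicated before scoring:

    example_ids = [0, 1, 1]  # example 1 repeated across processes to even out batches
    generated_texts = ["two", "blue", "blue"]
    answers = [["2", "two"], ["blue"], ["blue"]]
    buckets = ["split=val", "split=val", "split=val"]
    results = metric._compute(example_ids, generated_texts, answers, buckets)
    # results is a dict; here it would contain the single key "oe_vqa_accuracy"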