in vision/m4/evaluation/custom_metrics/open_ended_vqa_metrics.py
def _compute(self, example_ids, generated_texts, answers, buckets):
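    """Aggregate per-example predictions and compute the requested open-ended VQA metrics.

    Args:
        example_ids: one id per instance; ids repeated by distributed padding are deduplicated below.
        generated_texts: model generations, aligned with example_ids.
        answers: per-example lists of reference answers (may be None/empty for server-scored benchmarks).
        buckets: per-example bucket strings of the form "col_0=name_x/col_1=name_y/...".

    Returns:
        A dict mapping metric names (and possibly "server_results") to their values.
    """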
data_per_id = {}
for ex_id, gen_text, ans, bucket in zip(example_ids, generated_texts, answers, buckets):
        # This `if` is a dirty trick to handle the case of distributed evaluation, where some instances can be
        # repeated across a few processes to make the batches even.
        # In that case, we just verify that all processes predicted the same thing and keep only one copy of the
        # predictions in order to not skew the metrics. Ideally this "unique" logic should be handled outside of
        # the metric, or maybe in the add_batch call...
if ex_id not in data_per_id:
data_per_id[ex_id] = {
"generated_text": gen_text,
"answers": ans,
"bucket": bucket,
}
else:
if data_per_id[ex_id]["generated_text"] != gen_text:
logger.warning(
f"Example {ex_id} has different predictions accross processes. We have: {gen_text} and"
f" {data_per_id[ex_id]['generated_text']}"
)
if data_per_id[ex_id]["answers"] != ans:
logger.warning(
f"Example {ex_id} has different answers accross processes. We have: {ans} and"
f" {data_per_id[ex_id]['answers']}"
)
# assert list(range(len(data_per_id))) == sorted(data_per_id.keys())
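    # Build the deduplicated lists from data_per_id; dict insertion order keeps the generated
    # texts, answers, and (later) buckets index-aligned with one another.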
    generated_texts_unique = [data["generated_text"] for data in data_per_id.values()]
    answers_unique = [data["answers"] for data in data_per_id.values()]
results = {}
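    # When no ground truth is available locally (None or empty answers), the benchmark is scored
    # server-side, so we fall back to dumping the raw generations as "server_results".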
default_to_save_generations = (
answers_unique[0] is None or answers_unique[0][0] == ""
) and OEVQAMetrics.DEFAULT_TO_SERVER_RESULTS in self.metrics
if self.save_generations or default_to_save_generations:
# If answers are None, we default to the server results
if (
OEVQAMetrics.OE_MMMU_STYLE_VQA_ACCURACY in self.metrics
or OEVQAMetrics.OE_ONLY_MMMU_STYLE_VQA_ACCURACY in self.metrics
):
results["server_results"] = [
{
"question_id": ex_id,
"answer": data["generated_text"],
}
for ex_id, data in data_per_id.items()
]
else:
results["server_results"] = [
{
"question_id": ex_id,
"answer": self.gt_vision_lab_normalization.vqa_normalize_text(data["generated_text"]),
}
for ex_id, data in data_per_id.items()
]
if default_to_save_generations:
return results
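    # Server-result formats differ per benchmark: the MM-Vet variant is a dict keyed by example id,
    # while the LLaVA-Wild variant is a list of {"question_id", "answer"} records.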
if OEVQAMetrics.DEFAULT_TO_SERVER_RESULTS_MMVET in self.metrics:
results["server_results"] = {ex_id: data["generated_text"] for ex_id, data in data_per_id.items()}
return results
elif OEVQAMetrics.DEFAULT_TO_SERVER_RESULTS_LLAVA_WILD in self.metrics:
results["server_results"] = [
{
"question_id": ex_id,
"answer": data["generated_text"],
}
for ex_id, data in data_per_id.items()
]
return results
    # VQA Accuracy
    # From the "VQA: Visual Question Answering" paper:
    # an answer is deemed 100% accurate if at least 3 workers provided that exact answer. Before comparison,
    # all responses are made lowercase, numbers converted to digits, and punctuation & articles removed.
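    # In the usual formulation this amounts to min(# matching annotators / 3, 1): e.g. a prediction
    # matching exactly 2 of the human answers scores 2/3.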
if (
OEVQAMetrics.GT_VISION_LAB_FIRST_WORD_VQA_ACCURACY in self.metrics
or OEVQAMetrics.GT_VISION_LAB_OE_VQA_ACCURACY in self.metrics
):
        gt_vision_lab_first_word_vqa_accuracy_scores, gt_vision_lab_oe_accuracy_scores = (
            self._compute_vqa_accuracy(
                generated_texts_unique, answers_unique, self.gt_vision_lab_normalization.vqa_normalize_text
            )
        )
if OEVQAMetrics.GT_VISION_LAB_FIRST_WORD_VQA_ACCURACY in self.metrics:
results["gt_vision_lab_first_word_vqa_accuracy"] = float(
sum(gt_vision_lab_first_word_vqa_accuracy_scores)
/ len(gt_vision_lab_first_word_vqa_accuracy_scores)
)
if OEVQAMetrics.GT_VISION_LAB_OE_VQA_ACCURACY in self.metrics:
results["gt_vision_lab_oe_vqa_accuracy"] = float(
sum(get_visison_lab_oe_accuracy_scores) / len(get_visison_lab_oe_accuracy_scores)
)
confidence_level = 0.95
z_score = stats.norm.ppf((1 + confidence_level) / 2)
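            # Two-sided critical value (z ~= 1.96 at 95% confidence); the margin of error below is
            # z * std / sqrt(n).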
            std_dev = float(np.std(gt_vision_lab_oe_accuracy_scores))
            results["gt_vision_lab_oe_vqa_accuracy_std"] = std_dev
            results["gt_vision_lab_oe_vqa_accuracy_margin_of_error"] = float(
                z_score * std_dev / np.sqrt(len(gt_vision_lab_oe_accuracy_scores))
            )
if (
OEVQAMetrics.OE_MMMU_STYLE_PER_BUCKET_ACCURACY in self.metrics
or OEVQAMetrics.OE_ONLY_MMMU_STYLE_PER_BUCKET_ACCURACY in self.metrics
):
if (
OEVQAMetrics.OE_MMMU_STYLE_PER_BUCKET_ACCURACY in self.metrics
and OEVQAMetrics.OE_ONLY_MMMU_STYLE_PER_BUCKET_ACCURACY in self.metrics
):
raise ValueError(
"Cannot compute both OE_MMMU_STYLE_PER_BUCKET_ACCURACY and OE_ONLY_MMMU_STYLE_PER_BUCKET_ACCURACY"
" at the same time."
)
# Here each bucket has the form "bucket_col_0=sub_bucket_name_x/bucket_col_1=sub_bucket_name_y/... etc."
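        # e.g. (illustrative values) "category=counting/difficulty=easy" -> columns
        # ["category", "difficulty"] with sub-buckets "counting" and "easy".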
        buckets_aggregated = [data["bucket"] for data in data_per_id.values()]
# Get columns and unique buckets
unique_buckets = set(buckets_aggregated)
bucket_columns = [column_buckets.split("=")[0] for column_buckets in buckets_aggregated[0].split("/")]
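        # Assumes every bucket string shares the same column set, so the first entry is representative.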
# Initialize the scores dict
scores_dict = {}
for bucket_column in bucket_columns:
scores_dict[bucket_column] = {}
for unique_bucket in unique_buckets:
column_sub_bucket_names = [column_bucket.split("=")[1] for column_bucket in unique_bucket.split("/")]
for bucket_column, sub_bucket_name in zip(bucket_columns, column_sub_bucket_names):
scores_dict[bucket_column][sub_bucket_name] = []
        # Convert to numpy arrays so we can use np.where for positional indexing
generated_texts_unique_np = np.array(generated_texts_unique)
answers_unique_np = np.array(answers_unique, dtype=object)
buckets_aggregated = np.array(buckets_aggregated)
for b_ in unique_buckets:
# Find the positions of the unique_bucket in the buckets_aggregated to compute the scores
bucket_position = np.where(buckets_aggregated == b_)[0]
oe_mmmu_style_bucket_scores = self._compute_mmmu_style_vqa_accuracy(
generated_texts_unique_np[bucket_position],
answers_unique_np[bucket_position],
normalize_str_mmmu,
# Do not accept mcq when using OE_ONLY_MMMU_STYLE_PER_BUCKET_ACCURACY metric
accept_mcq=OEVQAMetrics.OE_ONLY_MMMU_STYLE_PER_BUCKET_ACCURACY not in self.metrics,
)
            # Each (column, name) sub-bucket pair parsed from this bucket string extends the score list
            # of its corresponding entry in scores_dict with the oe_mmmu_style_bucket_scores.
sub_buckets_tuples = [
(column_bucket.split("=")[0], column_bucket.split("=")[1]) for column_bucket in b_.split("/")
]
for sub_bucket_col, sub_bucket_name in sub_buckets_tuples:
scores_dict[sub_bucket_col][sub_bucket_name].extend(oe_mmmu_style_bucket_scores)
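        # Reduce each sub-bucket's accumulated per-example scores to a mean accuracy and standard deviation.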
for key, value in scores_dict.items():
for k, v in value.items():
scores_dict[key][k] = {"accuracy": float(sum(v) / len(v)), "std": np.std(v)}
if OEVQAMetrics.OE_ONLY_MMMU_STYLE_PER_BUCKET_ACCURACY in self.metrics:
results["oe_only_mmmu_style_per_bucket_accuracy"] = scores_dict
elif OEVQAMetrics.OE_MMMU_STYLE_PER_BUCKET_ACCURACY in self.metrics:
results["oe_mmmu_style_per_bucket_accuracy"] = scores_dict
if OEVQAMetrics.OE_MMMU_STYLE_VQA_ACCURACY in self.metrics:
oe_mmmu_style_accuracy_scores = self._compute_mmmu_style_vqa_accuracy(
generated_texts_unique, answers_unique, normalize_str_mmmu, accept_mcq=True
)
results["oe_mmmu_style_vqa_accuracy"] = float(
sum(oe_mmmu_style_accuracy_scores) / len(oe_mmmu_style_accuracy_scores)
)
if OEVQAMetrics.OE_ONLY_MMMU_STYLE_VQA_ACCURACY in self.metrics:
        oe_only_mmmu_style_accuracy_scores = self._compute_mmmu_style_vqa_accuracy(
            generated_texts_unique, answers_unique, normalize_str_mmmu, accept_mcq=False
        )
        results["oe_only_mmmu_style_vqa_accuracy"] = float(
            sum(oe_only_mmmu_style_accuracy_scores) / len(oe_only_mmmu_style_accuracy_scores)
        )
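    # Relaxed accuracy: the only normalization applied is stripping leading/trailing periods.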
if OEVQAMetrics.OE_RELAXED_VQA_ACCURACY in self.metrics:
oe_relaxed_vqa_accuracy = self._compute_relaxed_vqa_accuracy(
generated_texts_unique, answers_unique, lambda txt: txt.strip(".")
)
results["oe_relaxed_vqa_accuracy"] = float(sum(oe_relaxed_vqa_accuracy) / len(oe_relaxed_vqa_accuracy))
if OEVQAMetrics.FIRST_WORD_VQA_ACCURACY in self.metrics or OEVQAMetrics.OE_VQA_ACCURACY in self.metrics:
first_word_vqa_accuracy_scores, oe_accuracy_scores = self._compute_vqa_accuracy(
generated_texts_unique, answers_unique, vqa_normalize_text
)
if OEVQAMetrics.FIRST_WORD_VQA_ACCURACY in self.metrics:
results["first_word_vqa_accuracy"] = float(
sum(first_word_vqa_accuracy_scores) / len(first_word_vqa_accuracy_scores)
)
if OEVQAMetrics.OE_VQA_ACCURACY in self.metrics:
results["oe_vqa_accuracy"] = float(sum(oe_accuracy_scores) / len(oe_accuracy_scores))
return results