in kilt/eval_downstream.py [0:0]
def _calculate_metrics(gold_records, guess_records):
    """Compute downstream and KILT metrics for aligned gold/guess records.

    Records are compared pairwise in order: the i-th guess must carry the
    same ``id`` (compared as stripped strings) as the i-th gold record.
    For every pair four downstream scores are computed against the gold
    candidate answers returned by ``get_gold_answers``:

    * ``accuracy`` -- strict match: the stripped guess answer is literally
      one of the candidate answers;
    * ``em`` / ``f1`` / ``rougel`` -- best score over all candidates, as
      returned by ``_metric_max_over_ground_truths`` with the corresponding
      scorer (``_exact_match_score`` / ``_f1_score`` / ``_rougel_score``).

    Each KILT variant adds the same per-item score, but only for items where
    ``retrieval_metrics.rprecision`` (with ``rank_keys=["wikipedia_id"]``)
    equals 1; all other items contribute 0 to the KILT metrics.

    A guess whose answer is empty after stripping whitespace still counts
    toward the denominator but contributes 0 to every metric.

    Args:
        gold_records: gold items, each with an ``"id"`` plus whatever
            ``get_gold_answers`` and ``rprecision`` read from it.
        guess_records: prediction items, each with an ``"id"`` and an
            ``"output"`` list that must contain exactly one dict holding an
            ``"answer"`` key.

    Returns:
        ``{"kilt": {...}, "downstream": {...}}`` where every metric is the
        per-item sum divided by the number of records (all 0 when both
        inputs are empty).

    Raises:
        AssertionError: if the two lists differ in length, an id pair does
            not match, or a guess does not provide exactly one answer. These
            are raised explicitly rather than via ``assert`` so the checks
            are not stripped when Python runs with ``-O``.
    """
    if len(gold_records) != len(guess_records):
        # Without this check ``zip`` below would silently truncate to the
        # shorter list and the averages would be computed on a subset.
        raise AssertionError(
            "different size gold: {} guess: {}".format(
                len(gold_records), len(guess_records)
            )
        )

    total_count = 0
    # downstream metrics (per-item sums, averaged at the end)
    accuracy = 0
    normalized_em = 0
    normalized_f1 = 0
    rougel = 0
    # KILT metrics: same scores, but only credited when provenance is perfect
    kilt_accuracy = 0
    kilt_em = 0
    kilt_f1 = 0
    kilt_rougel = 0

    for guess_item, gold_item in zip(guess_records, gold_records):
        # Scores are only meaningful if both files list items in the same order.
        if str(gold_item["id"]).strip() != str(guess_item["id"]).strip():
            raise AssertionError("Items must have same order with same IDs")
        # Counted before the empty-answer check so that empty answers still
        # lower the averages instead of being ignored.
        total_count += 1

        gold_candidate_answers = get_gold_answers(gold_item)

        # The guess must provide exactly one output carrying an "answer";
        # short-circuit keeps an empty "output" list from raising IndexError.
        has_single_answer = (len(guess_item["output"]) == 1) and (
            "answer" in guess_item["output"][0]
        )
        if not has_single_answer:
            raise AssertionError(
                f"you should provide exactly one valid answer for {guess_item['id']}"
            )

        guess_answer = str(guess_item["output"][0]["answer"]).strip()
        if len(guess_answer) == 0:
            # Empty answer: scores 0 on every metric (already counted above).
            continue

        # 0. accuracy = strict (un-normalized) exact match
        local_accuracy = 1 if guess_answer in gold_candidate_answers else 0
        accuracy += local_accuracy

        # 1. normalized exact match
        local_em = _metric_max_over_ground_truths(
            _exact_match_score, guess_answer, gold_candidate_answers
        )
        normalized_em += local_em

        # 2. normalized f1
        local_f1 = _metric_max_over_ground_truths(
            _f1_score, guess_answer, gold_candidate_answers
        )
        normalized_f1 += local_f1

        # 3. rougel
        local_rougel = _metric_max_over_ground_truths(
            _rougel_score, guess_answer, gold_candidate_answers
        )
        rougel += local_rougel

        # KILT metrics: credit the downstream scores only when R-precision
        # over wikipedia_id provenance is exactly 1 for this item.
        Rprec = retrieval_metrics.rprecision(
            guess_item, gold_item, rank_keys=["wikipedia_id"]
        )
        if Rprec == 1:
            # 1. KILT-AC
            kilt_accuracy += local_accuracy
            # 2. KILT-EM
            kilt_em += local_em
            # 3. KILT-F1
            kilt_f1 += local_f1
            # 4. KILT-RL
            kilt_rougel += local_rougel

    # Guard against division by zero when both inputs are empty.
    if total_count > 0:
        accuracy /= total_count
        normalized_em /= total_count
        normalized_f1 /= total_count
        rougel /= total_count
        kilt_accuracy /= total_count
        kilt_em /= total_count
        kilt_f1 /= total_count
        kilt_rougel /= total_count

    return {
        "kilt": {
            "KILT-accuracy": kilt_accuracy,
            "KILT-em": kilt_em,
            "KILT-f1": kilt_f1,
            "KILT-rougel": kilt_rougel,
        },
        "downstream": {
            "accuracy": accuracy,
            "em": normalized_em,
            "f1": normalized_f1,
            "rougel": rougel,
        },
    }