in kilt/eval_retrieval.py [0:0]
from collections import OrderedDict


def compute(gold_dataset, guess_dataset, ks, rank_keys):
    ks = sorted([int(x) for x in ks])

    # Initialize every metric accumulator to zero.
    result = OrderedDict()
    result["Rprec"] = 0.0
    result["entity_in_input"] = 0.0
    for k in ks:
        if k > 0:
            result["precision@{}".format(k)] = 0.0
            result["answer_in_context@{}".format(k)] = 0.0
            result["answer_and_ent_in_context@{}".format(k)] = 0.0
        if k > 1:
            result["recall@{}".format(k)] = 0.0
            result["success_rate@{}".format(k)] = 0.0
    assert len(guess_dataset) == len(
        gold_dataset
    ), "different size gold: {} guess: {}".format(len(gold_dataset), len(guess_dataset))

    # Gold and guess items must be aligned one-to-one by ID.
    for guess, gold in zip(guess_dataset, gold_dataset):
        assert (
            str(gold["id"]).strip() == str(guess["id"]).strip()
        ), "Items must have same order with same IDs"
    # Accumulate per-item metrics; get_ranking_metrics is defined
    # elsewhere in this module.
    for guess_item, gold_item in zip(guess_dataset, gold_dataset):
        ranking_metrics = get_ranking_metrics(guess_item, gold_item, ks, rank_keys)

        result["Rprec"] += ranking_metrics["Rprec"]
        result["entity_in_input"] += ranking_metrics["entity_in_input"]
        for k in ks:
            if k > 0:
                result["precision@{}".format(k)] += ranking_metrics[
                    "precision@{}".format(k)
                ]
                result["answer_in_context@{}".format(k)] += ranking_metrics[
                    "answer_in_context@{}".format(k)
                ]
                result["answer_and_ent_in_context@{}".format(k)] += ranking_metrics[
                    "answer_and_ent_in_context@{}".format(k)
                ]
            if k > 1:
                result["recall@{}".format(k)] += ranking_metrics[
                    "recall@{}".format(k)
                ]
                result["success_rate@{}".format(k)] += ranking_metrics[
                    "success_rate@{}".format(k)
                ]
    # Macro-average: divide each accumulated metric by the dataset size.
    if len(guess_dataset) > 0:
        result["Rprec"] /= len(guess_dataset)
        result["entity_in_input"] /= len(guess_dataset)
        for k in ks:
            if k > 0:
                result["precision@{}".format(k)] /= len(guess_dataset)
                result["answer_in_context@{}".format(k)] /= len(guess_dataset)
                result["answer_and_ent_in_context@{}".format(k)] /= len(guess_dataset)
            if k > 1:
                result["recall@{}".format(k)] /= len(guess_dataset)
                result["success_rate@{}".format(k)] /= len(guess_dataset)

    return result
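
For orientation, a minimal invocation sketch. This is not part of the source file: the only field compute itself requires is "id"; everything else in the items is consumed by get_ranking_metrics (not shown here), so the KILT-style "output"/"provenance"/"wikipedia_id" fields and the rank_keys value below are assumptions based on the KILT data format.

# Hypothetical gold and guess items, aligned by "id".
gold = [
    {"id": "q1", "output": [{"answer": "Dublin",
                             "provenance": [{"wikipedia_id": "8504"}]}]},
]
guess = [
    {"id": "q1", "output": [{"provenance": [{"wikipedia_id": "8504"},
                                            {"wikipedia_id": "3434750"}]}]},
]
metrics = compute(gold, guess, ks=[1, 5], rank_keys=["wikipedia_id"])
# metrics is an OrderedDict with keys such as "Rprec", "entity_in_input",
# "precision@1", "recall@5", and "success_rate@5", each macro-averaged
# over the items (recall@k and success_rate@k only exist for k > 1).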