in kilt/eval_retrieval.py [0:0]
def get_rank(guess_item, gold_item, k, rank_keys, verbose=False):
    """
    Build the rank of a prediction, counting each gold evidence set as one point.

    The main idea is to consider each evidence set as a single point in the rank.
    The position of an evidence set in the rank is given by its lowest scored
    evidence, i.e. the last of its pages to appear in the guess: the set only
    counts as a hit once every page it contains has been retrieved.

    Args:
        guess_item: prediction datapoint. It is passed to ``_get_ids_list`` and
            the first element of the result is used as the ordered list of
            guessed ids (presumably the ids of the first output; verify
            against ``_get_ids_list``).
        gold_item: gold datapoint. ``gold_item["output"]`` is iterated and each
            output holding a ``"provenance"`` list contributes one evidence set.
        k: cut-off of the metric being computed. It must be a positive integer;
            here it is only used to size the ``verbose`` warning.
        rank_keys: provenance keys identifying a page. The stripped string
            values of the keys present in a provenance are joined with ``"+"``.
        verbose: when true, print a warning if the guess holds fewer ids than
            needed to fully retrieve the ``k`` largest evidence sets.

    Returns:
        A tuple ``(rank, num_distinct_evidence_sets)``. ``rank`` holds ``True``
        for every evidence set completed by the guess (at the position of its
        last page) and ``False`` for every guessed id that belongs to no
        remaining evidence set. Evidence sets that were only partially
        retrieved stay in the rank as ``"evidence_set:<idx>"`` placeholder
        strings. Both values are empty/zero when there are no guessed ids.

    Raises:
        AssertionError: if ``k`` is not greater than 0 (when assertions are
            enabled).
    """
    assert k > 0, "k must be a positive integer greater than 0."

    guess_ids = _get_ids_list(guess_item, rank_keys)[0]
    if not guess_ids:
        # nothing was predicted: empty rank, the gold item is not inspected
        return [], 0

    # 1. collect the distinct evidence sets
    evidence_sets = []
    for output in gold_item["output"]:
        if "provenance" not in output:
            continue
        e_set = {
            "+".join(
                str(provenance[rank_key]).strip()
                for rank_key in rank_keys
                if rank_key in provenance
            )
            for provenance in output["provenance"]
        }
        # NOTE(review): an output with an empty provenance list yields an empty
        # evidence set that can never be completed -- confirm this is intended.
        if e_set not in evidence_sets:  # no duplicate evidence set
            evidence_sets.append(e_set)
    num_distinct_evidence_sets = len(evidence_sets)

    # 2. check what's the minimum number of predicted pages needed to get a
    # robust P/R@k: all the pages of the k largest evidence sets, plus one page
    # for every missing set if there are fewer than k of them.
    # This has to run before step 3, which consumes the evidence sets.
    if verbose:
        largest_sizes = sorted(
            (len(e_set) for e_set in evidence_sets), reverse=True
        )[:k]
        min_prediction_size = sum(largest_sizes) + (k - len(largest_sizes))
        if len(guess_ids) < min_prediction_size:
            print(
                f"WARNING: you should provide at least {min_prediction_size} provenance items for a robust recall@{k} computation (you provided {len(guess_ids)} item(s))."
            )

    # 3. rank by grouping pages in each evidence set (each evidence set counts as 1),
    # the position in the rank of each evidence set is given by the last page in guess_ids
    # non evidence pages count as 1
    rank = []
    for guess_id in guess_ids:
        guess_id = str(guess_id).strip()
        found = False
        for idx, e_set in enumerate(evidence_sets):
            if guess_id not in e_set:
                continue
            found = True
            e_set_id = f"evidence_set:{idx}"

            # drop the previous point of this evidence set: it moves down to
            # the position of the page retrieved now
            if e_set_id in rank:
                rank.remove(e_set_id)

            # consume the page; the set is complete once nothing is left
            e_set.remove(guess_id)

            # True if it was the last evidence of the set, otherwise a
            # placeholder for the still partial evidence set
            rank.append(e_set_id if e_set else True)

        if not found:
            rank.append(False)

    return rank, num_distinct_evidence_sets