in kilt/retrievers/BLINK_connector.py [0:0]
def run(self):
    """Run the dense linker on ``self.test_data`` and build the provenance.

    Calls ``main_dense.run`` with ``self.args``, ``self.logger``, the
    unpacked ``self.models`` and ``test_data=self.test_data``, then:

    1. groups the returned predictions and scores by each record's ``"id"``
       (several records may share one id);
    2. within each group, ranks the (score, prediction) pairs from highest
       to lowest and keeps only the best-scoring occurrence of every
       distinct prediction;
    3. resolves each kept prediction through ``self.Wikipedia_title2id``;
       predictions missing from that mapping are reported with a printed
       warning and skipped.

    Returns:
        dict: maps every record ``"id"`` to a list (best score first) of
        dicts with the keys ``"score"``, ``"wikipedia_title"`` and
        ``"wikipedia_id"``; all three values are converted with ``str``.
        The list is empty when none of the predictions could be resolved.
    """
    # Only ``predictions`` and ``scores`` are consumed below. The other
    # values stay in the unpacking (underscore-prefixed) so the call still
    # fails loudly if ``main_dense.run`` stops returning a 7-tuple.
    (
        _biencoder_accuracy,
        _recall_at,
        _crossencoder_normalized_accuracy,
        _overall_unnormalized_accuracy,
        _num_datapoints,
        predictions,
        scores,
    ) = main_dense.run(
        self.args, self.logger, *self.models, test_data=self.test_data
    )

    # aggregate multiple records for the same datapoint
    print("aggregate multiple records for the same datapoint", flush=True)
    id_2_results = {}
    # NOTE(review): zip() stops at the shortest input, so a length mismatch
    # between test_data, predictions and scores is silently truncated.
    for record, record_predictions, record_scores in zip(
        self.test_data, predictions, scores
    ):
        grouped = id_2_results.setdefault(
            record["id"], {"predictions": [], "scores": []}
        )
        grouped["predictions"].extend(record_predictions)
        grouped["scores"].extend(record_scores)

    provenance = {}
    for datapoint_id, results in id_2_results.items():
        element = []

        # merge predictions when multiple entities are found: iterate the
        # pairs from best to worst score so the first time a title shows up
        # is its highest score; later duplicates are dropped.
        # A set gives O(1) membership tests; titles are presumably strings
        # (they are already used as lookup keys in Wikipedia_title2id).
        seen_titles = set()
        ranked_pairs = sorted(
            zip(results["scores"], results["predictions"]), reverse=True
        )
        for score, e_title in ranked_pairs:
            if e_title in seen_titles:
                continue
            seen_titles.add(e_title)

            if e_title not in self.Wikipedia_title2id:
                print(
                    "WARNING: title: {} not recognized".format(e_title), flush=True
                )
                continue

            # The id comes straight from the title -> id mapping, so no
            # page lookup is needed here.
            wikipedia_id = self.Wikipedia_title2id[e_title]
            element.append(
                {
                    "score": str(score),
                    "wikipedia_title": str(e_title),
                    "wikipedia_id": str(wikipedia_id),
                }
            )

        provenance[datapoint_id] = element

    return provenance