in entity_linking.py [0:0]
def calculate_prf_one_group_entity(all_gts, all_pds, raw_pds, distinct_slot_values):
common = all_gts.intersection(all_pds)
try:
precision = len(common) / len(all_pds)
except:
precision = None
try:
recall = len(common) / len(all_gts)
except:
recall = None
try:
f1 = 2 * precision * recall / (precision + recall)
except:
f1 = None
none_gts = set([t for t in all_gts if t[-1] is None])
try:
none_recall = len(none_gts.intersection(all_pds)) / len(none_gts)
except:
none_recall = None
all_gts_dict = {f'{t[0]}-{t[1]}': t[-1] for t in all_gts}
all_pds_dict = {f'{t[0]}-{t[1]}': t[-1] for t in all_pds}
raw_pds_dict = {f'{t[0]}-{t[1]}': list(t)[2:] for t in raw_pds}
common_slots = set(all_gts_dict.keys()).intersection(set(all_pds_dict.keys()))
try:
accuracy = sum([1 for s in common_slots if all_gts_dict[s] == all_pds_dict[s]]) / len(common_slots)
except:
accuracy = None
link_accuracy_at = dict()
count_at = {n: 0 for n in [2, 3, 5, 10]}
if len(common_slots) == 0:
for n in [2, 3, 5, 10]:
link_accuracy_at[n] = None
else:
for s in common_slots:
sorted_entities = entity_sorting(raw_pds_dict[s][0], distinct_slot_values[raw_pds_dict[s][-1]],
raw_pds_dict[s][1], threshold=raw_pds_dict[s][2],
slot_name=raw_pds_dict[s][-1])
for n in [2, 3, 5, 10]:
if n >= len(distinct_slot_values[raw_pds_dict[s][-1]]) + 1:
count_at[n] += 1
else:
top_n = sorted_entities[:min(n, len(sorted_entities))]
if all_gts_dict[s] in top_n:
count_at[n] += 1
for n in [2, 3, 5, 10]:
link_accuracy_at[n] = count_at[n] / len(common_slots)
return {'none_recall': none_recall,
'link_accuracy': accuracy,
'link_accuracy_at_2': link_accuracy_at[2],
'link_accuracy_at_3': link_accuracy_at[3],
'link_accuracy_at_5': link_accuracy_at[5],
'link_accuracy_at_10': link_accuracy_at[10],
'precision': precision,
'recall': recall,
'f1': f1}