def calculate_prf_one_group_entity()

in entity_linking.py [0:0]


def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or ``None`` when ``denominator`` is zero."""
    if denominator == 0:
        return None
    return numerator / denominator


def calculate_prf_one_group_entity(all_gts, all_pds, raw_pds, distinct_slot_values):
    """Compute set-level P/R/F1 and entity-linking accuracies for one group.

    Args:
        all_gts: set of ground-truth tuples. ``(t[0], t[1])`` identifies a slot
            and ``t[-1]`` is its value (``None`` is allowed).
        all_pds: sized collection of predicted tuples with the same layout.
        raw_pds: iterable of raw prediction tuples. ``(t[0], t[1])`` identifies
            the slot; ``t[2]`` and ``t[3]`` are passed positionally to
            ``entity_sorting``, ``t[4]`` is passed as ``threshold`` and
            ``t[-1]`` is the slot name.
        distinct_slot_values: mapping from slot name to its candidate values.

    Returns:
        dict with keys ``none_recall``, ``link_accuracy``,
        ``link_accuracy_at_{2,3,5,10}``, ``precision``, ``recall`` and ``f1``.
        A metric is ``None`` when its denominator is empty. ``f1`` is ``None``
        if precision or recall is ``None``, and ``0.0`` when both are zero.
    """
    ks = (2, 3, 5, 10)

    common = all_gts.intersection(all_pds)
    precision = _safe_ratio(len(common), len(all_pds))
    recall = _safe_ratio(len(common), len(all_gts))
    if precision is None or recall is None:
        f1 = None
    elif precision + recall == 0:
        # Both sets are non-empty but share nothing: F1 is 0, not undefined.
        f1 = 0.0
    else:
        f1 = 2 * precision * recall / (precision + recall)

    # Recall restricted to ground-truth tuples whose value is None.
    none_gts = {t for t in all_gts if t[-1] is None}
    none_recall = _safe_ratio(len(none_gts.intersection(all_pds)), len(none_gts))

    # Key slots by the (t[0], t[1]) tuple instead of an f'{t[0]}-{t[1]}'
    # string so ids containing '-' cannot collide (('a-b','c') vs ('a','b-c')).
    gt_values = {(t[0], t[1]): t[-1] for t in all_gts}
    pd_values = {(t[0], t[1]): t[-1] for t in all_pds}
    raw_pd_fields = {(t[0], t[1]): list(t)[2:] for t in raw_pds}
    common_slots = gt_values.keys() & pd_values.keys()

    accuracy = _safe_ratio(
        sum(1 for s in common_slots if gt_values[s] == pd_values[s]),
        len(common_slots),
    )

    hits_at = dict.fromkeys(ks, 0)
    for s in common_slots:
        fields = raw_pd_fields[s]
        slot_name = fields[-1]
        candidates = distinct_slot_values[slot_name]
        sorted_entities = entity_sorting(fields[0], candidates, fields[1],
                                         threshold=fields[2], slot_name=slot_name)
        for n in ks:
            # NOTE(review): when n >= len(candidates) + 1 the slot is counted as
            # a hit unconditionally -- presumably because top-n then covers every
            # candidate plus one extra option (None?); confirm intent.
            if n >= len(candidates) + 1 or gt_values[s] in sorted_entities[:n]:
                hits_at[n] += 1
    link_accuracy_at = {n: _safe_ratio(hits_at[n], len(common_slots)) for n in ks}

    return {'none_recall': none_recall,
            'link_accuracy': accuracy,
            'link_accuracy_at_2': link_accuracy_at[2],
            'link_accuracy_at_3': link_accuracy_at[3],
            'link_accuracy_at_5': link_accuracy_at[5],
            'link_accuracy_at_10': link_accuracy_at[10],
            'precision': precision,
            'recall': recall,
            'f1': f1}