# compute() — defined in kilt/eval_retrieval.py


def compute(gold_dataset, guess_dataset, ks, rank_keys):
    """Aggregate retrieval metrics over an aligned gold/guess dataset.

    Sums the per-item metrics produced by ``get_ranking_metrics`` for each
    (guess, gold) pair, then averages them over the dataset size.

    Args:
        gold_dataset: list of gold items, each carrying an "id" field.
        guess_dataset: list of guess items, aligned 1:1 with gold_dataset
            (same length, same ids in the same order).
        ks: iterable of integer cutoffs; per-k metrics are reported for each.
        rank_keys: keys forwarded to ``get_ranking_metrics`` for ranking.

    Returns:
        OrderedDict mapping metric names ("Rprec", "entity_in_input",
        "precision@k", ...) to dataset-averaged float values.
    """
    ks = sorted(int(x) for x in ks)

    # Initialize every accumulator to 0.0 so all metric keys are present
    # in the output even when the dataset is empty.
    result = OrderedDict()
    result["Rprec"] = 0.0
    result["entity_in_input"] = 0.0
    for k in ks:
        if k > 0:
            result["precision@{}".format(k)] = 0.0
            result["answer_in_context@{}".format(k)] = 0.0
            result["answer_and_ent_in_context@{}".format(k)] = 0.0
        if k > 1:
            # recall/success_rate are only reported for cutoffs above 1.
            result["recall@{}".format(k)] = 0.0
            result["success_rate@{}".format(k)] = 0.0

    # BUG FIX: the failure message previously swapped the counts, reporting
    # the guess size as "gold" and vice versa.
    assert len(guess_dataset) == len(
        gold_dataset
    ), "different size gold: {} guess: {}".format(len(gold_dataset), len(guess_dataset))

    # BUG FIX: the zip previously bound `gold` from guess_dataset and
    # `guess` from gold_dataset. The id equality check is symmetric, so
    # behavior was unaffected, but the binding is now correct.
    for gold, guess in zip(gold_dataset, guess_dataset):
        assert (
            str(gold["id"]).strip() == str(guess["id"]).strip()
        ), "Items must have same order with same IDs"

    # Accumulate per-item metrics into the result dict.
    for guess_item, gold_item in zip(guess_dataset, gold_dataset):
        ranking_metrics = get_ranking_metrics(guess_item, gold_item, ks, rank_keys)
        result["Rprec"] += ranking_metrics["Rprec"]
        result["entity_in_input"] += ranking_metrics["entity_in_input"]
        for k in ks:
            if k > 0:
                result["precision@{}".format(k)] += ranking_metrics[
                    "precision@{}".format(k)
                ]
                result["answer_in_context@{}".format(k)] += ranking_metrics[
                    "answer_in_context@{}".format(k)
                ]
                result["answer_and_ent_in_context@{}".format(k)] += ranking_metrics[
                    "answer_and_ent_in_context@{}".format(k)
                ]
            if k > 1:
                result["recall@{}".format(k)] += ranking_metrics["recall@{}".format(k)]
                result["success_rate@{}".format(k)] += ranking_metrics[
                    "success_rate@{}".format(k)
                ]

    # Turn sums into averages. Every key in `result` is an accumulator
    # initialized above, so a single loop replaces the duplicated per-key
    # division block; guard against division by zero on empty input.
    n = len(guess_dataset)
    if n > 0:
        for key in result:
            result[key] /= n

    return result