def main()

in dpr_scale/msmarco_eval.py [0:0]

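main() drives evaluation of a candidate ranking against a reference ranking: it first prints the official MS MARCO metrics produced by compute_metrics_from_files, then re-scores the run with pytrec_eval and reports macro-averaged NDCG@10, MAP@10, MRR, and recall at cutoffs 20/50/100/1000.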

def main():
    """Command line:
    python msmarco_eval_ranking.py <path_to_reference_file> <path_to_candidate_file>
    """

    if len(sys.argv) == 3:
        path_to_reference = sys.argv[1]
        path_to_candidate = sys.argv[2]
        # Official MS MARCO metrics (e.g. MRR@10) computed by this module's own helpers.
        metrics = compute_metrics_from_files(path_to_reference, path_to_candidate)
        print('#####################')
        for metric in sorted(metrics):
            print('{}: {}'.format(metric, metrics[metric]))
        print('#####################')
        print("pytrec eval")
        # Cross-check with pytrec_eval: build an evaluator over the reference
        # judgments and the requested trec_eval measures, then score the candidate run.
        evaluator = pytrec_eval.RelevanceEvaluator(
            load_reference_for_trec_eval(path_to_reference),
            {'map_cut', 'ndcg_cut', 'recip_rank', 'recall_20', 'recall_50', 'recall_100', 'recall_1000'},
        )
        result = evaluator.evaluate(load_candidate_for_trec_eval(path_to_candidate))
        # Accumulate per-query scores, then average over all evaluated queries.
        eval_query_cnt = 0
        ndcg = 0
        Map = 0
        mrr = 0
        recalls = Counter()
        for k in result.keys():
            eval_query_cnt += 1
            ndcg += result[k]["ndcg_cut_10"]
            Map += result[k]["map_cut_10"]
            mrr += result[k]["recip_rank"]
            for topk in [20, 50, 100, 1000]:
                recalls[topk] += result[k][f"recall_{topk}"]
        final_ndcg = ndcg / eval_query_cnt
        final_Map = Map / eval_query_cnt
        final_mrr = mrr / eval_query_cnt
        final_recalls = {}
        for topk in [20, 50, 100, 1000]:
            final_recalls[topk] = recalls[topk] / eval_query_cnt
        print("NDCG@10:" + str(final_ndcg))
        print("map@10:" + str(final_Map))
        print("pytrec_mrr:" + str(final_mrr))
        for topk in [20, 50, 100, 1000]:
            print(f"recall@{topk}"+":" + str(final_recalls[topk]))
    else:
        print('Usage: msmarco_eval.py <reference ranking> <candidate ranking>')
        sys.exit(1)
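For reference, below is a minimal sketch (not from the repository) of the dictionary shapes that pytrec_eval.RelevanceEvaluator consumes. The query/document ids and scores are made up for illustration; in main() the actual structures are produced by load_reference_for_trec_eval and load_candidate_for_trec_eval from the reference and candidate files.

import pytrec_eval

# qrels: query_id -> {doc_id: graded relevance (int)}  -- hypothetical example data
qrels = {
    "q1": {"d1": 1, "d2": 0},
    "q2": {"d3": 1},
}

# run: query_id -> {doc_id: retrieval score (float)}; pytrec_eval ranks docs by score
run = {
    "q1": {"d1": 12.3, "d2": 9.8, "d4": 5.0},
    "q2": {"d5": 7.1, "d3": 6.4},
}

evaluator = pytrec_eval.RelevanceEvaluator(qrels, {"recip_rank", "ndcg_cut", "recall_20"})
per_query = evaluator.evaluate(run)
# Each per-query dict holds keys such as "recip_rank", "ndcg_cut_10", "recall_20",
# which main() sums and divides by the number of evaluated queries.
print(per_query["q1"]["recip_rank"], per_query["q2"]["recall_20"])

Assuming the reference and candidate files are in the formats those helper loaders expect, a typical invocation would look like `python dpr_scale/msmarco_eval.py qrels.dev.tsv run.dev.tsv`, where both file names are placeholders.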