def evaluate_model()

in src/beir_utils.py [0:0]


def _compute_metrics(retriever, qrels, results):
    """Return the 7-tuple (ndcg, map, recall, precision, mrr, recall_cap, hole) for one retrieval run."""
    ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
    mrr = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="mrr")
    recall_cap = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="recall_cap")
    hole = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="hole")
    return ndcg, _map, recall, precision, mrr, recall_cap, hole


def _average_metric_dicts(dicts):
    """Average a non-empty list of {metric_name: value} dicts key-by-key (keys taken from the first dict)."""
    return {key: sum(d[key] for d in dicts) / len(dicts) for key in dicts[0]}


def evaluate_model(
        query_encoder,
        doc_encoder,
        tokenizer,
        dataset,
        batch_size=128,
        add_special_tokens=True,
        norm_query=False,
        norm_doc=False,
        is_main=True,
        split='test',
        metric='dot',
        beir_data_path="BEIR/datasets",
    ):
    """Evaluate a dense retriever on a BEIR dataset.

    The dataset archive is downloaded and unzipped into ``beir_data_path``.
    Retrieval then runs with exact search. For ``'cqadupstack'``, each metric
    is macro-averaged over the benchmark's sub-forums.

    Args:
        query_encoder: Model used to encode queries. If it has a ``.module``
            attribute, the inner module is used instead (presumably to unwrap
            DataParallel/DDP wrappers). It is switched to ``eval()`` mode.
        doc_encoder: Model used to encode documents. It is unwrapped and
            switched to eval mode the same way. If ``None``, ``query_encoder``
            is reused.
        tokenizer: Tokenizer forwarded to ``DenseEncoderModel``.
        dataset: BEIR dataset name, used to build the download URL.
        batch_size: Encoding batch size passed to ``DenseRetrievalExactSearch``.
        add_special_tokens, norm_query, norm_doc: Forwarded to ``DenseEncoderModel``.
        is_main: When False, retrieval still runs but no metrics are computed,
            and every returned value is ``None``.
        split: Data split to load (e.g. ``'test'``).
        metric: Score function name passed to ``EvaluateRetrieval``.
        beir_data_path: Directory where the dataset archive is unzipped.

    Returns:
        Tuple ``(ndcg, _map, recall, precision, mrr, recall_cap, hole)``.
        Each entry is the metric dict produced by BEIR, or ``None`` when
        ``is_main`` is False.
    """
    if hasattr(query_encoder, "module"):
        query_encoder = query_encoder.module
    query_encoder.eval()

    if doc_encoder is not None:
        if hasattr(doc_encoder, "module"):
            doc_encoder = doc_encoder.module
        doc_encoder.eval()
    else:
        # Shared (siamese) encoder: documents are encoded with the query model.
        doc_encoder = query_encoder

    dmodel = DenseRetrievalExactSearch(
        DenseEncoderModel(
            query_encoder=query_encoder,
            doc_encoder=doc_encoder,
            tokenizer=tokenizer,
            add_special_tokens=add_special_tokens,
            norm_query=norm_query,
            norm_doc=norm_doc
        ),
        # Previously hard-coded to 128, which silently ignored the argument.
        batch_size=batch_size
    )
    retriever = EvaluateRetrieval(dmodel, score_function=metric)
    url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
    data_path = beir.util.download_and_unzip(url, beir_data_path)

    if dataset == 'cqadupstack':
        cqasubsets = (
            'android',
            'english',
            'gaming',
            'gis',
            'mathematica',
            'physics',
            'programmers',
            'stats',
            'tex',
            'unix',
            'webmasters',
            'wordpress',
        )
        per_subset = []  # one 7-tuple of metric dicts per subset (main process only)
        for sub in cqasubsets:
            data_folder = os.path.join(data_path, sub)
            corpus, queries, qrels = GenericDataLoader(data_folder=data_folder).load(split=split)
            # Retrieval runs on every process; only the main one scores it.
            results = retriever.retrieve(corpus, queries)
            if is_main:
                per_subset.append(_compute_metrics(retriever, qrels, results))
        if not is_main:
            return None, None, None, None, None, None, None
        # Transpose to one list per metric, then macro-average each one over the
        # subsets. This also averages `hole`, which previously leaked the last
        # subset's value.
        ndcg, _map, recall, precision, mrr, recall_cap, hole = (
            _average_metric_dicts(list(metric_dicts)) for metric_dicts in zip(*per_subset)
        )
        return ndcg, _map, recall, precision, mrr, recall_cap, hole

    corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split=split)
    results = retriever.retrieve(corpus, queries)
    if not is_main:
        return None, None, None, None, None, None, None
    return _compute_metrics(retriever, qrels, results)