src/beir_utils.py
import os

import beir.util
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval.search.dense import DenseRetrievalExactSearch

# DenseEncoderModel is assumed to be defined earlier in this module.


def evaluate_model(
    query_encoder,
    doc_encoder,
    tokenizer,
    dataset,
    batch_size=128,
    add_special_tokens=True,
    norm_query=False,
    norm_doc=False,
    is_main=True,
    split='test',
    metric='dot',
    beir_data_path="BEIR/datasets",
):
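    """Evaluate a dense bi-encoder on a BEIR dataset with exact (brute-force) search.

    Downloads `dataset` into `beir_data_path` if it is not already present,
    retrieves with the given score function, and returns the tuple
    (ndcg, _map, recall, precision, mrr, recall_cap, hole), each a dict of
    metric@k values. On non-main ranks (is_main=False) all values are None.
    For 'cqadupstack', metrics are macro-averaged over its twelve subsets.
    """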
    # Unwrap encoders from DataParallel/DistributedDataParallel wrappers.
    if hasattr(query_encoder, "module"):
        query_encoder = query_encoder.module
    query_encoder.eval()

    if doc_encoder is not None:
        if hasattr(doc_encoder, "module"):
            doc_encoder = doc_encoder.module
        doc_encoder.eval()
    else:
        # Shared encoder: use the query encoder for documents as well.
        doc_encoder = query_encoder
    dmodel = DenseRetrievalExactSearch(
        DenseEncoderModel(
            query_encoder=query_encoder,
            doc_encoder=doc_encoder,
            tokenizer=tokenizer,
            add_special_tokens=add_special_tokens,
            norm_query=norm_query,
            norm_doc=norm_doc,
        ),
        batch_size=batch_size,
    )
    retriever = EvaluateRetrieval(dmodel, score_function=metric)
    url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
    data_path = beir.util.download_and_unzip(url, beir_data_path)
    if dataset == 'cqadupstack':
        ndcgs, _maps, recalls, precisions, mrrs, recall_caps, holes = [], [], [], [], [], [], []
        cqasubsets = [
            'android',
            'english',
            'gaming',
            'gis',
            'mathematica',
            'physics',
            'programmers',
            'stats',
            'tex',
            'unix',
            'webmasters',
            'wordpress',
        ]
        # Evaluate each CQADupStack subset separately, collecting per-subset metrics.
        for sub in cqasubsets:
            data_folder = os.path.join(data_path, sub)
            corpus, queries, qrels = GenericDataLoader(data_folder=data_folder).load(split=split)
            results = retriever.retrieve(corpus, queries)
            if is_main:
                ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
                mrr = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="mrr")
                recall_cap = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="recall_cap")
                hole = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="hole")
                ndcgs.append(ndcg)
                _maps.append(_map)
                recalls.append(recall)
                precisions.append(precision)
                mrrs.append(mrr)
                recall_caps.append(recall_cap)
                holes.append(hole)
        if is_main:
            # Macro-average each metric dict over the subsets, key by key.
            n = len(cqasubsets)
            ndcg = {key: sum(item.get(key) for item in ndcgs) / n for key in ndcgs[0]}
            _map = {key: sum(item.get(key) for item in _maps) / n for key in _maps[0]}
            recall = {key: sum(item.get(key) for item in recalls) / n for key in recalls[0]}
            precision = {key: sum(item.get(key) for item in precisions) / n for key in precisions[0]}
            mrr = {key: sum(item.get(key) for item in mrrs) / n for key in mrrs[0]}
            recall_cap = {key: sum(item.get(key) for item in recall_caps) / n for key in recall_caps[0]}
            hole = {key: sum(item.get(key) for item in holes) / n for key in holes[0]}
        else:
            ndcg, _map, recall, precision = None, None, None, None
            mrr, recall_cap, hole = None, None, None
    else:
        corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split=split)
        results = retriever.retrieve(corpus, queries)
        if is_main:
            ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
            mrr = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="mrr")
            recall_cap = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="recall_cap")
            hole = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="hole")
        else:
            ndcg, _map, recall, precision = None, None, None, None
            mrr, recall_cap, hole = None, None, None

    return ndcg, _map, recall, precision, mrr, recall_cap, hole
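
# A minimal usage sketch, not part of the original file: the checkpoint name
# ("facebook/contriever") and the assumption that the loaded encoder is
# callable on tokenized batches and returns embeddings the way
# DenseEncoderModel expects are illustrative; adapt both to your model wrapper.
if __name__ == "__main__":
    from transformers import AutoModel, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("facebook/contriever")
    encoder = AutoModel.from_pretrained("facebook/contriever")

    # doc_encoder=None shares the query encoder for documents.
    ndcg, _map, recall, precision, mrr, recall_cap, hole = evaluate_model(
        query_encoder=encoder,
        doc_encoder=None,
        tokenizer=tokenizer,
        dataset="scifact",  # any BEIR dataset name, e.g. "nfcorpus", "fiqa"
        batch_size=64,
        norm_query=False,
        norm_doc=False,
        split="test",
        metric="dot",
    )
    print("NDCG:", ndcg)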