search/mteb/dbpedia/track.py

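"""Rally track plugin for text-search relevance on the MTEB DBpedia workload.

Exposes two parameter sources ("query_param_source", "relevance_param_source")
and one async runner ("text_search_relevance"). The runner issues multi_match
queries over the title/text fields and reports NDCG@10 and NDCG@100 against the
qrels file shipped with the track.
"""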
import csv
import json
import logging
import os
from collections import defaultdict
from typing import Dict

Qrels = Dict[str, Dict[str, int]]
Results = Dict[str, Dict[str, float]]

logger = logging.getLogger(__name__)


def calc_ndcg(qrels: Qrels, results: Results, k_list: list):
    # Imported here so the dependency is only required when relevance is evaluated.
    import pytrec_eval as pe

    # Drop degenerate self-hits where a query id equals a document id.
    for qid, rels in results.items():
        for pid in list(rels):
            if qid == pid:
                results[qid].pop(pid)

    scores = defaultdict(float)
    metrics = ["ndcg_cut"]
    pytrec_strings = {f"{metric}.{','.join([str(k) for k in k_list])}" for metric in metrics}
    evaluator = pe.RelevanceEvaluator(qrels, pytrec_strings)
    pytrec_scores = evaluator.evaluate(results)

    for query_id in pytrec_scores.keys():
        for metric in metrics:
            for k in k_list:
                scores[f"{metric}@{k}"] += pytrec_scores[query_id][f"{metric}_{k}"]

    queries_count = len(pytrec_scores.keys())
    if queries_count == 0:
        return scores

    # Average the per-query scores.
    for metric in metrics:
        for k in k_list:
            scores[f"{metric}@{k}"] = float(scores[f"{metric}@{k}"] / queries_count)

    return scores


def read_qrels(qrels_input_file):
    qrels = defaultdict(dict)
    with open(qrels_input_file, "r") as input_file:
        tsv_reader = csv.reader(input_file, delimiter="\t")
        next(tsv_reader)  # Skip column names
        for query_id, doc_id, score in tsv_reader:
            qrels[query_id][doc_id] = int(score)
    return qrels


def generate_bm25_query(text_field, query, boost=1.0):
    return {"query": {"match": {f"{text_field}": {"query": query, "boost": boost}}}}


def generate_query(query_string, title_field, title_boost, text_field, text_boost):
    return {
        "query": {
            "multi_match": {
                "minimum_should_match": "1<-1 3<49%",
                "type": "best_fields",
                "fuzziness": "AUTO",
                "prefix_length": 2,
                "query": query_string,
                "fields": [f"{title_field}^{title_boost}", f"{text_field}^{text_boost}"],
            }
        }
    }


class QueryParamsSource:
    def __init__(self, track, params, **kwargs):
        # Choose a suitable index: if there is only one defined for this track
        # choose that one, but let the user always override index.
        if len(track.indices) == 1:
            default_index = track.indices[0].name
        else:
            default_index = "_all"

        self._index_name = params.get("index", default_index)
        self._cache = params.get("cache", False)
        self._size = params.get("size", 10)
        self._title_field = params.get("title_field", "title")
        self._text_field = params.get("text_field", "text")
        self._title_boost = params.get("title_boost", 5)
        self._text_boost = params.get("text_boost", 1)
        self._query_file = params.get("query_source", "queries.json")
        self._track_total_hits = params.get("track_total_hits", False)
        self._params = params
        self.infinite = True

        cwd = os.path.dirname(__file__)
        with open(os.path.join(cwd, self._query_file), "r") as file:
            self._queries = json.load(file)
        self._iters = 0

    def partition(self, partition_index, total_partitions):
        return self

    def params(self):
        query_obj = self._queries[self._iters]
        query = generate_query(query_obj["text"], self._title_field, self._title_boost, self._text_field, self._text_boost)
        query["track_total_hits"] = self._track_total_hits
        query["size"] = self._size
        # Cycle through the query file indefinitely.
        self._iters = (self._iters + 1) % len(self._queries)
        return {
            "index": self._index_name,
            "cache": self._cache,
            "body": query,
        }


class RelevanceParamsSource:
    def __init__(self, track, params, **kwargs):
        # Choose a suitable index: if there is only one defined for this track
        # choose that one, but let the user always override index.
        if len(track.indices) == 1:
            default_index = track.indices[0].name
        else:
            default_index = "_all"

        self._index_name = params.get("index", default_index)
        self._cache = params.get("cache", False)
        self._size = params.get("size", 10)
        self._title_field = params.get("title_field", "title")
        self._text_field = params.get("text_field", "text")
        self._title_boost = params.get("title_boost", 5)
        self._text_boost = params.get("text_boost", 1)
        self._query_file = params.get("query_source", "queries.json")
        self._qrels_file = params.get("qrels_source", "test.tsv")
        self._params = params
        self.infinite = True

        cwd = os.path.dirname(__file__)
        with open(os.path.join(cwd, self._query_file), "r") as file:
            self._queries = json.load(file)
        self._iters = 0
        self._qrels = read_qrels(os.path.join(cwd, self._qrels_file))

    def partition(self, partition_index, total_partitions):
        return self

    def params(self):
        return {
            "index": self._index_name,
            "cache": self._cache,
            "size": self._size,
            "queries": self._queries,
            "qrels": self._qrels,
            "text_field": self._text_field,
            "title_field": self._title_field,
            "text_boost": self._text_boost,
            "title_boost": self._title_boost,
        }


class TextSearchRelevanceRunner:
    async def __call__(self, es, params):
        text_search_results = defaultdict(dict)

        for query in params["queries"]:
            query_id = query["_id"]
            text_search_query = generate_query(
                query["text"], params["title_field"], params["title_boost"], params["text_field"], params["text_boost"]
            )
            text_search_result = await es.search(
                body=text_search_query,
                index=params["index"],
                request_cache=params["cache"],
                size=params["size"],
            )
            text_search_hits = {hit["_source"]["id"]: hit["_score"] for hit in text_search_result["hits"]["hits"]}

            # Construct input to NDCG calculation based on returned hits.
            for doc_id, score in text_search_hits.items():
                text_search_results[query_id][doc_id] = score

        text_search_relevance = calc_ndcg(params["qrels"], text_search_results, [10, 100])
        logger.debug(
            f'text_search_relevance_10 = {text_search_relevance["ndcg_cut@10"]}, '
            f'text_search_relevance_100 = {text_search_relevance["ndcg_cut@100"]}'
        )
        return {
            "text_search_relevance_10": text_search_relevance["ndcg_cut@10"],
            "text_search_relevance_100": text_search_relevance["ndcg_cut@100"],
        }

    def __repr__(self, *args, **kwargs):
        return "text_search_relevance"


def register(registry):
    registry.register_param_source("query_param_source", QueryParamsSource)
    registry.register_param_source("relevance_param_source", RelevanceParamsSource)
    registry.register_runner("text_search_relevance", TextSearchRelevanceRunner(), async_runner=True)
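
# A minimal sketch (hypothetical operation name and values) of how a track.json
# operation could reference the pieces registered above; the actual operations
# for this track are defined elsewhere in the repository:
#
#   {
#     "name": "text-search-relevance",
#     "operation-type": "text_search_relevance",
#     "param-source": "relevance_param_source",
#     "size": 100
#   }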