msmarco-passage-ranking/track.py
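"""Track plugin for the MS MARCO passage ranking track.

Defines Rally parameter sources and a custom runner for comparing BM25, ELSER
weighted-terms, pruned sparse_vector and hybrid query strategies, and for
measuring recall and NDCG of pruned queries against the un-pruned control.
"""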

import csv
import json
import os
from collections import defaultdict
from typing import Dict

Qrels = Dict[str, Dict[str, int]]
Results = Dict[str, Dict[str, float]]


def calc_ndcg(qrels: Qrels, results: Results, k_list: list):
    import pytrec_eval as pe

    for qid, rels in results.items():
        for pid in list(rels):
            if qid == pid:
                results[qid].pop(pid)

    scores = defaultdict(float)
    metrics = ["ndcg_cut"]
    pytrec_strings = {f"{metric}.{','.join([str(k) for k in k_list])}" for metric in metrics}
    evaluator = pe.RelevanceEvaluator(qrels, pytrec_strings)
    pytrec_scores = evaluator.evaluate(results)

    for query_id in pytrec_scores.keys():
        for metric in metrics:
            for k in k_list:
                scores[f"{metric}@{k}"] += pytrec_scores[query_id][f"{metric}_{k}"]

    queries_count = len(pytrec_scores.keys())
    if queries_count == 0:
        return scores

    for metric in metrics:
        for k in k_list:
            scores[f"{metric}@{k}"] = float(scores[f"{metric}@{k}"] / queries_count)

    return scores


def read_qrels(qrels_input_file):
    qrels = defaultdict(dict)
    with open(qrels_input_file, "r") as input_file:
        tsv_reader = csv.reader(input_file, delimiter="\t")
        for query_id, doc_id, score in tsv_reader:
            qrels[query_id][doc_id] = int(score)
    return qrels


def generate_weighted_terms_query(text_expansion_field, query_expansion, boost=1.0):
    return {
        "query": {
            "bool": {
                "should": [
                    {"term": {f"{text_expansion_field}": {"value": f"{key}", "boost": value}}}
                    for key, value in query_expansion.items()
                ],
                "boost": boost,
            }
        }
    }


def generate_bm25_query(text_field, query, boost=1.0):
    return {"query": {"match": {f"{text_field}": {"query": query, "boost": boost}}}}


def generate_combine_bm25_weighted_terms_query(
    text_field, text_expansion_field, query, query_boost, query_expansion, query_expansion_boost
):
    return {
        "query": {
            "bool": {
                "should": [
                    generate_bm25_query(text_field, query, query_boost)["query"],
                    generate_weighted_terms_query(text_expansion_field, query_expansion, query_expansion_boost)["query"],
                ]
            }
        }
    }


def generate_pruned_query(field, query_expansion, boost=1.0):
    return {"query": {"sparse_vector": {"field": field, "query_vector": query_expansion, "prune": True, "boost": boost}}}


def generate_rescored_pruned_query(field, query_expansion, num_candidates, boost=1.0):
    return {
        "query": {"sparse_vector": {"field": field, "query_vector": query_expansion, "prune": True, "boost": boost}},
        "rescore": {
            "window_size": num_candidates,
            "query": {
                "rescore_query": {
                    "sparse_vector": {
                        "field": field,
                        "query_vector": query_expansion,
                        "prune": True,
                        "pruning_config": {
                            "only_score_pruned_tokens": True,
                        },
                        "boost": boost,
                    }
                }
            },
        },
    }
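
# Illustrative only (the token weights below are invented): the generators above expand a
# query's token/weight map into Elasticsearch query DSL, e.g.
#
#   generate_weighted_terms_query("text_expansion_elser", {"car": 1.4, "auto": 0.7})
#   -> {"query": {"bool": {"should": [
#          {"term": {"text_expansion_elser": {"value": "car", "boost": 1.4}}},
#          {"term": {"text_expansion_elser": {"value": "auto", "boost": 0.7}}}],
#        "boost": 1.0}}}
#
#   generate_pruned_query("text_expansion_elser", {"car": 1.4, "auto": 0.7})
#   -> {"query": {"sparse_vector": {"field": "text_expansion_elser",
#          "query_vector": {"car": 1.4, "auto": 0.7}, "prune": True, "boost": 1.0}}}
#
# generate_rescored_pruned_query wraps the same pruned sparse_vector query in a rescore block
# that re-scores the top `num_candidates` hits using only the pruned tokens.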


class QueryParamsSource:
    def __init__(self, track, params, **kwargs):
        # choose a suitable index: if there is only one defined for this track
        # choose that one, but let the user always override index
        if len(track.indices) == 1:
            default_index = track.indices[0].name
        else:
            default_index = "_all"

        self._index_name = params.get("index", default_index)
        self._cache = params.get("cache", False)
        self._size = params.get("size", 10)
        self._num_candidates = params.get("num_candidates", 10)
        self._text_field = params.get("text_field", "text")
        self._text_expansion_field = params.get("text_expansion_field", "text_expansion_elser")
        self._query_file = params.get("query_source", "queries.json")
        self._query_strategy = params.get("query_strategy", "bm25")
        self._track_total_hits = params.get("track_total_hits", False)
        self._rescore = params.get("rescore", False)
        self._prune = params.get("prune", False)
        self._params = params
        self.infinite = True

        cwd = os.path.dirname(__file__)
        with open(os.path.join(cwd, self._query_file), "r") as file:
            self._queries = json.load(file)
        self._iters = 0

    def partition(self, partition_index, total_partitions):
        return self

    def params(self):
        query_obj = self._queries[self._iters]
        if self._query_strategy == "bm25":
            query = generate_bm25_query(text_field=self._text_field, query=query_obj["query"], boost=1)
        elif self._query_strategy == "text_expansion":
            if self._prune is False:
                query = generate_weighted_terms_query(
                    text_expansion_field=self._text_expansion_field,
                    query_expansion=query_obj[self._text_expansion_field],
                    boost=1,
                )
            elif self._rescore is True:
                query = generate_rescored_pruned_query(
                    field=self._text_expansion_field,
                    query_expansion=query_obj[self._text_expansion_field],
                    num_candidates=self._num_candidates,
                    boost=1,
                )
            else:
                query = generate_pruned_query(
                    field=self._text_expansion_field, query_expansion=query_obj[self._text_expansion_field], boost=1
                )
        elif self._query_strategy == "hybrid":
            query = generate_combine_bm25_weighted_terms_query(
                self._text_field, self._text_expansion_field, query_obj["query"], 1, query_obj[self._text_expansion_field], 1
            )
        else:
            raise Exception(f"The query strategy [{self._query_strategy}] is not implemented")

        self._iters = (self._iters + 1) % len(self._queries)
        query["track_total_hits"] = self._track_total_hits
        query["size"] = self._size

        return {
            "index": self._index_name,
            "cache": self._cache,
            "body": query,
        }


class WeightedRecallParamSource:
    def __init__(self, track, params, **kwargs):
        if len(track.indices) == 1:
            default_index = track.indices[0].name
        else:
            default_index = "_all"

        self._query_file = params.get("query_source", "queries-small.json")
        self._qrels_file = params.get("qrels_source", "qrels-small.tsv")
        self._index_name = params.get("index", default_index)
        self._cache = params.get("cache", False)
        self._top_k = params.get("top_k", 10)
        self._num_candidates = params.get("num_candidates", 100)
        self._params = params
        self._queries = []
        self._text_expansion_field = params.get("text_expansion_field", "text_expansion_elser")
        self.infinite = True

        cwd = os.path.dirname(__file__)
        with open(os.path.join(cwd, self._query_file), "r") as file:
            self._queries = json.load(file)
        self._qrels = read_qrels(os.path.join(cwd, self._qrels_file))

    def partition(self, partition_index, total_partitions):
        return self

    def params(self):
        return {
            "index": self._index_name,
            "cache": self._cache,
            "top_k": self._top_k,
            "num_candidates": self._num_candidates,
            "queries": self._queries,
            "qrels": self._qrels,
            "text_expansion_field": self._text_expansion_field,
        }
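
# Illustrative input shapes consumed by the parameter sources above (the placeholder values
# are invented; the real files ship with the track):
#
#   queries.json / queries-small.json — a JSON array of objects such as
#       {"id": "<query_id>", "query": "<query text>", "text_expansion_elser": {"<token>": <weight>, ...}}
#
#   qrels-small.tsv — tab-separated rows of <query_id>\t<doc_id>\t<relevance>,
#   which read_qrels() folds into {query_id: {doc_id: relevance}}.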


# For each query this will generate the weighted terms query, a pruned version and a rescored pruned version of the same query.
# These queries can then be executed and compared for accuracy.
class WeightedTermsRecallRunner:
    async def __call__(self, es, params):
        recall_total = 0
        recall_with_rescore_total = 0
        exact_total = 0
        min_recall = params["top_k"]

        weighted_term_results = defaultdict(dict)
        pruned_results = defaultdict(dict)
        pruned_rescored_results = defaultdict(dict)

        for query in params["queries"]:
            query_id = query["id"]

            # Build and execute all three queries
            weighted_terms_result = await es.search(
                body=generate_weighted_terms_query(params["text_expansion_field"], query[params["text_expansion_field"]], 1),
                index=params["index"],
                request_cache=params["cache"],
                size=params["top_k"],
            )
            pruned_result = await es.search(
                body=generate_pruned_query(params["text_expansion_field"], query[params["text_expansion_field"]], 1),
                index=params["index"],
                request_cache=params["cache"],
                size=params["top_k"],
            )
            pruned_rescored_result = await es.search(
                body=generate_rescored_pruned_query(
                    params["text_expansion_field"], query[params["text_expansion_field"]], params["num_candidates"], 1
                ),
                index=params["index"],
                request_cache=params["cache"],
                size=params["top_k"],
            )

            weighted_terms_hits = {hit["_source"]["id"]: hit["_score"] for hit in weighted_terms_result["hits"]["hits"]}
            pruned_hits = {hit["_source"]["id"]: hit["_score"] for hit in pruned_result["hits"]["hits"]}
            pruned_rescored_hits = {hit["_source"]["id"]: hit["_score"] for hit in pruned_rescored_result["hits"]["hits"]}

            # Recall calculations as compared to the control/non-pruned hits
            weighted_terms_ids = set(weighted_terms_hits.keys())
            pruned_ids = set(pruned_hits.keys())
            pruned_rescored_ids = set(pruned_rescored_hits.keys())

            current_recall_with_rescore = len(weighted_terms_ids.intersection(pruned_rescored_ids))
            current_recall = len(weighted_terms_ids.intersection(pruned_ids))

            recall_with_rescore_total += current_recall_with_rescore
            recall_total += current_recall
            exact_total += len(weighted_terms_ids)
            min_recall = min(min_recall, current_recall)

            # Construct input to NDCG calculation based on returned hits
            for doc_id, score in weighted_terms_hits.items():
                weighted_term_results[query_id][doc_id] = score
            for doc_id, score in pruned_hits.items():
                pruned_results[query_id][doc_id] = score
            for doc_id, score in pruned_rescored_hits.items():
                pruned_rescored_results[query_id][doc_id] = score

        control_relevance = calc_ndcg(params["qrels"], weighted_term_results, [10, 100])
        pruned_relevance = calc_ndcg(params["qrels"], pruned_results, [10, 100])
        pruned_rescored_relevance = calc_ndcg(params["qrels"], pruned_rescored_results, [10, 100])

        return (
            {
                "avg_recall": float(recall_with_rescore_total / exact_total),  # Calculated on pruned/rescored hits
                "avg_recall_without_rescore": float(recall_total / exact_total),  # Calculated on pruned hits without rescore
                "min_recall": min_recall,  # Calculated on pruned hits without rescore
                "top_k": params["top_k"],
                "num_candidates": params["num_candidates"],
                "control_ndcg_10": control_relevance["ndcg_cut@10"],
                "control_ndcg_100": control_relevance["ndcg_cut@100"],
                "pruned_ndcg_10": pruned_relevance["ndcg_cut@10"],
                "pruned_ndcg_100": pruned_relevance["ndcg_cut@100"],
                "pruned_rescored_ndcg_10": pruned_rescored_relevance["ndcg_cut@10"],
                "pruned_rescored_ndcg_100": pruned_rescored_relevance["ndcg_cut@100"],
            }
            if exact_total > 0
            else None
        )

    def __repr__(self, *args, **kwargs):
        return "weighted_terms_recall"
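
# Note on the metrics returned above: avg_recall and avg_recall_without_rescore are
# micro-averaged — overlap counts are summed across all queries and divided by the total
# number of control (un-pruned) hits rather than averaging per-query recall — and
# min_recall tracks the worst per-query overlap of the pruned (un-rescored) result set
# with the control result set.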


def register(registry):
    registry.register_param_source("query_param_source", QueryParamsSource)
    registry.register_param_source("weighted_terms_recall_param_source", WeightedRecallParamSource)
    registry.register_runner("weighted_terms_recall", WeightedTermsRecallRunner(), async_runner=True)
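
# Hypothetical wiring (not part of this file, shown only for orientation): a Rally operation
# in the track's JSON would typically reference the names registered above, for example
# something along the lines of
#
#   {"name": "text-expansion-search", "operation-type": "search",
#    "param-source": "query_param_source", "query_strategy": "text_expansion", "prune": true}
#
#   {"name": "recall-comparison", "operation-type": "weighted_terms_recall",
#    "param-source": "weighted_terms_recall_param_source", "top_k": 10, "num_candidates": 100}
#
# The actual operation definitions live in the track's JSON files, not here.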