# msmarco-passage-ranking/track.py
import csv
import json
import os
from collections import defaultdict
from typing import Dict
Qrels = Dict[str, Dict[str, int]]
Results = Dict[str, Dict[str, float]]
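# Expected shapes for the evaluation containers (placeholder ids and scores):
#   Qrels:   {"q1": {"d1": 1}}                 # query id -> doc id -> graded relevance
#   Results: {"q1": {"d1": 12.3, "d2": 9.8}}   # query id -> doc id -> retrieval score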
def calc_ndcg(qrels: Qrels, results: Results, k_list: list):
import pytrec_eval as pe
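    # Defensively drop any hit whose document id equals its query id before
    # scoring, so a self-match cannot inflate the relevance metrics.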
for qid, rels in results.items():
for pid in list(rels):
if qid == pid:
results[qid].pop(pid)
scores = defaultdict(float)
metrics = ["ndcg_cut"]
pytrec_strings = {f"{metric}.{','.join([str(k) for k in k_list])}" for metric in metrics}
evaluator = pe.RelevanceEvaluator(qrels, pytrec_strings)
pytrec_scores = evaluator.evaluate(results)
for query_id in pytrec_scores.keys():
for metric in metrics:
for k in k_list:
scores[f"{metric}@{k}"] += pytrec_scores[query_id][f"{metric}_{k}"]
queries_count = len(pytrec_scores.keys())
if queries_count == 0:
return scores
for metric in metrics:
for k in k_list:
scores[f"{metric}@{k}"] = float(scores[f"{metric}@{k}"] / queries_count)
return scores
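# Usage sketch with made-up judgments and scores -- calc_ndcg returns
# macro-averaged "ndcg_cut@k" values over all evaluated queries:
#   calc_ndcg({"q1": {"d1": 1}}, {"q1": {"d1": 3.2, "d2": 1.1}}, [10, 100])
#   -> {"ndcg_cut@10": 1.0, "ndcg_cut@100": 1.0}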
def read_qrels(qrels_input_file):
qrels = defaultdict(dict)
with open(qrels_input_file, "r") as input_file:
tsv_reader = csv.reader(input_file, delimiter="\t")
for query_id, doc_id, score in tsv_reader:
qrels[query_id][doc_id] = int(score)
return qrels
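# read_qrels expects a headerless, tab-separated judgments file with one
# <query_id> <doc_id> <relevance> triple per line, e.g. (placeholder ids):
#   q1  d1  1
#   q1  d2  0
# which yields {"q1": {"d1": 1, "d2": 0}}.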
def generate_weighted_terms_query(text_expansion_field, query_expansion, boost=1.0):
return {
"query": {
"bool": {
"should": [
{"term": {f"{text_expansion_field}": {"value": f"{key}", "boost": value}}} for key, value in query_expansion.items()
],
"boost": boost,
}
}
}
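# Illustrative body produced by generate_weighted_terms_query(
# "text_expansion_elser", {"cat": 1.8, "pet": 0.4}) -- one boosted term clause
# per expansion token (the tokens and weights here are made up):
#   {"query": {"bool": {"should": [
#       {"term": {"text_expansion_elser": {"value": "cat", "boost": 1.8}}},
#       {"term": {"text_expansion_elser": {"value": "pet", "boost": 0.4}}}],
#       "boost": 1.0}}}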
def generate_bm25_query(text_field, query, boost=1.0):
return {"query": {"match": {f"{text_field}": {"query": query, "boost": boost}}}}
def generate_combine_bm25_weighted_terms_query(
text_field, text_expansion_field, query, query_boost, query_expansion, query_expansion_boost
):
return {
"query": {
"bool": {
"should": [
generate_bm25_query(text_field, query, query_boost)["query"],
generate_weighted_terms_query(text_expansion_field, query_expansion, query_expansion_boost)["query"],
]
}
}
}
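# The hybrid query above nests the two generated "query" sub-objects as sibling
# should clauses, so the BM25 and weighted-terms scores are summed, each scaled
# by its own boost.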
def generate_pruned_query(field, query_expansion, boost=1.0):
return {"query": {"sparse_vector": {"field": field, "query_vector": query_expansion, "prune": True, "boost": boost}}}
def generate_rescored_pruned_query(field, query_expansion, num_candidates, boost=1.0):
return {
"query": {"sparse_vector": {"field": field, "query_vector": query_expansion, "prune": True, "boost": boost}},
"rescore": {
"window_size": num_candidates,
"query": {
"rescore_query": {
"sparse_vector": {
"field": field,
"query_vector": query_expansion,
"prune": True,
"pruning_config": {
"only_score_pruned_tokens": True,
},
"boost": boost,
}
}
},
},
}
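# Sketch of the two-phase pattern built above: the main sparse_vector query is
# pruned (less significant expansion tokens are dropped for speed), while the
# rescore query scores only the pruned tokens ("only_score_pruned_tokens": True)
# over the top `window_size` candidates, adding their contribution back so the
# final ranking approximates the unpruned weighted-terms query.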
class QueryParamsSource:
def __init__(self, track, params, **kwargs):
        # Choose a suitable index: if only one is defined for this track, use
        # that one, but always allow the user to override the index via params.
if len(track.indices) == 1:
default_index = track.indices[0].name
else:
default_index = "_all"
self._index_name = params.get("index", default_index)
self._cache = params.get("cache", False)
self._size = params.get("size", 10)
self._num_candidates = params.get("num_candidates", 10)
self._text_field = params.get("text_field", "text")
self._text_expansion_field = params.get("text_expansion_field", "text_expansion_elser")
self._query_file = params.get("query_source", "queries.json")
self._query_strategy = params.get("query_strategy", "bm25")
self._track_total_hits = params.get("track_total_hits", False)
self._rescore = params.get("rescore", False)
self._prune = params.get("prune", False)
self._params = params
self.infinite = True
cwd = os.path.dirname(__file__)
with open(os.path.join(cwd, self._query_file), "r") as file:
self._queries = json.load(file)
self._iters = 0
def partition(self, partition_index, total_partitions):
return self
def params(self):
query_obj = self._queries[self._iters]
if self._query_strategy == "bm25":
query = generate_bm25_query(text_field=self._text_field, query=query_obj["query"], boost=1)
elif self._query_strategy == "text_expansion":
if self._prune is False:
query = generate_weighted_terms_query(
text_expansion_field=self._text_expansion_field, query_expansion=query_obj[self._text_expansion_field], boost=1
)
elif self._rescore is True:
query = generate_rescored_pruned_query(
field=self._text_expansion_field,
query_expansion=query_obj[self._text_expansion_field],
num_candidates=self._num_candidates,
boost=1,
)
else:
query = generate_pruned_query(
field=self._text_expansion_field, query_expansion=query_obj[self._text_expansion_field], boost=1
)
elif self._query_strategy == "hybrid":
query = generate_combine_bm25_weighted_terms_query(
self._text_field, self._text_expansion_field, query_obj["query"], 1, query_obj[self._text_expansion_field], 1
)
else:
raise Exception(f"The query strategy \\`{self._query_strategy}]\\` is not implemented")
self._iters = (self._iters + 1) % len(self._queries)
query["track_total_hits"] = self._track_total_hits
query["size"] = self._size
return {
"index": self._index_name,
"cache": self._cache,
"body": query,
}
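# Illustrative value returned by QueryParamsSource.params() for the default
# "bm25" strategy (index name and query text are placeholders):
#   {"index": "<track-index>", "cache": False,
#    "body": {"query": {"match": {"text": {"query": "...", "boost": 1}}},
#             "track_total_hits": False, "size": 10}}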
class WeightedRecallParamSource:
def __init__(self, track, params, **kwargs):
if len(track.indices) == 1:
default_index = track.indices[0].name
else:
default_index = "_all"
self._query_file = params.get("query_source", "queries-small.json")
self._qrels_file = params.get("qrels_source", "qrels-small.tsv")
self._index_name = params.get("index", default_index)
self._cache = params.get("cache", False)
self._top_k = params.get("top_k", 10)
self._num_candidates = params.get("num_candidates", 100)
self._params = params
self._queries = []
self._text_expansion_field = params.get("text_expansion_field", "text_expansion_elser")
self.infinite = True
cwd = os.path.dirname(__file__)
with open(os.path.join(cwd, self._query_file), "r") as file:
self._queries = json.load(file)
self._qrels = read_qrels(os.path.join(cwd, self._qrels_file))
def partition(self, partition_index, total_partitions):
return self
def params(self):
return {
"index": self._index_name,
"cache": self._cache,
"top_k": self._top_k,
"num_candidates": self._num_candidates,
"queries": self._queries,
"qrels": self._qrels,
"text_expansion_field": self._text_expansion_field,
}
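# Each entry in the queries JSON file is expected to carry the raw query text
# under "query", a token -> weight map under the configured expansion field,
# and (for the recall runner below) an "id", e.g. with made-up weights:
#   {"id": "q1", "query": "example query",
#    "text_expansion_elser": {"example": 2.1, "query": 1.7, "sample": 0.3}}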
# For each query this generates the weighted terms query plus a pruned and a
# rescored pruned variant of it. The three queries are then executed and their
# results compared for recall and relevance.
class WeightedTermsRecallRunner:
async def __call__(self, es, params):
recall_total = 0
recall_with_rescore_total = 0
exact_total = 0
min_recall = params["top_k"]
weighted_term_results = defaultdict(dict)
pruned_results = defaultdict(dict)
pruned_rescored_results = defaultdict(dict)
for query in params["queries"]:
query_id = query["id"]
# Build and execute all three queries
weighted_terms_result = await es.search(
body=generate_weighted_terms_query(params["text_expansion_field"], query[params["text_expansion_field"]], 1),
index=params["index"],
request_cache=params["cache"],
size=params["top_k"],
)
pruned_result = await es.search(
body=generate_pruned_query(params["text_expansion_field"], query[params["text_expansion_field"]], 1),
index=params["index"],
request_cache=params["cache"],
size=params["top_k"],
)
pruned_rescored_result = await es.search(
body=generate_rescored_pruned_query(
params["text_expansion_field"], query[params["text_expansion_field"]], params["num_candidates"], 1
),
index=params["index"],
request_cache=params["cache"],
size=params["top_k"],
)
weighted_terms_hits = {hit["_source"]["id"]: hit["_score"] for hit in weighted_terms_result["hits"]["hits"]}
pruned_hits = {hit["_source"]["id"]: hit["_score"] for hit in pruned_result["hits"]["hits"]}
pruned_rescored_hits = {hit["_source"]["id"]: hit["_score"] for hit in pruned_rescored_result["hits"]["hits"]}
# Recall calculations as compared to the control/non-pruned hits
weighted_terms_ids = set(weighted_terms_hits.keys())
pruned_ids = set(pruned_hits.keys())
pruned_rescored_ids = set(pruned_rescored_hits.keys())
current_recall_with_rescore = len(weighted_terms_ids.intersection(pruned_rescored_ids))
current_recall = len(weighted_terms_ids.intersection(pruned_ids))
recall_with_rescore_total += current_recall_with_rescore
recall_total += current_recall
exact_total += len(weighted_terms_ids)
min_recall = min(min_recall, current_recall)
# Construct input to NDCG calculation based on returned hits
for doc_id, score in weighted_terms_hits.items():
weighted_term_results[query_id][doc_id] = score
for doc_id, score in pruned_hits.items():
pruned_results[query_id][doc_id] = score
for doc_id, score in pruned_rescored_hits.items():
pruned_rescored_results[query_id][doc_id] = score
control_relevance = calc_ndcg(params["qrels"], weighted_term_results, [10, 100])
pruned_relevance = calc_ndcg(params["qrels"], pruned_results, [10, 100])
pruned_rescored_relevance = calc_ndcg(params["qrels"], pruned_rescored_results, [10, 100])
return (
{
"avg_recall": float(recall_with_rescore_total / exact_total), # Calculated on pruned/rescored hits
"avg_recall_without_rescore": float(recall_total / exact_total), # Calculated on pruned hits without rescore
"min_recall": min_recall, # Calculated on pruned/rescored hits
"top_k": params["top_k"],
"num_candidates": params["num_candidates"],
"control_ndcg_10": control_relevance["ndcg_cut@10"],
"control_ndcg_100": control_relevance["ndcg_cut@100"],
"pruned_ndcg_10": pruned_relevance["ndcg_cut@10"],
"pruned_ndcg_100": pruned_relevance["ndcg_cut@100"],
"pruned_rescored_ndcg_10": pruned_rescored_relevance["ndcg_cut@10"],
"pruned_rescored_ndcg_100": pruned_rescored_relevance["ndcg_cut@100"],
}
if exact_total > 0
else None
)
def __repr__(self, *args, **kwargs):
return "weighted_terms_recall"
def register(registry):
registry.register_param_source("query_param_source", QueryParamsSource)
registry.register_param_source("weighted_terms_recall_param_source", WeightedRecallParamSource)
registry.register_runner("weighted_terms_recall", WeightedTermsRecallRunner(), async_runner=True)
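# Hypothetical operation definitions that would reference the names registered
# above; the actual definitions live in this track's challenge/operation JSON:
#   {"operation-type": "search", "param-source": "query_param_source",
#    "query_strategy": "text_expansion", "prune": true, "rescore": true}
#   {"operation-type": "weighted_terms_recall",
#    "param-source": "weighted_terms_recall_param_source",
#    "top_k": 10, "num_candidates": 100}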