msmarco-v2-vector/track.py (200 lines of code) (raw):

import bz2
import csv
import json
import os
import statistics
from collections import defaultdict
from typing import Any, Dict, List

# query id -> doc id -> integer relevance judgement
Qrels = Dict[str, Dict[str, int]]
# query id -> doc id -> score
Results = Dict[str, Dict[str, float]]

# Both files are resolved relative to the directory containing this module.
QUERIES_FILENAME: str = "queries.json.bz2"
QUERIES_RECALL_FILENAME: str = "queries-recall.json.bz2"


def extract_vector_operations_count(knn_result):
    """Sum ``vector_operations_count`` over every shard of a profiled search response.

    Args:
        knn_result: Search response containing ``profile.shards[*].dfs.knn``.

    Returns:
        The total of all per-shard ``vector_operations_count`` values; shards
        whose kNN profile lacks that key contribute 0.

    Raises:
        AssertionError: If a shard does not report exactly one kNN profile entry.
        KeyError: If the expected profile structure is missing.
    """
    vector_operations_count = 0
    for shard in knn_result["profile"]["shards"]:
        assert len(shard["dfs"]["knn"]) == 1
        knn_search = shard["dfs"]["knn"][0]
        vector_operations_count += knn_search.get("vector_operations_count", 0)
    return vector_operations_count


def compute_percentile(data: List[Any], percentile):
    """Return the element of ``data`` at the given percentile.

    The rank is ``round(percentile * len(data) / 100) - 1`` in the sorted data,
    clamped to the valid index range.

    Args:
        data: Values to rank; must be mutually comparable.
        percentile: Percentile to look up, expressed on a 0-100 scale.

    Returns:
        The selected element, or ``None`` when ``data`` is empty.
    """
    size = len(data)
    if size <= 0:
        return None
    sorted_data = sorted(data)
    index = int(round(percentile * size / 100)) - 1
    # Clamp so that very small / very large percentiles stay inside the list.
    return sorted_data[max(min(index, size - 1), 0)]


def calc_ndcg(qrels: Qrels, results: Results, k_list: list):
    """Average the ``ndcg_cut`` measure over all evaluated queries.

    Args:
        qrels: Relevance judgements per query.
        results: Retrieved documents and their scores per query.
        k_list: Cut-off values to evaluate.

    Returns:
        A ``defaultdict(float)`` keyed by ``"ndcg_cut@<k>"`` holding the mean
        score over the queries returned by the evaluator. When the evaluator
        returns no queries the mapping is empty (lookups then yield ``0.0``).
    """
    # Imported lazily so the module can be loaded without pytrec_eval installed.
    import pytrec_eval as pe

    scores = defaultdict(float)
    metrics = ["ndcg_cut"]
    pytrec_strings = {f"{metric}.{','.join(str(k) for k in k_list)}" for metric in metrics}
    evaluator = pe.RelevanceEvaluator(qrels, pytrec_strings)
    pytrec_scores = evaluator.evaluate(results)

    for query_scores in pytrec_scores.values():
        for metric in metrics:
            for k in k_list:
                scores[f"{metric}@{k}"] += query_scores[f"{metric}_{k}"]

    queries_count = len(pytrec_scores)
    if queries_count == 0:
        return scores

    for metric in metrics:
        for k in k_list:
            scores[f"{metric}@{k}"] = float(scores[f"{metric}@{k}"] / queries_count)

    return scores


def read_qrels(qrels_input_file):
    """Load relevance judgements from a tab-separated file.

    Each non-empty row must have exactly four columns, read as
    ``query_id, doc_id, score, <ignored>``.

    Args:
        qrels_input_file: Path of the TSV file.

    Returns:
        A ``defaultdict(dict)`` mapping query id -> doc id -> ``int(score)``.

    Raises:
        ValueError: If a non-empty row does not have four columns or the score
            is not an integer.
    """
    qrels = defaultdict(dict)
    # newline="" is what the csv module expects from the underlying file object.
    with open(qrels_input_file, "r", newline="") as input_file:
        tsv_reader = csv.reader(input_file, delimiter="\t")
        for row in tsv_reader:
            if not row:
                # csv.reader yields [] for blank lines (e.g. a trailing newline);
                # unpacking those used to fail with ValueError.
                continue
            query_id, doc_id, score, _ = row
            qrels[query_id][doc_id] = int(score)
    return qrels


def get_rescore_query(vec, window_size):
    """Build the ``rescore`` section that re-scores the top hits with a script.

    The script computes ``dotProduct`` between ``vec`` and the ``emb`` field
    and passes the result through ``sigmoid``. ``query_weight`` is 0, so the
    original query score presumably does not contribute to the final score
    (verify against the Elasticsearch rescore documentation).

    Args:
        vec: Query vector passed to the script as ``params.query_vector``.
        window_size: Value for the rescorer's ``window_size``.

    Returns:
        A dict suitable for the ``rescore`` key of a search body.
    """
    return {
        "window_size": window_size,
        "query": {
            "query_weight": 0,
            "rescore_query": {
                "script_score": {
                    "query": {"match_all": {}},
                    "script": {
                        "source": "double value = dotProduct(params.query_vector, 'emb'); return sigmoid(1, Math.E, -value);",
                        "params": {"query_vector": vec},
                    },
                }
            },
        },
    }


def _default_index_name(track):
    """Return the track's only index name, or ``"_all"`` if it does not define exactly one."""
    if len(track.indices) == 1:
        return track.indices[0].name
    return "_all"


class KnnParamSource:
    """Parameter source that cycles through the query vectors in ``QUERIES_FILENAME``.

    Recognised ``params`` keys: ``index``, ``cache``, ``k`` (default 10),
    ``num-candidates`` (default 50), ``num-rescore`` (default 0) and ``filter``.
    """

    def __init__(self, track, params, **kwargs):
        # Choose a suitable index: if there is only one defined for this track
        # choose that one, but let the user always override index.
        self._index_name = params.get("index", _default_index_name(track))
        self._cache = params.get("cache", False)
        self._params = params
        self._queries = []

        cwd = os.path.dirname(__file__)
        # One JSON document per line; json.loads accepts the bytes lines directly.
        with bz2.open(os.path.join(cwd, QUERIES_FILENAME), "r") as queries_file:
            for vector_query in queries_file:
                if not vector_query.strip():
                    # Tolerate blank lines instead of failing in json.loads.
                    continue
                self._queries.append(json.loads(vector_query))

        self._iters = 0
        self._maxIters = len(self._queries)
        # NOTE(review): presumably tells the benchmark driver this source never
        # exhausts - confirm against the framework's param-source contract.
        self.infinite = True

    def partition(self, partition_index, total_partitions):
        """Return this same instance for every partition."""
        return self

    def params(self):
        """Return search parameters for the next query vector, wrapping around at the end."""
        top_k = self._params.get("k", 10)
        num_candidates = self._params.get("num-candidates", 50)
        num_rescore = self._params.get("num-rescore", 0)
        query_vec = self._queries[self._iters]

        result = {
            "index": self._index_name,
            "cache": self._params.get("cache", False),
            "size": top_k,
            "body": {
                "knn": {
                    "field": "emb",
                    "query_vector": query_vec,
                    # Fetch enough neighbours to fill the rescore window.
                    "k": max(top_k, num_rescore),
                    "num_candidates": num_candidates,
                },
                "_source": False,
            },
        }
        if num_rescore > 0:
            result["body"]["rescore"] = get_rescore_query(query_vec, num_rescore)
        if "filter" in self._params:
            result["body"]["knn"]["filter"] = self._params["filter"]

        self._iters += 1
        if self._iters >= self._maxIters:
            self._iters = 0
        return result


class KnnRecallParamSource:
    """Parameter source that forwards recall-run settings to ``KnnRecallRunner``.

    Recognised ``params`` keys: ``index``, ``cache``, ``k`` (default 10),
    ``num-candidates`` (default 100) and ``num-rescore`` (default 0).
    """

    def __init__(self, track, params, **kwargs):
        self._index_name = params.get("index", _default_index_name(track))
        self._cache = params.get("cache", False)
        self._params = params
        self.infinite = True

    def partition(self, partition_index, total_partitions):
        """Return this same instance for every partition."""
        return self

    def params(self):
        """Return the flat parameter dict read by ``KnnRecallRunner.__call__``."""
        return {
            "index": self._index_name,
            "cache": self._params.get("cache", False),
            "size": self._params.get("k", 10),
            "num_candidates": self._params.get("num-candidates", 100),
            "num_rescore": self._params.get("num-rescore", 0),
        }


class KnnRecallRunner:
    """Runner that measures recall and NDCG of kNN search against stored ground truth."""

    async def __call__(self, es, params):
        """Run every query in ``QUERIES_RECALL_FILENAME`` and aggregate quality metrics.

        Args:
            es: Client exposing an awaitable ``search(index=, request_cache=, size=, body=)``.
            params: Dict with ``size``, ``num_candidates``, ``num_rescore``,
                ``index`` and ``cache``.

        Returns:
            A dict with ``best_ndcg_<k>``, ``ndcg_<k>``, ``avg_recall``,
            ``min_recall``, ``k``, ``num_candidates``, ``avg_nodes_visited`` and
            ``99th_percentile_nodes_visited`` (the last two are ``None`` when no
            query reported a positive vector-operations count), or ``None`` if
            no ground-truth hits were processed.
        """
        top_k = params["size"]
        num_candidates = params["num_candidates"]
        num_rescore = params["num_rescore"]
        index = params["index"]
        request_cache = params["cache"]

        cwd = os.path.dirname(__file__)
        qrels = read_qrels(os.path.join(cwd, "qrels.tsv"))

        results = defaultdict(dict)  # scores of the documents the search returned
        best_results = defaultdict(dict)  # scores of the stored ground-truth documents
        recall_total = 0
        exact_total = 0
        min_recall = top_k
        nodes_visited = []

        with bz2.open(os.path.join(cwd, QUERIES_RECALL_FILENAME), "r") as queries_file:
            for line in queries_file:
                if not line.strip():
                    # Tolerate blank lines instead of failing in json.loads.
                    continue
                query = json.loads(line)
                query_id = query["query_id"]
                body = {
                    "knn": {
                        "field": "emb",
                        "query_vector": query["emb"],
                        "k": max(top_k, num_rescore),
                        "num_candidates": num_candidates,
                    },
                    "_source": False,
                    "fields": ["docid"],
                    # Profiling is required to read the vector operations count below.
                    "profile": True,
                }
                if num_rescore > 0:
                    body["rescore"] = get_rescore_query(query["emb"], num_rescore)

                knn_result = await es.search(index=index, request_cache=request_cache, size=top_k, body=body)

                knn_hits = []
                for hit in knn_result["hits"]["hits"]:
                    doc_id = hit["fields"]["docid"][0]
                    results[query_id][doc_id] = hit["_score"]
                    knn_hits.append(doc_id)

                recall_hits = []
                # Slice rather than index: a query with fewer than top_k stored
                # ground-truth ids used to raise IndexError.
                for doc_id, score in query["ids"][:top_k]:
                    recall_hits.append(doc_id)
                    best_results[query_id][doc_id] = score

                nodes_visited.append(extract_vector_operations_count(knn_result))

                current_recall = len(set(knn_hits).intersection(recall_hits))
                recall_total += current_recall
                exact_total += len(recall_hits)
                min_recall = min(min_recall, current_recall)

        relevance_res = calc_ndcg(qrels, results, [top_k])
        best_relevance_res = calc_ndcg(qrels, best_results, [top_k])

        if exact_total <= 0:
            return None

        # Only report node statistics when at least one query produced a count.
        has_vector_ops = any(count > 0 for count in nodes_visited)
        return {
            f"best_ndcg_{top_k}": best_relevance_res[f"ndcg_cut@{top_k}"],
            f"ndcg_{top_k}": relevance_res[f"ndcg_cut@{top_k}"],
            "avg_recall": recall_total / exact_total,
            "min_recall": min_recall,
            "k": top_k,
            "num_candidates": num_candidates,
            "avg_nodes_visited": statistics.mean(nodes_visited) if has_vector_ops else None,
            "99th_percentile_nodes_visited": compute_percentile(nodes_visited, 99) if has_vector_ops else None,
        }

    def __repr__(self, *args, **kwargs):
        return "knn-recall"


def register(registry):
    """Register this module's param sources and the recall runner with ``registry``."""
    registry.register_param_source("knn-param-source", KnnParamSource)
    registry.register_param_source("knn-recall-param-source", KnnRecallParamSource)
    registry.register_runner("knn-recall", KnnRecallRunner(), async_runner=True)