# so_vector/track.py
import json
import os
class KnnParamSource:
    """Parameter source that builds vector-search requests on ``titleVector``.

    Query vectors are read once, at construction time, from ``queries.json``
    located next to this module (one JSON document per line). Each call to
    :meth:`params` uses the next vector and wraps around after the last one.

    Recognised keys in ``params`` (all optional):

    * ``index`` - target index. Defaults to the track's only index, or to
      ``"_all"`` when the track does not define exactly one index.
    * ``cache`` - forwarded as the ``cache`` request parameter
      (default ``False``).
    * ``exact`` - when truthy, build a ``script_score`` query that scores
      every matching document with ``dotProduct`` instead of a ``knn``
      section (default ``False``).
    * ``k`` - used for the request ``size`` and for the ``knn`` ``k``
      (default ``10``).
    * ``num-candidates`` - ``num_candidates`` of the ``knn`` section
      (default ``50``).
    * ``filter`` - query used as the ``knn`` filter, or as the inner query
      of ``script_score`` in place of ``match_all``.
    """

    def __init__(self, track, params, **kwargs):
        """Read the operation parameters and load all query vectors.

        :param track: object exposing ``indices``, a sequence whose items
            have a ``name`` attribute.
        :param params: mapping of operation parameters (see class docstring).
        :param kwargs: accepted for interface compatibility and ignored.
        :raises ValueError: if ``queries.json`` contains no query vectors.
        """
        # Choose a suitable index: if the track defines exactly one, use it,
        # but always let the user override it through ``params``.
        if len(track.indices) == 1:
            default_index = track.indices[0].name
        else:
            default_index = "_all"

        self._index_name = params.get("index", default_index)
        self._cache = params.get("cache", False)
        self._exact_scan = params.get("exact", False)
        self._params = params

        queries_path = os.path.join(os.path.dirname(__file__), "queries.json")
        with open(queries_path, "r", encoding="utf-8") as file:
            # Blank lines are skipped because json.loads() rejects them.
            self._queries = [json.loads(line) for line in file if line.strip()]
        if not self._queries:
            # Fail here with a clear message rather than with an IndexError
            # on the first call to params().
            raise ValueError("no query vectors found in {}".format(queries_path))

        # Position of the query vector the next params() call will use.
        self._iters = 0
        # NOTE(review): presumably read by the benchmark framework to learn
        # that this source never runs out of parameters - confirm.
        self.infinite = True

    def partition(self, partition_index, total_partitions):
        """Return the parameter source for one partition.

        Both arguments are ignored: the same instance serves every partition.
        """
        return self

    def params(self):
        """Build the parameters of the next search request.

        :return: dict with ``index``, ``cache``, ``size`` and ``body`` keys.
            ``body`` holds either a ``script_score`` query (exact scan) or a
            ``knn`` section, and always sets ``_source`` to ``False``.
        """
        k = self._params.get("k", 10)
        query_vector = self._next_query()

        if self._exact_scan:
            body = self._exact_body(query_vector)
        else:
            body = self._knn_body(query_vector, k)

        return {
            "index": self._index_name,
            "cache": self._cache,
            "size": k,
            "body": body,
        }

    def _next_query(self):
        """Return the next query vector, wrapping around after the last one."""
        query_vector = self._queries[self._iters]
        self._iters = (self._iters + 1) % len(self._queries)
        return query_vector

    def _exact_body(self, query_vector):
        """Build a ``script_score`` request body for ``query_vector``."""
        # Membership test (not .get) so an explicitly supplied filter is
        # always forwarded, whatever its value.
        if "filter" in self._params:
            inner_query = self._params["filter"]
        else:
            inner_query = {"match_all": {}}
        return {
            "query": {
                "script_score": {
                    "query": inner_query,
                    "script": {
                        "source": "dotProduct(params.query, 'titleVector') + 1.0",
                        "params": {"query": query_vector},
                    },
                }
            },
            "_source": False,
        }

    def _knn_body(self, query_vector, k):
        """Build a ``knn`` request body for ``query_vector``."""
        knn = {
            "field": "titleVector",
            "query_vector": query_vector,
            "k": k,
            "num_candidates": self._params.get("num-candidates", 50),
        }
        if "filter" in self._params:
            knn["filter"] = self._params["filter"]
        return {"knn": knn, "_source": False}
def register(registry):
    """Register this module's parameter source with ``registry``.

    ``KnnParamSource`` is made available under the name
    ``"knn-param-source"`` through ``registry.register_param_source``.
    """
    param_source_name = "knn-param-source"
    registry.register_param_source(param_source_name, KnnParamSource)