cohere_vector/track.py (41 lines of code) (raw):
import json
import os
QUERIES_FILENAME: str = "queries.json"
class KnnParamSource:
def __init__(self, track, params, **kwargs):
# choose a suitable index: if there is only one defined for this track
# choose that one, but let the user always override index
if len(track.indices) == 1:
default_index = track.indices[0].name
else:
default_index = "_all"
self._index_name = params.get("index", default_index)
self._cache = params.get("cache", False)
self._params = params
self._queries = []
cwd = os.path.dirname(__file__)
with open(os.path.join(cwd, QUERIES_FILENAME), "r") as queries_file:
for vector_query in queries_file:
self._queries.append(json.loads(vector_query))
self._iters = 0
self._maxIters = len(self._queries)
self.infinite = True
def partition(self, partition_index, total_partitions):
return self
def params(self):
result = {"index": self._index_name, "cache": self._params.get("cache", False), "size": self._params.get("k", 10)}
result["body"] = {
"knn": {
"field": "emb",
"query_vector": self._queries[self._iters],
"k": self._params.get("k", 10),
"num_candidates": self._params.get("num-candidates", 50),
},
"_source": False,
}
if "filter" in self._params:
result["body"]["knn"]["filter"] = self._params["filter"]
self._iters += 1
if self._iters >= self._maxIters:
self._iters = 0
return result
def register(registry):
registry.register_param_source("knn-param-source", KnnParamSource)