in msmarco-passage-ranking/track.py [0:0]
def params(self):
    """Build the search request parameters for the next query in the rotation.

    Takes ``self._queries[self._iters]`` and builds a query body for it
    according to ``self._query_strategy``. It then advances ``self._iters``
    cyclically, so successive calls walk the query list round-robin.

    Supported strategies:

    * ``"bm25"`` -- lexical query on ``self._text_field``.
    * ``"text_expansion"`` -- weighted-terms query on
      ``self._text_expansion_field``. Depending on ``self._prune`` and
      ``self._rescore``, the query is unpruned, pruned, or pruned and
      rescored.
    * ``"hybrid"`` -- BM25 combined with the weighted-terms query.

    Returns:
        dict with keys ``"index"``, ``"cache"`` and ``"body"``. The body
        additionally carries ``"track_total_hits"`` and ``"size"`` taken
        from the instance configuration.

    Raises:
        Exception: if ``self._query_strategy`` is not one of the strategies
            above. ``self._iters`` is not advanced in that case.
    """
    query_obj = self._queries[self._iters]
    strategy = self._query_strategy
    if strategy == "bm25":
        query = generate_bm25_query(text_field=self._text_field, query=query_obj["query"], boost=1)
    elif strategy == "text_expansion":
        expansion = query_obj[self._text_expansion_field]
        # Identity checks are deliberate: only a literal False selects the
        # unpruned query. Any other value (e.g. None) takes a pruned path.
        if self._prune is False:
            query = generate_weighted_terms_query(
                text_expansion_field=self._text_expansion_field, query_expansion=expansion, boost=1
            )
        elif self._rescore is True:
            query = generate_rescored_pruned_query(
                field=self._text_expansion_field,
                query_expansion=expansion,
                num_candidates=self._num_candidates,
                boost=1,
            )
        else:
            query = generate_pruned_query(field=self._text_expansion_field, query_expansion=expansion, boost=1)
    elif strategy == "hybrid":
        # NOTE(review): positional args presumably are (text_field,
        # text_expansion_field, query text, bm25 boost, expansion,
        # expansion boost) -- confirm against the helper's signature.
        query = generate_combine_bm25_weighted_terms_query(
            self._text_field, self._text_expansion_field, query_obj["query"], 1, query_obj[self._text_expansion_field], 1
        )
    else:
        raise Exception(f"The query strategy `{strategy}` is not implemented")
    # Advance only after a query was successfully built; wrap around at the end.
    self._iters = (self._iters + 1) % len(self._queries)
    query["track_total_hits"] = self._track_total_hits
    query["size"] = self._size
    return {
        "index": self._index_name,
        "cache": self._cache,
        "body": query,
    }