msmarco-v2-vector/_tools/parse_queries.py:
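# Generate query files for the msmarco-v2 vector benchmark:
#   -t / --throughput writes queries.json (Cohere query embeddings only)
#   -r / --recall writes queries-recall.json (embeddings plus exact top-1000 doc ids)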
import argparse
import asyncio
import json
import sys
from os import environ
import ir_datasets
import numpy
import vg
from cohere import AsyncClient
from elasticsearch import AsyncElasticsearch

DATASET_NAME: str = "msmarco-passage-v2/train"
RECALL_DATASET_NAME: str = "msmarco-passage-v2/trec-dl-2022/judged"
OUTPUT_FILENAME: str = "queries.json"
OUTPUT_RECALL_FILENAME: str = "queries-recall.json"
MAX_DOCS: int = 12_000  # number of query embeddings to generate for the throughput file
REQUEST_TIMEOUT: int = 60 * 60 * 5  # 5 hours; brute-force scoring over the full index is slow
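
# Exact (brute-force) scoring: dot product of the stored 'emb' vector with the
# query vector, passed through a logistic sigmoid so the score is always positive.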
def get_brute_force_query(emb):
return {
"script_score": {
"query": {"match_all": {}},
"script": {
"source": "double value = dotProduct(params.query_vector, 'emb'); return sigmoid(1, Math.E, -value);",
"params": {"query_vector": emb},
},
}
}
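
# Embed one query string with Cohere and L2-normalize the vector, so that dot
# product and cosine similarity agree.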
async def retrieve_embed_for_query(co, text):
response = await co.embed(texts=[text], model="embed-english-v3.0", input_type="search_query")
return vg.normalize(numpy.array(response.embeddings[0])).tolist()
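
# Stream queries from the train split, embed them in batches, and write up to
# MAX_DOCS embeddings to the output file, one JSON array per line.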
async def output_queries(queries_file):
output = []
dataset = ir_datasets.load(DATASET_NAME)
# Get your production Cohere API key from https://dashboard.cohere.com/api-keys
async with AsyncClient(environ["COHERE_API_KEY"]) as co:
co_queries = []
for query in dataset.queries_iter():
co_queries.append(query.text)
# Run our async requests every 100 queries *or* as soon as we
# have enough to fill our output list
output_left = MAX_DOCS - len(output)
if len(co_queries) in (100, output_left):
cos = (retrieve_embed_for_query(co, q) for q in co_queries)
co_queries = []
                # return_exceptions=True keeps one failed embedding call from
                # aborting the whole batch; failures are filtered out here.
                output += [v for v in await asyncio.gather(*cos, return_exceptions=True) if not isinstance(v, Exception)]
if len(output) == MAX_DOCS:
break
queries_file.write("\n".join(json.dumps(embed) for embed in output))
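
# For every judged TREC DL 2022 query, store its embedding together with the
# exact top-1000 (docid, score) pairs, the ground truth for measuring recall.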
async def output_recall_queries(queries_file):
async with AsyncElasticsearch(
"https://localhost:19200/", basic_auth=("esbench", "super-secret-password"), verify_certs=False, request_timeout=REQUEST_TIMEOUT
) as es:
        dataset = ir_datasets.load(RECALL_DATASET_NAME)
async with AsyncClient(environ["COHERE_API_KEY"]) as co:
for query in dataset.queries_iter():
                emb = await retrieve_embed_for_query(co, query.text)
resp = await es.search(
index="msmarco-v2", query=get_brute_force_query(emb), size=1000, _source=["_none_"], fields=["docid"]
)
ids = [(hit["fields"]["docid"][0], hit["_score"]) for hit in resp["hits"]["hits"]]
                line = {"query_id": query.query_id, "text": query.text, "emb": emb, "ids": ids}
queries_file.write(json.dumps(line) + "\n")
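
# Open the throughput output file and generate the queries.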
async def create_queries():
with open(OUTPUT_FILENAME, "w") as queries_file:
await output_queries(queries_file)
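
# Open the recall output file and generate the judged queries.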
async def create_recall_queries():
with open(OUTPUT_RECALL_FILENAME, "w") as queries_file:
await output_recall_queries(queries_file)
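
# CLI entry point: pick throughput and/or recall generation.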
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Create queries for throughput or recall operations")
parser.add_argument("-t", "--throughput", help="Create queries for throughput operations", action="store_true")
parser.add_argument("-r", "--recall", help="Create queries for recall operations", action="store_true")
    if len(sys.argv) == 1:
        # Neither -t nor -r was given; show the options and exit.
        parser.print_help(sys.stderr)
        sys.exit(1)
    args = parser.parse_args()
    if args.throughput:
        asyncio.run(create_queries())
    if args.recall:
        asyncio.run(create_recall_queries())