in dpr_scale/run_retrieval_multiset.py [0:0]
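# Imports this excerpt relies on (a sketch: the dpr_scale module paths below
# are assumptions based on the repo layout; build_index and merge_results are
# helpers defined elsewhere in the repo and are not imported here).
import glob
import json
import os
import pathlib
import pickle

from dpr_scale.datamodule.dpr import CSVDataset  # assumed location
from dpr_scale.utils.utils import PathManager  # assumed location
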
def main(args, logger):
    # Temp patch for datamodule refactoring
    logger.info(args.__dict__)
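    # The three path lists are zipped together below, so they must align
    # one-to-one: one query-embedding file, one questions file, and one
    # output path per retrieval set.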
    assert (
        len(args.query_emb_paths)
        == len(args.questions_jsonl_paths)
        == len(args.output_json_paths)
    )
    # index all passages
    print("Loading passage vectors.")
    local_ctx_embeddings_dir = PathManager.get_local_path(
        args.ctx_embeddings_dir
    )
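    # Passage embeddings are sharded across files named reps_* inside
    # ctx_embeddings_dir, presumably written by an earlier encoding step.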
    input_paths = sorted(glob.glob(
        os.path.join(local_ctx_embeddings_dir, "reps_*")
    ))
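    # build_index (defined elsewhere in the repo) merges the embedding
    # shards into a single similarity-search index.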
    index = build_index(input_paths)
    # load all passages
    print(f"Loading passages from {args.passages_tsv_path}")
    ctxs = CSVDataset(args.passages_tsv_path)
    for questions_jsonl_path, query_emb_path, output_json_path in zip(
        args.questions_jsonl_paths, args.query_emb_paths, args.output_json_paths
    ):
        # load questions file
        print(f"Loading questions file {questions_jsonl_path}")
        with PathManager.open(questions_jsonl_path) as f:
            questions = [json.loads(line) for line in f]
        # reload question embeddings
        print("Loading question vectors.")
        if not query_emb_path:
            query_emb_path = os.path.join(
                args.ctx_embeddings_dir, "query_reps.pkl"
            )
        with PathManager.open(query_emb_path, "rb") as f:
            q_repr = pickle.load(f)  # noqa
        print("Retrieving results...")
        scores, indexes = index.search(q_repr.numpy(), 100)
        # merge results and write output file
        print("Merging results...")
        results = merge_results(ctxs, questions, indexes, scores)
        print(f"Writing output to {output_json_path}")
        pathlib.Path(output_json_path).parent.mkdir(
            parents=True, exist_ok=True
        )
        with PathManager.open(output_json_path, "w") as g:
            g.write(json.dumps(results, indent=4))
            g.write("\n")