in dpr_scale/run_retrieval_fb.py [0:0]
import glob
import json
import os
import pathlib
import pickle

# Repo-internal helpers; these module paths are assumed, since the original
# imports fall outside this excerpt:
from dpr_scale.datamodule.dpr import CSVDataset
from dpr_scale.run_retrieval import build_index, merge_results
from dpr_scale.utils.utils import PathManager


def main(args, logger):
    # Log the full run configuration.
    logger.info(args.__dict__)
# index all passages
    local_ctx_embeddings_dir = PathManager.get_local_path(args.ctx_embeddings_dir)
    input_paths = sorted(glob.glob(os.path.join(local_ctx_embeddings_dir, "reps_*")))
index = build_index(input_paths)
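    # build_index (defined elsewhere in dpr_scale) folds the embedding shards
    # into a single index exposing .search(); see the sketch after main().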
# reload question embeddings
print("Loading question vectors.")
if not args.query_emb_path:
        args.query_emb_path = os.path.join(args.ctx_embeddings_dir, "query_reps.pkl")
    with PathManager.open(args.query_emb_path, "rb") as f:
        # q_repr is a torch tensor of query embeddings (hence .numpy() below)
        q_repr = pickle.load(f)  # noqa
print("Retrieving results...")
scores, indexes = index.search(q_repr.numpy(), args.topk)
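    # scores and indexes are (num_queries, topk) arrays; each index identifies
    # a passage embedding row, which merge_results maps back to a row of ctxs.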
# load questions file
print(f"Loading questions file {args.questions_jsonl_path}")
with PathManager.open(args.questions_jsonl_path) as f:
questions = [json.loads(line) for line in f]
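    # Each line of the jsonl file is one JSON object describing a question.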
# load all passages:
print(f"Loading passages from {args.passages_tsv_path}")
ctxs = CSVDataset(args.passages_tsv_path)
# write output file
print("Merging results...")
results = merge_results(ctxs, questions, indexes, scores)
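    # merge_results joins each question with its top-k passages and scores,
    # yielding the JSON-serializable structure written below.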
print(f"Writing output to {args.output_json_path}")
    pathlib.Path(args.output_json_path).parent.mkdir(parents=True, exist_ok=True)
with PathManager.open(args.output_json_path, "w") as g:
g.write(json.dumps(results, indent=4))
g.write("\n")
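
build_index and merge_results are defined elsewhere in the repo; all this function relies on is an index exposing a FAISS-style search(queries, topk). Below is a minimal sketch of such a builder, assuming each reps_* shard is a pickled float32 torch tensor and using an exact inner-product index; the shard format and index type are assumptions, not necessarily dpr_scale's actual implementation.

import pickle

import faiss
import numpy as np


def build_index(input_paths):
    # Sketch: load each pickled embedding shard, stack the matrices, and add
    # them to a flat inner-product index so search() returns dot products.
    shards = []
    for path in input_paths:
        with open(path, "rb") as f:
            shards.append(pickle.load(f).numpy().astype(np.float32))
    reps = np.concatenate(shards, axis=0)  # (num_passages, dim)
    index = faiss.IndexFlatIP(reps.shape[1])
    index.add(reps)  # FAISS ids are row positions, i.e. passage order
    return index

merge_results can be pictured the same way. A sketch of the pairing step, assuming ctxs supports integer indexing and returns JSON-serializable rows (the output field names here are illustrative):

def merge_results(ctxs, questions, indexes, scores):
    # Sketch: attach the topk retrieved passages and scores to each question.
    results = []
    for question, idxs, topk_scores in zip(questions, indexes, scores):
        retrieved = [
            {"passage": ctxs[int(i)], "score": float(s)}  # casts keep it JSON-safe
            for i, s in zip(idxs, topk_scores)
        ]
        results.append({"question": question, "ctxs": retrieved})
    return results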