def main()

in dpr_scale/run_retrieval_fb.py


import glob
import json
import os
import pathlib
import pickle

# PathManager, CSVDataset, build_index, and merge_results are repo-internal
# helpers imported elsewhere in dpr_scale; their exact import paths are not
# shown in this snippet.


def main(args, logger):
    # Temp patch for datamodule refactoring
    logger.info(args.__dict__)

    # index all passages
    local_ctx_embeddings_dir = PathManager.get_local_path(
        args.ctx_embeddings_dir
    )
    input_paths = sorted(glob.glob(
        os.path.join(local_ctx_embeddings_dir, "reps_*")
    ))
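    # build_index is defined elsewhere in the repo; it is expected to fold the
    # reps_* embedding shards into a single FAISS index (hedged sketch after
    # this listing).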
    index = build_index(input_paths)

    # reload question embeddings
    print("Loading question vectors.")
    if not args.query_emb_path:
        args.query_emb_path = os.path.join(
            args.ctx_embeddings_dir, "query_reps.pkl"
        )
    with PathManager.open(
        args.query_emb_path, "rb"
    ) as f:
        q_repr = pickle.load(f)  # noqa

    print("Retrieving results...")
    scores, indexes = index.search(q_repr.numpy(), args.topk)
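    # index.search returns two (num_queries, topk) arrays: similarity scores
    # and integer row ids into the indexed passage embeddings.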

    # load questions file
    print(f"Loading questions file {args.questions_jsonl_path}")
    with PathManager.open(args.questions_jsonl_path) as f:
        questions = [json.loads(line) for line in f]

    # load all passages
    print(f"Loading passages from {args.passages_tsv_path}")
    ctxs = CSVDataset(args.passages_tsv_path)

    # merge results and write output file
    print("Merging results...")
    results = merge_results(ctxs, questions, indexes, scores)

    print(f"Writing output to {args.output_json_path}")
    pathlib.Path(args.output_json_path).parent.mkdir(
        parents=True, exist_ok=True
    )
    with PathManager.open(args.output_json_path, "w") as g:
        g.write(json.dumps(results, indent=4))
        g.write("\n")