def main()

in dpr_scale/run_retrieval_multiset.py
import glob
import json
import os
import pathlib
import pickle

# PathManager, CSVDataset, build_index, and merge_results are dpr_scale
# helpers defined elsewhere in the repo; their modules are not shown in
# this excerpt.


def main(args, logger):
    # Temp patch for datamodule refactoring
    logger.info(args.__dict__)  # log the resolved run configuration

    assert (
        len(args.query_emb_paths)
        == len(args.questions_jsonl_paths)
        == len(args.output_json_paths)
    ), "query_emb_paths, questions_jsonl_paths, and output_json_paths must be parallel lists"
    # Index all passage embedding shards.
    print("Loading passage vectors.")
    local_ctx_embeddings_dir = PathManager.get_local_path(
        args.ctx_embeddings_dir
    )
    input_paths = sorted(glob.glob(
        os.path.join(local_ctx_embeddings_dir, "reps_*")
    ))
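    # One index is built over every reps_* shard and reused for every
    # question set in the loop below.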
    index = build_index(input_paths)

    # Load all passages; retrieved indexes are mapped back to these rows.
    print(f"Loading passages from {args.passages_tsv_path}")
    ctxs = CSVDataset(args.passages_tsv_path)

    for questions_jsonl_path, query_emb_path, output_json_path in zip(
        args.questions_jsonl_paths, args.query_emb_paths, args.output_json_paths
    ):
        # Load the questions file (one JSON object per line).
        print(f"Loading questions file {questions_jsonl_path}")
        with PathManager.open(questions_jsonl_path) as f:
            questions = [json.loads(line) for line in f]

        # Reload question embeddings; fall back to the default dump written
        # alongside the context embeddings when no explicit path is given.
        print("Loading question vectors.")
        if not query_emb_path:
            query_emb_path = os.path.join(
                args.ctx_embeddings_dir, "query_reps.pkl"
            )
        # q_repr is a torch.Tensor of query embeddings; it is converted
        # with .numpy() for the search call below.
        with PathManager.open(query_emb_path, "rb") as f:
            q_repr = pickle.load(f)  # noqa

        print("Retrieving results...")
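        # Retrieve the top-100 passages per query: `indexes` holds positions
        # into the indexed passage embeddings, `scores` the match scores.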
        scores, indexes = index.search(q_repr.numpy(), 100)

        # Map retrieved positions back to passages and attach scores.
        print("Merging results...")
        results = merge_results(ctxs, questions, indexes, scores)

        print(f"Writing output to {output_json_path}")
        pathlib.Path(output_json_path).parent.mkdir(
            parents=True, exist_ok=True
        )
        with PathManager.open(output_json_path, "w") as g:
            g.write(json.dumps(results, indent=4))
            g.write("\n")
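build_index is defined elsewhere in dpr_scale and is not shown in this excerpt. Below is a minimal sketch of the shape it plausibly has, assuming each reps_* shard unpickles to a 2-D tensor of context embeddings and that the index is a flat inner-product FAISS index (the index.search(q_repr.numpy(), 100) call above matches the FAISS API). All of this is an assumption, not the repo's implementation:

import pickle

import faiss
import numpy as np


def build_index(input_paths):  # hypothetical sketch, not the repo's code
    shards = []
    for path in input_paths:
        with open(path, "rb") as f:
            reps = pickle.load(f)
        # Assumed: each shard is a 2-D float tensor/array of embeddings.
        shards.append(np.asarray(reps, dtype=np.float32))
    embeddings = np.concatenate(shards, axis=0)
    # Exact (non-approximate) inner-product search over all passages; row ids
    # follow the sorted shard order, so they line up with the passage TSV.
    index = faiss.IndexFlatIP(embeddings.shape[1])
    index.add(embeddings)
    return index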
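merge_results is likewise defined elsewhere. A sketch under the assumptions that ctxs supports integer indexing and returns a JSON-serializable row, that row i of indexes/scores belongs to questions[i], and that the output field names below ("question", "ctxs", "score") are illustrative rather than confirmed:

def merge_results(ctxs, questions, indexes, scores):  # hypothetical sketch
    results = []
    for question, idx_row, score_row in zip(questions, indexes, scores):
        hits = [
            {"passage": ctxs[int(idx)], "score": float(score)}
            for idx, score in zip(idx_row, score_row)
        ]
        results.append({"question": question, "ctxs": hits})
    return results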
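Because the three path lists are zipped one-to-one, a two-dataset run would be configured along these lines. The attribute names follow the args fields read above; the file paths are hypothetical, and the real CLI wiring (argparse, hydra, or otherwise) is not shown in this excerpt:

import argparse
import logging

args = argparse.Namespace(
    ctx_embeddings_dir="/checkpoints/ctx_embeddings",  # hypothetical path
    passages_tsv_path="/data/passages.tsv",            # hypothetical path
    questions_jsonl_paths=["/data/nq-dev.jsonl", "/data/trivia-dev.jsonl"],
    query_emb_paths=[
        "/checkpoints/nq_query_reps.pkl",
        "/checkpoints/trivia_query_reps.pkl",
    ],
    output_json_paths=["/out/nq-dev.json", "/out/trivia-dev.json"],
)
main(args, logging.getLogger(__name__))

Passing an empty string in query_emb_paths makes the corresponding iteration fall back to <ctx_embeddings_dir>/query_reps.pkl, as in the loop above.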