def main()

in dpr_scale/run_retrieval.py


import glob
import json
import os
import pickle

import hydra
from omegaconf import open_dict
from pytorch_lightning import Trainer

# Project-local imports; the module paths below are assumptions based on the
# dpr_scale repo layout. build_index and merge_results are assumed to be
# defined earlier in run_retrieval.py and are not shown in this excerpt.
from dpr_scale.conf.config import MainConfig
from dpr_scale.datamodule.dpr import CSVDataset


@hydra.main(config_path="conf", config_name="config")  # config location assumed
def main(cfg: MainConfig):
    # Temp patch for datamodule refactoring
    cfg.task.datamodule = None
    cfg.task._target_ = (
        "dpr_scale.task.dpr_eval_task.GenerateQueryEmbeddingsTask"  # hack
    )
    # trainer.fit does some setup, so we need to call it even though no training is done
    with open_dict(cfg):
        cfg.trainer.limit_train_batches = 0
        if "plugins" in cfg.trainer:
            cfg.trainer.pop(
                "plugins"
            )  # remove ddp_sharded, because it breaks during loading

    print(cfg)

    task = hydra.utils.instantiate(cfg.task, _recursive_=False)
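    # _recursive_=False leaves cfg.task's nested configs (e.g. the transform)
    # un-instantiated; the transform is built explicitly below and handed to
    # the datamodule instead.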
    transform = hydra.utils.instantiate(cfg.task.transform)
    datamodule = hydra.utils.instantiate(cfg.datamodule, transform=transform)

    trainer = Trainer(**cfg.trainer)
    trainer.fit(task, datamodule=datamodule)
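    # the test pass is what actually computes the query embeddings; the task
    # writes them to task.query_emb_output_path, which is read back below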
    trainer.test(task, datamodule=datamodule)

    # index all passages
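    # each reps_* file in ctx_embeddings_dir is assumed to be one shard of
    # passage embeddings produced by a prior embedding-generation run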
    input_paths = sorted(glob.glob(os.path.join(cfg.task.ctx_embeddings_dir, "reps_*")))
    index = build_index(input_paths)

    # reload question embeddings
    print("Loading question vectors.")
    with open(task.query_emb_output_path, "rb") as f:
        q_repr = pickle.load(f)

    print("Retrieving results...")
    scores, indexes = index.search(q_repr.numpy(), 100)

    # load questions file
    print(f"Loading questions file {cfg.datamodule.test_path}")
    with open(cfg.datamodule.test_path) as f:
        questions = [json.loads(line) for line in f]

    # load all passages:
    print(f"Loading passages from {cfg.task.passages}")
    ctxs = CSVDataset(cfg.task.passages)

    # write output file
    print("Merging results...")
    results = merge_results(ctxs, questions, indexes, scores)
    print(f"Writing output to {cfg.task.output_path}")
    os.makedirs(os.path.dirname(cfg.task.output_path), exist_ok=True)
    with open(cfg.task.output_path, "w") as g:
        g.write(json.dumps(results, indent=4))
        g.write("\n")


if __name__ == "__main__":
    main()