in dpr_scale/run_retrieval.py [0:0]
import glob
import json
import os
import pickle

import hydra
from omegaconf import open_dict
from pytorch_lightning.trainer import Trainer

# MainConfig and CSVDataset come from the dpr_scale package; build_index and
# merge_results are helpers defined outside this excerpt (hedged sketches of
# both follow below).


def main(cfg: MainConfig):
    # Temp patch for datamodule refactoring
    cfg.task.datamodule = None
    cfg.task._target_ = (
        "dpr_scale.task.dpr_eval_task.GenerateQueryEmbeddingsTask"  # hack
    )
    # trainer.fit does some setup, so we need to call it even though no training
    # is done; limit_train_batches=0 ensures no optimization steps actually run.
    with open_dict(cfg):
        cfg.trainer.limit_train_batches = 0
        if "plugins" in cfg.trainer:
            # remove ddp_sharded, because it breaks during loading
            cfg.trainer.pop("plugins")
    print(cfg)
    # instantiate resolves cfg.task._target_ (patched above) to
    # GenerateQueryEmbeddingsTask and builds it from the task config
    task = hydra.utils.instantiate(cfg.task, _recursive_=False)
    transform = hydra.utils.instantiate(cfg.task.transform)
    datamodule = hydra.utils.instantiate(cfg.datamodule, transform=transform)
    trainer = Trainer(**cfg.trainer)
    trainer.fit(task, datamodule=datamodule)  # setup only; no training happens
    trainer.test(task, datamodule=datamodule)  # writes the query embeddings

    # index all passage embedding shards
    input_paths = sorted(glob.glob(os.path.join(cfg.task.ctx_embeddings_dir, "reps_*")))
    index = build_index(input_paths)

    # reload the question embeddings written by trainer.test() above
    print("Loading question vectors.")
    with open(task.query_emb_output_path, "rb") as f:
        q_repr = pickle.load(f)
    print("Retrieving results...")
    scores, indexes = index.search(q_repr.numpy(), 100)  # top-100 passages per query

    # load questions file (JSONL, one question per line)
    print(f"Loading questions file {cfg.datamodule.test_path}")
    with open(cfg.datamodule.test_path) as f:
        questions = [json.loads(line) for line in f]

    # load all passages
    print(f"Loading passages from {cfg.task.passages}")
    ctxs = CSVDataset(cfg.task.passages)

    # join retrieved ids and scores with question and passage records, then write
    print("Merging results...")
    results = merge_results(ctxs, questions, indexes, scores)
    print(f"Writing output to {cfg.task.output_path}")
    os.makedirs(os.path.dirname(cfg.task.output_path), exist_ok=True)
    with open(cfg.task.output_path, "w") as g:
        g.write(json.dumps(results, indent=4))
        g.write("\n")
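
build_index is not shown in this excerpt. A minimal sketch of what it plausibly does, assuming each reps_* shard is a pickled torch tensor of shape [num_passages, dim] and that retrieval uses an exact inner-product FAISS index (so the scores returned by index.search above are dot products); this is not the exact dpr_scale implementation:

import pickle
from typing import List

import faiss  # assumed dependency for the index
import torch


def build_index(input_paths: List[str]):
    # Sketch: load every pickled embedding shard, concatenate them,
    # and add the result to a flat inner-product index.
    shards = []
    for path in input_paths:
        with open(path, "rb") as f:
            shards.append(pickle.load(f))  # assumed [num_passages, dim] tensor
    vectors = torch.cat(shards, dim=0).numpy().astype("float32")
    index = faiss.IndexFlatIP(vectors.shape[1])  # exact inner-product search
    index.add(vectors)
    return index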
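
merge_results is likewise defined outside this excerpt. A hedged sketch assuming the conventional DPR retrieval output schema (a question, its answers, and a ctxs list carrying id/title/text/score per retrieved passage); the field names and the integer-indexable ctxs dataset are assumptions:

def merge_results(ctxs, questions, indexes, scores):
    # Sketch: for each question, look up the retrieved passage rows and
    # attach their text and retrieval scores.
    results = []
    for question, doc_ids, doc_scores in zip(questions, indexes, scores):
        results.append(
            {
                "question": question["question"],
                "answers": question.get("answers", []),
                "ctxs": [
                    {
                        "id": ctxs[int(i)]["id"],
                        "title": ctxs[int(i)]["title"],
                        "text": ctxs[int(i)]["text"],
                        "score": float(s),  # inner-product score from FAISS
                    }
                    for i, s in zip(doc_ids, doc_scores)
                ],
            }
        )
    return results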