in generate_dense_embeddings.py
def main(cfg: DictConfig):
    assert cfg.model_file, "Please specify encoder checkpoint as model_file param"
    assert cfg.ctx_src, "Please specify passages source as ctx_src param"
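    # resolve device / distributed settings, then override cfg with the encoder settings stored in the checkpoint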
    cfg = setup_cfg_gpu(cfg)
    saved_state = load_states_from_checkpoint(cfg.model_file)
    set_cfg_params_from_state(saved_state.encoder_params, cfg)
    logger.info("CFG:")
    logger.info("%s", OmegaConf.to_yaml(cfg))
    tensorizer, encoder, _ = init_biencoder_components(cfg.encoder.encoder_model_type, cfg, inference_only=True)
    encoder = encoder.ctx_model if cfg.encoder_type == "ctx" else encoder.question_model
    encoder, _ = setup_for_distributed_mode(
        encoder,
        None,
        cfg.device,
        cfg.n_gpu,
        cfg.local_rank,
        cfg.fp16,
        cfg.fp16_opt_level,
    )
    encoder.eval()

    # load weights from the model file
    model_to_load = get_model_obj(encoder)
    logger.info("Loading saved model state ...")
    logger.debug("saved model keys =%s", saved_state.model_dict.keys())
    prefix_len = len("ctx_model.")
    ctx_state = {
        key[prefix_len:]: value for (key, value) in saved_state.model_dict.items() if key.startswith("ctx_model.")
    }
    model_to_load.load_state_dict(ctx_state)
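    # instantiate the configured passage source and load all passages into memory as (id, passage) pairs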
logger.info("reading data source: %s", cfg.ctx_src)
ctx_src = hydra.utils.instantiate(cfg.ctx_sources[cfg.ctx_src])
all_passages_dict = {}
ctx_src.load_data_to(all_passages_dict)
all_passages = [(k, v) for k, v in all_passages_dict.items()]
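    # split the passages into num_shards contiguous chunks and select the slice assigned to this shard_id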
    shard_size = math.ceil(len(all_passages) / cfg.num_shards)
    start_idx = cfg.shard_id * shard_size
    end_idx = start_idx + shard_size
    logger.info(
        "Producing encodings for passages range: %d to %d (out of total %d)",
        start_idx,
        end_idx,
        len(all_passages),
    )
    shard_passages = all_passages[start_idx:end_idx]
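    # encode this shard's passages into dense vectors with the context encoder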
    data = gen_ctx_vectors(cfg, shard_passages, encoder, tensorizer, True)
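    # write the shard's encoded passages to <out_file>_<shard_id> as a pickle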
    file = cfg.out_file + "_" + str(cfg.shard_id)
    pathlib.Path(os.path.dirname(file)).mkdir(parents=True, exist_ok=True)
    logger.info("Writing results to %s", file)
    with open(file, mode="wb") as f:
        pickle.dump(data, f)
    logger.info("Total passages processed %d. Written to %s", len(data), file)