# dense_retriever.py
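
import glob
import logging

import hydra
from omegaconf import DictConfig, OmegaConf

# The DPR-specific imports below follow the upstream repo layout (an assumption,
# since this excerpt shows only main()); helpers such as LocalFaissRetriever,
# validate, validate_tables and save_results are defined earlier in this file.
from dpr.data.retriever_data import KiltCsvCtxSrc
from dpr.models import init_biencoder_components
from dpr.options import set_cfg_params_from_state, setup_cfg_gpu
from dpr.utils.model_utils import get_model_obj, load_states_from_checkpoint, setup_for_distributed_mode

logger = logging.getLogger()
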
@hydra.main(config_path="conf", config_name="dense_retriever")
def main(cfg: DictConfig):
    cfg = setup_cfg_gpu(cfg)
    logger.info("CFG (after gpu configuration):")
    logger.info("%s", OmegaConf.to_yaml(cfg))
    saved_state = load_states_from_checkpoint(cfg.model_file)
    set_cfg_params_from_state(saved_state.encoder_params, cfg)
    tensorizer, encoder, _ = init_biencoder_components(cfg.encoder.encoder_model_type, cfg, inference_only=True)
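
    # Pick the sub-model used for query encoding: a named attribute of the
    # bi-encoder if encoder_path is set, otherwise the standard question encoder.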
    encoder_path = cfg.encoder_path
    if encoder_path:
        logger.info("Selecting encoder: %s", encoder_path)
        encoder = getattr(encoder, encoder_path)
    else:
        logger.info("Selecting standard question encoder")
        encoder = encoder.question_model
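
    # Move the encoder to the configured device (with optional distributed /
    # fp16 wrapping) and switch to inference mode.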
    encoder, _ = setup_for_distributed_mode(encoder, None, cfg.device, cfg.n_gpu, cfg.local_rank, cfg.fp16)
    encoder.eval()
    # load weights from the model file
    model_to_load = get_model_obj(encoder)
    logger.info("Loading saved model state ...")
    encoder_prefix = (encoder_path if encoder_path else "question_model") + "."
    prefix_len = len(encoder_prefix)
    logger.info("Encoder state prefix %s", encoder_prefix)
    question_encoder_state = {
        key[prefix_len:]: value for (key, value) in saved_state.model_dict.items() if key.startswith(encoder_prefix)
    }
    # TODO: long term HF state compatibility fix
    model_to_load.load_state_dict(question_encoder_state, strict=False)
    vector_size = model_to_load.get_out_size()
    logger.info("Encoder vector_size=%d", vector_size)

    # get questions & answers
    questions = []
    question_answers = []
    if not cfg.qa_dataset:
        logger.warning("Please specify qa_dataset to use")
        return
    ds_key = cfg.qa_dataset
    logger.info("qa_dataset: %s", ds_key)
    qa_src = hydra.utils.instantiate(cfg.datasets[ds_key])
    qa_src.load_data()
    for ds_item in qa_src.data:
        question, answers = ds_item.query, ds_item.answers
        questions.append(question)
        question_answers.append(answers)
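
    # Instantiate the FAISS indexer selected by cfg.indexer and initialize it
    # for the encoder's output dimension.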
    index = hydra.utils.instantiate(cfg.indexers[cfg.indexer])
    logger.info("Index class %s", type(index))
    index_buffer_sz = index.buffer_size
    index.init_index(vector_size)
    retriever = LocalFaissRetriever(encoder, cfg.batch_size, tensorizer, index)
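
    # Encode all questions into dense vectors. A dataset-specific representation
    # token selector and/or special query token may adjust how queries are embedded.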
logger.info("Using special token %s", qa_src.special_query_token)
questions_tensor = retriever.generate_question_vectors(questions, query_token=qa_src.special_query_token)
if qa_src.selector:
logger.info("Using custom representation token selector")
retriever.selector = qa_src.selector

    id_prefixes = []
    ctx_sources = []
    # note: "ctx_datatsets" (sic) is the key spelling used in the config
    for ctx_src_name in cfg.ctx_datatsets:
        ctx_src = hydra.utils.instantiate(cfg.ctx_sources[ctx_src_name])
        id_prefixes.append(ctx_src.id_prefix)
        ctx_sources.append(ctx_src)
    logger.info("id_prefixes per dataset: %s", id_prefixes)

    # index all passages
    ctx_files_patterns = cfg.encoded_ctx_files
    index_path = cfg.index_path
    logger.info("ctx_files_patterns: %s", ctx_files_patterns)
    if ctx_files_patterns:
        assert len(ctx_files_patterns) == len(id_prefixes), "ctx len={} pref len={}".format(
            len(ctx_files_patterns), len(id_prefixes)
        )
    else:
        assert index_path, "Either encoded_ctx_files or index_path parameter should be set."

    input_paths = []
    path_id_prefixes = []
    for i, pattern in enumerate(ctx_files_patterns):
        pattern_files = glob.glob(pattern)
        pattern_id_prefix = id_prefixes[i]
        input_paths.extend(pattern_files)
        path_id_prefixes.extend([pattern_id_prefix] * len(pattern_files))
    logger.info("Embeddings files id prefixes: %s", path_id_prefixes)
    if index_path and index.index_exists(index_path):
        logger.info("Index path: %s", index_path)
        retriever.index.deserialize(index_path)
    else:
        logger.info("Reading all passages data from files: %s", input_paths)
        retriever.index_encoded_data(input_paths, index_buffer_sz, path_id_prefixes=path_id_prefixes)
        if index_path:
            retriever.index.serialize(index_path)

    # get top k results
    top_ids_and_scores = retriever.get_top_docs(questions_tensor.numpy(), cfg.n_docs)
    # drop the retriever reference; the index is no longer needed and can be freed
    retriever = None

    all_passages = {}
    for ctx_src in ctx_sources:
        ctx_src.load_data_to(all_passages)
    if len(all_passages) == 0:
        raise RuntimeError("No passages data found. Please specify ctx_datatsets param properly.")
    if cfg.validate_as_tables:
        questions_doc_hits = validate_tables(
            all_passages,
            question_answers,
            top_ids_and_scores,
            cfg.validation_workers,
            cfg.match,
        )
    else:
        questions_doc_hits = validate(
            all_passages,
            question_answers,
            top_ids_and_scores,
            cfg.validation_workers,
            cfg.match,
        )
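
    # Persist per-question retrieval results (top passages, scores and answer
    # hit flags) for downstream use.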
    if cfg.out_file:
        save_results(
            all_passages,
            questions,
            question_answers,
            top_ids_and_scores,
            questions_doc_hits,
            cfg.out_file,
        )
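
    # Optionally convert the results to the KILT benchmark format.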
    if cfg.kilt_out_file:
        kilt_ctx = next((ctx for ctx in ctx_sources if isinstance(ctx, KiltCsvCtxSrc)), None)
        if not kilt_ctx:
            raise RuntimeError("No Kilt compatible context file provided")
        # the conversion reads the regular results file, so out_file must be set as well
        assert cfg.out_file, "kilt_out_file conversion requires out_file to be set"
        kilt_ctx.convert_to_kilt(qa_src.kilt_gold_file, cfg.out_file, cfg.kilt_out_file)
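

if __name__ == "__main__":
    main()

# Example invocation (a sketch: the dataset/checkpoint names below are
# placeholders, and the option names follow the Hydra config keys used above):
#
#   python dense_retriever.py \
#       model_file=/path/to/biencoder_checkpoint \
#       qa_dataset=nq_test \
#       ctx_datatsets=[dpr_wiki] \
#       encoded_ctx_files=["/path/to/passage_embeddings_*"] \
#       out_file=retrieval_results.json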