in kilt/retrievers/DPR_connector.py [0:0]
def __init__(self, name, **config):
    """Set up the DPR question encoder and the dense passage index.

    Args:
        name: Retriever name, forwarded to the parent constructor.
        **config: Settings exposed as attributes of ``self.args``. This
            constructor reads ``model_file``, ``encoder_model_type``,
            ``device``, ``n_gpu``, ``local_rank``, ``fp16``, ``hnsw_index``,
            ``hnsw_index_path`` (only when ``hnsw_index`` is truthy),
            ``encoded_ctx_file`` (only when ``hnsw_index`` is falsy),
            ``batch_size``, ``ctx_file`` and ``KILT_mapping``. The helpers
            that receive ``self.args`` may read further attributes.

    Raises:
        AttributeError: If a setting read here is missing from ``config``.
    """
    super().__init__(name)
    self.args = argparse.Namespace(**config)

    saved_state = load_states_from_checkpoint(self.args.model_file)
    # NOTE(review): presumably copies the checkpoint's encoder settings onto
    # self.args so the components below match the saved model - confirm.
    set_encoder_params_from_state(saved_state.encoder_params, self.args)
    tensorizer, encoder, _ = init_biencoder_components(
        self.args.encoder_model_type, self.args, inference_only=True
    )

    # Only the question side of the bi-encoder is kept from here on.
    encoder = encoder.question_model
    encoder, _ = setup_for_distributed_mode(
        encoder,
        None,
        self.args.device,
        self.args.n_gpu,
        self.args.local_rank,
        self.args.fp16,
    )
    encoder.eval()

    # Load weights from the model file: keep the checkpoint entries whose key
    # starts with "question_model." and strip that prefix from the key.
    model_to_load = get_model_obj(encoder)
    prefix = "question_model."
    question_encoder_state = {
        key[len(prefix):]: value
        for key, value in saved_state.model_dict.items()
        if key.startswith(prefix)
    }
    model_to_load.load_state_dict(question_encoder_state, strict=False)
    vector_size = model_to_load.get_out_size()

    # Index all passages: either restore a serialized HNSW index, or build a
    # flat index from the files matching the encoded-context glob pattern.
    if self.args.hnsw_index:
        index = DenseHNSWFlatIndexer(vector_size)
        index.deserialize_from(self.args.hnsw_index_path)
    else:
        index = DenseFlatIndexer(vector_size)
        # The glob is only needed on this path, so it is evaluated here.
        index.index_data(glob.glob(self.args.encoded_ctx_file))

    self.retriever = DenseRetriever(
        encoder, self.args.batch_size, tensorizer, index
    )

    # not needed for now
    self.all_passages = load_passages(self.args.ctx_file)

    self.KILT_mapping = None
    if self.args.KILT_mapping:
        # NOTE(review): unpickling can execute arbitrary code; this path must
        # only ever point at a trusted, locally produced mapping file.
        # The context manager closes the file even if unpickling fails.
        with open(self.args.KILT_mapping, "rb") as mapping_file:
            self.KILT_mapping = pickle.load(mapping_file)