in mico/model/mico.py [0:0]
def __init__(self, hparams):
"""We initialize all the parameters for
the document assignment model P(Z|Y) which is `self.p_z_y`, where Z is the cluster index to which a document Y is assigned.
the query routing model P(Z|X) which is `self.q_z_x`, where Z is the cluster index to which a query X is routed.
the model for approximating the distribution of cluster sizes of document assignment E_Y[P(Z|Y)] which is `self.q_z`,
(For details, please check https://arxiv.org/pdf/2209.04378.pdf)
We load the BERT model and its corresponding tokenizer.
Parameters
----------
hparams : argparse result
All the hyper-parameters for MICO. This will also be used in training.
"""
    super().__init__()
    self.hparams = hparams
    # P(Z|Y): posterior over clusters for a document Y.
    self.p_z_y = ConditionalDistributionZ(self.hparams.number_clusters,
                                          self.hparams.dim_input,
                                          self.hparams.num_layers_posterior,
                                          self.hparams.dim_hidden)
    # E_Y[P(Z|Y)]: approximate marginal distribution of cluster sizes.
    self.q_z = MarginalDistributionZ(self.hparams.number_clusters)
    # P(Z|X): posterior over clusters for a query X.
    self.q_z_x = ConditionalDistributionZ(self.hparams.number_clusters,
                                          self.hparams.dim_input,
                                          self.hparams.num_layers_posterior,
                                          self.hparams.dim_hidden)
    # Initialize all weights with the scheme selected by `hparams.init`.
    self.apply(get_init_function(self.hparams.init))
    try:
        # Prefer files already cached on disk, so this works offline.
        tokenizer = BertTokenizerFast.from_pretrained('bert-base-multilingual-cased',
                                                      local_files_only=True)
        config_bert = BertConfig.from_pretrained('bert-base-multilingual-cased',
                                                 output_hidden_states=True,
                                                 local_files_only=True)
        model_bert = BertModel.from_pretrained('bert-base-multilingual-cased',
                                               config=config_bert,
                                               local_files_only=True)
    except OSError:
        # Not cached locally: connect to the Internet to download the
        # Hugging Face tokenizer, config, and model.
        tokenizer = BertTokenizerFast.from_pretrained('bert-base-multilingual-cased')
        config_bert = BertConfig.from_pretrained('bert-base-multilingual-cased',
                                                 output_hidden_states=True)
        model_bert = BertModel.from_pretrained('bert-base-multilingual-cased',
                                               config=config_bert)
    self.model_bert = model_bert
    self.tokenizer = tokenizer
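
# Usage sketch (illustrative, not part of the original file): constructing the
# module with a minimal hyper-parameter namespace. The attribute names mirror
# those read in __init__ above; the class name `MICO` and the 'xavier' init key
# are assumptions, and the numeric values are placeholders except dim_input,
# which matches the 768-dimensional hidden size of bert-base-multilingual-cased.
import argparse

hparams = argparse.Namespace(
    number_clusters=100,       # number of clusters Z
    dim_input=768,             # BERT hidden size fed to the posterior networks
    num_layers_posterior=2,    # depth of the posterior networks
    dim_hidden=512,            # hidden width of the posterior networks
    init='xavier',             # assumed key understood by get_init_function
)
model = MICO(hparams)          # hypothetical name of the class owning __init__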