in mico/model/mico.py [0:0]
def forward(self, document=None, query=None, is_monitor_forward=False, forward_method="update_all", device=None):
    """Multi-purpose forward pass, dispatched on `forward_method`:
    1. 'update_all'   : train the full MICO model (including finetuning BERT).
    2. 'update_q_z'   : train only the approximated distribution `q_z`.
    3. 'encode_doc'   : document assignment (call with `document` only).
    4. 'encode_query' : query routing (call with `query` only).
    Parameters
    ----------
    document : list of str, optional
        Raw sentences of the document titles in the samples.
    query : list of str, optional
        Raw sentences of the queries in the samples.
    is_monitor_forward : bool
        If True, also compute {'AUC', 'top1_cov', 'H__Z_Y', 'H__Z_X'}.
        For details please refer to the function `monitor_metrics`.
    forward_method : str
        One of ['update_all', 'update_q_z', 'encode_doc', 'encode_query'].
        The default value is 'update_all'.
    device : int (for multi-GPU) or str ('cpu' or 'cuda')
        The device that the BERT model is on.
    Returns
    -------
    metrics :
        For 'update_all', a dictionary
        {'loss': loss, 'h_z_cond': h_z_cond, 'h_z': h_z}
        of PyTorch Tensor scalars (float).
        If `is_monitor_forward=True`, extra float-valued keys are added:
        {'AUC': metric_auc * 100, 'top1_cov': acc.item() * 100, 'H__Z_X': enc_x_entropy, 'H__Z_Y': enc_y_entropy}
        For 'encode_doc' or 'encode_query', a PyTorch Tensor with the
        probability distribution over clusters for each sample.
        For 'update_q_z', the cross-entropy loss for updating `q_z`.
    Raises
    ------
    ValueError
        If `forward_method` is inconsistent with the provided inputs.
    """
    def _encode(text):
        # Single place for the repeated BERT-encoding call.
        return model_predict(self.model_bert, text,
                             self.tokenizer, self.hparams.max_length,
                             self.hparams.pooling_strategy,
                             self.hparams.selected_layer_idx, device)

    # Inference-only paths: exactly one of document/query is supplied.
    if query is None or document is None:
        if forward_method == "encode_doc":
            p = self.p_z_y(_encode(document).detach())
            return p.view(p.shape[0], p.shape[-1])
        elif forward_method == "encode_query":
            p = self.q_z_x(_encode(query).detach())
            return p.view(p.shape[0], p.shape[-1])
        else:
            raise ValueError("Unexpected usage of forward.")

    if self.hparams.bert_fix:
        # Freeze BERT: no gradients flow into the encoder.
        with torch.no_grad():
            query_bert = _encode(query).detach()
            document_bert = _encode(document).detach()
    else:
        query_bert = _encode(query)
        document_bert = _encode(document)

    p = self.p_z_y(document_bert)
    if forward_method == "update_q_z":
        # Only q_z is updated; stop gradients through p.
        return cross_entropy_p_q(p.detach(), self.q_z())
    elif forward_method != "update_all":
        raise ValueError("Unexpected usage of forward.")

    q = self.q_z_x(query_bert)
    h_z_cond = cross_entropy_p_q(p, q)       # conditional-entropy term
    h_z = cross_entropy_p_q(p, self.q_z())   # marginal-entropy term
    loss = h_z_cond - self.hparams.entropy_weight * h_z
    results = {'loss': loss, 'h_z_cond': h_z_cond, 'h_z': h_z}
    if is_monitor_forward:
        results.update(monitor_metrics(p.view(p.shape[0], p.shape[-1]),
                                       q.view(q.shape[0], q.shape[-1])))
    return results