def forward()

in mico/model/mico.py [0:0]


    def forward(self, document=None, query=None, is_monitor_forward=False, forward_method="update_all", device=None):
        """Run one of several forward passes, selected by `forward_method`.

        1. 'update_all'   -- train the full MICO model, including finetuning BERT
                             (unless `self.hparams.bert_fix` is set).
        2. 'update_q_z'   -- train only the approximated distribution `q_z`.
        3. 'encode_doc'   -- document assignment (requires `document` only).
        4. 'encode_query' -- query routing (requires `query` only).

        Parameters
        ----------
        document : list of str, optional
            Raw sentences of the document titles in the samples.
        query : list of str, optional
            Raw sentences of the queries in the samples.
        is_monitor_forward : bool
            If True, additionally compute {'AUC', 'top1_cov', 'H__Z_Y', 'H__Z_X'};
            see `monitor_metrics` for details.
        forward_method : str
            One of ['update_all', 'update_q_z', 'encode_doc', 'encode_query'].
            The default is 'update_all'.
        device : int (for multi-GPU) or str ('cpu' or 'cuda')
            The device that the BERT model is on.

        Returns
        -------
        For 'update_all': a dict {'loss', 'h_z_cond', 'h_z'} of PyTorch scalar
            tensors; with `is_monitor_forward=True` the extra float-valued keys
            {'AUC', 'top1_cov', 'H__Z_X', 'H__Z_Y'} are merged in.
        For 'encode_doc' / 'encode_query': a 2-D tensor of per-sample cluster
            probabilities.
        For 'update_q_z': the cross-entropy loss for updating `q_z`.

        Raises
        ------
        ValueError
            If `forward_method` is unrecognized, or the input required by an
            encoding-only mode is missing.
        """

        def _encode(text):
            # All four call sites share the same BERT-encoding arguments; keep
            # them in one place so they cannot drift apart.
            return model_predict(self.model_bert, text,
                                 self.tokenizer, self.hparams.max_length,
                                 self.hparams.pooling_strategy,
                                 self.hparams.selected_layer_idx, device)

        if query is None or document is None:
            # Encoding-only modes: exactly one of `document`/`query` is expected.
            # Fail early with a clear message instead of crashing inside
            # `model_predict` on a None input.
            if forward_method == "encode_doc":
                if document is None:
                    raise ValueError("forward_method='encode_doc' requires `document`.")
                p = self.p_z_y(_encode(document).detach())
                return p.view(p.shape[0], p.shape[-1])
            elif forward_method == "encode_query":
                if query is None:
                    raise ValueError("forward_method='encode_query' requires `query`.")
                p = self.q_z_x(_encode(query).detach())
                return p.view(p.shape[0], p.shape[-1])
            else:
                raise ValueError("Unexpected usage of forward.")

        if self.hparams.bert_fix:
            # BERT is frozen: avoid building its autograd graph entirely.
            with torch.no_grad():
                query_bert = _encode(query).detach()
                document_bert = _encode(document).detach()
        else:
            query_bert = _encode(query)
            document_bert = _encode(document)

        p = self.p_z_y(document_bert)

        if forward_method == "update_q_z":
            # Only q_z is trained in this mode, so gradients through p are blocked.
            return cross_entropy_p_q(p.detach(), self.q_z())
        elif forward_method != "update_all":
            raise ValueError("Unexpected usage of forward.")

        q = self.q_z_x(query_bert)
        h_z_cond = cross_entropy_p_q(p, q)
        h_z = cross_entropy_p_q(p, self.q_z())

        # NOTE(review): subtracting the weighted h_z term looks like an entropy
        # regularizer on the cluster marginal -- confirm against the MICO paper.
        loss = h_z_cond - self.hparams.entropy_weight * h_z
        results = {'loss': loss, 'h_z_cond': h_z_cond, 'h_z': h_z}
        if is_monitor_forward:
            results_more = monitor_metrics(p.view(p.shape[0], p.shape[-1]),
                                           q.view(q.shape[0], q.shape[-1]))
            results.update(results_more)
        return results