Excerpt: the `forward()` method of the assistant model, from
mm_action_prediction/models/assistant.py.


    def forward(self, batch, mode=None):
        """Forward propagation through encoder, action executor, and decoder.

        Args:
          batch: Dict of batch input variables.
          mode: None for training or teacher forcing evaluation;
                BEAMSEARCH / SAMPLE / MAX to generate text.

        Returns:
          Dict of losses ("token", "action", "action_attr", "total") when
          training; otherwise the accumulated model outputs, including
          "loss_sum" / "num_tokens" for the ground-truth responses and,
          when retrieval evaluation is enabled, "candidate_scores".
        """
        outputs = self.encoder(batch)
        action_output = self.action_executor(batch, outputs)
        outputs.update(action_output)
        # Ground-truth decoding; loss_sum / num_tokens below must come
        # from this output, not from any retrieval-candidate decoding.
        decoder_output = self.decoder(batch, outputs)
        if mode:
            generation_output = self.decoder.forward_beamsearch_multiple(
                batch, outputs, mode
            )
            outputs.update(generation_output)

        # If evaluating by retrieval, construct fake batch for each candidate.
        # Inputs from batch used in decoder:
        #   assist_in, assist_out, assist_in_len, assist_mask
        if self.params["retrieval_evaluation"] and not self.training:
            option_scores = []
            batch_size, num_rounds, num_candidates, _ = batch["candidate_in"].shape
            replace_keys = ("assist_in", "assist_out", "assist_in_len", "assist_mask")
            # BUGFIX: remember the ground-truth entries so the caller's batch
            # is not left clobbered with the last candidate's tensors.
            saved_inputs = {key: batch[key] for key in replace_keys}
            for ii in range(num_candidates):
                for key in replace_keys:
                    new_key = key.replace("assist", "candidate")
                    batch[key] = batch[new_key][:, :, ii]
                # BUGFIX: decode into a local so `decoder_output` (ground
                # truth, computed above) is not silently overwritten.
                candidate_output = self.decoder(batch, outputs)
                log_probs = torch_support.unflatten(
                    candidate_output["loss_token"], batch_size, num_rounds
                )
                # Candidate score = total token log-likelihood (negated loss).
                option_scores.append(-1 * log_probs.sum(-1))
            # Restore the original assist_* inputs for the caller.
            batch.update(saved_inputs)
            option_scores = torch.stack(option_scores, 2)
            outputs["candidate_scores"] = [
                {
                    "dialog_id": batch["dialog_id"][ii].item(),
                    "candidate_scores": [
                        {
                            "scores": [
                                float(kk) for kk in option_scores[ii, jj].cpu()
                            ],
                            "turn_id": jj
                        }
                        for jj in range(batch["dialog_len"][ii])
                    ]
                }
                for ii in range(batch_size)
            ]

        # Local aliases (ground-truth decoding).
        loss_token = decoder_output["loss_token"]
        pad_mask = decoder_output["pad_mask"]
        if self.training:
            # Normalize token loss by the number of non-pad tokens.
            loss_token = loss_token.sum() / (~pad_mask).sum().item()
            loss_action = action_output["action_loss"]
            loss_action_attr = action_output["action_attr_loss"]
            loss_total = loss_action + loss_token + loss_action_attr
            return {
                "token": loss_token,
                "action": loss_action,
                "action_attr": loss_action_attr,
                "total": loss_total,
            }
        else:
            # Expose raw sums so the caller can aggregate perplexity
            # across batches.
            outputs.update(
                {"loss_sum": loss_token.sum(), "num_tokens": (~pad_mask).sum()}
            )
            return outputs