def forward()

in mm_action_prediction/models/decoder.py


    def forward(self, batch, encoder_output):
        """Forward pass through the decoder.

        Args:
            batch: Dict of batch variables.
            encoder_output: Dict of outputs from the encoder.

        Returns:
            decoder_outputs: Dict of outputs from the forward pass.
        """
        # Flatten for history_agnostic encoder.
        batch_size, num_rounds, max_length = batch["assist_in"].shape
        decoder_in = support.flatten(batch["assist_in"], batch_size, num_rounds)
        decoder_out = support.flatten(batch["assist_out"], batch_size, num_rounds)
        decoder_len = support.flatten(batch["assist_in_len"], batch_size, num_rounds)
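        # support.flatten is assumed to merge the batch and round dimensions,
        # yielding tensors shaped [batch_size * num_rounds, ...].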
        word_embeds_dec = self.word_embed_net(decoder_in)

        if self.params["encoder"] in self.DIALOG_CONTEXT_ENCODERS:
            dialog_context = support.flatten(
                encoder_output["dialog_context"], batch_size, num_rounds
            ).unsqueeze(1)
            dialog_context = dialog_context.expand(-1, max_length, -1)
            decoder_steps_in = torch.cat([dialog_context, word_embeds_dec], -1)
        else:
            decoder_steps_in = word_embeds_dec

        # Condition the encoder states on the action outputs, if required.
        if self.params["use_action_output"]:
            action_out = encoder_output["action_output_all"].unsqueeze(1)
            time_steps = encoder_output["hidden_states_all"].shape[1]
            fusion_out = torch.cat(
                [
                    encoder_output["hidden_states_all"],
                    action_out.expand(-1, time_steps, -1),
                ],
                dim=-1,
            )
            encoder_output["hidden_states_all"] = self.action_fusion_net(fusion_out)

        if self.params["text_encoder"] == "transformer":
            # Rebuild the no-peek mask if it is missing or sized for a different max_length.
            if self.no_peek_mask is None or self.no_peek_mask.size(0) != max_length:
                self.no_peek_mask = self._generate_no_peek_mask(max_length)
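            # The no-peek mask is presumably causal (upper triangular), preventing
            # each target position from attending to later positions.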

            hidden_state = encoder_output["hidden_states_all"]
            enc_pad_mask = batch["user_utt"] == batch["pad_token"]
            enc_pad_mask = support.flatten(enc_pad_mask, batch_size, num_rounds)
            dec_pad_mask = batch["assist_in"] == batch["pad_token"]
            dec_pad_mask = support.flatten(dec_pad_mask, batch_size, num_rounds)
            if self.params["encoder"] != "pretrained_transformer":
                dec_embeds = self.pos_encoder(decoder_steps_in).transpose(0, 1)
                outputs = self.decoder_unit(
                    dec_embeds,
                    hidden_state.transpose(0, 1),
                    memory_key_padding_mask=enc_pad_mask,
                    tgt_mask=self.no_peek_mask,
                    tgt_key_padding_mask=dec_pad_mask,
                )
                outputs = outputs.transpose(0, 1)
            else:
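                # Pretrained decoder path: batch-major embeddings with inverted
                # masks (True = attend), presumably a HuggingFace-style interface.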
                outputs = self.decoder_unit(
                    inputs_embeds=decoder_steps_in,
                    attention_mask=~dec_pad_mask,
                    encoder_hidden_states=hidden_state,
                    encoder_attention_mask=~enc_pad_mask,
                )
                outputs = outputs[0]
        else:
            hidden_state = encoder_output["hidden_state"]
            if self.params["encoder"] == "tf_idf":
                hidden_state = None

            # If Bahdanau attention is to be used.
            if (
                self.params["use_bahdanau_attention"]
                and self.params["encoder"] != "tf_idf"
            ):
                encoder_states = encoder_output["hidden_states_all"]
                max_decoder_len = min(
                    decoder_in.shape[1], self.params["max_decoder_len"]
                )
                encoder_states_proj = self.attention_net(encoder_states)
                enc_mask = (batch["user_utt"] == batch["pad_token"]).unsqueeze(-1)
                enc_mask = support.flatten(enc_mask, batch_size, num_rounds)
                outputs = []
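                # Decode step by step: attend over encoder states using the last
                # decoder hidden state, then feed [context; token embedding] to
                # the recurrent unit.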
                for step in range(max_decoder_len):
                    previous_state = hidden_state[0][-1].unsqueeze(1)
                    att_logits = previous_state * encoder_states_proj
                    att_logits = att_logits.sum(dim=-1, keepdim=True)
                    # Use the encoder mask to set attention logits at <pad> positions to -Inf.
                    att_logits.masked_fill_(enc_mask, float("-Inf"))
                    att_wts = nn.functional.softmax(att_logits, dim=1)
                    context = (encoder_states * att_wts).sum(1, keepdim=True)
                    # Run through LSTM.
                    concat_in = [context, decoder_steps_in[:, step : step + 1, :]]
                    step_in = torch.cat(concat_in, dim=-1)
                    decoder_output, hidden_state = self.decoder_unit(
                        step_in, hidden_state
                    )
                    concat_out = torch.cat([decoder_output, context], dim=-1)
                    outputs.append(concat_out)
                outputs = torch.cat(outputs, dim=1)
            else:
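                # rnn.dynamic_rnn is assumed to run the recurrent decoder over
                # variable-length sequences, using decoder_len to handle padding.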
                outputs = rnn.dynamic_rnn(
                    self.decoder_unit,
                    decoder_steps_in,
                    decoder_len,
                    init_state=hidden_state,
                )
        if self.params["encoder"] == "pretrained_transformer":
            output_logits = outputs
        else:
            # Logits over vocabulary.
            output_logits = self.inv_word_net(outputs)
        # Zero out the per-token loss at padded positions.
        pad_mask = support.flatten(batch["assist_mask"], batch_size, num_rounds)
        loss_token = self.criterion(output_logits.transpose(1, 2), decoder_out)
        loss_token.masked_fill_(pad_mask, 0.0)
        return {"loss_token": loss_token, "pad_mask": pad_mask}