#  Copyright 2020 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
#  Copyright (c) 2023 Graphcore Ltd. All rights reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

import warnings
from typing import Optional, Tuple, Union

import poptorch
import torch
import torch.nn as nn
from torch import Tensor
from transformers import MT5ForConditionalGeneration
from transformers.activations import NewGELUActivation
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
# `__HEAD_MASK_WARNING_MSG` is aliased so that it can be referenced inside a class body
# without being affected by Python's private-name mangling.
from transformers.models.mt5.modeling_mt5 import __HEAD_MASK_WARNING_MSG as HEAD_MASK_WARNING_MSG, MT5Block, MT5Stack

from optimum.utils import logging

from ...generation import IPUGenerationMixin
from ...modeling_utils import (
    PipelineMixin,
    SerializedEmbedding,
    SerializedLinear,
    SharedEmbedding,
    SplitProjection,
    get_layer_ipu,
    recomputation_checkpoint,
    register,
    split_encoder_decoder_ipu_config,
)


logger = logging.get_logger(__name__)


# Copied from optimum.graphcore.models.t5.modeling_t5.UpCastWrapper
class UpCastWrapper(nn.Module):
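    # Casts the wrapped module's output to FP32 (and optionally rescales it) so that the
    # residual path of the pipelined model stays in full precision.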
    def __init__(self, module: nn.Module, scale: float = 1.0):
        super().__init__()
        self.module = module
        self.scale = scale

    def forward(self, input):
        return self.module(input).to(torch.float32) * self.scale


# Copied from optimum.graphcore.models.t5.modeling_t5.CustomGELU
class CustomGELU(NewGELUActivation):
    # Work-around bug with torch.nn.GELU(approximate="tanh")
    # TODO: Remove this when bug is fixed
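    # For |input| >= 39 the tanh-based GELU approximation is numerically indistinguishable
    # from ReLU (gelu(x) ~ x for large positive x, ~ 0 for large negative x), so outside the
    # safe window we substitute ReLU instead of evaluating the approximation.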
    def forward(self, input: Tensor) -> Tensor:
        safe = torch.logical_and(-39 < input, input < 39)
        safe_input = torch.where(safe, input, 0.0)
        gelu = super().forward(safe_input)
        relu = nn.functional.relu(input)
        return torch.where(safe, gelu, relu)


# Copied from optimum.graphcore.models.t5.modeling_t5.CustomT5Block with t5->mt5 and T5->MT5
class CustomMT5Block(MT5Block):
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        encoder_decoder_position_bias=None,
        layer_head_mask=None,
        cross_attn_layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
        return_dict=True,
    ):
        if past_key_value is not None:
            if not self.is_decoder:
                logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.")
            expected_num_past_key_values = 2 if encoder_hidden_states is None else 4

            if len(past_key_value) != expected_num_past_key_values:
                raise ValueError(
                    f"There should be {expected_num_past_key_values} past states. "
                    f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
                    f"Got {len(past_key_value)} past key / value states"
                )

            self_attn_past_key_value = past_key_value[:2]
            cross_attn_past_key_value = past_key_value[2:]
        else:
            self_attn_past_key_value, cross_attn_past_key_value = None, None

        self_attention_outputs = self.layer[0](
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=self_attn_past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        hidden_states, present_key_value_state = self_attention_outputs[:2]
        attention_outputs = self_attention_outputs[2:]  # Keep self-attention outputs and relative position weights

        # clamp inf values to enable fp16 training
        # Custom: Remove check for inf
        if hidden_states.dtype == torch.float16:
            clamp_value = torch.tensor(torch.finfo(hidden_states.dtype).max - 1000, dtype=hidden_states.dtype)
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        do_cross_attention = self.is_decoder and encoder_hidden_states is not None
        if do_cross_attention:
            # the actual query length is unknown for cross attention
            # if using past key value states. Need to inject it here
            if present_key_value_state is not None:
                query_length = present_key_value_state[0].shape[2]
            else:
                query_length = None

            cross_attention_outputs = self.layer[1](
                hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                position_bias=encoder_decoder_position_bias,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=cross_attn_past_key_value,
                query_length=query_length,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )
            hidden_states = cross_attention_outputs[0]

            # clamp inf values to enable fp16 training
            # Custom: Remove check for inf
            if hidden_states.dtype == torch.float16:
                clamp_value = torch.tensor(torch.finfo(hidden_states.dtype).max - 1000, dtype=hidden_states.dtype)
                hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

            # Combine self attn and cross attn key value states
            if present_key_value_state is not None:
                present_key_value_state = present_key_value_state + cross_attention_outputs[1]

            # Keep cross-attention outputs and relative position weights
            attention_outputs = attention_outputs + cross_attention_outputs[2:]

        # Apply Feed Forward layer
        hidden_states = self.layer[-1](hidden_states)

        # clamp inf values to enable fp16 training
        # Custom: Remove check for inf
        if hidden_states.dtype == torch.float16:
            clamp_value = torch.tensor(torch.finfo(hidden_states.dtype).max - 1000, dtype=hidden_states.dtype)
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        outputs = (hidden_states,)

        if use_cache:
            outputs = outputs + (present_key_value_state,) + attention_outputs
        else:
            outputs = outputs + attention_outputs

        return outputs  # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)


# Copied from optimum.graphcore.models.t5.modeling_t5.CustomT5Stack with t5->mt5 and T5->MT5
class CustomMT5Stack(MT5Stack):
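    # Scale the additive attention masks by 0.75: masked positions use very large negative
    # values, and scaling them down leaves headroom so that adding the attention scores
    # does not overflow in FP16.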
    def invert_attention_mask(self, *args, **kwargs) -> Tensor:
        return super().invert_attention_mask(*args, **kwargs) * 0.75

    def get_extended_attention_mask(self, *args, **kwargs) -> Tensor:
        return super().get_extended_attention_mask(*args, **kwargs) * 0.75


@register(MT5ForConditionalGeneration)
class PipelinedMT5ForConditionalGeneration(MT5ForConditionalGeneration, PipelineMixin, IPUGenerationMixin):
    # Copied from optimum.graphcore.models.t5.modeling_t5.PipelinedT5ForConditionalGeneration.is_encoder_and_decoder_embeddings_computation_shared
    @property
    def is_encoder_and_decoder_embeddings_computation_shared(self):
        return isinstance(self.shared, SharedEmbedding)

    # Copied from optimum.graphcore.models.t5.modeling_t5.PipelinedT5ForConditionalGeneration.encoder_and_decoder_embeddings_computation with t5->mt5 and T5->MT5
    def encoder_and_decoder_embeddings_computation(self, use_shared_embedding: bool):
        """Sets the MT5ForConditionalGeneration shared embedding layer to SharedEmbedding that combines the computation under one layer.

        Args:
            use_shared_embedding: whether to use SharedEmbedding or not.
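
        For example, given a `model` instance of this class:
        ```
        model.encoder_and_decoder_embeddings_computation(True)   # wrap the shared embedding in SharedEmbedding
        model.encoder_and_decoder_embeddings_computation(False)  # restore the plain embedding layer
        ```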
        """

        if use_shared_embedding:
            if isinstance(self.shared, SharedEmbedding):
                logger.warning("encoder and decoder embeddings computation is already shared")
            else:
                self.shared = SharedEmbedding(self.shared)
        else:
            if isinstance(self.shared, nn.Embedding):
                logger.warning("encoder and decoder embeddings computation is not shared")
            else:
                self.shared = self.shared.shared

    # Copied from optimum.graphcore.models.t5.modeling_t5.PipelinedT5ForConditionalGeneration.parallelize
    # with parallelization changes for MT5
    def parallelize(self, for_generation=False):
        """
        Transform the model to run in an IPU pipeline.
        - Adds pipeline stages to the model
        - (If enabled) Replaces the shared embedding with a SerializedEmbedding
        - (If enabled) Replaces the LM head with a SerializedLinear or SplitProjection
        - Adds recomputation checkpoints

        Args:
            for_generation: whether to parallelize for text generation, in which case the IPU
                config is split into separate encoder and decoder configs.

        Recommended usage:
        ```
        model = PipelinedMT5ForConditionalGeneration(config).parallelize().half()
        ```
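
        When running text generation, the encoder and decoder are compiled as separate Poplar
        executors, so pass `for_generation=True` (a sketch, reusing the same `config`):
        ```
        model = PipelinedMT5ForConditionalGeneration(config).parallelize(for_generation=True).half()
        ```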
        """
        PipelineMixin.parallelize(self)

        serialized_projection_splits_per_ipu = self.ipu_config._serialized_projection_splits_per_ipu
        projection_serialization_factor = (
            self.ipu_config._projection_serialization_factor
            if self.ipu_config._projection_serialization_factor
            else sum(serialized_projection_splits_per_ipu)
        )
        serialized_embedding_splits_per_ipu = self.ipu_config._serialized_embedding_splits_per_ipu
        embedding_serialization_factor = (
            self.ipu_config._embedding_serialization_factor
            if self.ipu_config._embedding_serialization_factor
            else sum(serialized_embedding_splits_per_ipu)
        )

        # Cannot shard input and output embeddings when using
        # tied weights. Using `SerializedLinear` is exempt since
        # the weights are not sharded
        if self.config.tie_word_embeddings and (
            embedding_serialization_factor > 1 or serialized_projection_splits_per_ipu is not None
        ):
            serialized_projection_splits_per_ipu_mode_str = self.ipu_config._get_managed_attr_mode_name(
                "serialized_projection_splits_per_ipu"
            )
            serialized_embedding_splits_per_ipu_mode_str = self.ipu_config._get_managed_attr_mode_name(
                "serialized_embedding_splits_per_ipu"
            )
            embedding_serialization_factor_mode_str = self.ipu_config._get_managed_attr_mode_name(
                "embedding_serialization_factor"
            )
            raise ValueError(
                "Cannot shard input and output embedding layers when using tied weights."
                f" {serialized_projection_splits_per_ipu_mode_str}={serialized_projection_splits_per_ipu}"
                f" {serialized_embedding_splits_per_ipu_mode_str}={serialized_embedding_splits_per_ipu}"
                " should not be provided when using tied input and output embeddings as it is"
                " redundant to split layers that can only reside on 1 IPU."
                f" {embedding_serialization_factor_mode_str}={embedding_serialization_factor}"
                " should also be set to 1 as creating a `SerializedEmbedding` will split the"
                " embedding table into sub embedding tables."
            )

        logger.info("-------------------- Device Allocation --------------------")

        if embedding_serialization_factor > 1:
            self.shared = SerializedEmbedding.from_model(self.shared, embedding_serialization_factor)
            self.encoder.embed_tokens = self.shared
            self.decoder.embed_tokens = self.shared

        if projection_serialization_factor > 1:
            if serialized_projection_splits_per_ipu is None:
                self.lm_head = SerializedLinear.from_model(self.lm_head, projection_serialization_factor)
                if self.config.tie_word_embeddings:
                    self.tie_weights()
            else:
                self.lm_head = SplitProjection.from_model(
                    self.lm_head, serialization_factor=projection_serialization_factor
                )

        self.encoder_and_decoder_embeddings_computation(True)

        # Parallelize the embedding layer
        if embedding_serialization_factor > 1 and serialized_embedding_splits_per_ipu is not None:
            # Sharing encoder and decoder computation wraps the
            # SerializedEmbedding using SharedEmbedding
            logger.info("Embedding Placement: ")
            self.shared.shared = self.shared.shared.parallelize(serialized_embedding_splits_per_ipu)
        else:
            logger.info("Embedding  --> IPU 0")
            self.shared = poptorch.BeginBlock(self.shared, "Embedding", ipu_id=0)

        # Use a custom MT5Stack implementation because sharing the position bias causes an OOM error
        self.encoder.__class__ = CustomMT5Stack
        self.decoder.__class__ = CustomMT5Stack

        # Upcast input embeddings so that the residuals remain in FP32. This
        # cast is reversed where necessary by the MT5LayerNorm layers in:
        # - the first layer of MT5LayerSelfAttention
        # - the first layer of MT5LayerFF
        # - final_layer_norm
        # which, conveniently, are all the places where it needs to happen.
        # Therefore, we just need to upcast immediately before the residual
        # adds in MT5LayerSelfAttention and MT5LayerFF. This is handled in the
        # for loop below.
        self.encoder.embed_tokens = UpCastWrapper(self.encoder.embed_tokens)

        # Use a custom MT5Block implementation that removes dynamic if-blocks that can't be statically traced
        for block in self.encoder.block:
            block.__class__ = CustomMT5Block
            # Dropout happens immediately before the residual add. Inserting a
            # cast in MT5LayerSelfAttention and MT5LayerFF keeps the residual
            # structure in FP32
            block.layer[0].dropout = UpCastWrapper(block.layer[0].dropout)
            # Scale down the weights for the MT5LayerFF down-projection and
            # then scale its output back up again after it is cast to FP32
            scale = 8.0
            with torch.no_grad():
                block.layer[1].DenseReluDense.wo.weight /= scale
            block.layer[1].dropout = UpCastWrapper(block.layer[1].dropout, scale)
            # Prevent overflow in NewGELUActivation
            if self.config.dense_act_fn == "gelu_new":
                # TODO: Work-around bug with torch.nn.GELU(approximate="tanh"). Replace
                # this with block.layer[1].DenseReluDense.act = torch.nn.GELU(approximate="tanh")
                # when bug is fixed
                block.layer[1].DenseReluDense.act = CustomGELU()

        for block in self.decoder.block:
            block.__class__ = CustomMT5Block
            # Work-around bug with torch.nn.GELU(approximate="tanh")
            # TODO: Remove this when bug is fixed
            if self.config.dense_act_fn == "gelu_new":
                block.layer[2].DenseReluDense.act = CustomGELU()

        num_encoder_layers = len(self.encoder.block)
        num_decoder_layers = len(self.decoder.block)

        if for_generation:
            # If running for text generation we split the IPU config into two configs
            # because we run the encoder and decoder as separate Poplar executors.
            ipu_configs = split_encoder_decoder_ipu_config(self.ipu_config, num_encoder_layers, num_decoder_layers)
            self.encoder_ipu_config, self.decoder_ipu_config = ipu_configs
            encoder_layer_ipu = get_layer_ipu(self.encoder_ipu_config, num_encoder_layers)
            decoder_layer_ipu = get_layer_ipu(self.decoder_ipu_config, num_decoder_layers)
        else:
            number_of_layers = num_encoder_layers + num_decoder_layers
            layer_ipu = get_layer_ipu(self.ipu_config, number_of_layers)
            encoder_layer_ipu = layer_ipu[:num_encoder_layers]
            decoder_layer_ipu = layer_ipu[num_encoder_layers:]

        for index, (layer, ipu) in enumerate(zip(self.encoder.block, encoder_layer_ipu)):
            if self.ipu_config.recompute_checkpoint_every_layer and index != self.config.num_layers - 1:
                self._hooks.append(recomputation_checkpoint(layer))
            self.encoder.block[index] = poptorch.BeginBlock(layer, f"Encoder{index}", ipu_id=ipu)
            logger.info(f"Encoder {index:<2} --> IPU {ipu}")

        self.encoder.final_layer_norm = poptorch.BeginBlock(
            self.encoder.final_layer_norm, "Encoder Stack Final LayerNorm", ipu_id=ipu
        )

        for index, (layer, ipu) in enumerate(zip(self.decoder.block, decoder_layer_ipu)):
            if self.ipu_config.recompute_checkpoint_every_layer and index != self.config.num_decoder_layers - 1:
                self._hooks.append(recomputation_checkpoint(layer))
            self.decoder.block[index] = poptorch.BeginBlock(layer, f"Decoder{index}", ipu_id=ipu)
            logger.info(f"Decoder {index:<2} --> IPU {ipu}")

        self.decoder.final_layer_norm = poptorch.BeginBlock(
            self.decoder.final_layer_norm, "Decoder Stack Final LayerNorm", ipu_id=ipu
        )

        # Parallelize the lm head
        if self.config.tie_word_embeddings:
            # Place LM head on IPU 0
            ipu_id = 0
            logger.info(f"LM Head Output --> IPU {ipu_id}")
            self.lm_head = poptorch.BeginBlock(self.lm_head, "LM Head Output", ipu_id=ipu_id)
        else:
            # Place the LM head on the last IPU if serialized_projection_splits_per_ipu is not provided.
            # For generation, take serialized_projection_splits_per_ipu and the last IPU from the decoder IPU config.
            ipu_id = self.ipu_config._ipus_per_replica - 1
            if for_generation:
                serialized_projection_splits_per_ipu = self.decoder_ipu_config._serialized_projection_splits_per_ipu
                ipu_id = self.decoder_ipu_config._ipus_per_replica - 1

            # Parallelize the `SplitProjection` layer if a configuration is provided
            if self.lm_head.__class__ == SplitProjection:
                logger.info("LM Head Placement: ")
                self.lm_head = self.lm_head.parallelize(serialized_projection_splits_per_ipu)
            else:
                # Place SerializedLinear and nn.Linear forms of the lm head on the last IPU
                logger.info(f"LM Head Output --> IPU {ipu_id}")
                self.lm_head = poptorch.BeginBlock(self.lm_head, "LM Head Output", ipu_id=ipu_id)

        self.change_lm_head_to_indexed_input_linear(restore=not for_generation)

        logger.info("-----------------------------------------------------------")
        return self

    # Copied from optimum.graphcore.models.t5.modeling_t5.PipelinedT5ForConditionalGeneration.deparallelize
    # with deparallelization changes for MT5
    def deparallelize(self):
        """
        Undo the changes to the model done by `parallelize`.
        You should call this before doing `save_pretrained` so that the `model.state_dict` is
        fully compatible with `transformers.MT5ForConditionalGeneration`.
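
        For example, to export a checkpoint after training (`path/to/checkpoint` is a placeholder):
        ```
        model.deparallelize()
        model.save_pretrained("path/to/checkpoint")
        ```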
        """
        # MT5ForConditionalGeneration has a deparallelize method, so make sure that the PipelineMixin one is used here.
        PipelineMixin.deparallelize(self)

        self.encoder_and_decoder_embeddings_computation(False)

        if self.shared.__class__ == SerializedEmbedding:
            self.shared = self.shared.to_model()
            self.encoder.embed_tokens = self.shared
            self.decoder.embed_tokens = self.shared

        self.change_lm_head_to_indexed_input_linear(restore=True)

        if self.lm_head.__class__ == SerializedLinear:
            self.lm_head = self.lm_head.to_model()
            if self.config.tie_word_embeddings:
                self.tie_weights()
        elif self.lm_head.__class__ == SplitProjection:
            self.lm_head = self.lm_head.to_model()

        self.encoder.__class__ = MT5Stack
        self.decoder.__class__ = MT5Stack

        for block in self.encoder.block:
            block.__class__ = MT5Block
            block.layer[0].dropout = block.layer[0].dropout.module
            with torch.no_grad():
                block.layer[1].DenseReluDense.wo.weight *= block.layer[1].dropout.scale
            block.layer[1].dropout = block.layer[1].dropout.module
            if self.config.dense_act_fn == "gelu_new":
                block.layer[1].DenseReluDense.act = NewGELUActivation()

        for block in self.decoder.block:
            block.__class__ = MT5Block
            if self.config.dense_act_fn == "gelu_new":
                block.layer[2].DenseReluDense.act = NewGELUActivation()

        return self

    # Copied from optimum.graphcore.models.t5.modeling_t5.PipelinedT5ForConditionalGeneration.forward
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        decoder_head_mask: Optional[torch.FloatTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

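        # When the encoder and decoder embedding lookups have been combined into a SharedEmbedding,
        # compute both sets of embeddings here and pass embeddings (rather than token ids) to the
        # encoder and decoder stacks below.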
        if self.is_encoder_and_decoder_embeddings_computation_shared:
            inputs_embeds, decoder_inputs_embeds = self.shared(
                input_ids=input_ids,
                decoder_input_ids=decoder_input_ids,
            )
            if inputs_embeds is not None:
                input_ids = None
            if decoder_inputs_embeds is not None:
                decoder_input_ids = None

        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
        if head_mask is not None and decoder_head_mask is None:
            if self.config.num_layers == self.config.num_decoder_layers:
                warnings.warn(HEAD_MASK_WARNING_MSG, FutureWarning)
                decoder_head_mask = head_mask

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            # Convert encoder inputs in embeddings if needed
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]

        # Set device for model parallelism
        if self.model_parallel:
            self.lm_head = self.lm_head.to(self.encoder.first_device)
            sequence_output = sequence_output.to(self.lm_head.weight.device)

        if self.config.tie_word_embeddings:
            # Rescale output before projecting on vocab
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            sequence_output = sequence_output * (self.model_dim**-0.5)

        lm_scale_modifier = getattr(self, "lm_scale_modifier", None)
        if lm_scale_modifier is not None:
            sequence_output = sequence_output * lm_scale_modifier

        lm_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
        # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666

        # Only returning the loss to make the communication between the host and the device faster.
        if not return_dict:
            output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
            return (loss,) if labels is not None else output

        if loss is not None:
            return Seq2SeqLMOutput(
                loss=loss,
            )

        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
