# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for generation with Neuron."""

import copy
import inspect
import warnings
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.distributed as dist

from optimum.neuron.utils.import_utils import is_torch_xla_available

from ..utils.import_utils import is_neuronx_distributed_available
from ..utils.misc import args_and_kwargs_to_kwargs_only


if is_torch_xla_available():
    import torch_xla.core.xla_model as xm
from transformers import GenerationMixin
from transformers.generation.beam_search import BeamScorer, BeamSearchScorer
from transformers.generation.configuration_utils import GenerationConfig
from transformers.generation.logits_process import (
    LogitsProcessorList,
)
from transformers.generation.stopping_criteria import (
    MaxLengthCriteria,
    MaxTimeCriteria,
    StoppingCriteriaList,
    validate_stopping_criteria,
)
from transformers.generation.utils import (
    BeamSearchDecoderOnlyOutput,
    BeamSearchEncoderDecoderOutput,
    BeamSearchOutput,
    GenerateOutput,
    GreedySearchDecoderOnlyOutput,
    GreedySearchEncoderDecoderOutput,
    GreedySearchOutput,
)
from transformers.utils import ModelOutput, logging


# Module-level logger built with transformers' logging utilities.
logger = logging.get_logger(__name__)


# `parallel_state` is only bound when neuronx_distributed is installed; its use in
# `_get_fwd_for_general_sampling` is guarded by `is_neuronx_distributed_available()`.
if is_neuronx_distributed_available():
    from neuronx_distributed.parallel_layers import parallel_state


def _move_dict_args_to_device(kwargs: Dict[str, Any], device: str = "cpu") -> Dict[str, Any]:
    """
    Takes keyword arguments which will be passed to a model's forward function
    and moves its tensor values to `device`.

    Three kinds of values are handled; everything else is left untouched:
      - tensors, which are moved directly;
      - dictionaries, whose tensor values are moved (one level deep, in place);
      - the `past_key_values` entry, an iterable of per-layer iterables, which is rebuilt as a tuple of tuples
        with every tensor moved (non-tensor entries, e.g. `None`, are kept as they are).

    The dictionary is modified in place and also returned.

    Args:
        kwargs (`Dict[str, Any]`):
            The kwargs to be passed to the models forward function.
        device (`str` or `torch.device`, defaults to `cpu`):
            The target device to which tensors should be moved.
    Returns:
        `Dict[str, Any]`: The kwargs dict with its tensors moved to `device`.
    """
    # Normalize once: a `torch.device` does not reliably compare equal to a device *string*, which made every
    # tensor look like it needed a move when `device` was given as a string.
    target_device = torch.device(device)

    def needs_move(tensor: torch.Tensor) -> bool:
        return tensor.device != target_device

    for key, value in kwargs.items():
        # Handle nested dicts
        if isinstance(value, dict):
            for sub_key, sub_value in value.items():
                if isinstance(sub_value, torch.Tensor) and needs_move(sub_value):
                    value[sub_key] = sub_value.to(device=target_device)

        # Handle tensor types
        elif isinstance(value, torch.Tensor):
            if needs_move(value):
                kwargs[key] = value.to(device=target_device)

        # Handle past_key_value tuples
        elif key == "past_key_values" and value is not None:
            kwargs[key] = tuple(
                tuple(
                    state.to(device=target_device)
                    if isinstance(state, torch.Tensor) and needs_move(state)
                    else state
                    for state in layer_past
                )
                for layer_past in value
            )

    return kwargs


def _pad_input_ids_for_general_sampling(
    input_ids: torch.Tensor, num_padding_values: int, pad_token_id: int
) -> torch.Tensor:
    """
    Right-pads `input_ids` along dimension 1 with `num_padding_values` copies of `pad_token_id`.

    Args:
        input_ids (`torch.Tensor`):
            Input ids to be padded; the first dimension is treated as the batch dimension.
        num_padding_values (`int`):
            Number of padding columns appended to every row.
        pad_token_id (`int`):
            Token ID written into the padding columns.
    Returns:
        `torch.Tensor`: `input_ids` followed by the padding columns (the padding block is `torch.long`,
        created on the same device as `input_ids`).
    """
    padding_block = torch.full(
        (input_ids.shape[0], num_padding_values),
        pad_token_id,
        dtype=torch.long,
        device=input_ids.device,
    )
    return torch.cat((input_ids, padding_block), dim=1)


def _get_fwd_for_general_sampling(
    current_fwd: Callable,
    generation_config: GenerationConfig,
    is_encoder_decoder: bool,
    vocab_size: int,
    main_device: str,
    to_device: str = "cpu",
    output_dtype: torch.dtype = torch.float32,
) -> Callable:
    """
    Wraps the passed forward function and extends it such that before each forward call
    the `decoder_input_ids` (or `input_ids` for decoder-only models) are padded to `generation_config.max_length`
    and all tensors are moved to `main_device` (e.g. XLA). Then the original forward is called followed by a
    `xm.mark_step`. Subsequently, the outputs are moved to `to_device` and the logits are "unpadded" back to the
    current length. This way, all functions that process the logits can be called without making any changes.

    Positional arguments are converted to keyword arguments before any processing, so positional inputs are padded
    and moved exactly like keyword inputs.

    Args:
        current_fwd (`Callable`):
            The current forward function of the model.
        generation_config (`GenerationConfig`):
            The GenerationConfig of the model (`max_length` and `pad_token_id` are read at every call).
        is_encoder_decoder (`bool`):
            Defines if this is a encoder-decoder model.
        vocab_size (`int`):
            The total number of vocabs of the current model; logits with a different last dimension are gathered
            across the tensor-parallel group when NxD tensor parallelism is active.
        main_device (`str` or `torch.device`):
            The device on which the forward pass should be executed.
        to_device (`str`, defaults to `cpu`):
            The device on which all other processing should be executed.
        output_dtype (`torch.dtype`, defaults to `torch.float32`):
            The expected data type of the output logits.
    Returns:
        `Callable`: The extended forward function.
    """
    input_ids_string = "decoder_input_ids" if is_encoder_decoder else "input_ids"
    # The wrapped signature never changes, so inspect it once instead of at every generation step.
    accepts_position_ids = "position_ids" in inspect.signature(current_fwd).parameters

    @wraps(current_fwd)
    def new_fwd(*args, **kwargs):
        # Normalize positional arguments into keywords first; previously positional inputs bypassed both the
        # padding and the move to `main_device` below.
        kwargs = args_and_kwargs_to_kwargs_only(current_fwd, args, kwargs)

        # Pad input to max length
        cur_len = None
        current_input_ids = kwargs.get(input_ids_string)
        if current_input_ids is not None:
            batch_size, cur_len = current_input_ids.shape
            num_padding_values = generation_config.max_length - cur_len
            kwargs[input_ids_string] = _pad_input_ids_for_general_sampling(
                current_input_ids, num_padding_values, generation_config.pad_token_id
            )

            # For decoder only models, pad decoder attention mask in addition to prompts.
            # A present-but-`None` mask is left alone instead of crashing in `torch.cat`.
            attention_mask = kwargs.get("attention_mask")
            if attention_mask is not None and not is_encoder_decoder and num_padding_values > 0:
                mask_padding = torch.zeros(
                    (batch_size, num_padding_values), dtype=torch.long, device=attention_mask.device
                )
                attention_mask = torch.cat([attention_mask, mask_padding], 1)
                kwargs["attention_mask"] = attention_mask
                # create position_ids on the fly for batch generation
                if accepts_position_ids:
                    position_ids = attention_mask.long().cumsum(-1) - 1
                    position_ids.masked_fill_(attention_mask == 0, 1)
                    kwargs["position_ids"] = position_ids

        # Move inputs to device
        _move_dict_args_to_device(kwargs, main_device)

        # Forward
        outputs = current_fwd(**kwargs)
        # Gather outputs if NxD tensor parallelism is applied and the output logits have not been gathered.
        if (
            is_neuronx_distributed_available()
            and parallel_state.model_parallel_is_initialized()
            and parallel_state.get_tensor_model_parallel_size() > 1
            and outputs["logits"].shape[-1] != vocab_size
        ):
            outputs["logits"] = xm.all_gather(
                outputs["logits"],
                dim=-1,
                groups=parallel_state.get_tensor_model_parallel_group(as_list=True),
            )
        xm.mark_step()

        # Move to CPU
        _move_dict_args_to_device(outputs, to_device)

        # Post-process output as a function of cur_len: drop the logits of the padded positions
        # (when no input ids were seen, `cur_len` is None and all positions are kept).
        outputs["logits"] = outputs["logits"][:, :cur_len, ...].to(output_dtype)

        return outputs

    return new_fwd


class GeneralNeuronGenerationMixin(GenerationMixin):
    """
    A class containing all functions for auto-regressive text generation on Trn1, to be used as a mixin in [`PreTrainedModel`].
    The generation will be handled on both CPU and TRN1 in the following way:
      1. Model forward pass will be executed on TRN1
      2. All other logics including padding, searching, and sampling will be handled by general device (CPU).
    This implementation allows us to support general searching and sampling methods with minimal code changes.
    """

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        **kwargs,
    ):
        """
        Runs the Hugging Face `generate` loop with the model forward pass wrapped for Neuron execution.

        The wrapped forward pads inputs to `generation_config.max_length`, runs on `self.device` and moves the outputs
        back to CPU, where searching and sampling happen. Caching is disabled. `self.forward` is restored when
        generation ends, including on error.

        Args:
            inputs (`Optional[torch.Tensor]`, defaults to `None`):
                The prompt / encoder inputs, moved to CPU before generation.
            generation_config (`Optional[GenerationConfig]`, defaults to `None`):
                Base generation configuration; `self.generation_config` is used when not provided.
            kwargs:
                Generation-config overrides and model kwargs, forwarded to `GenerationMixin.generate`.
        Returns:
            Whatever `GenerationMixin.generate` returns.
        Raises:
            ValueError: for decoder-only models fed with `inputs_embeds` while `use_cache` is requested, since caching
                is not supported on Neuron devices.
        """
        # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
        self._validate_model_class()

        # priority: `generation_config` argument > `model.generation_config` (the default generation config)
        if generation_config is None:
            # legacy: users may modify the model configuration to control generation. To trigger this legacy behavior,
            # two conditions must be met
            # 1) the generation config must have been created from the model config (`_from_model_config` field);
            # 2) the generation config must have seen no modification since its creation (the hash is the same).
            if self.generation_config._from_model_config:
                new_generation_config = GenerationConfig.from_model_config(self.config)
                if new_generation_config != self.generation_config:
                    warnings.warn(
                        "You have modified the pretrained model configuration to control generation. This is a"
                        " deprecated strategy to control generation and will be removed soon, in a future version."
                        " Please use and modify the model generation configuration (see"
                        " https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )"
                    )
                    self.generation_config = new_generation_config
            generation_config = self.generation_config

        generation_config = copy.deepcopy(generation_config)
        model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs
        generation_config.validate()
        self._validate_model_kwargs(model_kwargs.copy())

        # 2. Set generation parameters if not already defined
        if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
            if model_kwargs.get("attention_mask", None) is None:
                logger.warning(
                    "The attention mask and the pad token id were not set. As a consequence, you may observe "
                    "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
                )
            eos_token_id = generation_config.eos_token_id
            if isinstance(eos_token_id, list):
                eos_token_id = eos_token_id[0]
            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
            generation_config.pad_token_id = eos_token_id

        # 3. Define model inputs and move to CPU
        general_device = "cpu"
        if kwargs.get("input_ids") is not None:
            kwargs["input_ids"] = kwargs["input_ids"].to(general_device)
        if inputs is not None:
            inputs = inputs.to(general_device)
        input_ids, model_input_name, model_kwargs = self._prepare_model_inputs(
            inputs, generation_config.bos_token_id, model_kwargs
        )

        # 4. Set Neuron specific generation configurations
        # Remember only the *instance-level* override of `forward` (normally there is none). Do NOT deep-copy
        # `self.forward`: deep-copying a bound method deep-copies its `__self__`, i.e. the whole model, and restoring
        # that copy would rebind `self.forward` to a detached duplicate of the model (and double its memory).
        instance_forward = self.__dict__.get("forward")
        try:
            general_forward = _get_fwd_for_general_sampling(
                self.forward,
                generation_config,
                self.config.is_encoder_decoder,
                self.config.vocab_size,
                self.device,
            )
            self.forward = general_forward
            if generation_config.use_cache:
                warnings.warn(
                    "use_cache is not supported for generation on Neuron devices, switching to use_cache=False."
                )
                # decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are
                # generating the first new token or not, and we only want to use the embeddings for the first new token)
                if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":
                    raise ValueError("Decoder-only models with inputs_embeds forwarding must use `use_cache=True`")
            generation_config.use_cache = False
            if generation_config.max_new_tokens is not None:
                generation_config.max_length = generation_config.max_new_tokens + input_ids.shape[-1]

            # 5. Run HuggingFace generate function
            return super().generate(inputs, generation_config, **kwargs)
        finally:
            if instance_forward is None:
                # No instance override existed before: drop ours so lookup falls back to the class `forward`.
                self.__dict__.pop("forward", None)
            else:
                self.forward = instance_forward

    def _prepare_encoder_decoder_kwargs_for_generation(
        self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None
    ) -> Dict[str, Any]:
        """Move the input tensor to XLA device and move the output tensors back to CPU."""
        output = super()._prepare_encoder_decoder_kwargs_for_generation(
            inputs_tensor.to(self.device), model_kwargs, model_input_name
        )
        _move_dict_args_to_device(output, "cpu")
        return output


class NeuronGenerationMixin(GenerationMixin):
    """
    A class containing all functions for auto-regressive text generation on Trn1, to be used as a mixin in [`PreTrainedModel`].

    The class exposes [`~generation.GenerationMixin.generate`], which can be used for:
        - *greedy decoding* by calling [`~generation.GenerationMixin.greedy_search`] if `num_beams=1` and
          `do_sample=False`
        - *contrastive search* by calling [`~generation.GenerationMixin.contrastive_search`] if `penalty_alpha>0` and
          `top_k>1`
        - *multinomial sampling* by calling [`~generation.GenerationMixin.sample`] if `num_beams=1` and
          `do_sample=True`
        - *beam-search decoding* by calling [`~generation.GenerationMixin.beam_search`] if `num_beams>1` and
          `do_sample=False`
        - *beam-search multinomial sampling* by calling [`~generation.GenerationMixin.beam_sample`] if `num_beams>1`
          and `do_sample=True`
        - *diverse beam-search decoding* by calling [`~generation.GenerationMixin.group_beam_search`], if `num_beams>1`
          and `num_beam_groups>1`
        - *constrained beam-search decoding* by calling [`~generation.GenerationMixin.constrained_beam_search`], if
          `constraints!=None` or `force_words_ids!=None`

    You do not need to call any of the above methods directly. Pass custom parameter values to 'generate' instead. To
    learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
    """

    @staticmethod
    def _initialize_attention(
        model_kwargs,
        num_padding_values,
        batch_size,
        device,
        is_encoder_decoder,
    ):
        """Initializes the appropriate attention mask -- encoder-decoder models use `decoder_attention_mask`"""
        if is_encoder_decoder:
            # One 1 for decoder_start_token_id, 0s for the currently-unfilled locations in the past_key_values tensor,
            # 1s for the actual input_ids
            decoder_attention_mask = torch.cat(
                [
                    torch.zeros((batch_size, num_padding_values), dtype=torch.int32),
                    torch.ones((batch_size, 2), dtype=torch.int32),
                ],
                axis=1,
            ).to(device)
            mask = {"decoder_attention_mask": decoder_attention_mask}
        else:
            attention_mask = model_kwargs.pop("attention_mask")
            # 0s for the currently-unfilled locations in the past_key_values tensor, 1s for the actual input_ids
            attention_mask = torch.cat(
                [
                    torch.zeros(
                        (batch_size, num_padding_values), dtype=attention_mask.dtype, device=attention_mask.device
                    ),
                    attention_mask,
                    torch.ones((batch_size, 1), dtype=attention_mask.dtype, device=attention_mask.device),
                ],
                axis=1,
            )
            mask = {"attention_mask": attention_mask}

        return mask

    @staticmethod
    def _update_attention(model_kwargs, batch_size, is_encoder_decoder):
        """Updates the appropriate attention mask -- encoder-decoder models use `decoder_attention_mask`"""

        attention_mask_name = "decoder_attention_mask" if is_encoder_decoder else "attention_mask"
        attention_mask = model_kwargs.pop(attention_mask_name)
        attention_mask_update_slice = torch.ones(
            (batch_size, 1), dtype=attention_mask.dtype, device=attention_mask.device
        )
        attention_mask = torch.cat([attention_mask[:, 1:], attention_mask_update_slice], dim=-1)
        mask = {attention_mask_name: attention_mask}
        return mask

    @staticmethod
    def _initialize_past(past_key_values, num_padding_values):
        """Initialize past_key_values with zeros -- the structure depends on `batch_axis`"""

        new_past = ()
        for past_layer in past_key_values:
            new_past_layer = list(past_layer)
            for i in range(len(new_past_layer[:2])):
                b, n_heads, _, head_dim = past_layer[i].shape
                new_past_layer[i] = torch.cat(
                    [
                        torch.zeros(
                            (b, n_heads, num_padding_values, head_dim),
                            dtype=past_layer[i].dtype,
                            device=past_layer[i].device,
                        ),
                        past_layer[i],
                    ],
                    dim=2,
                )
            new_past += (tuple(new_past_layer),)

        return new_past

    @staticmethod
    def _update_past(past_key_values):
        new_past = ()
        for past_layer in past_key_values:
            new_past_layer = list(past_layer)
            for i, _ in enumerate(new_past_layer[:2]):
                new_past_layer[i] = past_layer[i][:, :, 1:]
            new_past += (tuple(new_past_layer),)

        return new_past

    def _update_model_kwargs_for_xla_generation(
        self,
        outputs: ModelOutput,
        model_kwargs: Dict[str, Any],
        batch_size: int,
        is_encoder_decoder: bool = False,
        standardize_cache_format: bool = False,
        max_length: Optional[int] = None,
        seq_length: Optional[int] = None,
        use_cache: bool = True,
    ) -> Dict[str, Any]:
        if use_cache:
            past_key_values = self._extract_past_from_model_output(outputs)
            if past_key_values is None:
                raise ValueError(
                    "No known `past_key_values variable` found in model outputs (model outputs keys:"
                    f" {list(outputs.keys())})"
                )
            is_past_initialized = model_kwargs.pop("past_key_values", None) is not None

            if not is_past_initialized:
                # The padded version of `past_key_values` has a length of `max_length - 1`, as `past_key_values` holds information relative to
                # previous autoregressive generation steps (step 0 has no past_key_values, step 1 has 1 past_key_values value, ..., the last step
                # has `max_length - 1` past_key_values values).
                num_padding_values = max_length - seq_length
                mask = self._initialize_attention(
                    model_kwargs, num_padding_values, batch_size, outputs.logits.device, is_encoder_decoder
                )
                new_past = self._initialize_past(past_key_values, num_padding_values)
            else:
                mask = self._update_attention(model_kwargs, batch_size, is_encoder_decoder)
                new_past = self._update_past(past_key_values)

            # sets the updated variables (mask and past_key_values)
            model_kwargs.update(mask)
            model_kwargs["past_key_values"] = tuple(new_past)
        else:
            model_kwargs["past_key_values"] = None
            if "token_type_ids" in model_kwargs:
                token_type_ids = model_kwargs["token_type_ids"]
                model_kwargs["token_type_ids"] = torch.cat(
                    [token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1
                )

            if not is_encoder_decoder:
                # update attention mask
                if "attention_mask" in model_kwargs:
                    batch_size = model_kwargs["attention_mask"].shape[0]
                    update_indices = torch.stack(
                        [torch.arange(batch_size), torch.tensor(seq_length).repeat(batch_size)], dim=-1
                    )
                    model_kwargs["attention_mask"][update_indices[:, 0], update_indices[:, 1]] = model_kwargs[
                        "attention_mask"
                    ].new_ones((batch_size, 1))

            else:
                # update decoder attention mask
                if "decoder_attention_mask" in model_kwargs:
                    batch_size = model_kwargs["decoder_attention_mask"].shape[0]
                    update_indices = torch.stack(
                        [torch.arange(batch_size), torch.tensor(seq_length).repeat(batch_size)], dim=-1
                    )
                    model_kwargs["decoder_attention_mask"][update_indices[:, 0], update_indices[:, 1]] = model_kwargs[
                        "decoder_attention_mask"
                    ].new_ones((batch_size, 1))

        return model_kwargs

    @staticmethod
    def _expand_inputs_for_generation(
        expand_size: int = 1,
        is_encoder_decoder: bool = False,
        input_ids: Optional[torch.LongTensor] = None,
        **model_kwargs,
    ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
        """Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]"""

        def _expand_dict_for_generation(dict_to_expand):
            for key in dict_to_expand:
                if dict_to_expand[key] is not None and isinstance(dict_to_expand[key], torch.Tensor):
                    if len(dict_to_expand[key].shape) == 2:
                        dict_to_expand[key] = (
                            dict_to_expand[key].repeat(1, expand_size).view(-1, dict_to_expand[key].shape[1])
                        )
                    elif len(dict_to_expand[key].shape) <= 1:
                        dict_to_expand[key] = dict_to_expand[key].repeat(expand_size)
                    else:
                        dict_to_expand[key] = torch.concat(
                            [tensor.unsqueeze(0).repeat(expand_size, 1, 1) for tensor in dict_to_expand[key]]
                        )
            return dict_to_expand

        if input_ids is not None:
            # Manual repeat interleave
            input_ids = input_ids.repeat(1, expand_size).view(-1, input_ids.shape[1])

        model_kwargs = _expand_dict_for_generation(model_kwargs)

        if is_encoder_decoder:
            if model_kwargs.get("encoder_outputs") is None:
                raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
            model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])

        return input_ids, model_kwargs

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
        synced_gpus: Optional[bool] = None,
        is_traced_inference: bool = False,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        r"""

        Generates sequences of token ids for models with a language modeling head.

        <Tip warning={true}>

        Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
        model's default generation configuration. You can override any `generation_config` by passing the corresponding
        parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.

        For an overview of generation strategies and code examples, check out the [following
        guide](../generation_strategies).

        </Tip>

        Parameters:
            inputs (`Optional[torch.Tensor]`, defaults to `None`):
                The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
                method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
                should of in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of
                `input_ids`, `input_values`, `input_features`, or `pixel_values`.
            generation_config (`Optional[GenerationConfig]`, defaults to `None`):
                The generation configuration to be used as base parametrization for the generation call. `**kwargs`
                passed to generate matching the attributes of `generation_config` will override them. If
                `generation_config` is not provided, the default will be used, which had the following loading
                priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
                configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
                default values, whose documentation should be checked to parameterize generation.
            logits_processor (`Optional[LogitsProcessorList]`, defaults to `None`):
                Custom logits processors that complement the default logits processors built from arguments and
                generation config. If a logit processor is passed that is already created with the arguments or a
                generation config an error is thrown. This feature is intended for advanced users.
            stopping_criteria (`Optional[StoppingCriteriaList]`, defaults to `None`):
                Custom stopping criteria that complement the default stopping criteria built from arguments and a
                generation config. If a stopping criteria is passed that is already created with the arguments or a
                generation config an error is thrown. This feature is intended for advanced users.
            prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
                If provided, this function constraints the beam search to allowed tokens only at each step. If not
                provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
                `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
                on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful
                for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
                Retrieval](https://arxiv.org/abs/2010.00904).
            synced_gpus (`Optional[bool]`, defaults to `None`):
                Whether to continue running the while loop until max_length. Unless overridden this flag will be set to
                `True` under DeepSpeed ZeRO Stage 3 multiple GPUs environment to avoid hanging if one GPU finished
                generating before other GPUs. Otherwise it'll be set to `False`.
            is_traced_inference (`bool`, defaults to `False`):
                Whether the decoder is traced or using XLA lazy tensor. If the decoder is traced, next tokens and the beam scores
                are computed inside the decoder.
            kwargs:
                Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
                forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
                specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.

        Return:
            [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
            or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`.

                If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
                [`~utils.ModelOutput`] types are:

                    - [`~generation.GreedySearchDecoderOnlyOutput`],
                    - [`~generation.SampleDecoderOnlyOutput`],
                    - [`~generation.BeamSearchDecoderOnlyOutput`],
                    - [`~generation.BeamSampleDecoderOnlyOutput`]

                If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
                [`~utils.ModelOutput`] types are:

                    - [`~generation.GreedySearchEncoderDecoderOutput`],
                    - [`~generation.SampleEncoderDecoderOutput`],
                    - [`~generation.BeamSearchEncoderDecoderOutput`],
                    - [`~generation.BeamSampleEncoderDecoderOutput`]
        """

        # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
        self._validate_model_class()

        # priority: `generation_config` argument > `model.generation_config` (the default generation config)
        if generation_config is None:
            # legacy: users may modify the model configuration to control generation -- update the generation config
            # model attribute accordingly, if it was created from the model config
            if self.generation_config._from_model_config:
                new_generation_config = GenerationConfig.from_model_config(self.config)
                if new_generation_config != self.generation_config:
                    warnings.warn(
                        "You have modified the pretrained model configuration to control generation. This is a"
                        " deprecated strategy to control generation and will be removed soon, in a future version."
                        " Please use a generation configuration file (see"
                        " https://huggingface.co/docs/transformers/main_classes/text_generation)"
                    )
                    self.generation_config = new_generation_config
            generation_config = self.generation_config

        generation_config = copy.deepcopy(generation_config)
        model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs
        generation_config.validate()
        self._validate_model_kwargs(model_kwargs.copy())

        # 2. Set generation parameters if not already defined
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

        accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
        requires_attention_mask = "encoder_outputs" not in model_kwargs
        kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None

        # 3. Define model inputs
        # inputs_tensor has to be defined
        # model_input_name is defined if model-specific keyword input is passed
        # otherwise model_input_name is None
        # all model-specific keyword inputs are removed from `model_kwargs`
        inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
            inputs, generation_config.bos_token_id, model_kwargs
        )
        batch_size = inputs_tensor.shape[0]

        # 4. Define other model kwargs
        model_kwargs["output_attentions"] = generation_config.output_attentions
        model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
        if generation_config.use_cache and not is_traced_inference:
            warnings.warn("use_cache is not supported for generation on Neuron devices, switching to use_cache=False.")
            model_kwargs["use_cache"] = False
        else:
            model_kwargs["use_cache"] = generation_config.use_cache

        accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
        requires_attention_mask = "encoder_outputs" not in model_kwargs and not is_traced_inference

        if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
            model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
                inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
            )

        device = inputs_tensor.device
        self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)

        # decoder-only models should use left-padding for generation
        if not self.config.is_encoder_decoder:
            if (
                generation_config.pad_token_id is not None
                and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0
            ):
                logger.warning(
                    "A decoder-only architecture is being used, but right-padding was detected! For correct "
                    "generation results, please set `padding_side='left'` when initializing the tokenizer."
                )

        if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs and not is_traced_inference:
            # if model is encoder decoder encoder_outputs are created
            # and added to `model_kwargs`
            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
                inputs_tensor, model_kwargs, model_input_name
            )

        # 5. Prepare `input_ids` which will be used for auto-regressive generation
        if self.config.is_encoder_decoder:
            input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
                batch_size=batch_size,
                model_input_name=model_input_name,
                model_kwargs=model_kwargs,
                decoder_start_token_id=generation_config._decoder_start_token_tensor,
                device=inputs_tensor.device,
            )
        else:
            input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")

        # 6. Prepare `max_length` depending on other stopping criteria.
        input_ids_seq_length = input_ids.shape[-1]
        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
        if has_default_max_length and generation_config.max_new_tokens is None:
            warnings.warn(
                f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
                "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
                " recommend using `max_new_tokens` to control the maximum length of the generation.",
                UserWarning,
            )
        elif generation_config.max_new_tokens is not None:
            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
            if not has_default_max_length:
                logger.warning(
                    f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
                    f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
                    "Please refer to the documentation for more information. "
                    "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
                )

        if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
            raise ValueError(
                f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than"
                f" the maximum length ({generation_config.max_length})"
            )
        if input_ids_seq_length >= generation_config.max_length:
            input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
            logger.warning(
                f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
                " increasing `max_new_tokens`."
            )

        # Pad to max_length
        input_ids = torch.cat(
            [
                input_ids,
                (
                    torch.ones(
                        (batch_size, (generation_config.max_length - input_ids_seq_length)),
                    )
                    .long()
                    .to(input_ids.device)
                )
                * generation_config.pad_token_id,
            ],
            1,
        )
        # For decoder only models, pad decoder attention mask in addition to prompts
        if (
            "attention_mask" in model_kwargs
            and model_kwargs.get("use_cache", False) is False
            and not self.config.is_encoder_decoder
        ):
            model_kwargs["attention_mask"] = torch.cat(
                [
                    model_kwargs["attention_mask"],
                    torch.zeros((batch_size, (generation_config.max_length - input_ids_seq_length)))
                    .long()
                    .to(model_kwargs["attention_mask"].device),
                ],
                1,
            )

        # 7. determine generation mode
        is_constraint_gen_mode = (
            generation_config.constraints is not None or generation_config.force_words_ids is not None
        )

        is_contrastive_search_gen_mode = (
            (generation_config.num_beams == 1)
            and generation_config.top_k is not None
            and generation_config.top_k > 1
            and generation_config.do_sample is False
            and generation_config.penalty_alpha is not None
            and generation_config.penalty_alpha > 0
        )

        is_greedy_gen_mode = (
            (generation_config.num_beams == 1)
            and (generation_config.num_beam_groups == 1)
            and generation_config.do_sample is False
            and not is_constraint_gen_mode
            and not is_contrastive_search_gen_mode
        )
        is_beam_gen_mode = (
            (generation_config.num_beams > 1)
            and (generation_config.num_beam_groups == 1)
            and generation_config.do_sample is False
            and not is_constraint_gen_mode
            and not is_contrastive_search_gen_mode
        )

        if generation_config.num_beam_groups > generation_config.num_beams:
            raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")

        if hasattr(self, "device") and self.device.type != input_ids.device.type:
            warnings.warn(
                "You are calling .generate() with the `input_ids` being on a device type different"
                f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
                f" is on {self.device.type}. You may experience unexpected behaviors or slower generation."
                " Please make sure that you have put `input_ids` to the"
                f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before"
                " running `.generate()`.",
                UserWarning,
            )

        # 8. prepare distribution pre_processing samplers
        logits_processor = self._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids_seq_length,
            encoder_input_ids=inputs_tensor,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            logits_processor=logits_processor,
        )

        # 9. prepare stopping criteria
        stopping_criteria = self._get_stopping_criteria(
            generation_config=generation_config, stopping_criteria=stopping_criteria
        )

        if is_greedy_gen_mode:
            if generation_config.num_return_sequences > 1:
                raise ValueError(
                    "num_return_sequences has to be 1 when doing greedy search, "
                    f"but is {generation_config.num_return_sequences}."
                )

            # 11. run greedy search
            return self.greedy_search(
                input_ids,
                logits_processor=logits_processor,
                stopping_criteria=stopping_criteria,
                pad_token_id=generation_config.pad_token_id,
                eos_token_id=generation_config.eos_token_id,
                output_scores=generation_config.output_scores,
                return_dict_in_generate=generation_config.return_dict_in_generate,
                synced_gpus=synced_gpus,
                seq_length=input_ids_seq_length,
                is_traced_inference=is_traced_inference,
                **model_kwargs,
            )
        elif is_beam_gen_mode:
            if generation_config.num_return_sequences > generation_config.num_beams:
                raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")

            if stopping_criteria.max_length is None:
                raise ValueError("`max_length` needs to be a stopping_criteria for now.")

            # 11. prepare beam search scorer
            beam_scorer = BeamSearchScorer(
                batch_size=batch_size,
                num_beams=generation_config.num_beams,
                device="cpu",
                length_penalty=generation_config.length_penalty,
                do_early_stopping=generation_config.early_stopping,
                num_beam_hyps_to_keep=generation_config.num_return_sequences,
                max_length=generation_config.max_length,
            )
            # 12. interleave input_ids with `num_beams` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids=input_ids,
                expand_size=generation_config.num_beams,
                is_encoder_decoder=self.config.is_encoder_decoder,
                **model_kwargs,
            )
            # 13. run beam search
            return self.beam_search(
                input_ids,
                beam_scorer,
                logits_processor=logits_processor,
                stopping_criteria=stopping_criteria,
                pad_token_id=generation_config.pad_token_id,
                eos_token_id=generation_config.eos_token_id,
                output_scores=generation_config.output_scores,
                return_dict_in_generate=generation_config.return_dict_in_generate,
                synced_gpus=synced_gpus,
                seq_length=input_ids_seq_length,
                is_traced_inference=is_traced_inference,
                **model_kwargs,
            )

        else:
            raise ValueError("Only greedy search and beam search are supported on Neuron.")

    def greedy_search(
        self,
        input_ids: torch.LongTensor,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: bool = False,
        seq_length: Optional[int] = None,
        is_traced_inference: bool = False,
        **model_kwargs,
    ) -> Union[GreedySearchOutput, torch.LongTensor]:
        r"""
        Generates sequences of token ids for models with a language modeling head using **greedy decoding** and can be
        used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

        `input_ids` is updated in place: the prompt occupies its first `seq_length` columns and every step writes the
        selected token at column `seq_length` before incrementing it, so the shape of `input_ids` never changes. It
        therefore has to be padded on the right beforehand (`generate` pads it up to `max_length`).

        <Tip warning={true}>

        In most cases, you do not need to call [`~generation.GenerationMixin.greedy_search`] directly. Use generate()
        instead. For an overview of generation strategies and code examples, check the [following
        guide](../generation_strategies).

        </Tip>


        Parameters:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation, right-padded so that there is room for the
                generated tokens (see above).
            logits_processor (`LogitsProcessorList`, *optional*):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
                used to modify the prediction scores of the language modeling head applied at each generation step.
                Ignored (with a warning when non-empty) if `is_traced_inference=True`.
            stopping_criteria (`StoppingCriteriaList`, *optional*):
                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                used to tell if the generation loop should stop.

            max_length (`int`, *optional*, defaults to 20):
                **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
                tokens. The maximum length of the sequence to be generated.
            pad_token_id (`int`, *optional*):
                The id of the *padding* token. Required whenever an `eos_token_id` is in use.
            eos_token_id (`Union[int, List[int]]`, *optional*):
                The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more details.
            output_hidden_states (`bool`, *optional*, defaults to `False`):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more details.
            output_scores (`bool`, *optional*, defaults to `False`):
                Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
                Ignored (with a warning) if `is_traced_inference=True`.
            return_dict_in_generate (`bool`, *optional*, defaults to `False`):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            synced_gpus (`bool`, *optional*, defaults to `False`):
                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
            seq_length (`int`):
                Length of the prompt currently held in `input_ids`, i.e. the number of meaningful leading columns.
                Must be provided: it is used directly in index arithmetic and is not inferred from `input_ids`.
            is_traced_inference (`bool`, defaults to `False`):
                Whether the decoder is traced or using XLA lazy tensor. If the decoder is traced, next tokens and the beam scores
                are computed inside the decoder: the first model output is used as the next tokens, `logits_processor`
                and `output_scores` are ignored, and no `xm.mark_step()` is issued.
            model_kwargs:
                Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
                If model is an encoder-decoder model the kwargs should include `encoder_outputs`.

        Return:
            [`~generation.GreedySearchDecoderOnlyOutput`], [`~generation.GreedySearchEncoderDecoderOutput`] or
            `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
            [`~generation.GreedySearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
            `return_dict_in_generate=True` or a [`~generation.GreedySearchEncoderDecoderOutput`] if
            `model.config.is_encoder_decoder=True` and `return_dict_in_generate=True`. The returned sequences keep the
            full padded width of `input_ids`: columns past the last generated token keep the values they had on input
            (the pad token when called from `generate`).

        Raises:
            `ValueError`: at the first decoding step, if `eos_token_id` is set while `pad_token_id` is `None`.

        Examples:

        ```python
        >>> from transformers import AutoTokenizer
        >>> from optimum.neuron import NeuronModelForSeq2SeqLM

        >>> tokenizer = AutoTokenizer.from_pretrained("t5-small")
        >>> input_shapes = {"batch_size": 1, "sequence_length": 128, "num_beams": 1}
        >>> model = NeuronModelForSeq2SeqLM.from_pretrained("t5-small", export=True, dynamic_batch_size=False, **input_shapes)

        >>> input_prompt = "translate English to German: Lets eat good food."
        >>> inputs = tokenizer(input_prompt, return_tensors="pt")

        >>> outputs = model.greedy_search(inputs["input_ids"])

        >>> results = [tokenizer.decode(t, skip_special_tokens=True) for t in outputs]
        ```
        """
        # init values
        if is_traced_inference and logits_processor:
            # In traced mode `next_tokens` is read straight from the compiled decoder output (`outputs[0]` below),
            # so any processor would be silently skipped. An empty list has nothing to skip: no warning then.
            logger.warning(
                "`logits_processor` will be neglected because in `optimum-neuron`, `next_tokens` is computed inside the compiled decoder. If you want us to support custom logits_processor during the compilation, please file an issue to https://github.com/huggingface/optimum-neuron."
            )
        elif logits_processor is None:
            logits_processor = LogitsProcessorList()
        use_cache = model_kwargs.pop("use_cache", False)
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        if max_length is not None:
            warnings.warn(
                "`max_length` is deprecated in this function, use"
                " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
                UserWarning,
            )
            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
        output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
        output_attentions = (
            output_attentions if output_attentions is not None else self.generation_config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate
            if return_dict_in_generate is not None
            else self.generation_config.return_dict_in_generate
        )

        # The stopping criteria do not change during generation, so resolve them once here. `stopping_criteria`
        # itself is never rebound: its `.max_length` stays readable on every step even when its only entry is not a
        # `MaxLengthCriteria` (e.g. a lone `MaxTimeCriteria`, which has no `max_length` attribute).
        criteria_max_length = stopping_criteria.max_length
        stop_criterion = stopping_criteria
        if isinstance(stop_criterion, list) and len(stop_criterion) == 1:
            # unwrap a single criterion so the cheap XLA-friendly checks below can recognize it
            stop_criterion = stop_criterion[0]

        # `input_ids` is only written in place, so its shape is fixed for the whole loop
        batch_size, _ = input_ids.shape

        # init attention / hidden states / scores tuples
        scores = None
        if return_dict_in_generate and output_scores:
            if is_traced_inference:
                logger.warning(
                    "`output_scores` will be neglected because currently we do not trace `next_token_scores` for greedy search (we do only in beam search). If you want us to support the option during the compilation, please file an issue to https://github.com/huggingface/optimum-neuron."
                )
            else:
                scores = ()
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        # keep track of which sequences are already finished
        unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)

        this_peer_finished = False  # used by synced_gpus only
        while True:
            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break

            # prepare model inputs
            if use_cache:
                # With a cache, the first step feeds the whole prompt (first `seq_length` columns) and later steps
                # only feed the last written token, at column `seq_length - 1`.
                if model_kwargs.get("past_key_values") is None:
                    input_ids_ = input_ids[:, :seq_length]
                else:
                    update_indices = torch.stack(
                        [torch.arange(input_ids.size(0)), torch.tensor(seq_length - 1).repeat(input_ids.size(0))],
                        dim=-1,
                    )
                    input_ids_ = input_ids[update_indices[:, 0], update_indices[:, 1], None]

                model_inputs = self.prepare_inputs_for_generation(input_ids_, **model_kwargs)
            else:
                model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            if synced_gpus and this_peer_finished:
                continue  # don't waste resources running the code we don't need

            if not is_traced_inference:
                if not use_cache:
                    # Without a cache the logits cover every (padded) position; pick row `seq_length - 1` with a
                    # one-hot matmul instead of a data-dependent index.
                    one_hot = (
                        torch.cat(
                            [
                                torch.tensor([0]).repeat(1, seq_length - 1),
                                torch.tensor([1]).repeat(1, 1),
                                torch.tensor([0]).repeat(1, input_ids.size(1) - seq_length),
                            ],
                            dim=1,
                        )
                        .to(device=outputs.logits.device)
                        .float()
                    )
                    next_token_logits = torch.matmul(one_hot, outputs.logits)
                    next_token_logits = next_token_logits.squeeze(1)
                else:
                    next_token_logits = outputs.logits[:, -1, :]

                # pre-process distribution
                # Move to cpu to handle arbitrary logits_processor
                next_tokens_scores = logits_processor(input_ids.to("cpu")[:, :seq_length], next_token_logits.to("cpu"))
                next_tokens_scores = next_tokens_scores.to(input_ids.device)

                # argmax
                next_tokens = torch.argmax(next_tokens_scores, dim=-1)

                if return_dict_in_generate and output_scores:
                    scores += (next_tokens_scores,)
            else:
                # traced decoder already returns the selected tokens
                next_tokens = outputs[0]

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # finished sentences should have their next token be a padding token
            if eos_token_id is not None:
                if pad_token_id is None:
                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            # update generated ids (in place, at column `seq_length`), model inputs, and length for next step
            update_indices = torch.stack(
                [torch.arange(batch_size), torch.tensor(seq_length).repeat(batch_size)], dim=-1
            )
            input_ids[update_indices[:, 0], update_indices[:, 1]] = next_tokens[:]
            model_kwargs = self._update_model_kwargs_for_xla_generation(
                outputs=outputs,
                model_kwargs=model_kwargs,
                batch_size=batch_size,
                is_encoder_decoder=self.config.is_encoder_decoder,
                max_length=criteria_max_length,
                seq_length=seq_length,
                use_cache=use_cache,
            )

            seq_length += 1

            # if eos_token was found in one sentence, set sentence to finished
            if eos_token_id_tensor is not None:
                unfinished_sequences = unfinished_sequences.mul(
                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
                )

            if not is_traced_inference:
                xm.mark_step()

            # stop when each sentence is finished, or if we exceed the maximum length
            stop_criterion_1 = unfinished_sequences.max() == 0

            # Cases that can be handled in XLA without requiring
            # non-padded input_ids
            if isinstance(stop_criterion, MaxLengthCriteria):
                stop_criterion_2 = seq_length >= stop_criterion.max_length
            elif isinstance(stop_criterion, MaxTimeCriteria):
                stop_criterion_2 = stop_criterion(input_ids, scores)
            else:
                # Other cases will be handled on CPU, on the non-padded prefix of `input_ids`. Moving to CPU before
                # slicing mirrors the logits-processor call above and keeps the mask/slice off the device.
                input_ids_cpu = input_ids.to("cpu")[:, :seq_length]
                scores_cpu = scores.to("cpu") if torch.is_tensor(scores) else scores
                stop_criterion_2 = stop_criterion(input_ids_cpu, scores_cpu)

            # NOTE(review): some `transformers` criteria return one flag per sequence rather than a single bool;
            # a multi-element tensor cannot be used in `if` directly. Stop only once every sequence is flagged
            # (sequences finished through EOS are already tracked by `unfinished_sequences`).
            if torch.is_tensor(stop_criterion_2):
                stop_criterion_2 = bool(stop_criterion_2.all())

            if stop_criterion_1 or stop_criterion_2:
                this_peer_finished = True

            if this_peer_finished and not synced_gpus:
                break

        if return_dict_in_generate:
            if self.config.is_encoder_decoder:
                return GreedySearchEncoderDecoderOutput(
                    sequences=input_ids,
                    scores=scores,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            else:
                return GreedySearchDecoderOnlyOutput(
                    sequences=input_ids,
                    scores=scores,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
        else:
            return input_ids

    def beam_search(
        self,
        input_ids: torch.LongTensor,
        beam_scorer: BeamScorer,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: Optional[bool] = False,
        seq_length: Optional[int] = None,
        is_traced_inference: bool = False,
        **model_kwargs,
    ) -> Union[BeamSearchOutput, torch.LongTensor]:
        r"""
        Generates sequences of token ids for models with a language modeling head using **beam search decoding** and
        can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

        <Tip warning={true}>

        In most cases, you do not need to call [`~generation.GenerationMixin.beam_search`] directly. Use generate()
        instead. For an overview of generation strategies and code examples, check the [following
        guide](../generation_strategies).

        </Tip>

        Parameters:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation.
            beam_scorer (`BeamScorer`):
                An derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
                sorted during generation. For more information, the documentation of [`BeamScorer`] should be read.
            logits_processor (`LogitsProcessorList`, *optional*):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
                used to modify the prediction scores of the language modeling head applied at each generation step.
            stopping_criteria (`StoppingCriteriaList`, *optional*):
                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                used to tell if the generation loop should stop.
            max_length (`int`, *optional*, defaults to 20):
                **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
                tokens. The maximum length of the sequence to be generated.
            pad_token_id (`int`, *optional*):
                The id of the *padding* token.
            eos_token_id (`Union[int, List[int]]`, *optional*):
                The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
            output_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more details.
            output_hidden_states (`bool`, *optional*, defaults to `False`):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more details.
            output_scores (`bool`, *optional*, defaults to `False`):
                Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
            return_dict_in_generate (`bool`, *optional*, defaults to `False`):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            synced_gpus (`bool`, *optional*, defaults to `False`):
                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
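            seq_length (`int`):
                Length of the current `input_ids` sequence, i.e. the number of tokens already written into the padded
                `input_ids` tensor. New tokens are written starting at this position. Although the signature defaults
                to `None`, a value must be provided.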
            is_traced_inference (`bool`, defaults to `False`):
                Whether the decoder is traced or using XLA lazy tensor. If the decoder is traced, next tokens and the beam scores
                are computed inside the decoder.
            model_kwargs:
                Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
                an encoder-decoder model the kwargs should include `encoder_outputs`.

        Return:
            [`~generation.BeamSearchDecoderOnlyOutput`], [`~generation.BeamSearchEncoderDecoderOutput`] or
            `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
            [`~generation.BeamSearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
            `return_dict_in_generate=True` or a [`~generation.BeamSearchEncoderDecoderOutput`] if
            `model.config.is_encoder_decoder=True` and `return_dict_in_generate=True`.


        Examples:

        ```python
        >>> import torch
        >>> from transformers import AutoTokenizer, BeamSearchScorer
        >>> from optimum.neuron import NeuronModelForSeq2SeqLM

        >>> tokenizer = AutoTokenizer.from_pretrained("t5-small")
        >>> num_beams = 4
        >>> input_shapes = {"batch_size": 1, "sequence_length": 128, "num_beams": num_beams}
        >>> model = NeuronModelForSeq2SeqLM.from_pretrained("t5-small", export=True, dynamic_batch_size=False, **input_shapes)

        >>> input_prompt = "translate English to German: Lets eat good food."
        >>> inputs = tokenizer(input_prompt, return_tensors="pt")
        >>> encoder_input_ids = inputs.input_ids

        >>> # add encoder_outputs to model keyword arguments
        >>> model_kwargs = {
        ...     "encoder_outputs": model.get_encoder()(
        ...         encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
        ...     )
        ... }
        >>> # define decoder start token ids
        >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
        >>> input_ids = input_ids * model.config.decoder_start_token_id
        >>> # instantiate beam scorer
        >>> beam_scorer = BeamSearchScorer(
        ...     batch_size=1,
        ...     num_beams=num_beams,
        ...     device=model.device,
        ... )

        >>> outputs = model.beam_search(input_ids, beam_scorer, **model_kwargs)
        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ```
        """
        # init values
        if logits_processor is not None and is_traced_inference:
            logger.warning(
                "`logits_processor` will be neglected because in `optimum-neuron`, `next_tokens` is computed inside the compiled decoder. If you want us to support custom logits_processor during the compilation, please file an issue to https://github.com/huggingface/optimum-neuron."
            )
        elif logits_processor is None:
            logits_processor = LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        if max_length is not None:
            warnings.warn(
                "`max_length` is deprecated in this function, use"
                " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
                UserWarning,
            )
            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
        if len(stopping_criteria) == 0:
            warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
        output_attentions = (
            output_attentions if output_attentions is not None else self.generation_config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate
            if return_dict_in_generate is not None
            else self.generation_config.return_dict_in_generate
        )

        batch_size = len(beam_scorer._beam_hyps)
        num_beams = beam_scorer.num_beams

        batch_beam_size, cur_len = input_ids.shape

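        # Overwrite cur_len: `input_ids` is padded, so its width is not the current sequence length;
        # use the caller-provided `seq_length` instead.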
        cur_len = seq_length

        if num_beams * batch_size != batch_beam_size:
            raise ValueError(
                f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
            )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        beam_indices = (
            tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
        )
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens
        # of the first beam are considered to avoid sampling the exact same tokens across all beams.
        beam_scores_device = "cpu"
        beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=beam_scores_device)
        beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view((batch_size * num_beams,))

        this_peer_finished = False  # used by synced_gpus only
        while True:
            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break

            # prepare model inputs
            if model_kwargs["use_cache"]:
                # From max_length-sized input_ids, select only the token at
                # position cur_len - 1 (the last valid token) of each row.
                update_indices = torch.stack(
                    [torch.arange(input_ids.size(0)), torch.tensor(cur_len - 1).repeat(input_ids.size(0))], dim=-1
                )
                input_ids_ = input_ids[update_indices[:, 0], update_indices[:, 1], None]
                model_inputs = self.prepare_inputs_for_generation(input_ids_, **model_kwargs)
            else:
                model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            if is_traced_inference:
                outputs = self(
                    **model_inputs,
                    beam_scores=beam_scores,
                    return_dict=True,
                    output_attentions=output_attentions,
                    output_hidden_states=output_hidden_states,
                )
                next_token_scores = outputs.next_token_scores
                next_tokens = outputs.next_tokens
                next_indices = outputs.next_indices

                if return_dict_in_generate and output_scores:
                    scores += (next_token_scores,)
            else:
                outputs = self(
                    **model_inputs,
                    return_dict=True,
                    output_attentions=output_attentions,
                    output_hidden_states=output_hidden_states,
                )

                if synced_gpus and this_peer_finished:
                    cur_len = cur_len + 1
                    continue  # don't waste resources running the code we don't need

                if not model_kwargs["use_cache"]:
                    one_hot = (
                        torch.cat(
                            [
                                torch.tensor([0]).repeat(1, cur_len - 1),
                                torch.tensor([1]).repeat(1, 1),
                                torch.tensor([0]).repeat(1, input_ids.size(1) - cur_len),
                            ],
                            dim=1,
                        )
                        .to(device=outputs.logits.device)
                        .float()
                    )
                    next_token_logits = torch.matmul(one_hot, outputs.logits)
                    next_token_logits = next_token_logits.squeeze(1)
                else:
                    next_token_logits = outputs.logits[:, -1, :]

                # Manually compute log softmax
                # log_softmax(vi) = vi - max(vi) - log(sum(exp(vi - max(vi))))
                logit_max, _ = torch.max(next_token_logits, dim=-1, keepdim=True)
                logsumexp = torch.log(torch.exp(next_token_logits - logit_max).sum(dim=-1, keepdim=True))
                next_token_scores = next_token_logits - logit_max - logsumexp
                # (batch_size * num_beams, vocab_size)

                xm.mark_step()

                # We don't want to change every single logit processor, so
                # we perform this processing on CPU.
                input_ids_ = input_ids.to("cpu")[:, :cur_len]
                next_token_scores_ = next_token_scores.to("cpu")
                next_token_scores_processed = logits_processor(input_ids_, next_token_scores_)

                next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)

                # reshape for beam search
                vocab_size = next_token_scores.shape[-1]
                next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
                next_token_scores = next_token_scores * 1

                # Keep the top 2 * num_beams candidates per batch entry, taken across all of its beams
                # (so we have some spare tokens and match output of beam search)
                next_token_scores, next_tokens = torch.topk(
                    next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
                )

                next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
                next_tokens = next_tokens % vocab_size

                if return_dict_in_generate and output_scores:
                    scores += (next_token_scores_processed,)

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )
            # stateless
            beam_outputs = beam_scorer.process(
                input_ids.to("cpu")[:, :cur_len],
                next_token_scores.to("cpu"),
                next_tokens.to("cpu"),
                next_indices.to("cpu"),
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                beam_indices=beam_indices,
            )

            beam_scores = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
            beam_idx = beam_outputs["next_beam_indices"]

            update_indices = torch.stack(
                [torch.arange(batch_beam_size), torch.tensor(cur_len - 1).repeat(batch_beam_size)], dim=-1
            )
            update_indices_2 = torch.stack(
                [torch.arange(batch_beam_size), torch.tensor(cur_len).repeat(batch_beam_size)], dim=-1
            )
            # First select beam_indices
            device = input_ids.device
            beam_idx_device = beam_idx.to(device=input_ids.device)
            input_ids[:, :] = input_ids[beam_idx_device.long(), :]

            # Then append new tokens
            if is_traced_inference:
                # int64 is not natively supported by inf2 and has been cast down to int32
                input_ids[update_indices_2[:, 0], update_indices_2[:, 1], None] = (
                    beam_next_tokens.unsqueeze(-1).to(device).to(torch.long)
                )
            else:
                input_ids[update_indices_2[:, 0], update_indices_2[:, 1], None] = beam_next_tokens.unsqueeze(-1).to(
                    device
                )
            input_ids = input_ids * 1  # Hack to materialize tensor

            # update generated ids, model inputs, and length for next step
            model_kwargs = self._update_model_kwargs_for_xla_generation(
                outputs=outputs,
                model_kwargs=model_kwargs,
                batch_size=batch_beam_size,
                is_encoder_decoder=self.config.is_encoder_decoder,
                max_length=stopping_criteria.max_length,
                seq_length=cur_len,
                use_cache=model_kwargs["use_cache"],
            )
            if is_traced_inference:
                self._reorder_cache(beam_idx.to(torch.int64))
            elif model_kwargs["past_key_values"] is not None:
                model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)

            if return_dict_in_generate and output_scores:
                beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))

            # increase cur_len
            cur_len = cur_len + 1

            # stop when each sentence is finished, or if we exceed the maximum length
            stop_criterion_1 = beam_scorer.is_done
            if isinstance(stopping_criteria, list):
                if len(stopping_criteria) == 1:
                    stopping_criteria = stopping_criteria[0]

            # Cases that can be handled in XLA without requiring
            # non-padded input_ids
            if isinstance(stopping_criteria, MaxLengthCriteria):
                stop_criterion_2 = cur_len >= stopping_criteria.max_length
            elif isinstance(stopping_criteria, MaxTimeCriteria):
                stop_criterion_2 = stopping_criteria(input_ids, scores)
            else:
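                # Other cases will be handled on CPU
                # NOTE(review): the next line rebinds `batch_size` to `input_ids.shape[0]`, which is
                # batch_size * num_beams (see the batch-dimension check before the loop). The non-traced
                # branch reuses `batch_size` in `next_token_scores.view(batch_size, num_beams * vocab_size)`
                # on the next iteration, so this looks like a bug; confirm and consider a separate local name.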
                batch_size, _ = input_ids.shape
                input_ids_cpu = input_ids.to("cpu")
                mask = torch.cat(
                    [torch.ones(batch_size, cur_len), torch.zeros(batch_size, input_ids.shape[1] - cur_len)], dim=1
                ).bool()
                input_ids_cpu = torch.masked_select(input_ids_cpu, mask).reshape((batch_size, cur_len))
                scores_cpu = scores.to("cpu") if torch.is_tensor(scores) else scores
                stop_criterion_2 = stopping_criteria(input_ids_cpu, scores_cpu)

            # TODO: validate with @JingyaHuang
            if stop_criterion_1 or torch.all(stop_criterion_2):
                if not synced_gpus:
                    break
                else:
                    this_peer_finished = True

        sequence_outputs = beam_scorer.finalize(
            input_ids.to("cpu"),
            beam_scores.to("cpu"),
            next_tokens.to("cpu"),
            next_indices.to("cpu"),
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
            beam_indices=beam_indices,
        )

        for k, v in sequence_outputs.items():
            if type(v) is torch.Tensor:
                sequence_outputs[k] = sequence_outputs[k].to(input_ids.device)

        if return_dict_in_generate:
            if not output_scores:
                sequence_outputs["sequence_scores"] = None

            if self.config.is_encoder_decoder:
                return BeamSearchEncoderDecoderOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    beam_indices=sequence_outputs["beam_indices"],
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                )
            else:
                return BeamSearchDecoderOnlyOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
                    beam_indices=sequence_outputs["beam_indices"],
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
        else:
            return sequence_outputs["sequences"]
