#  Copyright 2022 The HuggingFace Team. All rights reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
"""Classes handling causal-lm related architectures in ONNX Runtime."""

import logging
import os
import re
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple, Union

import onnx
import torch
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from onnx.tools import update_model_dims
from transformers import AutoModelForCausalLM, GenerationConfig
from transformers.file_utils import add_end_docstrings, add_start_docstrings_to_model_forward
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import cached_file

from onnxruntime import InferenceSession, SessionOptions

from ..exporters.onnx import MODEL_TYPES_REQUIRING_POSITION_IDS, main_export
from ..exporters.tasks import TasksManager
from ..onnx.utils import check_model_uses_external_data
from ..utils import NormalizedConfigManager, is_transformers_version
from ..utils.file_utils import find_files_matching_pattern
from ..utils.save_utils import maybe_save_preprocessors
from .constants import (
    DECODER_MERGED_ONNX_FILE_PATTERN,
    DECODER_ONNX_FILE_PATTERN,
    DECODER_WITH_PAST_ONNX_FILE_PATTERN,
    ONNX_FILE_PATTERN,
)
from .modeling_ort import ONNX_MODEL_END_DOCSTRING, ORTModel
from .utils import prepare_providers_and_provider_options


if TYPE_CHECKING:
    from transformers import PretrainedConfig

if is_transformers_version(">=", "4.25.0"):
    from transformers.generation import GenerationMixin
else:
    from transformers.generation_utils import GenerationMixin  # type: ignore # noqa: F401


logger = logging.getLogger(__name__)

DECODER_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor`):
            Indices of decoder input sequence tokens in the vocabulary, of shape `(batch_size, sequence_length)`.
        attention_mask (`torch.LongTensor`, *optional*):
            Mask to avoid performing attention on padding token indices, of shape
            `(batch_size, sequence_length)`. Mask values selected in `[0, 1]`.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, defaults to `None`):
            Contains the precomputed key and value hidden states of the attention blocks used to speed up decoding.
            The tuple is of length `config.n_layers` with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`.
"""

CAUSALLM_ONNX_MODEL_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor`):
            Indices of decoder input sequence tokens in the vocabulary, of shape `(batch_size, sequence_length)`.
        attention_mask (`torch.LongTensor`):
            Mask to avoid performing attention on padding token indices, of shape
            `(batch_size, sequence_length)`. Mask values selected in `[0, 1]`.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, defaults to `None`):
            Contains the precomputed key and value hidden states of the attention blocks used to speed up decoding.
            The tuple is of length `config.n_layers` with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`.
"""

_TOKENIZER_FOR_DOC = "AutoTokenizer"

TEXT_GENERATION_EXAMPLE = r"""
    Example of text generation:

    ```python
    >>> from transformers import {processor_class}
    >>> from optimum.onnxruntime import {model_class}
    >>> import torch

    >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
    >>> model = {model_class}.from_pretrained("{checkpoint}")

    >>> inputs = tokenizer("My name is Arthur and I live in", return_tensors="pt")

    >>> gen_tokens = model.generate(**inputs, do_sample=True, temperature=0.9, min_length=20, max_length=20)
    >>> tokenizer.batch_decode(gen_tokens)  # doctest: +IGNORE_RESULT
    ```

    Example using `transformers.pipelines`:

    ```python
    >>> from transformers import {processor_class}, pipeline
    >>> from optimum.onnxruntime import {model_class}

    >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
    >>> model = {model_class}.from_pretrained("{checkpoint}")
    >>> onnx_gen = pipeline("text-generation", model=model, tokenizer=tokenizer)

    >>> text = "My name is Arthur and I live in"
    >>> gen = onnx_gen(text)
    ```
"""


@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
class ORTModelForCausalLM(ORTModel, GenerationMixin):
    """
    ONNX model with a causal language modeling head for ONNX Runtime inference. This class officially supports bloom, codegen, falcon, gpt2, gpt-bigcode, gpt_neo, gpt_neox, gptj, llama.
    """

    auto_model_class = AutoModelForCausalLM
    main_input_name = "input_ids"
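    # transformers `Cache` objects are not supported; past key values are handled as legacy tuples of tensors.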
    _supports_cache_class = False

    def __init__(
        self,
        *args,
        config: "PretrainedConfig" = None,
        session: "InferenceSession" = None,
        use_io_binding: Optional[bool] = None,
        generation_config: Optional["GenerationConfig"] = None,
        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
        **kwargs,
    ):
        # DEPRECATED BEHAVIOR
        if args:
            logger.warning(
                "Instantiating an ORTModelForCausalLM with positional arguments is deprecated and will be removed in the next version. "
                "Please use the keywords arguments {config, session, use_io_binding, generation_config, model_save_dir, use_cache} instead."
            )
            # the old signature is ORTModelForCausalLM(model, config, use_io_binding, model_save_dir, preprocessors, generation_config, use_cache)
            session = args[0]
            if len(args) > 1:
                config = args[1]
            if len(args) > 2:
                use_io_binding = args[2]
            if len(args) > 3:
                model_save_dir = args[3]
            if len(args) > 4:
                _ = args[4]
            if len(args) > 5:
                generation_config = args[5]
            if len(args) > 6:
                _ = args[6]

        if kwargs.get("model", None) is not None:
            logger.warning(
                "Passing the inference session as `model` argument to an ORTModelForCausalLM is deprecated. Please use `session` instead."
            )
            session = kwargs.pop("model")
        if kwargs:
            logger.warning(
                f"Some keyword arguments were passed to the ORTModelForCausalLM constructor that are not part of its signature: {', '.join(kwargs.keys())}. "
                "These arguments will be ignored in the current version and will raise an error in the next version."
            )

        if config is None:
            raise ValueError(
                "The parameter config is required. Please pass a config or use the from_pretrained method."
            )
        if session is None:
            raise ValueError(
                "The parameter session is required. Please pass a session or use the from_pretrained method."
            )
        ## END OF DEPRECATED BEHAVIOR
        super().__init__(config=config, session=session, use_io_binding=use_io_binding, model_save_dir=model_save_dir)

        self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)
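        # Past key/value inputs and outputs are identified by their ONNX names (e.g. `past_key_values.0.key`).
        # Their presence determines whether the exported model supports the key/value cache, and a `use_cache_branch`
        # input indicates a merged decoder (a single graph containing both the without-past and with-past branches).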
        self.key_value_input_names = [key for key in self.input_names if (".key" in key) or (".value" in key)]
        self.key_value_output_names = [key for key in self.output_names if (".key" in key) or (".value" in key)]
        self.can_use_cache = len(self.key_value_input_names) > 0 and len(self.key_value_output_names) > 0
        self.is_merged = "use_cache_branch" in self.input_names
        self.generation_config = generation_config

        # Reference: https://github.com/huggingface/optimum/pull/1381
        model_type = self.config.model_type
        if model_type in MODEL_TYPES_REQUIRING_POSITION_IDS and "position_ids" not in self.input_names:
            logger.warning(
                f"ORTModelForCausalLM loaded a legacy ONNX model with no position_ids input, although the model type {model_type} "
                "requires it. for correct batched generation. We strongly encourage to re-export the model with "
                "a newer version of Optimum for better performance and more reliable generation. "
            )

        if not self.can_use_cache and self.generation_config.use_cache:
            logger.warning(
                "`model.generation_config.use_cache=True` but the loaded model does not support using the past key values cache."
                "Please re-export the original model once again with `use_cache=True` to be able to use it during generation. "
                "Or set `model.generation_config.use_cache=False` to avoid errors from attempting to use the cache. "
                "To re-export your model, simply set `export=True` as in `from_pretrained(..., export=True, use_cache=True)`."
            )

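        # Infer the per-head embedding size and the number of key/value heads, which are needed to allocate
        # and reshape the past key/value tensors (these differ across architectures, e.g. MQA/GQA vs. standard MHA).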
        if self.config.model_type == "gemma":
            self.embed_size_per_head = self.normalized_config.head_dim
        else:
            self.embed_size_per_head = self.normalized_config.hidden_size // self.normalized_config.num_attention_heads
        if self.config.model_type in {"gemma", "mistral", "llama", "qwen2", "qwen3", "qwen3_moe", "granite"}:
            self.num_key_value_heads = self.normalized_config.num_key_value_heads
        elif self.config.model_type == "falcon":
            self.num_key_value_heads = (
                self.config.num_kv_heads
                if (self.config.new_decoder_architecture or not self.config.multi_query)
                else 1
            )
        else:
            self.num_key_value_heads = self.normalized_config.num_attention_heads

    @property
    def use_cache(self):
        logger.warning(
            "The `ORTModelForCausalLM.use_cache` property is deprecated and will be removed in a future version. "
            "Please rather use `ORTModelForCausalLM.can_use_cache` to check if a model supports using cache during generation. "
            "And use `ORTModelForCausalLM.generation_config.use_cache` to check if the model is configured to use cache during generation."
        )
        return self.can_use_cache

    @property
    def use_merged(self):
        logger.warning(
            "The `ORTModelForCausalLM.use_merged` property is deprecated and will be removed in a future version. "
            "Please rather use `ORTModelForCausalLM.is_merged` to check if the underlying model is merged or not."
        )
        return self.is_merged

    @add_start_docstrings_to_model_forward(
        CAUSALLM_ONNX_MODEL_DOCSTRING.format("batch_size, sequence_length")
        + TEXT_GENERATION_EXAMPLE.format(
            processor_class=_TOKENIZER_FOR_DOC,
            model_class="ORTModelForCausalLM",
            checkpoint="optimum/gpt2",
        )
    )
    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        position_ids: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        use_torch = isinstance(input_ids, torch.Tensor)
        self.raise_on_numpy_input_io_binding(use_torch)
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if use_cache and not self.can_use_cache:
            raise ValueError(
                f"`use_cache={use_cache}` was passed to the model but the loaded model only supports `use_cache={self.can_use_cache}`. "
                f"Please load your current model with `use_cache={self.can_use_cache}` or export the original model "
                f"once again with `use_cache={use_cache}` when calling the `from_pretrained` method. "
                "To re-export your model, simply set `export=True` in the `from_pretrained` method."
            )

        if past_key_values is not None and isinstance(past_key_values[0], tuple):
            # Flattens the past_key_values to a single tuple
            past_key_values = sum(past_key_values, ())

        if "position_ids" in self.input_names and position_ids is None:
            if attention_mask is not None:
                # Create position_ids from attention_mask
                position_ids = attention_mask.cumsum(-1) - 1
                position_ids.masked_fill_(attention_mask == 0, 1)
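                # Padded positions are given a dummy value of 1; when a cache is used, only the position of the newly added token is needed.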
                if past_key_values is not None:
                    position_ids = position_ids[:, -1].unsqueeze(-1)
            else:
                raise ValueError(
                    "The model requires position_ids for batched generation but none were provided. "
                    "Please provide position_ids or attention_mask (from which position_ids can be inferred)."
                )

        use_cache_branch = None
        if self.is_merged:
            # Uses cache branch of merged decoders depending on whether real past key values are passed
            use_cache_branch = torch.full((1,), past_key_values is not None, dtype=torch.bool, device=self.device)

        if past_key_values is None and len(self.key_value_input_names) > 0:
            # Generates the input pkv for the first forward of the model (merged or with past)
            batch_size, seq_len = input_ids.shape
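            # gpt_bigcode uses multi-query attention with fused key/value states of shape
            # (batch_size, past_sequence_length, 2 * embed_size_per_head); other architectures use separate
            # key and value tensors of shape (batch_size, num_key_value_heads, past_sequence_length, embed_size_per_head).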
            if self.config.model_type == "gpt_bigcode":
                shape = (batch_size, 0, self.embed_size_per_head * 2)
            else:
                shape = (batch_size, self.num_key_value_heads, 0, self.embed_size_per_head)
            tensor = torch.empty(shape, dtype=self.dtype, device=self.device)
            past_key_values = tuple(tensor for _ in range(len(self.key_value_input_names)))

        model_inputs = {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "use_cache_branch": use_cache_branch,
        }
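        # The flattened past key/value tensors are matched positionally with the ONNX cache input names.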
        if len(self.key_value_input_names) > 0:
            model_inputs.update(zip(self.key_value_input_names, past_key_values))

        known_output_shapes = None
        outputs_to_not_bind = None
        if use_cache:
            # Infers the shape of the output pkv
            batch_size, seq_len = input_ids.shape
            if self.config.model_type == "gpt_bigcode":
                pkv_seq_len, embed_size_per_head_2 = past_key_values[0].shape[1:]
                pkv_output_shape = (batch_size, pkv_seq_len + seq_len, embed_size_per_head_2)
            else:
                num_key_value_heads, pkv_seq_len, embed_size_per_head = past_key_values[0].shape[1:]
                pkv_output_shape = (batch_size, num_key_value_heads, pkv_seq_len + seq_len, embed_size_per_head)
            known_output_shapes = dict.fromkeys(self.key_value_output_names, pkv_output_shape)
        else:
            # Don't bind the output pkv if not used/returned
            outputs_to_not_bind = self.key_value_output_names

        if self.use_io_binding:
            output_shapes, output_buffers = self._prepare_io_binding(
                model_inputs,
                outputs_to_not_bind=outputs_to_not_bind,
                known_output_shapes=known_output_shapes,
            )

            if self.device.type == "cpu":
                self.session.run_with_iobinding(self._io_binding)
            else:
                self._io_binding.synchronize_inputs()
                self.session.run_with_iobinding(self._io_binding)
                self._io_binding.synchronize_outputs()

            loss = output_buffers.get("loss", None)
            logits = output_buffers["logits"].view(output_shapes["logits"])

            if use_cache:
                past_key_values = tuple(
                    output_buffers.pop(name).view(output_shapes[name]) for name in self.key_value_output_names
                )
        else:
            onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs)
            onnx_outputs = self.session.run(None, onnx_inputs)
            model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs)

            loss = model_outputs.pop("loss", None)
            logits = model_outputs.pop("logits")

            if use_cache:
                past_key_values = tuple(model_outputs.pop(name) for name in self.key_value_output_names)

        if use_cache and self.config.model_type != "gpt_bigcode":
            # Convert the flat tuple into a tuple of length `config.n_layers`, with each element being a (key, value) pair of tensors for one decoder layer
            past_key_values = tuple(past_key_values[i : i + 2] for i in range(0, len(past_key_values), 2))

        return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=past_key_values)

    def prepare_inputs_for_generation(self, *args, **kwargs):
        if is_transformers_version("<", "4.46.0"):
            return self._prepare_inputs_for_generation_legacy(*args, **kwargs)
        else:
            return super().prepare_inputs_for_generation(*args, **kwargs)

    # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
    def _prepare_inputs_for_generation_legacy(
        self,
        input_ids,
        attention_mask=None,
        past_key_values=None,
        token_type_ids=None,
        position_ids=None,
        use_cache=None,
        **kwargs,
    ):
        if past_key_values is not None:
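            # The number of tokens already covered by the cache is read from the sequence axis of the cache tensors,
            # whose position depends on the architecture (fused key/value layout for gpt_bigcode multi-query attention).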
            if self.config.model_type == "gpt_bigcode":
                if self.config.multi_query:
                    past_length = past_key_values[0].shape[1]
                else:
                    past_length = past_key_values[0].shape[2]
            else:
                past_length = past_key_values[0][0].shape[2]

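            # Keep only the input tokens that are not already covered by the cache;
            # otherwise, assume only the last token is new.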
            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                remove_prefix_length = input_ids.shape[1] - 1
            input_ids = input_ids[:, remove_prefix_length:]

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            "token_type_ids": token_type_ids,
            "position_ids": position_ids,
            "use_cache": use_cache,
        }

    @staticmethod
    def _reorder_cache(
        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
    ) -> Tuple[Tuple[torch.Tensor]]:
        if isinstance(past_key_values, tuple) and isinstance(past_key_values[0], tuple):
            # GPT2 style
            return tuple(
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
                for layer_past in past_key_values
            )
        elif isinstance(past_key_values, tuple) and isinstance(past_key_values[0], torch.Tensor):
            # GPT BigCode style
            return tuple(layer_past.index_select(0, beam_idx.to(layer_past.device)) for layer_past in past_key_values)
        else:
            raise ValueError(
                f"Unexpected past_key_values: {past_key_values}. "
                "Expected tuple of tuples (GPT2 style) or tuple of tensors (GPT BigCode style)."
            )

    @classmethod
    def _from_pretrained(
        cls,
        model_id: Union[str, Path],
        config: "PretrainedConfig",
        # hub options
        subfolder: str = "",
        revision: str = "main",
        force_download: bool = False,
        local_files_only: bool = False,
        trust_remote_code: bool = False,
        cache_dir: str = HUGGINGFACE_HUB_CACHE,
        token: Optional[Union[bool, str]] = None,
        # file options
        file_name: Optional[str] = None,
        # session options
        provider: str = "CPUExecutionProvider",
        providers: Optional[Sequence[str]] = None,
        provider_options: Optional[Union[Sequence[Dict[str, Any]], Dict[str, Any]]] = None,
        session_options: Optional[SessionOptions] = None,
        # inference options
        use_cache: bool = True,
        use_merged: Optional[bool] = None,
        use_io_binding: Optional[bool] = None,
        generation_config: Optional[GenerationConfig] = None,
        # other arguments
        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
    ) -> "ORTModelForCausalLM":
        onnx_files = find_files_matching_pattern(
            model_id,
            ONNX_FILE_PATTERN,
            glob_pattern="**/*.onnx",
            subfolder=subfolder,
            token=token,
            revision=revision,
        )

        if len(onnx_files) == 0:
            raise FileNotFoundError(f"Could not find any ONNX model file in {model_id}")

        if len(onnx_files) == 1:
            subfolder = onnx_files[0].parent
            _file_name = onnx_files[0].name
            if file_name and file_name != _file_name:
                raise FileNotFoundError(f"Trying to load {file_name} but only found {_file_name}")
            file_name = _file_name

        else:
            model_files = []
            # Check first for merged models and then for decoder / decoder_with_past models
            if use_merged is not False:
                model_files = [p for p in onnx_files if re.search(DECODER_MERGED_ONNX_FILE_PATTERN, str(p))]
                use_merged = len(model_files) != 0

            if use_merged is False:
                pattern = DECODER_WITH_PAST_ONNX_FILE_PATTERN if use_cache else DECODER_ONNX_FILE_PATTERN
                model_files = [p for p in onnx_files if re.search(pattern, str(p))]

            # Fall back to all ONNX files if no files matched the patterns or if a file_name was explicitly specified (legacy models are not filtered out in that case)
            if not model_files or file_name:
                model_files = onnx_files
            else:
                logger.warning(
                    f"Legacy models found in {model_files} will be loaded. "
                    "Legacy models will be deprecated in the next version of optimum, please re-export your model"
                )
            _file_name = model_files[0].name
            subfolder = model_files[0].parent

            default_file_name = file_name or "model.onnx"
            for file in model_files:
                if file.name == default_file_name:
                    _file_name = file.name
                    subfolder = file.parent
                    break

            file_name = _file_name

            if len(model_files) > 1:
                logger.warning(
                    f"Too many ONNX model files were found in {' ,'.join(map(str, model_files))}. "
                    "specify which one to load by using the `file_name` and/or the `subfolder` arguments. "
                    f"Loading the file {file_name} in the subfolder {subfolder}."
                )

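        # For local model directories, the resolved subfolder is used directly as the model location.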
        if os.path.isdir(model_id):
            model_id = subfolder
            subfolder = ""

        if isinstance(subfolder, Path):
            subfolder = subfolder.as_posix()

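        # Resolve the ONNX model file locally, downloading it from the Hub if necessary.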
        model_cache_path = cached_file(
            model_id,
            filename=file_name,
            # hub options
            token=token,
            revision=revision,
            subfolder=subfolder,
            cache_dir=cache_dir,
            force_download=force_download,
            local_files_only=local_files_only,
        )

        # model_save_dir can be provided in kwargs as a TemporaryDirectory instance, in which case we want to keep it
        # instead of the path only.
        if model_save_dir is None:
            model_save_dir = Path(model_cache_path).parent

        try:
            cached_file(
                model_id,
                filename=file_name + "_data",
                # hub options
                token=token,
                revision=revision,
                subfolder=subfolder,
                cache_dir=cache_dir,
                force_download=force_download,
                local_files_only=local_files_only,
            )
        except EnvironmentError:
            # If the external data file is not found, we assume that the model is not using external data.
            pass

        # This should be removed at some point
        onnx_model = onnx.load(str(model_cache_path), load_external_data=False)
        model_uses_external_data = check_model_uses_external_data(onnx_model)
        if model_uses_external_data:
            onnx_model = onnx.load(str(model_cache_path), load_external_data=True)
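        # Collect the input/output dimensions (fixed values or symbolic names) declared in the ONNX graph.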
        input_dims = {
            node.name: [dim.dim_value or dim.dim_param for dim in node.type.tensor_type.shape.dim]
            for node in onnx_model.graph.input
        }
        output_dims = {
            node.name: [dim.dim_value or dim.dim_param for dim in node.type.tensor_type.shape.dim]
            for node in onnx_model.graph.output
        }
        override_dims = False
        # Since v1.7.0, decoder-with-past models have a fixed sequence length of 1.
        # To keep these models compatible, we set this dimension back to dynamic.
        if input_dims["input_ids"][1] == 1:
            input_dims["input_ids"][1] = "sequence_length"
            output_dims["logits"][1] = "sequence_length"
            override_dims = True
        # Since https://github.com/huggingface/optimum/pull/871/
        # changed axis notation/naming during export, we need to update the dims
        for input_name in input_dims.keys():
            if "past" in input_name and input_dims[input_name][2] == "past_sequence_length + sequence_length":
                input_dims[input_name][2] = "past_sequence_length"
                override_dims = True
        if override_dims:
            # Overwriting the model file in place is somewhat risky, so warning the user is the least we can do
            logger.warning(
                "The ONNX model was probably exported with an older version of optimum. "
                "We are updating the input/output dimensions and overwriting the model file "
                "with new dimensions. This is necessary for the model to work correctly with "
                "the current version of optimum. If you encounter any issues, please re-export "
                "the model with the latest version of optimum for optimal performance."
            )
            onnx_model = update_model_dims.update_inputs_outputs_dims(onnx_model, input_dims, output_dims)
            onnx.save(
                onnx_model,
                str(model_cache_path),
                save_as_external_data=model_uses_external_data,
                location=Path(model_cache_path).name + "_data",
                all_tensors_to_one_file=True,
                convert_attribute=True,
                size_threshold=0,
            )
        del onnx_model

        # Important: for encoder-decoder models used with CausalLM, we need to set the is_decoder flag to True
        # and the is_encoder_decoder flag to False. This is needed for the model to work correctly with generation logic.
        if hasattr(config, "is_decoder"):
            config.is_decoder = True
        if hasattr(config, "is_encoder_decoder"):
            config.is_encoder_decoder = False

        if generation_config is None:
            try:
                generation_config = GenerationConfig.from_pretrained(
                    model_id,
                    token=token,
                    revision=revision,
                    subfolder=subfolder,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    local_files_only=local_files_only,
                )
            except OSError:
                logger.info("Generation config file not found, creating a new one from model config.")
                generation_config = GenerationConfig.from_model_config(config)

        # TODO: not sure if setting config.use_cache is needed for older versions of transformers
        generation_config.use_cache = use_cache
        config.use_cache = use_cache

        if is_transformers_version(">=", "4.45.0"):
            misplaced_generation_parameters = config._get_non_default_generation_parameters()
            if len(misplaced_generation_parameters) > 0:
                logger.warning(
                    "Moving the following attributes in the config to the generation config: "
                    f"{misplaced_generation_parameters}. You are seeing this warning because you've set "
                    "generation parameters in the model config, as opposed to in the generation config.",
                )
                for param_name, param_value in misplaced_generation_parameters.items():
                    setattr(generation_config, param_name, param_value)
                    setattr(config, param_name, None)

        providers, provider_options = prepare_providers_and_provider_options(
            provider=provider, providers=providers, provider_options=provider_options
        )
        session = InferenceSession(
            model_cache_path,
            providers=providers,
            provider_options=provider_options,
            sess_options=session_options,
        )

        return cls(
            config=config,
            session=session,
            use_io_binding=use_io_binding,
            generation_config=generation_config,
            model_save_dir=model_save_dir,
        )

    @classmethod
    def _export(
        cls,
        model_id: Union[str, Path],
        config: "PretrainedConfig",
        # hub options
        subfolder: str = "",
        revision: str = "main",
        force_download: bool = False,
        local_files_only: bool = False,
        trust_remote_code: bool = False,
        cache_dir: str = HUGGINGFACE_HUB_CACHE,
        token: Optional[Union[bool, str]] = None,
        # inference options
        use_cache: bool = True,
        **kwargs,
    ) -> "ORTModelForCausalLM":
        # This is guaranteed to work since it uses a mapping from model classes to task names
        # instead of relying on the Hub metadata or the model configuration
        task = TasksManager._infer_task_from_model_or_model_class(model_class=cls.auto_model_class)
        if use_cache:
            task += "-with-past"

        if kwargs.get("task", None) is not None:
            raise ValueError(
                f"The `task` argument is not needed when exporting a model with `{cls.__name__}`. "
                f"The `task` is automatically inferred from the class as `{task}`."
            )

        save_dir = TemporaryDirectory()
        save_dir_path = Path(save_dir.name)

        main_export(
            model_name_or_path=model_id,
            output=save_dir_path,
            task=task,
            do_validation=False,
            no_post_process=False,
            legacy=False,
            subfolder=subfolder,
            revision=revision,
            cache_dir=cache_dir,
            token=token,
            local_files_only=local_files_only,
            force_download=force_download,
            trust_remote_code=trust_remote_code,
        )
        maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder)

        return cls._from_pretrained(
            save_dir_path,
            config,
            use_cache=use_cache,
            model_save_dir=save_dir,
            **kwargs,
        )

    def _save_config(self, save_directory):
        """
        Save the model and generation configs to the specified directory.

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory where the model and generation configs will be saved.
        """
        self.config.save_pretrained(save_directory)
        self.generation_config.save_pretrained(save_directory)
