# optimum/onnxruntime/modeling_decoder.py
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Classes handling causal-lm related architectures in ONNX Runtime."""
import logging
import os
import re
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple, Union
import onnx
import torch
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from onnx.tools import update_model_dims
from transformers import AutoModelForCausalLM, GenerationConfig
from transformers.file_utils import add_end_docstrings, add_start_docstrings_to_model_forward
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import cached_file
from onnxruntime import InferenceSession, SessionOptions
from ..exporters.onnx import MODEL_TYPES_REQUIRING_POSITION_IDS, main_export
from ..exporters.tasks import TasksManager
from ..onnx.utils import check_model_uses_external_data
from ..utils import NormalizedConfigManager, is_transformers_version
from ..utils.file_utils import find_files_matching_pattern
from ..utils.save_utils import maybe_save_preprocessors
from .constants import (
DECODER_MERGED_ONNX_FILE_PATTERN,
DECODER_ONNX_FILE_PATTERN,
DECODER_WITH_PAST_ONNX_FILE_PATTERN,
ONNX_FILE_PATTERN,
)
from .modeling_ort import ONNX_MODEL_END_DOCSTRING, ORTModel
from .utils import prepare_providers_and_provider_options
if TYPE_CHECKING:
from transformers import PretrainedConfig
if is_transformers_version(">=", "4.25.0"):
from transformers.generation import GenerationMixin
else:
from transformers.generation_utils import GenerationMixin # type: ignore # noqa: F401
logger = logging.getLogger(__name__)
DECODER_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor`):
Indices of decoder input sequence tokens in the vocabulary of shape `(batch_size, sequence_length)`.
attention_mask (`torch.LongTensor`, *optional*):
Mask to avoid performing attention on padding token indices, of shape
`(batch_size, sequence_length)`. Mask values selected in `[0, 1]`.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, defaults to `None`):
Contains the precomputed key and value hidden states of the attention blocks used to speed up decoding.
The tuple is of length `config.n_layers` with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`.
"""
CAUSALLM_ONNX_MODEL_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor`):
Indices of decoder input sequence tokens in the vocabulary of shape `(batch_size, sequence_length)`.
attention_mask (`torch.LongTensor`):
Mask to avoid performing attention on padding token indices, of shape
`(batch_size, sequence_length)`. Mask values selected in `[0, 1]`.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, defaults to `None`):
Contains the precomputed key and value hidden states of the attention blocks used to speed up decoding.
The tuple is of length `config.n_layers` with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`.
"""
_TOKENIZER_FOR_DOC = "AutoTokenizer"
TEXT_GENERATION_EXAMPLE = r"""
Example of text generation:
```python
>>> from transformers import {processor_class}
>>> from optimum.onnxruntime import {model_class}
>>> import torch
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> inputs = tokenizer("My name is Arthur and I live in", return_tensors="pt")
    >>> gen_tokens = model.generate(**inputs, do_sample=True, temperature=0.9, min_length=20, max_length=20)
>>> tokenizer.batch_decode(gen_tokens) # doctest: +IGNORE_RESULT
```
Example using `transformers.pipelines`:
```python
>>> from transformers import {processor_class}, pipeline
>>> from optimum.onnxruntime import {model_class}
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> onnx_gen = pipeline("text-generation", model=model, tokenizer=tokenizer)
>>> text = "My name is Arthur and I live in"
>>> gen = onnx_gen(text)
```
"""
@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
class ORTModelForCausalLM(ORTModel, GenerationMixin):
"""
ONNX model with a causal language modeling head for ONNX Runtime inference. This class officially supports bloom, codegen, falcon, gpt2, gpt-bigcode, gpt_neo, gpt_neox, gptj, llama.
"""
auto_model_class = AutoModelForCausalLM
main_input_name = "input_ids"
_supports_cache_class = False
def __init__(
self,
*args,
config: "PretrainedConfig" = None,
session: "InferenceSession" = None,
use_io_binding: Optional[bool] = None,
generation_config: Optional["GenerationConfig"] = None,
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
**kwargs,
):
# DEPRECATED BEHAVIOR
if args:
            logger.warning(
                "Instantiating an ORTModelForCausalLM with positional arguments is deprecated and will be removed in the next version. "
                "Please use the keyword arguments {config, session, use_io_binding, generation_config, model_save_dir, use_cache} instead."
            )
# the old signature is ORTModelForCausalLM(model, config, use_io_binding, model_save_dir, preprocessors, generation_config, use_cache)
session = args[0]
if len(args) > 1:
config = args[1]
if len(args) > 2:
use_io_binding = args[2]
if len(args) > 3:
model_save_dir = args[3]
if len(args) > 4:
_ = args[4]
if len(args) > 5:
generation_config = args[5]
if len(args) > 6:
_ = args[6]
if kwargs.get("model", None) is not None:
logger.warning(
"Passing the inference session as `model` argument to an ORTModelForCausalLM is deprecated. Please use `session` instead."
)
session = kwargs.pop("model")
if kwargs:
logger.warning(
f"Some keyword arguments were passed to the ORTModelForCausalLM constructor that are not part of its signature: {', '.join(kwargs.keys())}. "
"These arguments will be ignored in the current version and will raise an error in the next version."
)
if config is None:
raise ValueError(
"The parameter config is required. Please pass a config or use the from_pretrained method."
)
if session is None:
raise ValueError(
"The parameter session is required. Please pass a session or use the from_pretrained method."
)
## END OF DEPRECATED BEHAVIOR
super().__init__(config=config, session=session, use_io_binding=use_io_binding, model_save_dir=model_save_dir)
self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)
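        # Cache support is inferred from the session's input/output names: the model can reuse past
        # key/values only if it exposes both past key/value inputs and present key/value outputs.
        # A merged decoder (a single ONNX file handling both the no-cache and with-cache branches)
        # additionally exposes a `use_cache_branch` input.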
self.key_value_input_names = [key for key in self.input_names if (".key" in key) or (".value" in key)]
self.key_value_output_names = [key for key in self.output_names if (".key" in key) or (".value" in key)]
self.can_use_cache = len(self.key_value_input_names) > 0 and len(self.key_value_output_names) > 0
self.is_merged = "use_cache_branch" in self.input_names
self.generation_config = generation_config
# Reference: https://github.com/huggingface/optimum/pull/1381
model_type = self.config.model_type
if model_type in MODEL_TYPES_REQUIRING_POSITION_IDS and "position_ids" not in self.input_names:
            logger.warning(
                f"ORTModelForCausalLM loaded a legacy ONNX model with no position_ids input, although the model type {model_type} "
                "requires it for correct batched generation. We strongly encourage you to re-export the model with "
                "a newer version of Optimum for better performance and more reliable generation."
            )
if not self.can_use_cache and self.generation_config.use_cache:
            logger.warning(
                "`model.generation_config.use_cache=True` but the loaded model does not support using the past key values cache. "
                "Please re-export the original model with `use_cache=True` to be able to use the cache during generation, "
                "or set `model.generation_config.use_cache=False` to avoid errors from attempting to use it. "
                "To re-export your model, simply set `export=True` as in `from_pretrained(..., export=True, use_cache=True)`."
            )
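        # Compute the expected geometry of each past key/value tensor, which depends on the model
        # family: gemma defines an explicit head_dim, grouped-query models (llama, mistral, qwen2, ...)
        # use num_key_value_heads instead of num_attention_heads, and falcon's multi-query variant
        # uses a single key/value head.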
if self.config.model_type == "gemma":
self.embed_size_per_head = self.normalized_config.head_dim
else:
self.embed_size_per_head = self.normalized_config.hidden_size // self.normalized_config.num_attention_heads
if self.config.model_type in {"gemma", "mistral", "llama", "qwen2", "qwen3", "qwen3_moe", "granite"}:
self.num_key_value_heads = self.normalized_config.num_key_value_heads
elif self.config.model_type == "falcon":
self.num_key_value_heads = (
self.config.num_kv_heads
if (self.config.new_decoder_architecture or not self.config.multi_query)
else 1
)
else:
self.num_key_value_heads = self.normalized_config.num_attention_heads
@property
def use_cache(self):
        logger.warning(
            "The `ORTModelForCausalLM.use_cache` property is deprecated and will be removed in a future version. "
            "Use `ORTModelForCausalLM.can_use_cache` instead to check whether a model supports using the past key values cache during generation, "
            "and `ORTModelForCausalLM.generation_config.use_cache` to check whether the model is configured to use it."
        )
return self.can_use_cache
@property
def use_merged(self):
        logger.warning(
            "The `ORTModelForCausalLM.use_merged` property is deprecated and will be removed in a future version. "
            "Use `ORTModelForCausalLM.is_merged` instead to check whether the underlying model is merged."
        )
return self.is_merged
@add_start_docstrings_to_model_forward(
CAUSALLM_ONNX_MODEL_DOCSTRING.format("batch_size, sequence_length")
+ TEXT_GENERATION_EXAMPLE.format(
processor_class=_TOKENIZER_FOR_DOC,
model_class="ORTModelForCausalLM",
checkpoint="optimum/gpt2",
)
)
def forward(
self,
input_ids: torch.LongTensor,
attention_mask: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
position_ids: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
**kwargs,
) -> CausalLMOutputWithPast:
use_torch = isinstance(input_ids, torch.Tensor)
self.raise_on_numpy_input_io_binding(use_torch)
use_cache = use_cache if use_cache is not None else self.config.use_cache
if use_cache and not self.can_use_cache:
raise ValueError(
f"`use_cache={use_cache}` was passed to the model but the loaded model only supports `use_cache={self.can_use_cache}`. "
f"Please load your current model with `use_cache={self.can_use_cache}` or export the original model "
f"once again with `use_cache={use_cache}` when calling the `from_pretrained` method. "
"To re-export your model, simply set `export=True` in the `from_pretrained` method."
)
if past_key_values is not None and isinstance(past_key_values[0], tuple):
# Flattens the past_key_values to a single tuple
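            # (transformers passes a tuple of per-layer (key, value) tuples, while the ONNX session
            # expects one flat input tensor per past key/value name)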
past_key_values = sum(past_key_values, ())
if "position_ids" in self.input_names and position_ids is None:
if attention_mask is not None:
# Create position_ids from attention_mask
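                # position_ids count the non-padding tokens (0-indexed); padding positions get a
                # dummy value of 1, and when a cache is used only the last position is kept.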
position_ids = attention_mask.cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values is not None:
position_ids = position_ids[:, -1].unsqueeze(-1)
else:
raise ValueError(
"The model requires position_ids for batched generation but none were provided. "
"Please provide position_ids or attention_mask (from which position_ids can be inferred)."
)
use_cache_branch = None
if self.is_merged:
# Uses cache branch of merged decoders depending on whether real past key values are passed
use_cache_branch = torch.full((1,), past_key_values is not None, dtype=torch.bool, device=self.device)
if past_key_values is None and len(self.key_value_input_names) > 0:
# Generates the input pkv for the first forward of the model (merged or with past)
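            # (dummy tensors with sequence length 0, so that concatenating them with the freshly
            # computed key/values inside the graph is effectively a no-op)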
batch_size, seq_len = input_ids.shape
if self.config.model_type == "gpt_bigcode":
shape = (batch_size, 0, self.embed_size_per_head * 2)
else:
shape = (batch_size, self.num_key_value_heads, 0, self.embed_size_per_head)
tensor = torch.empty(shape, dtype=self.dtype, device=self.device)
past_key_values = tuple(tensor for _ in range(len(self.key_value_input_names)))
model_inputs = {
"input_ids": input_ids,
"position_ids": position_ids,
"attention_mask": attention_mask,
"use_cache_branch": use_cache_branch,
}
if len(self.key_value_input_names) > 0:
model_inputs.update(zip(self.key_value_input_names, past_key_values))
known_output_shapes = None
outputs_to_not_bind = None
if use_cache:
# Infers the shape of the output pkv
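            # The present key/value shapes can be inferred ahead of time: same geometry as the past,
            # with the sequence axis grown by the number of new input tokens. With IO binding these
            # known shapes are used to pre-allocate the output buffers.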
batch_size, seq_len = input_ids.shape
if self.config.model_type == "gpt_bigcode":
pkv_seq_len, embed_size_per_head_2 = past_key_values[0].shape[1:]
pkv_output_shape = (batch_size, pkv_seq_len + seq_len, embed_size_per_head_2)
else:
num_key_value_heads, pkv_seq_len, embed_size_per_head = past_key_values[0].shape[1:]
pkv_output_shape = (batch_size, num_key_value_heads, pkv_seq_len + seq_len, embed_size_per_head)
known_output_shapes = dict.fromkeys(self.key_value_output_names, pkv_output_shape)
else:
# Don't bind the output pkv if not used/returned
outputs_to_not_bind = self.key_value_output_names
if self.use_io_binding:
output_shapes, output_buffers = self._prepare_io_binding(
model_inputs,
outputs_to_not_bind=outputs_to_not_bind,
known_output_shapes=known_output_shapes,
)
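            # On CPU the bound run is synchronous; on other devices inputs/outputs are explicitly
            # synchronized around the run.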
if self.device.type == "cpu":
self.session.run_with_iobinding(self._io_binding)
else:
self._io_binding.synchronize_inputs()
self.session.run_with_iobinding(self._io_binding)
self._io_binding.synchronize_outputs()
loss = output_buffers.get("loss", None)
logits = output_buffers["logits"].view(output_shapes["logits"])
if use_cache:
past_key_values = tuple(
output_buffers.pop(name).view(output_shapes[name]) for name in self.key_value_output_names
)
else:
onnx_inputs = self._prepare_onnx_inputs(use_torch, model_inputs)
onnx_outputs = self.session.run(None, onnx_inputs)
model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs)
loss = model_outputs.pop("loss", None)
logits = model_outputs.pop("logits")
if use_cache:
past_key_values = tuple(model_outputs.pop(name) for name in self.key_value_output_names)
if use_cache and self.config.model_type != "gpt_bigcode":
            # Tuple of tuples of length `n_layers`, each inner tuple holding the 2 (key, value) tensors of a decoder layer's self-attention
past_key_values = tuple(past_key_values[i : i + 2] for i in range(0, len(past_key_values), 2))
return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=past_key_values)
def prepare_inputs_for_generation(self, *args, **kwargs):
if is_transformers_version("<", "4.46.0"):
return self._prepare_inputs_for_generation_legacy(*args, **kwargs)
else:
return super().prepare_inputs_for_generation(*args, **kwargs)
# Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
def _prepare_inputs_for_generation_legacy(
self,
input_ids,
attention_mask=None,
past_key_values=None,
token_type_ids=None,
position_ids=None,
use_cache=None,
**kwargs,
):
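        # When a cache is present, only the not-yet-processed tokens (usually just the last generated
        # token) are kept in input_ids. The cache length is read from the tensor layout, which differs
        # for gpt_bigcode's multi-query attention.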
if past_key_values is not None:
if self.config.model_type == "gpt_bigcode":
if self.config.multi_query:
past_length = past_key_values[0].shape[1]
else:
past_length = past_key_values[0].shape[2]
else:
past_length = past_key_values[0][0].shape[2]
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": past_key_values,
"token_type_ids": token_type_ids,
"position_ids": position_ids,
"use_cache": use_cache,
}
@staticmethod
def _reorder_cache(
past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
) -> Tuple[Tuple[torch.Tensor]]:
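        # During beam search, the cache has to follow the beams selected at each step: every cached
        # tensor is re-indexed along the batch dimension with `beam_idx`.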
if isinstance(past_key_values, tuple) and isinstance(past_key_values[0], tuple):
# GPT2 style
return tuple(
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
for layer_past in past_key_values
)
elif isinstance(past_key_values, tuple) and isinstance(past_key_values[0], torch.Tensor):
# GPT BigCode style
return tuple(layer_past.index_select(0, beam_idx.to(layer_past.device)) for layer_past in past_key_values)
else:
raise ValueError(
f"Unexpected past_key_values: {past_key_values}. "
"Expected tuple of tuples (GPT2 style) or tuple of tensors (GPT BigCode style)."
)
@classmethod
def _from_pretrained(
cls,
model_id: Union[str, Path],
config: "PretrainedConfig",
# hub options
subfolder: str = "",
revision: str = "main",
force_download: bool = False,
local_files_only: bool = False,
trust_remote_code: bool = False,
cache_dir: str = HUGGINGFACE_HUB_CACHE,
token: Optional[Union[bool, str]] = None,
# file options
file_name: Optional[str] = None,
# session options
provider: str = "CPUExecutionProvider",
providers: Optional[Sequence[str]] = None,
provider_options: Optional[Union[Sequence[Dict[str, Any]], Dict[str, Any]]] = None,
session_options: Optional[SessionOptions] = None,
# inference options
use_cache: bool = True,
use_merged: Optional[bool] = None,
use_io_binding: Optional[bool] = None,
generation_config: Optional[GenerationConfig] = None,
# other arguments
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
) -> "ORTModelForCausalLM":
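        # Locate the ONNX file to load. Recent exports produce a single model file, while legacy
        # exports may contain a merged decoder or separate decoder / decoder-with-past files, which
        # are matched by filename pattern below.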
onnx_files = find_files_matching_pattern(
model_id,
ONNX_FILE_PATTERN,
glob_pattern="**/*.onnx",
subfolder=subfolder,
token=token,
revision=revision,
)
if len(onnx_files) == 0:
raise FileNotFoundError(f"Could not find any ONNX model file in {model_id}")
if len(onnx_files) == 1:
subfolder = onnx_files[0].parent
_file_name = onnx_files[0].name
if file_name and file_name != _file_name:
raise FileNotFoundError(f"Trying to load {file_name} but only found {_file_name}")
file_name = _file_name
else:
model_files = []
# Check first for merged models and then for decoder / decoder_with_past models
if use_merged is not False:
model_files = [p for p in onnx_files if re.search(DECODER_MERGED_ONNX_FILE_PATTERN, str(p))]
use_merged = len(model_files) != 0
if use_merged is False:
pattern = DECODER_WITH_PAST_ONNX_FILE_PATTERN if use_cache else DECODER_ONNX_FILE_PATTERN
model_files = [p for p in onnx_files if re.search(pattern, str(p))]
# if file_name is specified we don't filter legacy models
if not model_files or file_name:
model_files = onnx_files
else:
                logger.warning(
                    f"Legacy models found in {model_files} will be loaded. "
                    "Legacy models will be deprecated in the next version of optimum; please re-export your model."
                )
_file_name = model_files[0].name
subfolder = model_files[0].parent
            default_file_name = file_name or "model.onnx"
            for file in model_files:
                if file.name == default_file_name:
_file_name = file.name
subfolder = file.parent
break
file_name = _file_name
if len(model_files) > 1:
                logger.warning(
                    f"Too many ONNX model files were found: {', '.join(map(str, model_files))}. "
                    "Please specify which one to load by using the `file_name` and/or the `subfolder` arguments. "
                    f"Loading the file {file_name} in the subfolder {subfolder}."
                )
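        # For local checkpoints, the resolved subfolder is already the full path to the directory
        # containing the ONNX file, so it is used as the `model_id` passed to `cached_file` below.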
if os.path.isdir(model_id):
model_id = subfolder
subfolder = ""
if isinstance(subfolder, Path):
subfolder = subfolder.as_posix()
model_cache_path = cached_file(
model_id,
filename=file_name,
# hub options
token=token,
revision=revision,
subfolder=subfolder,
cache_dir=cache_dir,
force_download=force_download,
local_files_only=local_files_only,
)
# model_save_dir can be provided in kwargs as a TemporaryDirectory instance, in which case we want to keep it
# instead of the path only.
if model_save_dir is None:
model_save_dir = Path(model_cache_path).parent
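        # Also pull the external data file (if any) into the cache so that the weights end up next to
        # the ONNX graph when the inference session is created.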
try:
cached_file(
model_id,
filename=file_name + "_data",
# hub options
token=token,
revision=revision,
subfolder=subfolder,
cache_dir=cache_dir,
force_download=force_download,
local_files_only=local_files_only,
)
except EnvironmentError:
# If the external data file is not found, we assume that the model is not using external data.
pass
# This should be removed at some point
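        # The graph is first loaded without its weights to inspect the input/output dimensions and,
        # for models exported with older versions of optimum, rewrite them below.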
onnx_model = onnx.load(str(model_cache_path), load_external_data=False)
model_uses_external_data = check_model_uses_external_data(onnx_model)
if model_uses_external_data:
onnx_model = onnx.load(str(model_cache_path), load_external_data=True)
input_dims = {
node.name: [dim.dim_value or dim.dim_param for dim in node.type.tensor_type.shape.dim]
for node in onnx_model.graph.input
}
output_dims = {
node.name: [dim.dim_value or dim.dim_param for dim in node.type.tensor_type.shape.dim]
for node in onnx_model.graph.output
}
override_dims = False
# Since v1.7.0 decoder with past models have fixed sequence length of 1
# To keep these models compatible we set this dimension to dynamic
if input_dims["input_ids"][1] == 1:
input_dims["input_ids"][1] = "sequence_length"
output_dims["logits"][1] = "sequence_length"
override_dims = True
# Since https://github.com/huggingface/optimum/pull/871/
# changed axis notation/naming during export, we need to update the dims
for input_name in input_dims.keys():
if "past" in input_name and input_dims[input_name][2] == "past_sequence_length + sequence_length":
input_dims[input_name][2] = "past_sequence_length"
override_dims = True
if override_dims:
            # overwriting the model file in place is risky, so the least we can do is warn the user
logger.warning(
"The ONNX model was probably exported with an older version of optimum. "
"We are updating the input/output dimensions and overwriting the model file "
"with new dimensions. This is necessary for the model to work correctly with "
"the current version of optimum. If you encounter any issues, please re-export "
"the model with the latest version of optimum for optimal performance."
)
onnx_model = update_model_dims.update_inputs_outputs_dims(onnx_model, input_dims, output_dims)
onnx.save(
onnx_model,
str(model_cache_path),
save_as_external_data=model_uses_external_data,
location=Path(model_cache_path).name + "_data",
all_tensors_to_one_file=True,
convert_attribute=True,
size_threshold=0,
)
del onnx_model
# Important: for encoder-decoder models used with CausalLM, we need to set the is_decoder flag to True
# and the is_encoder_decoder flag to False. This is needed for the model to work correctly with generation logic.
if hasattr(config, "is_decoder"):
config.is_decoder = True
if hasattr(config, "is_encoder_decoder"):
config.is_encoder_decoder = False
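        # If no generation config was passed, load the one stored alongside the checkpoint, falling
        # back to one derived from the model config.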
if generation_config is None:
try:
generation_config = GenerationConfig.from_pretrained(
model_id,
token=token,
revision=revision,
subfolder=subfolder,
cache_dir=cache_dir,
force_download=force_download,
local_files_only=local_files_only,
)
except OSError:
logger.info("Generation config file not found, creating a new one from model config.")
generation_config = GenerationConfig.from_model_config(config)
# TODO: not sure if setting config.use_cache is needed for older versions of transformers
generation_config.use_cache = use_cache
config.use_cache = use_cache
if is_transformers_version(">=", "4.45.0"):
misplaced_generation_parameters = config._get_non_default_generation_parameters()
if len(misplaced_generation_parameters) > 0:
logger.warning(
"Moving the following attributes in the config to the generation config: "
f"{misplaced_generation_parameters}. You are seeing this warning because you've set "
"generation parameters in the model config, as opposed to in the generation config.",
)
for param_name, param_value in misplaced_generation_parameters.items():
setattr(generation_config, param_name, param_value)
setattr(config, param_name, None)
providers, provider_options = prepare_providers_and_provider_options(
provider=provider, providers=providers, provider_options=provider_options
)
session = InferenceSession(
model_cache_path,
providers=providers,
provider_options=provider_options,
sess_options=session_options,
)
return cls(
config=config,
session=session,
use_io_binding=use_io_binding,
generation_config=generation_config,
model_save_dir=model_save_dir,
)
@classmethod
def _export(
cls,
model_id: Union[str, Path],
config: "PretrainedConfig",
# hub options
subfolder: str = "",
revision: str = "main",
force_download: bool = False,
local_files_only: bool = False,
trust_remote_code: bool = False,
cache_dir: str = HUGGINGFACE_HUB_CACHE,
token: Optional[Union[bool, str]] = None,
# inference options
use_cache: bool = True,
**kwargs,
) -> "ORTModelForCausalLM":
        # this is guaranteed to work since it uses a mapping from model classes to task names
# instead of relying on the hub metadata or the model configuration
task = TasksManager._infer_task_from_model_or_model_class(model_class=cls.auto_model_class)
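        # The "-with-past" suffix makes the exporter produce past key/value inputs and present
        # key/value outputs, i.e. a model that supports the cache during generation.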
if use_cache:
task += "-with-past"
if kwargs.get("task", None) is not None:
raise ValueError(
f"The `task` argument is not needed when exporting a model with `{cls.__name__}`. "
f"The `task` is automatically inferred from the class as `{task}`."
)
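        # Export to a temporary directory, then load through `_from_pretrained`. The TemporaryDirectory
        # instance is passed as `model_save_dir` so that it is kept alive (and not deleted) while the
        # model is in use.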
save_dir = TemporaryDirectory()
save_dir_path = Path(save_dir.name)
main_export(
model_name_or_path=model_id,
output=save_dir_path,
task=task,
do_validation=False,
no_post_process=False,
legacy=False,
subfolder=subfolder,
revision=revision,
cache_dir=cache_dir,
token=token,
local_files_only=local_files_only,
force_download=force_download,
trust_remote_code=trust_remote_code,
)
maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder)
return cls._from_pretrained(
save_dir_path,
config,
use_cache=use_cache,
model_save_dir=save_dir,
**kwargs,
)
def _save_config(self, save_directory):
"""
Save the model and generation configs to the specified directory.
Args:
save_directory (`str` or `os.PathLike`):
Directory where the model and generation configs will be saved.
"""
self.config.save_pretrained(save_directory)
self.generation_config.save_pretrained(save_directory)