# optimum/intel/openvino/modeling_decoder.py
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import logging
import os
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import openvino
import torch
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from openvino import Core, Tensor, Type
from openvino.preprocess import PrePostProcessor
from transformers import AutoModelForCausalLM, PretrainedConfig
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from transformers.generation import GenerationMixin
from transformers.generation.configuration_utils import GenerationConfig
from transformers.generation.logits_process import LogitsProcessorList
from transformers.generation.stopping_criteria import StoppingCriteriaList
from transformers.generation.utils import GenerateOutput, GenerationMode
from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
from transformers.utils.hub import PushToHubMixin
from optimum.utils.normalized_config import NormalizedConfigManager
from ...exporters.openvino import ensure_stateful_is_available, main_export, patch_stateful
from ...exporters.openvino.stateful import model_has_state
from ..utils.import_utils import compare_versions, is_nncf_available, is_transformers_version
from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS
from .configuration import (
_DEFAULT_4BIT_WQ_CONFIG,
OVConfig,
OVWeightQuantizationConfig,
get_default_quantization_config,
)
from .modeling import _TOKENIZER_FOR_DOC, INPUTS_DOCSTRING, MODEL_START_DOCSTRING, OVModel
from .utils import (
ONNX_WEIGHTS_NAME,
OV_XML_FILE_NAME,
STR_TO_OV_TYPE,
TemporaryDirectory,
get_export_transformers_version,
model_has_dynamic_inputs,
)
if TYPE_CHECKING:
try:
from transformers.generation.streamers import BaseStreamer
except Exception:
from typing import Generator as BaseStreamer
from transformers.modeling_utils import PreTrainedModel
logger = logging.getLogger(__name__)
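# A single module-level openvino.Core instance, reused below to query device properties
# (e.g. INFERENCE_PRECISION_HINT) instead of creating a new Core for every model.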
core = Core()
TEXT_GENERATION_EXAMPLE = r"""
Example of text generation:
```python
>>> from transformers import {processor_class}
>>> from optimum.intel import {model_class}
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}", export=True)
>>> inputs = tokenizer("I love this story because", return_tensors="pt")
>>> gen_tokens = model.generate(**inputs, do_sample=True, temperature=0.9, min_length=20, max_length=20)
>>> tokenizer.batch_decode(gen_tokens)
```
Example using `transformers.pipelines`:
```python
>>> from transformers import {processor_class}, pipeline
>>> from optimum.intel import {model_class}
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}", export=True)
>>> gen_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
>>> text = "I love this story because"
>>> gen = gen_pipeline(text)
```
"""
# inheritance from PushToHubMixin added as a workaround for transformers>=4.52.0 and nncf<=2.16.0 compatibility:
# during dataset preparation nncf checks isinstance(model, PreTrainedModel.__bases__),
# and in transformers 4.52.0 PreTrainedModel no longer includes GenerationMixin, so this check fails for OVModelForCausalLM
# TODO: remove it after migration to a newer nncf
@add_start_docstrings(
"""
Base OVBaseDecoderModel class.
""",
)
class OVBaseDecoderModel(OVModel, PushToHubMixin):
def __init__(
self,
model: openvino.Model,
config: PretrainedConfig = None,
device: str = "CPU",
dynamic_shapes: bool = True,
ov_config: Optional[Dict[str, str]] = None,
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
quantization_config: Optional[Union[OVWeightQuantizationConfig, Dict]] = None,
**kwargs,
):
if not dynamic_shapes:
raise ValueError(
"`dynamic_shapes` was set to `False` but static shapes are not supported for causal language model. Please set `dynamic_shapes=True`."
)
compile_only = kwargs.get("compile_only", False)
enable_compilation = kwargs.get("compile", True)
kwargs["compile"] = False or compile_only # avoid extra compilation in the base class
if compile_only and not enable_compilation:
            raise ValueError(
                "`compile_only` mode does not support disabling compilation. "
                "Please provide `compile=True` if you want to use `compile_only=True`, or set `compile_only=False`."
            )
config.is_encoder_decoder = False
super().__init__(
model,
config,
device=device,
dynamic_shapes=False if not compile_only else model_has_dynamic_inputs(model),
ov_config=ov_config,
model_save_dir=model_save_dir,
quantization_config=quantization_config,
**kwargs,
)
self.is_dynamic = dynamic_shapes
use_cache = kwargs.pop("use_cache", True)
model_has_sinks = model_has_state(self.model)
self.use_cache = any("past_key_values" in key.get_any_name() for key in model.inputs) or model_has_sinks
stateful = kwargs.pop("stateful", None) # stateful model only if it is converted with stateful=True
self.stateful = model_has_sinks
self.main_input_name = "input_ids"
self.num_pkv = 2
self.key_value_input_names = [key for key in self.input_names if "key_values" in key]
self.key_value_output_names = [key for key in self.output_names if "present" in key]
        # Default KV-cache precision; possibly updated below by `update_pkv_precision()`
        self._pkv_precision = Type.f32
self.next_beam_idx = None
self._past_length = 0
self._first_iter_beam_search = False
self._second_iter_beam_search = False
self.update_pkv_precision()
if self.is_dynamic and not self._compile_only:
self.model = self._reshape(self.model, -1, -1)
is_stateful_supported = ensure_stateful_is_available(warn=False)
if self.use_cache and not self.stateful:
            logger.warning(
                "Provided model does not contain state. It may lead to sub-optimal performance. "
                "Please re-export the model with OpenVINO >= 2023.3.0 by calling the `from_pretrained` method "
                "on the original model with `export=True`."
            )
if self.stateful:
if stateful is None:
stateful = is_stateful_supported
if model_has_sinks and not is_stateful_supported:
                raise ValueError(
                    "Loaded a stateful model, while the OpenVINO runtime version does not support stateful model inference. "
                    "Please update OpenVINO to version >= 2023.3.0, "
                    "or export the original model once again with `stateful=False` when calling the `from_pretrained` method. "
                    "To export your model, simply set `export=True`."
                )
def raise_error(model_prop, user_prop, name):
raise ValueError(
f"`{name}` was set to `{user_prop}` but the loaded model only supports `{name}={model_prop}`. "
f"Please load your current model with `{name}={model_prop}` or export the original model "
f"once again with `{name}={user_prop}` when calling the `from_pretrained` method. "
"To export your model, simply set `export=True`."
)
if stateful is not None and stateful ^ self.stateful:
# We cannot transform stateful model to stateless
raise_error(self.stateful, stateful, "stateful")
if use_cache ^ self.use_cache:
raise_error(self.use_cache, use_cache, "use_cache")
if self._compile_only:
self.request = self.model.create_infer_request()
if not self._compile_only and enable_compilation:
self.compile()
@staticmethod
def _get_model_with_updated_pkv_precision(model: openvino.Model, pkv_precision: Type) -> openvino.Model:
ppp = PrePostProcessor(model)
for key in model.inputs:
if "past_key_values" in key.get_any_name() and pkv_precision != key.get_element_type():
ppp.input(key.get_any_name()).tensor().set_element_type(pkv_precision)
for key in model.outputs:
if "present" in key.get_any_name() and pkv_precision != key.get_element_type():
ppp.output(key.get_any_name()).tensor().set_element_type(pkv_precision)
return ppp.build()
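    # update_pkv_precision aligns the past_key_values inputs/outputs with the device's preferred
    # inference precision (typically f16 on GPU) so the KV-cache is not converted on every generation
    # step; force_fp32=True restores the f32 cache precision.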
def update_pkv_precision(self, force_fp32=False):
if not self.use_cache or self.stateful or self._compile_only:
return
pkv_precision = Type.f32
if not force_fp32:
device = self._device.upper()
try:
if "INFERENCE_PRECISION_HINT" in core.get_property(device, "SUPPORTED_PROPERTIES"):
pkv_precision = core.get_property(device, "INFERENCE_PRECISION_HINT")
except RuntimeError: # use default precision when get_property fails, e.g. when device is "AUTO:GPU"
pass
# ov_config["INFERENCE_PRECISION_HINT"] may override the prefer precision
if self.ov_config:
inference_precision_hint = self.ov_config.get("INFERENCE_PRECISION_HINT", "")
if inference_precision_hint in STR_TO_OV_TYPE:
pkv_precision = STR_TO_OV_TYPE[inference_precision_hint]
self.model = self._get_model_with_updated_pkv_precision(self.model, pkv_precision)
self._pkv_precision = pkv_precision
self.request = None
else:
if hasattr(self, "_pkv_precision") and self._pkv_precision != Type.f32:
self.model = self._get_model_with_updated_pkv_precision(self.model, Type.f32)
self._pkv_precision = Type.f32
if self.is_dynamic:
self.model = self._reshape(self.model, -1, -1)
self.request = None
def _save_pretrained(self, save_directory: Union[str, Path]):
"""
Saves the model to the OpenVINO IR format so that it can be re-loaded using the
[`~optimum.intel.openvino.modeling.OVModel.from_pretrained`] class method.
Arguments:
save_directory (`str` or `Path`):
The directory where to save the model files.
"""
if self._compile_only:
raise ValueError(
"`save_pretrained()` is not supported with `compile_only` mode, please initialize model without this option"
)
model_to_save = (
self.model
if self._pkv_precision == Type.f32
else self._get_model_with_updated_pkv_precision(self.model.clone(), Type.f32)
)
dst_path = os.path.join(save_directory, OV_XML_FILE_NAME)
openvino.save_model(model_to_save, dst_path, compress_to_fp16=False)
if self.generation_config is not None:
try:
self.generation_config.save_pretrained(save_directory)
except Exception as exception:
logger.warning(
f"The generation config will not be saved, saving failed with following error:\n{exception}"
)
self._save_openvino_config(save_directory)
@classmethod
def _export(
cls,
model_id: str,
config: PretrainedConfig,
token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
force_download: bool = False,
cache_dir: str = HUGGINGFACE_HUB_CACHE,
subfolder: str = "",
local_files_only: bool = False,
task: Optional[str] = None,
use_cache: bool = True,
trust_remote_code: bool = False,
load_in_8bit: Optional[bool] = None,
quantization_config: Optional[Union[OVWeightQuantizationConfig, Dict]] = None,
**kwargs,
):
save_dir = TemporaryDirectory()
save_dir_path = Path(save_dir.name)
        # This attribute is needed to keep one reference on the temporary directory, since garbage collection
        # would end up removing the directory containing the underlying OpenVINO model
cls._model_save_dir_tempdirectory_instance = save_dir
compile_only = kwargs.pop("compile_only", False)
if compile_only:
            logger.warning(
                "`compile_only` mode will be disabled because it does not support model export. "
                "Please provide an OpenVINO model obtained using optimum-cli or saved on disk using `save_pretrained`."
            )
compile_only = False
if task is None:
task = cls.export_feature
if use_cache:
task = task + "-with-past"
        # If neither load_in_8bit nor quantization_config is specified, ov_config is set to None and the default
        # is chosen during conversion depending on the model size
if load_in_8bit is None and not quantization_config:
ov_export_config = None
else:
ov_export_config = OVConfig(dtype="auto")
stateful = kwargs.pop("stateful", ensure_stateful_is_available(warn=False) and use_cache)
torch_dtype = kwargs.pop("torch_dtype", None)
model_loading_kwargs = {}
if torch_dtype is not None:
model_loading_kwargs["torch_dtype"] = torch_dtype
variant = kwargs.pop("variant", None)
main_export(
model_name_or_path=model_id,
output=save_dir_path,
task=task,
subfolder=subfolder,
revision=revision,
cache_dir=cache_dir,
token=token,
local_files_only=local_files_only,
force_download=force_download,
trust_remote_code=trust_remote_code,
ov_config=ov_export_config,
stateful=stateful,
model_loading_kwargs=model_loading_kwargs,
library_name=cls._library_name,
variant=variant,
)
if config.model_type == "phi3" and config.max_position_embeddings != getattr(
config, "original_max_position_embeddings", config.max_position_embeddings
):
config.max_position_embeddings = config.original_max_position_embeddings
return cls._from_pretrained(
model_id=save_dir_path,
config=config,
use_cache=use_cache,
stateful=None,
load_in_8bit=load_in_8bit,
quantization_config=quantization_config,
trust_remote_code=trust_remote_code,
compile_only=compile_only,
**kwargs,
)
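    # _reshape marks the batch and cache-length dimensions of every input as dynamic (-1). Which axis
    # holds the cache length depends on the export layout: axis 2 for the usual
    # [batch, num_heads, seq_len, head_dim] caches, axis 1 for 3-D multi-query caches, and axis 0 for
    # chatglm exports without `rope_ratio` (which keep the batch on axis 1).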
def _reshape(
self,
model: openvino.Model,
batch_size: int,
sequence_length: int,
height: int = None,
width: int = None,
):
if self._compile_only:
raise ValueError(
"`reshape()` is not supported with `compile_only` mode, please initialize model without this option"
)
if height is not None:
logger.warning(f"`height` set to `{height}` will be ignored during reshaping operation.")
if width is not None:
logger.warning(f"`width` set to `{width}` will be ignored during reshaping operation.")
shapes = {}
for inputs in model.inputs:
shapes[inputs] = inputs.get_partial_shape()
shapes[inputs][0] = -1
input_name = inputs.get_any_name()
if input_name.startswith("past_key_values"):
if (len(inputs.partial_shape) == 3 and input_name.endswith("value")) or (
self.config.model_type == "chatglm" and not hasattr(self.config, "rope_ratio")
):
shapes[inputs][1] = -1
else:
shapes[inputs][2] = -1
elif input_name.startswith("beam_idx"):
shapes[inputs][0] = -1
else:
shapes[inputs][1] = -1
model.reshape(shapes)
return model
def reshape(self, batch_size: int, sequence_length: int):
logger.warning("Static shapes are not supported for causal language model.")
return self
@property
def normalized_config(self):
        logger.warning(
            "Access to the `normalized_config` attribute is deprecated and will be removed in future versions, please use `config` instead."
        )
return NormalizedConfigManager.get_normalized_config_class(self.config.model_type)(self.config)
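    # compile() is lazy: the infer request is created only once. In compile_only mode self.model is
    # already a compiled model, otherwise the base class compiles it for the target device first.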
def compile(self):
if self.request is None:
            if self._compile_only:
                self.request = self.model.create_infer_request()
            else:
                super().compile()
                self.request = self.request.create_infer_request()
def _make_stateful(self):
patch_stateful(self.config, self.model)
self.stateful = True
@add_start_docstrings(
"""
OpenVINO Model with a causal language modeling head on top (linear layer with weights tied to the input
embeddings).
""",
MODEL_START_DOCSTRING,
)
class OVModelForCausalLM(OVBaseDecoderModel, GenerationMixin):
export_feature = "text-generation"
auto_model_class = AutoModelForCausalLM
@add_start_docstrings_to_model_forward(
INPUTS_DOCSTRING.format("batch_size, sequence_length")
+ TEXT_GENERATION_EXAMPLE.format(
processor_class=_TOKENIZER_FOR_DOC,
model_class="OVModelForCausalLM",
checkpoint="gpt2",
)
)
def prepare_inputs(
self,
input_ids: torch.LongTensor,
attention_mask: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
position_ids: Optional[torch.LongTensor] = None,
**kwargs,
) -> Dict:
batch_size = input_ids.shape[0]
model_transformers_version = get_export_transformers_version(self.model, self.config)
if self.config.model_type == "bloom" and compare_versions(model_transformers_version, "<", "4.44"):
batch_size *= self.config.num_attention_heads
inputs = {}
if not self.stateful:
if past_key_values is not None:
if self.config.model_type not in MULTI_QUERY_ATTN_MODELS or (
self.config.model_type == "falcon" and self.config.new_decoder_architecture
):
if self._pkv_precision == Type.bf16:
                        # numpy has no bf16 dtype, so the buffer stays f16 and is reinterpreted as bf16 by the OV Tensor
past_key_values = tuple(
Tensor(past_key_value, past_key_value.shape, Type.bf16)
for pkv_per_layer in past_key_values
for past_key_value in pkv_per_layer
)
else:
# Flatten the past_key_values
past_key_values = tuple(
past_key_value for pkv_per_layer in past_key_values for past_key_value in pkv_per_layer
)
# Add the past_key_values to the decoder inputs
inputs = dict(zip(self.key_value_input_names, past_key_values))
# Create empty past_key_values for decoder_with_past first generation step
elif self.use_cache:
for input_name in self.key_value_input_names:
model_inputs = self.model.input(input_name)
shape = model_inputs.get_partial_shape()
if self.config.model_type == "chatglm" and not hasattr(self.config, "rope_ratio"):
shape[0] = 0
shape[1] = batch_size
else:
shape[0] = batch_size
if shape[2].is_dynamic:
shape[2] = 0
else:
shape[1] = 0
inputs[input_name] = Tensor(model_inputs.get_element_type(), [dim.get_length() for dim in shape])
else:
# past_key_values are not used explicitly, instead they are handled inside the model
if past_key_values is None:
# This is the first iteration in a sequence, reset all states
if self.request is not None:
self.request.reset_state()
# Set initial value for the next beam_idx input that will be used at the current iteration
# and will be optionally updated by _reorder_cache at the next iterations if beam_search is used
self.next_beam_idx = np.arange(batch_size, dtype=int)
self._past_length = 0
past_len = self._get_past_length(past_key_values)
inputs["input_ids"] = input_ids.cpu().numpy()
# Add the attention_mask inputs when needed
if "attention_mask" in self.input_names or "position_ids" in self.input_names:
if attention_mask is not None:
attention_mask = attention_mask.cpu().numpy()
else:
attention_mask = np.ones(
(input_ids.shape[0], input_ids.shape[1] + past_len), dtype=inputs["input_ids"].dtype
)
if "attention_mask" in self.input_names:
inputs["attention_mask"] = attention_mask
if "position_ids" in self.input_names:
if position_ids is not None:
position_ids = position_ids.cpu().numpy()
else:
position_ids = np.cumsum(attention_mask, axis=1) - 1
position_ids[attention_mask == 0] = 1
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
inputs["position_ids"] = position_ids
if "beam_idx" in self.input_names:
inputs["beam_idx"] = (
self.next_beam_idx if self.next_beam_idx is not None else np.arange(batch_size, dtype=int)
)
return inputs
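    # For stateful models the KV-cache lives inside the compiled model (as internal states), so forward()
    # only feeds input_ids, attention_mask/position_ids and beam_idx, and tracks the number of already
    # processed tokens in self._past_length instead of returning past_key_values tensors.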
def forward(
self,
input_ids: torch.LongTensor,
attention_mask: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
position_ids: Optional[torch.LongTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
**kwargs,
) -> CausalLMOutputWithPast:
self.compile()
# added as model.generate validates model inputs based on forward signature
kwargs["token_type_ids"] = token_type_ids
inputs = self.prepare_inputs(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
position_ids=position_ids,
**kwargs,
)
if self._first_iter_beam_search:
inputs, duplication_indices = self._deduplicate_inputs(inputs)
# Run inference
self.request.start_async(inputs, share_inputs=True)
self.request.wait()
logits = torch.from_numpy(self.request.get_tensor("logits").data).clone().to(self.device)
if self.stateful:
            # Use a non-None, truthy placeholder for past_key_values so that the next call can tell the
            # first generation iteration apart from the following ones (see the check at the beginning of
            # prepare_inputs).
past_key_values = ((),)
self._past_length += input_ids.shape[1]
if not self.stateful:
if self.use_cache:
                # Flat tuple of length num_layers * num_past_key_values_per_layer (2 per decoder layer: key and value)
past_key_values = tuple(
np.copy(self.request.get_tensor(key).data) for key in self.key_value_output_names
)
if self.config.model_type not in MULTI_QUERY_ATTN_MODELS or (
self.config.model_type == "falcon" and self.config.new_decoder_architecture
):
                    # Tuple of tuples of length `n_layers`, each inner tuple of length 2 (key/value of self-attention)
past_key_values = tuple(
past_key_values[i : i + self.num_pkv] for i in range(0, len(past_key_values), self.num_pkv)
)
else:
past_key_values = None
if self._first_iter_beam_search:
logits, past_key_values = self._expand_outputs_for_generation(duplication_indices, logits, past_key_values)
self._first_iter_beam_search = False
return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)
# Adapted from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
attention_mask = kwargs.get("attention_mask", None)
use_cache = kwargs.get("use_cache", None)
if past_key_values is not None:
past_len = self._get_past_length(past_key_values)
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
# input)
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
input_ids = input_ids[:, -(attention_mask.shape[1] - past_len) :]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif past_len < input_ids.shape[1]:
input_ids = input_ids[:, past_len:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None and "position_ids" in self.input_names:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
model_inputs = {
"input_ids": input_ids,
"past_key_values": past_key_values,
"use_cache": use_cache,
"position_ids": position_ids,
"attention_mask": attention_mask,
}
return model_inputs
def _update_model_kwargs_for_generation(
self,
outputs: ModelOutput,
model_kwargs: Dict[str, Any],
is_encoder_decoder: bool = False,
**kwargs,
) -> Dict[str, Any]:
model_kwargs = super()._update_model_kwargs_for_generation(
outputs=outputs, model_kwargs=model_kwargs, is_encoder_decoder=is_encoder_decoder, **kwargs
)
if "position_ids" in model_kwargs:
position_ids = model_kwargs["position_ids"]
new_position_id = position_ids[..., -1:].clone()
new_position_id += 1
model_kwargs["position_ids"] = torch.cat([position_ids, new_position_id], dim=-1)
return model_kwargs
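    # Beam search duplicates each prompt `num_beams` times with identical content. On the first iteration
    # _deduplicate_inputs (below) runs inference only on the unique prompts, and
    # _expand_outputs_for_generation re-expands the logits (and, for stateless models, the KV-cache) back
    # to the full beam batch using the reverse indices returned by np.unique.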
def _expand_outputs_for_generation(self, indicies, logits: torch.Tensor, past_key_values: Tuple):
batch_size = logits.shape[0]
if indicies.shape[0] != 1:
logits = logits[indicies]
if past_key_values and not self.stateful:
if self.config.model_type not in MULTI_QUERY_ATTN_MODELS or (
self.config.model_type == "falcon" and self.config.new_decoder_architecture
):
past_key_values = tuple(
tuple(
(
past_state[indicies]
if not (self.config.model_type == "chatglm" and not hasattr(self.config, "rope_ratio"))
else past_state[:, indicies, ...]
)
for past_state in layer_past
)
for layer_past in past_key_values
)
else:
past_key_values = tuple([past_state[indicies] for past_state in past_key_values])
if self.stateful:
self.next_beam_idx = (
self.next_beam_idx[indicies]
if self.next_beam_idx is not None
else np.arange(batch_size, dtype=int)[indicies]
)
self._second_iter_beam_search = True
return logits, past_key_values
def _deduplicate_inputs(self, model_inputs: Dict):
input_ids = model_inputs["input_ids"]
upd_model_inputs = {}
unique_input_ids, indicies, reverse_indicies = np.unique(
input_ids, axis=0, return_index=True, return_inverse=True
)
export_transformers_version = get_export_transformers_version(self.model, self.config)
for input_name, input_tensor in model_inputs.items():
if input_name not in ["input_ids", "beam_idx"]:
if input_name not in self.key_value_input_names:
upd_model_inputs[input_name] = input_tensor[indicies]
else:
shape = input_tensor.shape if isinstance(input_tensor, Tensor) else list(input_tensor.shape)
dtype = input_tensor.element_type if isinstance(input_tensor, Tensor) else Type(input_tensor.dtype)
upd_batch_size = indicies.shape[0]
if self.config.model_type == "bloom" and compare_versions(
export_transformers_version, "<", "4.44"
):
upd_batch_size *= self.config.num_attention_heads
shape[
(
0
if not (self.config.model_type == "chatglm" and not hasattr(self.config, "rope_ratio"))
else 1
)
] = upd_batch_size
upd_model_inputs[input_name] = Tensor(dtype, shape)
upd_model_inputs["input_ids"] = unique_input_ids
if "beam_idx" in model_inputs:
beam_range = (
unique_input_ids.shape[0] * self.config.num_attention_heads
if (self.config.model_type == "bloom" and compare_versions(export_transformers_version, "<", "4.44"))
else unique_input_ids.shape[0]
)
beam_idx = np.arange(beam_range, dtype=int)
upd_model_inputs["beam_idx"] = beam_idx
return upd_model_inputs, reverse_indicies
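    # generate() only adds bookkeeping on top of GenerationMixin.generate: when a beam-search mode is
    # detected, _first_iter_beam_search is set so that the first forward() call deduplicates the
    # beam-expanded prompts (see _deduplicate_inputs above).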
@torch.no_grad()
def generate(
self,
inputs: Optional[torch.Tensor] = None,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
synced_gpus: Optional[bool] = None,
assistant_model: Optional["PreTrainedModel"] = None,
streamer: Optional["BaseStreamer"] = None,
negative_prompt_ids: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
if is_transformers_version(">=", "4.39.0"):
_generation_config, _ = self._prepare_generation_config(generation_config, **kwargs)
generation_mode = _generation_config.get_generation_mode(assistant_model)
else:
_generation_config = generation_config or self.generation_config
generation_mode = self._get_generation_mode(_generation_config, assistant_model)
is_beam_search = generation_mode in [
GenerationMode.BEAM_SEARCH,
GenerationMode.BEAM_SAMPLE,
GenerationMode.GROUP_BEAM_SEARCH,
GenerationMode.CONSTRAINED_BEAM_SEARCH,
]
if is_beam_search:
self._first_iter_beam_search = True
result = super().generate(
inputs,
generation_config,
logits_processor,
stopping_criteria,
prefix_allowed_tokens_fn,
synced_gpus,
assistant_model,
streamer,
negative_prompt_ids,
negative_prompt_attention_mask,
**kwargs,
)
return result
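    # Number of tokens already present in the KV-cache: tracked explicitly for stateful models, read from
    # the cached tensors' sequence dimension otherwise (the dimension index is model-specific).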
def _get_past_length(self, past_key_values=None):
if past_key_values is None:
return 0
if self.stateful:
return self._past_length
if self.config.model_type in MULTI_QUERY_ATTN_MODELS and not (
self.config.model_type == "falcon" and self.config.new_decoder_architecture
):
return past_key_values[0].shape[-2]
seq_length_dim = -2
if self.config.model_type == "chatglm" and not hasattr(self.config, "rope_ratio"):
seq_length_dim = 0
elif self.config.model_type == "qwen":
seq_length_dim = 1
        # non-flattened input is a tuple of (key, value) pairs per layer
        if isinstance(past_key_values[0], (tuple, list)):
            return past_key_values[0][1].shape[seq_length_dim]
        # otherwise past_key_values comes already flattened
        return past_key_values[1].shape[seq_length_dim]
# Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache
def _reorder_cache(
self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
) -> Tuple[Tuple[torch.Tensor]]:
"""
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
[`~PreTrainedModel.beam_sample`] is called.
This is required to match `past_key_values` with the correct beam_idx at every generation step.
"""
if self.stateful:
# TODO: Apply it differently based on model type
# TODO: At least for bloom we need to replicate values for each attention head
self.next_beam_idx = (
np.array(beam_idx) if not self._second_iter_beam_search else self.next_beam_idx
) # save beam_idx to be used as an input in the next iteration
self._second_iter_beam_search = False
return past_key_values
else:
if self.config.model_type not in MULTI_QUERY_ATTN_MODELS or (
self.config.model_type == "falcon" and self.config.new_decoder_architecture
):
return tuple(
tuple(np.take(past_state, beam_idx, 0) for past_state in layer_past)
for layer_past in past_key_values
)
return tuple(np.take(past_state, beam_idx, 0) for past_state in past_key_values)
@classmethod
def _from_pretrained(
cls,
model_id: Union[str, Path],
config: PretrainedConfig,
token: Optional[Union[bool, str]] = None,
revision: Optional[Union[str, None]] = None,
force_download: bool = False,
cache_dir: str = HUGGINGFACE_HUB_CACHE,
file_name: Optional[str] = None,
subfolder: str = "",
from_onnx: bool = False,
local_files_only: bool = False,
load_in_8bit: bool = False,
compile_only: bool = False,
quantization_config: Optional[Union[OVWeightQuantizationConfig, Dict]] = None,
**kwargs,
):
generation_config = kwargs.pop("generation_config", None)
model_path = Path(model_id)
default_file_name = ONNX_WEIGHTS_NAME if from_onnx else OV_XML_FILE_NAME
file_name = file_name or default_file_name
model_cache_path = cls._cached_file(
model_path=model_path,
token=token,
revision=revision,
force_download=force_download,
cache_dir=cache_dir,
file_name=file_name,
subfolder=subfolder,
local_files_only=local_files_only,
)
if not compile_only:
model = cls.load_model(model_cache_path)
else:
model = cls._compile_model(
model_cache_path, kwargs.get("device", "CPU"), kwargs.get("ov_config"), model_cache_path.parent
)
model_type = config.model_type.replace("_", "-")
export_transformers_version = get_export_transformers_version(model, config)
if model_type == "bloom" and compare_versions(export_transformers_version, "<", "4.44"):
init_cls = OVBloomForCausalLM
elif model_type == "gpt-bigcode":
init_cls = OVGPTBigCodeForCausalLM
else:
init_cls = cls
if isinstance(quantization_config, dict) and quantization_config == {"bits": 4}:
default_config = get_default_quantization_config(config.name_or_path, weight_format="int4")
quantization_config = cls._prepare_quantization_config(
default_config or _DEFAULT_4BIT_WQ_CONFIG, load_in_8bit
)
if quantization_config.dataset is not None:
quantization_config.trust_remote_code = kwargs.get("trust_remote_code", False)
else:
quantization_config = cls._prepare_quantization_config(quantization_config, load_in_8bit)
if isinstance(quantization_config, OVWeightQuantizationConfig) and quantization_config.bits == 4:
default_config = get_default_quantization_config(config.name_or_path, weight_format="int4")
if default_config:
logger.info(
f"For the given model, we recommend the following `quantization_config` : {default_config}"
)
enable_compilation = kwargs.pop("compile", True) and not quantization_config
if generation_config is None:
try:
generation_config = GenerationConfig.from_pretrained(
model_id,
cache_dir=cache_dir,
force_download=force_download,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
)
if getattr(generation_config, "cache_implementation", None) is not None:
generation_config.cache_implementation = None
except OSError:
logger.info(
"Generation config file not found, using a generation config created from the model config."
)
causal_model = init_cls(
model=model,
config=config,
model_save_dir=model_cache_path.parent,
compile=enable_compilation,
compile_only=compile_only,
quantization_config=quantization_config,
generation_config=generation_config,
**kwargs,
)
if quantization_config:
if not is_nncf_available():
raise ImportError(
"Quantization of the weights requires nncf, please install it with `pip install nncf`"
)
if compile_only:
raise ValueError(
"quantization is not supported with `compile_only` mode, please initialize model without this option"
)
from optimum.intel.openvino.quantization import OVQuantizer
quantizer = OVQuantizer(causal_model)
quantization_config_copy = copy.deepcopy(quantization_config)
quantization_config_copy.tokenizer = quantization_config.tokenizer or model_id
quantizer.quantize(ov_config=OVConfig(quantization_config=quantization_config_copy))
return causal_model
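# Bloom checkpoints exported with transformers < 4.44 store their KV-cache with the batch and head
# dimensions fused ([batch_size * num_heads, ...]), so cache reordering and beam expansion need the
# format conversions implemented below.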
class OVBloomForCausalLM(OVModelForCausalLM):
# Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
# only last token for input_ids if past is not None
if past_key_values and not self.stateful:
            # the cache may be in the standard format (e.g. in contrastive search), convert to bloom's format if needed
if past_key_values[0][0].shape[0] == input_ids.shape[0]:
past_key_values = self._convert_to_bloom_cache(past_key_values)
return super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, **kwargs)
# Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM._reorder_cache
def _reorder_cache(
self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
) -> Tuple[Tuple[torch.Tensor]]:
"""
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called for the bloom architecture.
This is required to match `past_key_values` with the correct beam_idx at every generation step.
"""
if self.stateful:
batch_size = beam_idx.shape[0]
beam_idx = np.array(beam_idx) if not self._second_iter_beam_search else self.next_beam_idx
indices = np.array(range(batch_size * self.config.num_attention_heads))
indices = indices.reshape([batch_size, self.config.num_attention_heads])
self.next_beam_idx = np.take(indices, beam_idx, 0).flatten()
self._second_iter_beam_search = False
return past_key_values
else:
standardized_past = self._convert_to_standard_cache(past_key_values, batch_size=len(beam_idx))
reordered_past = tuple(
(
np.take(layer_past[0], beam_idx, 0),
np.take(layer_past[1], beam_idx, 0),
)
for layer_past in standardized_past
)
return self._convert_to_bloom_cache(reordered_past)
# Copied from transformers.models.bloom.modeling_bloom.BloomPreTrainedModel._convert_to_bloom_cache
@staticmethod
def _convert_to_bloom_cache(past_key_value: Tuple[Tuple[torch.Tensor]]) -> Tuple[Tuple[torch.Tensor]]:
"""
Converts the cache to the format expected by Bloom, i.e. to tuple(tuple([batch_size * num_heads, ...]))
"""
batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
batch_size_times_num_heads = batch_size * num_heads
# key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
# value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
return tuple(
(
layer_past[0].reshape((batch_size_times_num_heads, head_dim, seq_length)),
layer_past[1].reshape((batch_size_times_num_heads, seq_length, head_dim)),
)
for layer_past in past_key_value
)
# Adapted from transformers.models.bloom.modeling_bloom.BloomPreTrainedModel._convert_to_standard_cache
def _convert_to_standard_cache(
self, past_key_value: Tuple[Tuple[torch.Tensor]], batch_size: int
) -> Tuple[Tuple[torch.Tensor]]:
"""
Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size, num_heads, ...]))
"""
batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
num_heads = batch_size_times_num_heads // batch_size
# key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
# value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
return tuple(
(
layer_past[0].reshape((batch_size, num_heads, head_dim, seq_length)),
layer_past[1].reshape((batch_size, num_heads, seq_length, head_dim)),
)
for layer_past in past_key_value
)
def _expand_outputs_for_generation(self, indicies, logits: torch.Tensor, past_key_values: Tuple):
batch_size = logits.shape[0]
if indicies.shape[0] != 1:
logits = logits[indicies]
if past_key_values and not self.stateful:
pkv_standard = self._convert_to_standard_cache(past_key_values, batch_size)
pkv = tuple(tuple(past_state[indicies] for past_state in layer_past) for layer_past in pkv_standard)
past_key_values = self._convert_to_bloom_cache(pkv)
if self.stateful:
self.next_beam_idx = (
self.next_beam_idx[indicies]
if self.next_beam_idx is not None
else np.arange(batch_size, dtype=int)[indicies]
)
self._second_iter_beam_search = True
return logits, past_key_values
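# GPT-BigCode uses multi-query attention with a single fused past tensor per layer, so reordering
# indexes each layer tensor directly instead of per-layer (key, value) pairs.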
class OVGPTBigCodeForCausalLM(OVModelForCausalLM):
# Adapted from transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM._reorder_cache
def _reorder_cache(
self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
) -> Tuple[Tuple[torch.Tensor]]:
if self.stateful:
# save beam_idx to be used as an input in the next iteration
self.next_beam_idx = np.array(beam_idx) if not self._second_iter_beam_search else self.next_beam_idx
self._second_iter_beam_search = False
return past_key_values
else:
return tuple(np.take(layer_past, beam_idx, 0) for layer_past in past_key_values)