optimum/intel/openvino/modeling_decoder.py (813 lines of code) (raw):

# Copyright 2022 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import copy import logging import os from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import openvino import torch from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE from openvino import Core, Tensor, Type from openvino.preprocess import PrePostProcessor from transformers import AutoModelForCausalLM, PretrainedConfig from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward from transformers.generation import GenerationMixin from transformers.generation.configuration_utils import GenerationConfig from transformers.generation.logits_process import LogitsProcessorList from transformers.generation.stopping_criteria import StoppingCriteriaList from transformers.generation.utils import GenerateOutput, GenerationMode from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput from transformers.utils.hub import PushToHubMixin from optimum.utils.normalized_config import NormalizedConfigManager from ...exporters.openvino import ensure_stateful_is_available, main_export, patch_stateful from ...exporters.openvino.stateful import model_has_state from ..utils.import_utils import compare_versions, is_nncf_available, is_transformers_version from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS from .configuration import ( _DEFAULT_4BIT_WQ_CONFIG, OVConfig, OVWeightQuantizationConfig, get_default_quantization_config, ) from .modeling import _TOKENIZER_FOR_DOC, INPUTS_DOCSTRING, MODEL_START_DOCSTRING, OVModel from .utils import ( ONNX_WEIGHTS_NAME, OV_XML_FILE_NAME, STR_TO_OV_TYPE, TemporaryDirectory, get_export_transformers_version, model_has_dynamic_inputs, ) if TYPE_CHECKING: try: from transformers.generation.streamers import BaseStreamer except Exception: from typing import Generator as BaseStreamer from transformers.modeling_utils import PreTrainedModel logger = logging.getLogger(__name__) core = Core() TEXT_GENERATION_EXAMPLE = r""" Example of text generation: ```python >>> from transformers import {processor_class} >>> from optimum.intel import {model_class} >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}") >>> model = {model_class}.from_pretrained("{checkpoint}", export=True) >>> inputs = tokenizer("I love this story because", return_tensors="pt") >>> gen_tokens = model.generate(**inputs, do_sample=True, temperature=0.9, min_length=20, max_length=20) >>> tokenizer.batch_decode(gen_tokens) ``` Example using `transformers.pipelines`: ```python >>> from transformers import {processor_class}, pipeline >>> from optimum.intel import {model_class} >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}") >>> model = {model_class}.from_pretrained("{checkpoint}", export=True) >>> gen_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer) >>> text = "I love this story because" >>> gen = gen_pipeline(text) ``` """ # inheritage from PushToHubMixin added as workaround for transformers>=4.52.0 and nncf<=2.16.0 compatibility # during dataset preparatioon nncf checks isinstance(model, PreTrainedModel.__bases__) # in transformers 4.52.0 PreTrainedModel does not include GenerationMixin and this check failed for OVModelForCausalLM # TO DO: remove it after migration on new nncf @add_start_docstrings( """ Base OVBaseDecoderModel class. """, ) class OVBaseDecoderModel(OVModel, PushToHubMixin): def __init__( self, model: openvino.Model, config: PretrainedConfig = None, device: str = "CPU", dynamic_shapes: bool = True, ov_config: Optional[Dict[str, str]] = None, model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, quantization_config: Optional[Union[OVWeightQuantizationConfig, Dict]] = None, **kwargs, ): if not dynamic_shapes: raise ValueError( "`dynamic_shapes` was set to `False` but static shapes are not supported for causal language model. Please set `dynamic_shapes=True`." ) compile_only = kwargs.get("compile_only", False) enable_compilation = kwargs.get("compile", True) kwargs["compile"] = False or compile_only # avoid extra compilation in the base class if compile_only and not enable_compilation: raise ValueError( "`compile_only` mode does not support disabling compilation." "Please provide `compile=True` if you want to use `compile_only=True` or set `compile_only=False`" ) config.is_encoder_decoder = False super().__init__( model, config, device=device, dynamic_shapes=False if not compile_only else model_has_dynamic_inputs(model), ov_config=ov_config, model_save_dir=model_save_dir, quantization_config=quantization_config, **kwargs, ) self.is_dynamic = dynamic_shapes use_cache = kwargs.pop("use_cache", True) model_has_sinks = model_has_state(self.model) self.use_cache = any("past_key_values" in key.get_any_name() for key in model.inputs) or model_has_sinks stateful = kwargs.pop("stateful", None) # stateful model only if it is converted with stateful=True self.stateful = model_has_sinks self.main_input_name = "input_ids" self.num_pkv = 2 self.key_value_input_names = [key for key in self.input_names if "key_values" in key] self.key_value_output_names = [key for key in self.output_names if "present" in key] # Keeping the original model for serialization self._pkv_precision = Type.f32 self.next_beam_idx = None self._past_length = 0 self._first_iter_beam_search = False self._second_iter_beam_search = False self.update_pkv_precision() if self.is_dynamic and not self._compile_only: self.model = self._reshape(self.model, -1, -1) is_stateful_supported = ensure_stateful_is_available(warn=False) if self.use_cache and not self.stateful: logger.warning( "Provided model does not contain state. It may lead to sub-optimal performance." "Please reexport model with updated OpenVINO version >= 2023.3.0 calling the `from_pretrained` method with original model " "and `export=True` parameter" ) if self.stateful: if stateful is None: stateful = is_stateful_supported if model_has_sinks and not is_stateful_supported: raise ValueError( "Loaded stateful model, while OpenVINO runtime version does not support stateful model inference. " "Please update OpenVINO version >= 2023.3.0 " "or export the original model once again with `stateful=False` when calling the `from_pretrained` method." "To export your model, simply set `export=True`." ) def raise_error(model_prop, user_prop, name): raise ValueError( f"`{name}` was set to `{user_prop}` but the loaded model only supports `{name}={model_prop}`. " f"Please load your current model with `{name}={model_prop}` or export the original model " f"once again with `{name}={user_prop}` when calling the `from_pretrained` method. " "To export your model, simply set `export=True`." ) if stateful is not None and stateful ^ self.stateful: # We cannot transform stateful model to stateless raise_error(self.stateful, stateful, "stateful") if use_cache ^ self.use_cache: raise_error(self.use_cache, use_cache, "use_cache") if self._compile_only: self.request = self.model.create_infer_request() if not self._compile_only and enable_compilation: self.compile() @staticmethod def _get_model_with_updated_pkv_precision(model: openvino.Model, pkv_precision: Type) -> openvino.Model: ppp = PrePostProcessor(model) for key in model.inputs: if "past_key_values" in key.get_any_name() and pkv_precision != key.get_element_type(): ppp.input(key.get_any_name()).tensor().set_element_type(pkv_precision) for key in model.outputs: if "present" in key.get_any_name() and pkv_precision != key.get_element_type(): ppp.output(key.get_any_name()).tensor().set_element_type(pkv_precision) return ppp.build() def update_pkv_precision(self, force_fp32=False): if not self.use_cache or self.stateful or self._compile_only: return pkv_precision = Type.f32 if not force_fp32: device = self._device.upper() try: if "INFERENCE_PRECISION_HINT" in core.get_property(device, "SUPPORTED_PROPERTIES"): pkv_precision = core.get_property(device, "INFERENCE_PRECISION_HINT") except RuntimeError: # use default precision when get_property fails, e.g. when device is "AUTO:GPU" pass # ov_config["INFERENCE_PRECISION_HINT"] may override the prefer precision if self.ov_config: inference_precision_hint = self.ov_config.get("INFERENCE_PRECISION_HINT", "") if inference_precision_hint in STR_TO_OV_TYPE: pkv_precision = STR_TO_OV_TYPE[inference_precision_hint] self.model = self._get_model_with_updated_pkv_precision(self.model, pkv_precision) self._pkv_precision = pkv_precision self.request = None else: if hasattr(self, "_pkv_precision") and self._pkv_precision != Type.f32: self.model = self._get_model_with_updated_pkv_precision(self.model, Type.f32) self._pkv_precision = Type.f32 if self.is_dynamic: self.model = self._reshape(self.model, -1, -1) self.request = None def _save_pretrained(self, save_directory: Union[str, Path]): """ Saves the model to the OpenVINO IR format so that it can be re-loaded using the [`~optimum.intel.openvino.modeling.OVModel.from_pretrained`] class method. Arguments: save_directory (`str` or `Path`): The directory where to save the model files. """ if self._compile_only: raise ValueError( "`save_pretrained()` is not supported with `compile_only` mode, please initialize model without this option" ) model_to_save = ( self.model if self._pkv_precision == Type.f32 else self._get_model_with_updated_pkv_precision(self.model.clone(), Type.f32) ) dst_path = os.path.join(save_directory, OV_XML_FILE_NAME) openvino.save_model(model_to_save, dst_path, compress_to_fp16=False) if self.generation_config is not None: try: self.generation_config.save_pretrained(save_directory) except Exception as exception: logger.warning( f"The generation config will not be saved, saving failed with following error:\n{exception}" ) self._save_openvino_config(save_directory) @classmethod def _export( cls, model_id: str, config: PretrainedConfig, token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, force_download: bool = False, cache_dir: str = HUGGINGFACE_HUB_CACHE, subfolder: str = "", local_files_only: bool = False, task: Optional[str] = None, use_cache: bool = True, trust_remote_code: bool = False, load_in_8bit: Optional[bool] = None, quantization_config: Optional[Union[OVWeightQuantizationConfig, Dict]] = None, **kwargs, ): save_dir = TemporaryDirectory() save_dir_path = Path(save_dir.name) # This attribute is needed to keep one reference on the temporary directory, since garbage collecting # would end-up removing the directory containing the underlying OpenVINO model cls._model_save_dir_tempdirectory_instance = save_dir compile_only = kwargs.pop("compile_only", False) if compile_only: logger.warning( "`compile_only` mode will be disabled because it does not support model export." "Please provide openvino model obtained using optimum-cli or saved on disk using `save_pretrained`" ) compile_only = False if task is None: task = cls.export_feature if use_cache: task = task + "-with-past" # If load_in_8bit and quantization_config not specified then ov_config is set to None and will be set by default in convert depending on the model size if load_in_8bit is None and not quantization_config: ov_export_config = None else: ov_export_config = OVConfig(dtype="auto") stateful = kwargs.pop("stateful", ensure_stateful_is_available(warn=False) and use_cache) torch_dtype = kwargs.pop("torch_dtype", None) model_loading_kwargs = {} if torch_dtype is not None: model_loading_kwargs["torch_dtype"] = torch_dtype variant = kwargs.pop("variant", None) main_export( model_name_or_path=model_id, output=save_dir_path, task=task, subfolder=subfolder, revision=revision, cache_dir=cache_dir, token=token, local_files_only=local_files_only, force_download=force_download, trust_remote_code=trust_remote_code, ov_config=ov_export_config, stateful=stateful, model_loading_kwargs=model_loading_kwargs, library_name=cls._library_name, variant=variant, ) if config.model_type == "phi3" and config.max_position_embeddings != getattr( config, "original_max_position_embeddings", config.max_position_embeddings ): config.max_position_embeddings = config.original_max_position_embeddings return cls._from_pretrained( model_id=save_dir_path, config=config, use_cache=use_cache, stateful=None, load_in_8bit=load_in_8bit, quantization_config=quantization_config, trust_remote_code=trust_remote_code, compile_only=compile_only, **kwargs, ) def _reshape( self, model: openvino.Model, batch_size: int, sequence_length: int, height: int = None, width: int = None, ): if self._compile_only: raise ValueError( "`reshape()` is not supported with `compile_only` mode, please initialize model without this option" ) if height is not None: logger.warning(f"`height` set to `{height}` will be ignored during reshaping operation.") if width is not None: logger.warning(f"`width` set to `{width}` will be ignored during reshaping operation.") shapes = {} for inputs in model.inputs: shapes[inputs] = inputs.get_partial_shape() shapes[inputs][0] = -1 input_name = inputs.get_any_name() if input_name.startswith("past_key_values"): if (len(inputs.partial_shape) == 3 and input_name.endswith("value")) or ( self.config.model_type == "chatglm" and not hasattr(self.config, "rope_ratio") ): shapes[inputs][1] = -1 else: shapes[inputs][2] = -1 elif input_name.startswith("beam_idx"): shapes[inputs][0] = -1 else: shapes[inputs][1] = -1 model.reshape(shapes) return model def reshape(self, batch_size: int, sequence_length: int): logger.warning("Static shapes are not supported for causal language model.") return self @property def normalized_config(self): logger.warning( "access to normalized_config attribute is deprecated and will be removed in future versions, please use config" ) return NormalizedConfigManager.get_normalized_config_class(self.config.model_type)(self.config) def compile(self): if self.request is None: if self._compile_only: self.request = self.model.create_infer_request() super().compile() self.request = self.request.create_infer_request() def _make_stateful(self): patch_stateful(self.config, self.model) self.stateful = True @add_start_docstrings( """ OpenVINO Model with a causal language modeling head on top (linear layer with weights tied to the input embeddings). """, MODEL_START_DOCSTRING, ) class OVModelForCausalLM(OVBaseDecoderModel, GenerationMixin): export_feature = "text-generation" auto_model_class = AutoModelForCausalLM @add_start_docstrings_to_model_forward( INPUTS_DOCSTRING.format("batch_size, sequence_length") + TEXT_GENERATION_EXAMPLE.format( processor_class=_TOKENIZER_FOR_DOC, model_class="OVModelForCausalLM", checkpoint="gpt2", ) ) def prepare_inputs( self, input_ids: torch.LongTensor, attention_mask: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, position_ids: Optional[torch.LongTensor] = None, **kwargs, ) -> Dict: batch_size = input_ids.shape[0] model_transformers_version = get_export_transformers_version(self.model, self.config) if self.config.model_type == "bloom" and compare_versions(model_transformers_version, "<", "4.44"): batch_size *= self.config.num_attention_heads inputs = {} if not self.stateful: if past_key_values is not None: if self.config.model_type not in MULTI_QUERY_ATTN_MODELS or ( self.config.model_type == "falcon" and self.config.new_decoder_architecture ): if self._pkv_precision == Type.bf16: # numpy does not support bf16, pretending f16, should change to bf16 past_key_values = tuple( Tensor(past_key_value, past_key_value.shape, Type.bf16) for pkv_per_layer in past_key_values for past_key_value in pkv_per_layer ) else: # Flatten the past_key_values past_key_values = tuple( past_key_value for pkv_per_layer in past_key_values for past_key_value in pkv_per_layer ) # Add the past_key_values to the decoder inputs inputs = dict(zip(self.key_value_input_names, past_key_values)) # Create empty past_key_values for decoder_with_past first generation step elif self.use_cache: for input_name in self.key_value_input_names: model_inputs = self.model.input(input_name) shape = model_inputs.get_partial_shape() if self.config.model_type == "chatglm" and not hasattr(self.config, "rope_ratio"): shape[0] = 0 shape[1] = batch_size else: shape[0] = batch_size if shape[2].is_dynamic: shape[2] = 0 else: shape[1] = 0 inputs[input_name] = Tensor(model_inputs.get_element_type(), [dim.get_length() for dim in shape]) else: # past_key_values are not used explicitly, instead they are handled inside the model if past_key_values is None: # This is the first iteration in a sequence, reset all states if self.request is not None: self.request.reset_state() # Set initial value for the next beam_idx input that will be used at the current iteration # and will be optionally updated by _reorder_cache at the next iterations if beam_search is used self.next_beam_idx = np.arange(batch_size, dtype=int) self._past_length = 0 past_len = self._get_past_length(past_key_values) inputs["input_ids"] = input_ids.cpu().numpy() # Add the attention_mask inputs when needed if "attention_mask" in self.input_names or "position_ids" in self.input_names: if attention_mask is not None: attention_mask = attention_mask.cpu().numpy() else: attention_mask = np.ones( (input_ids.shape[0], input_ids.shape[1] + past_len), dtype=inputs["input_ids"].dtype ) if "attention_mask" in self.input_names: inputs["attention_mask"] = attention_mask if "position_ids" in self.input_names: if position_ids is not None: position_ids = position_ids.cpu().numpy() else: position_ids = np.cumsum(attention_mask, axis=1) - 1 position_ids[attention_mask == 0] = 1 if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] inputs["position_ids"] = position_ids if "beam_idx" in self.input_names: inputs["beam_idx"] = ( self.next_beam_idx if self.next_beam_idx is not None else np.arange(batch_size, dtype=int) ) return inputs def forward( self, input_ids: torch.LongTensor, attention_mask: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, position_ids: Optional[torch.LongTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, **kwargs, ) -> CausalLMOutputWithPast: self.compile() # added as model.generate validates model inputs based on forward signature kwargs["token_type_ids"] = token_type_ids inputs = self.prepare_inputs( input_ids=input_ids, attention_mask=attention_mask, past_key_values=past_key_values, position_ids=position_ids, **kwargs, ) if self._first_iter_beam_search: inputs, duplication_indices = self._deduplicate_inputs(inputs) # Run inference self.request.start_async(inputs, share_inputs=True) self.request.wait() logits = torch.from_numpy(self.request.get_tensor("logits").data).clone().to(self.device) if self.stateful: # Need a marker to differentiate the first generate iteration from the others in # the first condition at the function beginning above. # It should be something that is not None and it should be True when converted to Boolean. past_key_values = ((),) self._past_length += input_ids.shape[1] if not self.stateful: if self.use_cache: # Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the self-attention layer) past_key_values = tuple( np.copy(self.request.get_tensor(key).data) for key in self.key_value_output_names ) if self.config.model_type not in MULTI_QUERY_ATTN_MODELS or ( self.config.model_type == "falcon" and self.config.new_decoder_architecture ): # Tuple of tuple of length `n_layers`, with each tuple of length equal to 2 (k/v of self-attention) past_key_values = tuple( past_key_values[i : i + self.num_pkv] for i in range(0, len(past_key_values), self.num_pkv) ) else: past_key_values = None if self._first_iter_beam_search: logits, past_key_values = self._expand_outputs_for_generation(duplication_indices, logits, past_key_values) self._first_iter_beam_search = False return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values) # Adapted from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly attention_mask = kwargs.get("attention_mask", None) use_cache = kwargs.get("use_cache", None) if past_key_values is not None: past_len = self._get_past_length(past_key_values) # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as # input) if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: input_ids = input_ids[:, -(attention_mask.shape[1] - past_len) :] # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard # input_ids based on the past_length. elif past_len < input_ids.shape[1]: input_ids = input_ids[:, past_len:] # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None and "position_ids" in self.input_names: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] model_inputs = { "input_ids": input_ids, "past_key_values": past_key_values, "use_cache": use_cache, "position_ids": position_ids, "attention_mask": attention_mask, } return model_inputs def _update_model_kwargs_for_generation( self, outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False, **kwargs, ) -> Dict[str, Any]: model_kwargs = super()._update_model_kwargs_for_generation( outputs=outputs, model_kwargs=model_kwargs, is_encoder_decoder=is_encoder_decoder, **kwargs ) if "position_ids" in model_kwargs: position_ids = model_kwargs["position_ids"] new_position_id = position_ids[..., -1:].clone() new_position_id += 1 model_kwargs["position_ids"] = torch.cat([position_ids, new_position_id], dim=-1) return model_kwargs def _expand_outputs_for_generation(self, indicies, logits: torch.Tensor, past_key_values: Tuple): batch_size = logits.shape[0] if indicies.shape[0] != 1: logits = logits[indicies] if past_key_values and not self.stateful: if self.config.model_type not in MULTI_QUERY_ATTN_MODELS or ( self.config.model_type == "falcon" and self.config.new_decoder_architecture ): past_key_values = tuple( tuple( ( past_state[indicies] if not (self.config.model_type == "chatglm" and not hasattr(self.config, "rope_ratio")) else past_state[:, indicies, ...] ) for past_state in layer_past ) for layer_past in past_key_values ) else: past_key_values = tuple([past_state[indicies] for past_state in past_key_values]) if self.stateful: self.next_beam_idx = ( self.next_beam_idx[indicies] if self.next_beam_idx is not None else np.arange(batch_size, dtype=int)[indicies] ) self._second_iter_beam_search = True return logits, past_key_values def _deduplicate_inputs(self, model_inputs: Dict): input_ids = model_inputs["input_ids"] upd_model_inputs = {} unique_input_ids, indicies, reverse_indicies = np.unique( input_ids, axis=0, return_index=True, return_inverse=True ) export_transformers_version = get_export_transformers_version(self.model, self.config) for input_name, input_tensor in model_inputs.items(): if input_name not in ["input_ids", "beam_idx"]: if input_name not in self.key_value_input_names: upd_model_inputs[input_name] = input_tensor[indicies] else: shape = input_tensor.shape if isinstance(input_tensor, Tensor) else list(input_tensor.shape) dtype = input_tensor.element_type if isinstance(input_tensor, Tensor) else Type(input_tensor.dtype) upd_batch_size = indicies.shape[0] if self.config.model_type == "bloom" and compare_versions( export_transformers_version, "<", "4.44" ): upd_batch_size *= self.config.num_attention_heads shape[ ( 0 if not (self.config.model_type == "chatglm" and not hasattr(self.config, "rope_ratio")) else 1 ) ] = upd_batch_size upd_model_inputs[input_name] = Tensor(dtype, shape) upd_model_inputs["input_ids"] = unique_input_ids if "beam_idx" in model_inputs: beam_range = ( unique_input_ids.shape[0] * self.config.num_attention_heads if (self.config.model_type == "bloom" and compare_versions(export_transformers_version, "<", "4.44")) else unique_input_ids.shape[0] ) beam_idx = np.arange(beam_range, dtype=int) upd_model_inputs["beam_idx"] = beam_idx return upd_model_inputs, reverse_indicies @torch.no_grad() def generate( self, inputs: Optional[torch.Tensor] = None, generation_config: Optional[GenerationConfig] = None, logits_processor: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, synced_gpus: Optional[bool] = None, assistant_model: Optional["PreTrainedModel"] = None, streamer: Optional["BaseStreamer"] = None, negative_prompt_ids: Optional[torch.Tensor] = None, negative_prompt_attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> Union[GenerateOutput, torch.LongTensor]: if is_transformers_version(">=", "4.39.0"): _generation_config, _ = self._prepare_generation_config(generation_config, **kwargs) generation_mode = _generation_config.get_generation_mode(assistant_model) else: _generation_config = generation_config or self.generation_config generation_mode = self._get_generation_mode(_generation_config, assistant_model) is_beam_search = generation_mode in [ GenerationMode.BEAM_SEARCH, GenerationMode.BEAM_SAMPLE, GenerationMode.GROUP_BEAM_SEARCH, GenerationMode.CONSTRAINED_BEAM_SEARCH, ] if is_beam_search: self._first_iter_beam_search = True result = super().generate( inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs, ) return result def _get_past_length(self, past_key_values=None): if past_key_values is None: return 0 if self.stateful: return self._past_length if self.config.model_type in MULTI_QUERY_ATTN_MODELS and not ( self.config.model_type == "falcon" and self.config.new_decoder_architecture ): return past_key_values[0].shape[-2] seq_length_dim = -2 if self.config.model_type == "chatglm" and not hasattr(self.config, "rope_ratio"): seq_length_dim = 0 elif self.config.model_type == "qwen": seq_length_dim = 1 # input is tuple of pairs if isinstance(past_key_values[0], (tuple, list)): return past_key_values[0][1].shape[seq_length_dim] # past key values comes after flattening return past_key_values[1].shape[seq_length_dim] # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache def _reorder_cache( self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor ) -> Tuple[Tuple[torch.Tensor]]: """ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct beam_idx at every generation step. """ if self.stateful: # TODO: Apply it differently based on model type # TODO: At least for bloom we need to replicate values for each attention head self.next_beam_idx = ( np.array(beam_idx) if not self._second_iter_beam_search else self.next_beam_idx ) # save beam_idx to be used as an input in the next iteration self._second_iter_beam_search = False return past_key_values else: if self.config.model_type not in MULTI_QUERY_ATTN_MODELS or ( self.config.model_type == "falcon" and self.config.new_decoder_architecture ): return tuple( tuple(np.take(past_state, beam_idx, 0) for past_state in layer_past) for layer_past in past_key_values ) return tuple(np.take(past_state, beam_idx, 0) for past_state in past_key_values) @classmethod def _from_pretrained( cls, model_id: Union[str, Path], config: PretrainedConfig, token: Optional[Union[bool, str]] = None, revision: Optional[Union[str, None]] = None, force_download: bool = False, cache_dir: str = HUGGINGFACE_HUB_CACHE, file_name: Optional[str] = None, subfolder: str = "", from_onnx: bool = False, local_files_only: bool = False, load_in_8bit: bool = False, compile_only: bool = False, quantization_config: Optional[Union[OVWeightQuantizationConfig, Dict]] = None, **kwargs, ): generation_config = kwargs.pop("generation_config", None) model_path = Path(model_id) default_file_name = ONNX_WEIGHTS_NAME if from_onnx else OV_XML_FILE_NAME file_name = file_name or default_file_name model_cache_path = cls._cached_file( model_path=model_path, token=token, revision=revision, force_download=force_download, cache_dir=cache_dir, file_name=file_name, subfolder=subfolder, local_files_only=local_files_only, ) if not compile_only: model = cls.load_model(model_cache_path) else: model = cls._compile_model( model_cache_path, kwargs.get("device", "CPU"), kwargs.get("ov_config"), model_cache_path.parent ) model_type = config.model_type.replace("_", "-") export_transformers_version = get_export_transformers_version(model, config) if model_type == "bloom" and compare_versions(export_transformers_version, "<", "4.44"): init_cls = OVBloomForCausalLM elif model_type == "gpt-bigcode": init_cls = OVGPTBigCodeForCausalLM else: init_cls = cls if isinstance(quantization_config, dict) and quantization_config == {"bits": 4}: default_config = get_default_quantization_config(config.name_or_path, weight_format="int4") quantization_config = cls._prepare_quantization_config( default_config or _DEFAULT_4BIT_WQ_CONFIG, load_in_8bit ) if quantization_config.dataset is not None: quantization_config.trust_remote_code = kwargs.get("trust_remote_code", False) else: quantization_config = cls._prepare_quantization_config(quantization_config, load_in_8bit) if isinstance(quantization_config, OVWeightQuantizationConfig) and quantization_config.bits == 4: default_config = get_default_quantization_config(config.name_or_path, weight_format="int4") if default_config: logger.info( f"For the given model, we recommend the following `quantization_config` : {default_config}" ) enable_compilation = kwargs.pop("compile", True) and not quantization_config if generation_config is None: try: generation_config = GenerationConfig.from_pretrained( model_id, cache_dir=cache_dir, force_download=force_download, local_files_only=local_files_only, token=token, revision=revision, subfolder=subfolder, ) if getattr(generation_config, "cache_implementation", None) is not None: generation_config.cache_implementation = None except OSError: logger.info( "Generation config file not found, using a generation config created from the model config." ) causal_model = init_cls( model=model, config=config, model_save_dir=model_cache_path.parent, compile=enable_compilation, compile_only=compile_only, quantization_config=quantization_config, generation_config=generation_config, **kwargs, ) if quantization_config: if not is_nncf_available(): raise ImportError( "Quantization of the weights requires nncf, please install it with `pip install nncf`" ) if compile_only: raise ValueError( "quantization is not supported with `compile_only` mode, please initialize model without this option" ) from optimum.intel.openvino.quantization import OVQuantizer quantizer = OVQuantizer(causal_model) quantization_config_copy = copy.deepcopy(quantization_config) quantization_config_copy.tokenizer = quantization_config.tokenizer or model_id quantizer.quantize(ov_config=OVConfig(quantization_config=quantization_config_copy)) return causal_model class OVBloomForCausalLM(OVModelForCausalLM): # Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM.prepare_inputs_for_generation def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): # only last token for input_ids if past is not None if past_key_values and not self.stateful: # the cache may be in the stardard format (e.g. in contrastive search), convert to bloom's format if needed if past_key_values[0][0].shape[0] == input_ids.shape[0]: past_key_values = self._convert_to_bloom_cache(past_key_values) return super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, **kwargs) # Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM._reorder_cache def _reorder_cache( self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor ) -> Tuple[Tuple[torch.Tensor]]: """ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or [`~PreTrainedModel.beam_sample`] is called for bloom architecture. This is required to match `past_key_values` with the correct beam_idx at every generation step. """ if self.stateful: batch_size = beam_idx.shape[0] beam_idx = np.array(beam_idx) if not self._second_iter_beam_search else self.next_beam_idx indices = np.array(range(batch_size * self.config.num_attention_heads)) indices = indices.reshape([batch_size, self.config.num_attention_heads]) self.next_beam_idx = np.take(indices, beam_idx, 0).flatten() self._second_iter_beam_search = False return past_key_values else: standardized_past = self._convert_to_standard_cache(past_key_values, batch_size=len(beam_idx)) reordered_past = tuple( ( np.take(layer_past[0], beam_idx, 0), np.take(layer_past[1], beam_idx, 0), ) for layer_past in standardized_past ) return self._convert_to_bloom_cache(reordered_past) # Copied from transformers.models.bloom.modeling_bloom.BloomPreTrainedModel._convert_to_bloom_cache @staticmethod def _convert_to_bloom_cache(past_key_value: Tuple[Tuple[torch.Tensor]]) -> Tuple[Tuple[torch.Tensor]]: """ Converts the cache to the format expected by Bloom, i.e. to tuple(tuple([batch_size * num_heads, ...])) """ batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape batch_size_times_num_heads = batch_size * num_heads # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length] # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim] return tuple( ( layer_past[0].reshape((batch_size_times_num_heads, head_dim, seq_length)), layer_past[1].reshape((batch_size_times_num_heads, seq_length, head_dim)), ) for layer_past in past_key_value ) # Adapted from transformers.models.bloom.modeling_bloom.BloomPreTrainedModel._convert_to_standard_cache def _convert_to_standard_cache( self, past_key_value: Tuple[Tuple[torch.Tensor]], batch_size: int ) -> Tuple[Tuple[torch.Tensor]]: """ Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size, num_heads, ...])) """ batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape num_heads = batch_size_times_num_heads // batch_size # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length] # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim] return tuple( ( layer_past[0].reshape((batch_size, num_heads, head_dim, seq_length)), layer_past[1].reshape((batch_size, num_heads, seq_length, head_dim)), ) for layer_past in past_key_value ) def _expand_outputs_for_generation(self, indicies, logits: torch.Tensor, past_key_values: Tuple): batch_size = logits.shape[0] if indicies.shape[0] != 1: logits = logits[indicies] if past_key_values and not self.stateful: pkv_standard = self._convert_to_standard_cache(past_key_values, batch_size) pkv = tuple(tuple(past_state[indicies] for past_state in layer_past) for layer_past in pkv_standard) past_key_values = self._convert_to_bloom_cache(pkv) if self.stateful: self.next_beam_idx = ( self.next_beam_idx[indicies] if self.next_beam_idx is not None else np.arange(batch_size, dtype=int)[indicies] ) self._second_iter_beam_search = True return logits, past_key_values class OVGPTBigCodeForCausalLM(OVModelForCausalLM): # Adapted from transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM._reorder_cache def _reorder_cache( self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor ) -> Tuple[Tuple[torch.Tensor]]: if self.stateful: # save beam_idx to be used as an input in the next iteration self.next_beam_idx = np.array(beam_idx) if not self._second_iter_beam_search else self.next_beam_idx self._second_iter_beam_search = False return past_key_values else: return tuple(np.take(layer_past, beam_idx, 0) for layer_past in past_key_values)