optimum/exporters/ipex/model_config.py (68 lines of code) (raw):

# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple

from optimum.exporters.onnx.model_configs import (
    FalconOnnxConfig,
    GPT2OnnxConfig,
    LlamaOnnxConfig,
)
from optimum.utils import DEFAULT_DUMMY_SHAPES
from optimum.utils.input_generators import DummyPastKeyValuesGenerator, DummyTextInputGenerator
from optimum.utils.normalized_config import NormalizedTextConfig


# Force the default dummy batch size to 1. This mutates the dict object imported
# from `optimum.utils`, so every later reader of that dict sees the new value.
# It has to run before the classes below are defined: their `batch_size`
# parameter defaults are read from this dict at definition time.
DEFAULT_DUMMY_SHAPES["batch_size"] = 1


class IPEXDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
    """Dummy past-key-values generator producing four tensors per layer.

    On top of what the parent initializer stores, the instance keeps
    `num_key_value_heads` and `max_position_embeddings` taken from the
    normalized config; both are used to size the tensors built by `generate`.
    """

    def __init__(
        self,
        task: str,
        normalized_config: NormalizedTextConfig,
        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
        sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
        random_batch_size_range: Optional[Tuple[int, int]] = None,
        random_sequence_length_range: Optional[Tuple[int, int]] = None,
        **kwargs,
    ):
        """Initialize the parent generator and record the IPEX-specific sizes.

        Args:
            task: Forwarded to the parent initializer.
            normalized_config: Model configuration; must expose
                `max_position_embeddings` and may expose `num_key_value_heads`.
            batch_size: Forwarded to the parent initializer.
            sequence_length: Forwarded to the parent initializer.
            random_batch_size_range: Forwarded to the parent initializer.
            random_sequence_length_range: Forwarded to the parent initializer.
            **kwargs: Accepted for signature compatibility; not forwarded.
        """
        super().__init__(
            task=task,
            normalized_config=normalized_config,
            batch_size=batch_size,
            sequence_length=sequence_length,
            random_batch_size_range=random_batch_size_range,
            random_sequence_length_range=random_sequence_length_range,
        )
        # Falls back to a single key/value head when the config has no such field.
        # NOTE(review): for configs without `num_key_value_heads`, confirm that 1
        # (rather than the attention head count) is the intended dummy value.
        self.num_key_value_heads = getattr(normalized_config, "num_key_value_heads", 1)
        self.max_position_embeddings = normalized_config.max_position_embeddings

    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
        """Build the dummy past-key-values structure.

        Args:
            input_name: Unused by this implementation.
            framework: Framework identifier passed to the tensor helpers.
            int_dtype: Unused by this implementation.
            float_dtype: Dtype passed to the float tensor helper.

        Returns:
            A list with `self.num_layers` entries. Each entry is a 4-tuple of
            contiguous tensors, in this order:

            1. an int tensor of shape `(1, sequence_length, sequence_length, 1)`;
            2. a float tensor of shape `(max_position_embeddings, batch_size,
               num_key_value_heads, hidden_size // num_attention_heads)`;
            3. a second float tensor of that same shape;
            4. an int tensor of shape `(max_position_embeddings, batch_size)`.
        """
        # `sequence_length`, `batch_size`, `hidden_size`, `num_attention_heads` and
        # `num_layers` are not assigned in this class; they are expected to be set
        # by the parent initializer.
        first_slot_shape = (1, self.sequence_length, self.sequence_length, 1)
        beam_idx_shape = (self.max_position_embeddings, self.batch_size)
        head_dim = self.hidden_size // self.num_attention_heads
        kv_shape = (
            self.max_position_embeddings,
            self.batch_size,
            self.num_key_value_heads,
            head_dim,
        )

        def int_tensor(shape):
            # NOTE(review): with max_value=1 the values are presumably all zeros if
            # the upper bound is exclusive - confirm against `random_int_tensor`.
            return self.random_int_tensor(shape, max_value=1, framework=framework).contiguous()

        def float_tensor(shape):
            return self.random_float_tensor(shape, framework=framework, dtype=float_dtype).contiguous()

        def layer_entry():
            # Tuple elements are evaluated left to right, so the random helpers are
            # called in the order: int, float, float, int.
            return (
                int_tensor(first_slot_shape),
                float_tensor(kv_shape),
                float_tensor(kv_shape),
                int_tensor(beam_idx_shape),
            )

        return [layer_entry() for _ in range(self.num_layers)]


class IPEXDummyTextInputGenerator(DummyTextInputGenerator):
    """Text input generator whose `batch_size` default comes from this module.

    The only difference from the parent visible here is the default value of
    `batch_size`, which is read from `DEFAULT_DUMMY_SHAPES` after this module
    has overridden it.
    """

    def __init__(
        self,
        task: str,
        normalized_config: NormalizedTextConfig,
        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
        **kwargs,
    ):
        """Forward all arguments unchanged to the parent initializer.

        Args:
            task: Passed positionally to the parent.
            normalized_config: Passed positionally to the parent.
            batch_size: Passed positionally to the parent.
            **kwargs: Passed through to the parent.
        """
        super().__init__(task, normalized_config, batch_size, **kwargs)


class LlamaIPEXConfig(LlamaOnnxConfig):
    """Llama export config wired to the IPEX dummy input generators."""

    DUMMY_INPUT_GENERATOR_CLASSES = (IPEXDummyTextInputGenerator, IPEXDummyPastKeyValuesGenerator)
    DUMMY_PKV_GENERATOR_CLASS = IPEXDummyPastKeyValuesGenerator


class FalconIPEXConfig(FalconOnnxConfig):
    """Falcon export config wired to the IPEX dummy input generators."""

    DUMMY_INPUT_GENERATOR_CLASSES = (IPEXDummyTextInputGenerator, IPEXDummyPastKeyValuesGenerator)
    DUMMY_PKV_GENERATOR_CLASS = IPEXDummyPastKeyValuesGenerator


class GPT2IPEXConfig(GPT2OnnxConfig):
    """GPT-2 export config wired to the IPEX dummy input generators."""

    DUMMY_INPUT_GENERATOR_CLASSES = (IPEXDummyTextInputGenerator, IPEXDummyPastKeyValuesGenerator)
    DUMMY_PKV_GENERATOR_CLASS = IPEXDummyPastKeyValuesGenerator


# Lookup table from model-type string to its IPEX export config class.
ipex_onnx_config = {
    "llama": LlamaIPEXConfig,
    "falcon": FalconIPEXConfig,
    "gpt2": GPT2IPEXConfig,
}