optimum/exporters/ipex/model_patcher.py (119 lines of code) (raw):

# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from transformers.models.bert.modeling_bert import BertIntermediate
from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel
from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model
from transformers.models.llama.modeling_llama import (
    LlamaDecoderLayer,
    LlamaModel,
    LlamaRMSNorm,
)
from transformers.models.mistral.modeling_mistral import MistralDecoderLayer, MistralModel, MistralRMSNorm
from transformers.models.qwen2.modeling_qwen2 import (
    Qwen2DecoderLayer,
    Qwen2Model,
    Qwen2RMSNorm,
)
from transformers.models.vit.modeling_vit import ViTIntermediate

from optimum.intel.utils.import_utils import is_ipex_version, is_transformers_version
from optimum.intel.utils.modeling_utils import replace_customized_linear_with_linear

from .modeling_utils import (
    _IPEX_MINIMUM_VERSION_FOR_PATCHING,
    _falcon_for_causal_lm_forward,
    _falcon_model_forward,
    _gpt2_lm_head_model_forward,
    _gpt2_model_forward,
    _ipex_rms_layer_norm_forward,
    _IPEXFalconDecoderLayer,
    _IPEXGPT2Block,
    _IPEXIntermediate,
    _IPEXLlamaDecoderLayer,
    _IPEXMistralDecoderLayer,
    _IPEXQwen2DecoderLayer,
    _llama_model_forward,
    _mistral_model_forward,
    _qwen2_model_forward,
)


# Please also update in the setup.py and .github/workflows/test_ipex.yml if you change the transformers version
_TRANSFORMERS_MIN_VERSION = "4.51.0"
_TRANSFORMERS_MAX_VERSION = "4.52.99"

_IPEX_EXPORTED_GENERATION_TASKS = ("text-generation",)


def convert_func(m, func_name, new_function):
    """
    Bind `new_function` to the instance `m` and store it on `m` under `func_name`.

    The function is bound through the descriptor protocol, so it receives `m` as its
    first argument when called. Only this instance is affected; the class is untouched.
    """
    bound_method = new_function.__get__(m, m.__class__)
    setattr(m, func_name, bound_method)


def convert_functions(m, target_m, new_function_name, new_function):
    """
    Recursively rebind `new_function_name` on every descendant of `m` that is an instance of `target_m`.

    Only children (at any depth) are inspected; `m` itself is never rebound, so callers
    that also need to patch the root module must call `convert_func` on it directly.
    """
    for _, sub_m in m.named_children():
        if isinstance(sub_m, target_m):
            convert_func(sub_m, new_function_name, new_function)
        convert_functions(sub_m, target_m, new_function_name, new_function)


def convert_class(m, target_m, new_class, device, config):
    """
    Recursively replace every descendant of `m` that is an instance of `target_m`.

    Each match is swapped, on its parent, for `new_class(sub_m, device, config)`.
    The recursion then continues into the original child (`sub_m`), not into the
    newly created replacement.
    """
    for name, sub_m in m.named_children():
        if isinstance(sub_m, target_m):
            new_m = new_class(sub_m, device, config)
            setattr(m, name, new_m)
        convert_class(sub_m, target_m, new_class, device, config)


def patch_op(m, target_m, new_op_name, new_op):
    """
    Recursively set attribute `new_op_name` to `new_op` on every descendant of `m` that is an instance of `target_m`.

    Unlike `convert_func`, `new_op` is stored as-is and is not bound to the module.
    """
    for _, sub_m in m.named_children():
        if isinstance(sub_m, target_m):
            setattr(sub_m, new_op_name, new_op)
        patch_op(sub_m, target_m, new_op_name, new_op)


def _patch_llama_model(model):
    """
    Patch llama model:
        1. Use IPEX rope and paged cache
        2. Linear fusion with (2 Linears + Silu + Mul) and (Linear + Add)
    """
    convert_functions(model, LlamaModel, "forward", _llama_model_forward)
    convert_functions(model, LlamaRMSNorm, "forward", _ipex_rms_layer_norm_forward)
    convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayer, model.device, model.config)
    return model


def _patch_falcon_model(model):
    """
    Patch falcon model:
        1. Use IPEX rope and paged cache
        2. Linear fusion with (Linear + Gelu) and (Linear + Add + Add)
    """
    config = model.config
    # Expose the KV head count under the `num_key_value_heads` name: `num_kv_heads` when the
    # new decoder architecture is used or multi-query is disabled, otherwise a single head.
    uses_num_kv_heads = config.new_decoder_architecture or not config.multi_query
    config.num_key_value_heads = config.num_kv_heads if uses_num_kv_heads else 1

    # `convert_functions` only visits children, so the top-level forward is rebound explicitly.
    convert_func(model, "forward", _falcon_for_causal_lm_forward)
    convert_functions(model, FalconModel, "forward", _falcon_model_forward)
    replace_customized_linear_with_linear(model)
    convert_class(model, FalconDecoderLayer, _IPEXFalconDecoderLayer, model.device, model.config)
    return model


def _patch_gpt2_model(model):
    """
    Patch gpt2 model:
        1. Use IPEX paged attention
        2. Linear fusion with (Linear + Add)
    """
    # Expose the attention head count under the `num_key_value_heads` name as well.
    model.config.num_key_value_heads = model.config.num_attention_heads

    # `convert_functions` only visits children, so the top-level forward is rebound explicitly.
    convert_func(model, "forward", _gpt2_lm_head_model_forward)
    convert_functions(model, GPT2Model, "forward", _gpt2_model_forward)
    convert_class(model, GPT2Block, _IPEXGPT2Block, model.device, model.config)
    return model


def _patch_qwen2_model(model):
    """
    Patch qwen2 model:
        1. Use IPEX rope and paged cache
        2. Linear fusion with (2 Linears + Silu + Mul) and (Linear + Add)
    """
    # Avoid calling `_ignore_causal_mask_sdpa`, which would cause a recompile.
    model.config._attn_implementation = "ipex_paged"
    convert_functions(model, Qwen2Model, "forward", _qwen2_model_forward)
    convert_functions(model, Qwen2RMSNorm, "forward", _ipex_rms_layer_norm_forward)
    convert_class(model, Qwen2DecoderLayer, _IPEXQwen2DecoderLayer, model.device, model.config)
    return model


def _patch_mistral_model(model):
    """
    Patch mistral model:
        1. Use IPEX rope and paged cache
        2. Linear fusion with (Linear + Add)
    """
    convert_functions(model, MistralModel, "forward", _mistral_model_forward)
    convert_functions(model, MistralRMSNorm, "forward", _ipex_rms_layer_norm_forward)
    convert_class(model, MistralDecoderLayer, _IPEXMistralDecoderLayer, model.device, model.config)
    return model


def _patch_bert_model(model):
    """
    Patch bert model:
        1. Linear fusion with Linear + Gelu
    """
    convert_class(model, BertIntermediate, _IPEXIntermediate, model.device, model.config)
    return model


def _patch_vit_model(model):
    """
    Patch vit model:
        1. Linear fusion with Linear + Gelu
    """
    convert_class(model, ViTIntermediate, _IPEXIntermediate, model.device, model.config)
    return model


def _patch_model(model):
    """
    Apply the IPEX patches that match `model.config.model_type`.

    The model is patched in place and returned. Model types without a dedicated
    patcher are returned unchanged.

    Raises:
        ImportError: If the installed ipex version is older than
            `_IPEX_MINIMUM_VERSION_FOR_PATCHING`, or if the installed transformers version is
            outside the `_TRANSFORMERS_MIN_VERSION` ~ `_TRANSFORMERS_MAX_VERSION` range.
    """
    if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING):
        # This check gates every supported architecture, not only llama.
        raise ImportError(f"Only ipex version >= {_IPEX_MINIMUM_VERSION_FOR_PATCHING} supports model patching")
    if is_transformers_version("<", _TRANSFORMERS_MIN_VERSION) or is_transformers_version(
        ">", _TRANSFORMERS_MAX_VERSION
    ):
        raise ImportError(
            f"Only transformers versions {_TRANSFORMERS_MIN_VERSION} ~ {_TRANSFORMERS_MAX_VERSION} are verified."
        )

    # Built at call time so the patcher functions are looked up on every call.
    patchers = {
        "llama": _patch_llama_model,
        "falcon": _patch_falcon_model,
        "gpt2": _patch_gpt2_model,
        "qwen2": _patch_qwen2_model,
        "mistral": _patch_mistral_model,
        "bert": _patch_bert_model,
        "vit": _patch_vit_model,
    }
    patcher = patchers.get(model.config.model_type)
    if patcher is not None:
        model = patcher(model)
    return model