# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from transformers.models.bert.modeling_bert import BertIntermediate
from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel
from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model
from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer,
LlamaModel,
LlamaRMSNorm,
)
from transformers.models.mistral.modeling_mistral import MistralDecoderLayer, MistralModel, MistralRMSNorm
from transformers.models.qwen2.modeling_qwen2 import (
Qwen2DecoderLayer,
Qwen2Model,
Qwen2RMSNorm,
)
from transformers.models.vit.modeling_vit import ViTIntermediate

from optimum.intel.utils.import_utils import is_ipex_version, is_transformers_version
from optimum.intel.utils.modeling_utils import replace_customized_linear_with_linear

from .modeling_utils import (
_IPEX_MINIMUM_VERSION_FOR_PATCHING,
_falcon_for_causal_lm_forward,
_falcon_model_forward,
_gpt2_lm_head_model_forward,
_gpt2_model_forward,
_ipex_rms_layer_norm_forward,
_IPEXFalconDecoderLayer,
_IPEXGPT2Block,
_IPEXIntermediate,
_IPEXLlamaDecoderLayer,
_IPEXMistralDecoderLayer,
_IPEXQwen2DecoderLayer,
_llama_model_forward,
_mistral_model_forward,
_qwen2_model_forward,
)


# Please also update setup.py and .github/workflows/test_ipex.yml if you change the supported transformers version range
_TRANSFORMERS_MIN_VERSION = "4.51.0"
_TRANSFORMERS_MAX_VERSION = "4.52.99"
_IPEX_EXPORTED_GENERATION_TASKS = ("text-generation",)


def convert_func(m, func_name, new_function):
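    """Bind `new_function` to module `m` as an instance method named `func_name` (typically "forward")."""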
bound_method = new_function.__get__(m, m.__class__)
setattr(m, func_name, bound_method)


def convert_functions(m, target_m, new_function_name, new_function):
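    """Recursively rebind `new_function` as `new_function_name` on every submodule of `m` of type `target_m`."""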
for _, sub_m in m.named_children():
if isinstance(sub_m, target_m):
convert_func(sub_m, new_function_name, new_function)
convert_functions(sub_m, target_m, new_function_name, new_function)


def convert_class(m, target_m, new_class, device, config):
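    """Recursively replace every submodule of `m` of type `target_m` with `new_class(submodule, device, config)`."""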
for name, sub_m in m.named_children():
if isinstance(sub_m, target_m):
new_m = new_class(sub_m, device, config)
setattr(m, name, new_m)
convert_class(sub_m, target_m, new_class, device, config)


def patch_op(m, target_m, new_op_name, new_op):
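    """Recursively set attribute `new_op_name` to `new_op` on every submodule of `m` of type `target_m`."""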
for name, sub_m in m.named_children():
if isinstance(sub_m, target_m):
setattr(sub_m, new_op_name, new_op)
patch_op(sub_m, target_m, new_op_name, new_op)


def _patch_llama_model(model):
"""
Patch llama model:
1. Use IPEX rope and paged cache
2. Linear fusion with (2 Linears + Silu + Mul) and (Linear + Add)
"""
convert_functions(model, LlamaModel, "forward", _llama_model_forward)
convert_functions(model, LlamaRMSNorm, "forward", _ipex_rms_layer_norm_forward)
convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayer, model.device, model.config)
return model


def _patch_falcon_model(model):
"""
Patch falcon model:
1. Use IPEX rope and paged cache
2. Linear fusion with (Linear + Gelu) and (Linear + Add + Add)
"""
    num_key_value_heads = (
        model.config.num_kv_heads if (model.config.new_decoder_architecture or not model.config.multi_query) else 1
    )
    model.config.num_key_value_heads = num_key_value_heads
convert_func(model, "forward", _falcon_for_causal_lm_forward)
convert_functions(model, FalconModel, "forward", _falcon_model_forward)
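    # Falcon checkpoints may use a customized Linear subclass; normalize it to torch.nn.Linear so the fusions below apply.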
replace_customized_linear_with_linear(model)
convert_class(model, FalconDecoderLayer, _IPEXFalconDecoderLayer, model.device, model.config)
return model


def _patch_gpt2_model(model):
"""
Patch gpt2 model:
1. Use IPEX paged attention
2. Linear fusion with (Linear + Add)
"""
    num_key_value_heads = model.config.num_attention_heads
    model.config.num_key_value_heads = num_key_value_heads
convert_func(model, "forward", _gpt2_lm_head_model_forward)
convert_functions(model, GPT2Model, "forward", _gpt2_model_forward)
convert_class(model, GPT2Block, _IPEXGPT2Block, model.device, model.config)
return model


def _patch_qwen2_model(model):
"""
Patch qwen2 model:
1. Use IPEX rope and paged cache
2. Linear fusion with (2 Linears + Silu + Mul) and (Linear + Add)
"""
    # Use a non-SDPA attention implementation so transformers skips _ignore_causal_mask_sdpa, which would otherwise cause a recompile
model.config._attn_implementation = "ipex_paged"
convert_functions(model, Qwen2Model, "forward", _qwen2_model_forward)
convert_functions(model, Qwen2RMSNorm, "forward", _ipex_rms_layer_norm_forward)
convert_class(model, Qwen2DecoderLayer, _IPEXQwen2DecoderLayer, model.device, model.config)
return model


def _patch_mistral_model(model):
"""
Patch mistral model:
1. Use IPEX rope and paged cache
2. Linear fusion with (Linear + Add)
"""
convert_functions(model, MistralModel, "forward", _mistral_model_forward)
convert_functions(model, MistralRMSNorm, "forward", _ipex_rms_layer_norm_forward)
convert_class(model, MistralDecoderLayer, _IPEXMistralDecoderLayer, model.device, model.config)
return model


def _patch_bert_model(model):
"""
Patch bert model:
1. Linear fusion with Linear + Gelu
"""
convert_class(model, BertIntermediate, _IPEXIntermediate, model.device, model.config)
return model


def _patch_vit_model(model):
"""
Patch vit model:
1. Linear fusion with Linear + Gelu
"""
convert_class(model, ViTIntermediate, _IPEXIntermediate, model.device, model.config)
return model


def _patch_model(model):
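    """Dispatch to the per-architecture patcher after checking the ipex and transformers versions."""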
if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING):
raise ImportError(f"Only ipex version >= {_IPEX_MINIMUM_VERSION_FOR_PATCHING} supports llama model patching")
if is_transformers_version("<", _TRANSFORMERS_MIN_VERSION) or is_transformers_version(
">", _TRANSFORMERS_MAX_VERSION
):
raise ImportError(
f"Only transformers versions {_TRANSFORMERS_MIN_VERSION} ~ {_TRANSFORMERS_MAX_VERSION} are verified."
)
if model.config.model_type == "llama":
model = _patch_llama_model(model)
elif model.config.model_type == "falcon":
model = _patch_falcon_model(model)
elif model.config.model_type == "gpt2":
model = _patch_gpt2_model(model)
elif model.config.model_type == "qwen2":
model = _patch_qwen2_model(model)
elif model.config.model_type == "mistral":
model = _patch_mistral_model(model)
elif model.config.model_type == "bert":
model = _patch_bert_model(model)
elif model.config.model_type == "vit":
model = _patch_vit_model(model)
return model
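

# Minimal usage sketch (illustrative only; `AutoModelForCausalLM` is the standard
# transformers loader and the checkpoint name is a placeholder, not something this
# module prescribes):
#
#     from transformers import AutoModelForCausalLM
#
#     model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
#     model = _patch_model(model)  # swaps in IPEX forwards and fused decoder layers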