optimum/intel/ipex/modeling_sentence_transformers.py (59 lines of code) (raw):
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional
import torch
from sentence_transformers import SentenceTransformer
from sentence_transformers.models import Transformer
from sentence_transformers.models.Transformer import _save_pretrained_wrapper
from sentence_transformers.util import import_from_string
from transformers import MT5Config, T5Config
from transformers.dynamic_module_utils import get_class_from_dynamic_module
from ..utils.import_utils import _sentence_transformers_version, is_sentence_transformers_version
from .modeling_base import IPEXModel
class IPEXTransformer(Transformer):
    """sentence-transformers ``Transformer`` module whose model is loaded through ``IPEXModel``."""

    def __init__(self, *args, **kwargs):
        """Check the installed sentence-transformers version, then defer to ``Transformer.__init__``.

        All positional and keyword arguments are forwarded unchanged to the parent constructor.

        Raises:
            ImportError: If the installed sentence-transformers version is lower than 3.4.
        """
        version_too_old = is_sentence_transformers_version("<", "3.4")
        if version_too_old:
            raise ImportError(
                f"Backend: ipex requires sentence-transformers>=3.4 but found {_sentence_transformers_version}. "
                "You can install it with pip: `pip install --upgrade sentence-transformers`"
            )

        super().__init__(*args, **kwargs)
        # Force the backend marker to "ipex" once the parent constructor has run.
        self.backend = "ipex"

    def _load_model(self, model_name_or_path, config, cache_dir, backend, is_peft_model, **model_args) -> None:
        """Load the model via the IPEX path.

        ``backend`` and ``is_peft_model`` are accepted but not used here; the remaining
        arguments are passed straight to ``_load_ipex_model``.
        """
        self._load_ipex_model(model_name_or_path, config, cache_dir, **model_args)

    def _load_ipex_model(self, model_name_or_path, config, cache_dir, **model_args) -> None:
        """Instantiate an ``IPEXModel`` and store it as ``self.auto_model``.

        Args:
            model_name_or_path: Forwarded as the first argument of ``IPEXModel.from_pretrained``.
            config: Model configuration; also forwarded to ``from_pretrained``.
            cache_dir: Forwarded to ``from_pretrained`` as ``cache_dir``.
            **model_args: Extra keyword arguments forwarded to ``from_pretrained``.

        Raises:
            ValueError: If ``config`` is a ``T5Config`` or an ``MT5Config``.
        """
        if isinstance(config, (T5Config, MT5Config)):
            raise ValueError("T5 models are not yet supported by the IPEX backend.")

        ipex_model = IPEXModel.from_pretrained(
            model_name_or_path,
            config=config,
            cache_dir=cache_dir,
            **model_args,
        )
        self.auto_model = ipex_model

        # Wrap the save_pretrained method to save the model in the correct subfolder
        original_save = ipex_model._save_pretrained
        ipex_model._save_pretrained = _save_pretrained_wrapper(original_save, "ipex")
class IPEXSentenceTransformer(SentenceTransformer):
    """``SentenceTransformer`` that swaps in ``IPEXTransformer`` for the stock ``Transformer`` module class."""

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``SentenceTransformer.__init__`` and then set the backend to ``"ipex"``."""
        super().__init__(*args, **kwargs)
        self.backend = "ipex"

    def _load_module_class_from_ref(
        self,
        class_ref: str,
        model_name_or_path: str,
        trust_remote_code: bool,
        revision: Optional[str] = None,
        model_kwargs: Optional[Dict[str, Any]] = None,
    ) -> torch.nn.Module:
        """Resolve a dotted class reference to the object it names.

        Resolution order:
          1. ``sentence_transformers.models.Transformer`` is redirected to ``IPEXTransformer``.
          2. Any other ``sentence_transformers.*`` reference is imported directly.
          3. With ``trust_remote_code`` enabled, the reference is looked up as a dynamic module;
             an ``OSError`` from that lookup is ignored.
          4. Otherwise (or after such an ``OSError``), the reference is imported directly.

        Args:
            class_ref: Dotted reference of the class to load.
            model_name_or_path: Passed to ``get_class_from_dynamic_module`` for the dynamic lookup.
            trust_remote_code: Whether the dynamic-module lookup is attempted at all.
            revision: Passed to ``get_class_from_dynamic_module`` as ``revision``.
            model_kwargs: Optional dict; when non-empty, its ``"code_revision"`` entry is popped
                (removed from the dict) and passed on as ``code_revision``.

        Returns:
            Whatever ``import_from_string`` or ``get_class_from_dynamic_module`` returns for the reference.
        """
        # The exact Transformer reference also matches the generic prefix below, so handle it first.
        if class_ref == "sentence_transformers.models.Transformer":
            return import_from_string("optimum.intel.ipex.modeling_sentence_transformers.IPEXTransformer")

        if class_ref.startswith("sentence_transformers."):
            return import_from_string(class_ref)

        if not trust_remote_code:
            return import_from_string(class_ref)

        code_revision = None
        if model_kwargs:
            # Note: pop() removes the entry from the dict that was passed in.
            code_revision = model_kwargs.pop("code_revision", None)

        try:
            return get_class_from_dynamic_module(
                class_ref,
                model_name_or_path,
                revision=revision,
                code_revision=code_revision,
            )
        except OSError:
            # Ignore the error if the file does not exist, and fall back to the default import
            return import_from_string(class_ref)