# backends/python/server/text_embeddings_server/models/__init__.py
import os
from pathlib import Path
from typing import Optional

import torch
from loguru import logger
from transformers import AutoConfig
from transformers.models.bert import BertConfig

from text_embeddings_server.models.model import Model
from text_embeddings_server.models.masked_model import MaskedLanguageModel
from text_embeddings_server.models.default_model import DefaultModel
from text_embeddings_server.models.classification_model import ClassificationModel
from text_embeddings_server.models.jinaBert_model import FlashJinaBert
from text_embeddings_server.models.flash_mistral import FlashMistral
from text_embeddings_server.models.flash_qwen3 import FlashQwen3
from text_embeddings_server.utils.device import get_device, use_ipex

__all__ = ["Model"]

TRUST_REMOTE_CODE = os.getenv("TRUST_REMOTE_CODE", "false").lower() in ["true", "1"]
DISABLE_TENSOR_CACHE = os.getenv("DISABLE_TENSOR_CACHE", "false").lower() in [
    "true",
    "1",
]
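
# Both flags are parsed once, at import time, from the process environment.
# Accepted truthy values follow the parsing above, e.g.:
#
#   export TRUST_REMOTE_CODE=true      # or "1"
#   export DISABLE_TENSOR_CACHE=1      # or "true"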

# Inference-only server: disable autograd globally.
torch.set_grad_enabled(False)

FLASH_ATTENTION = True
try:
from text_embeddings_server.models.flash_bert import FlashBert
except ImportError as e:
logger.warning(f"Could not import Flash Attention enabled models: {e}")
FLASH_ATTENTION = False
if FLASH_ATTENTION:
    __all__.append("FlashBert")


def wrap_model_if_hpu(model_handle, device):
    """Wrap the model in an HPU graph if the device is HPU."""
    if device.type == "hpu":
        # Imported lazily so non-HPU installs do not need habana_frameworks.
        from habana_frameworks.torch.hpu import wrap_in_hpu_graph

        model_handle.model = wrap_in_hpu_graph(
            model_handle.model, disable_tensor_cache=DISABLE_TENSOR_CACHE
        )
    return model_handle


def create_model(model_class, model_path, device, datatype, pool="cls"):
    """Create a model instance and wrap it for HPU if needed."""
    model_handle = model_class(
        model_path,
        device,
        datatype,
        pool,
        trust_remote=TRUST_REMOTE_CODE,
    )
    return wrap_model_if_hpu(model_handle, device)
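
# Illustrative sketch (not part of the module): every branch in get_model below
# goes through create_model, so constructing a CPU float32 model by hand would
# look like this (the model directory is hypothetical):
#
#   handle = create_model(DefaultModel, Path("/data/my-model"), get_device(),
#                         torch.float32, pool="mean")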


def get_model(model_path: Path, dtype: Optional[str], pool: str):
if dtype == "float32":
datatype = torch.float32
elif dtype == "float16":
datatype = torch.float16
elif dtype == "bfloat16":
datatype = torch.bfloat16
else:
raise RuntimeError(f"Unknown dtype {dtype}")
device = get_device()
logger.info(f"backend device: {device}")
config = AutoConfig.from_pretrained(model_path, trust_remote_code=TRUST_REMOTE_CODE)

    # Special-case "jinaai/jina-embeddings-v2-base-code": its config uses
    # "auto_map" to point AutoModel at modeling code hosted in another
    # repository, so we map it to the offline FlashJinaBert implementation.
    if (
        hasattr(config, "auto_map")
        and isinstance(config.auto_map, dict)
        and "AutoModel" in config.auto_map
        and config.auto_map["AutoModel"]
        == "jinaai/jina-bert-v2-qk-post-norm--modeling_bert.JinaBertModel"
    ):
        return create_model(FlashJinaBert, model_path, device, datatype)
if config.model_type == "bert":
config: BertConfig
if (
use_ipex()
or device.type in ["cuda", "hpu"]
and config.position_embedding_type == "absolute"
and datatype in [torch.float16, torch.bfloat16]
and FLASH_ATTENTION
):
if pool != "cls":
if config.architectures[0].endswith("ForMaskedLM") and pool == "splade":
return create_model(
MaskedLanguageModel, model_path, device, datatype, pool
)
return create_model(DefaultModel, model_path, device, datatype, pool)
            try:
                return create_model(FlashBert, model_path, device, datatype)
            except FileNotFoundError:
                logger.info(
                    "No safetensors weights found for this model; falling back to the default transformers implementation"
                )
                return create_model(DefaultModel, model_path, device, datatype, pool)
if config.architectures[0].endswith("Classification"):
return create_model(ClassificationModel, model_path, device, datatype)
elif config.architectures[0].endswith("ForMaskedLM") and pool == "splade":
return create_model(MaskedLanguageModel, model_path, device, datatype)
else:
return create_model(DefaultModel, model_path, device, datatype, pool)
if config.model_type == "mistral" and device.type == "hpu":
try:
return create_model(FlashMistral, model_path, device, datatype, pool)
except FileNotFoundError:
return create_model(DefaultModel, model_path, device, datatype, pool)
if config.model_type == "qwen3" and device.type == "hpu":
try:
return create_model(FlashQwen3, model_path, device, datatype, pool)
except FileNotFoundError:
return create_model(DefaultModel, model_path, device, datatype, pool)

    # Default case: dispatch on the architecture declared in the config.
if config.architectures[0].endswith("Classification"):
return create_model(ClassificationModel, model_path, device, datatype)
elif config.architectures[0].endswith("ForMaskedLM") and pool == "splade":
return create_model(MaskedLanguageModel, model_path, device, datatype)
else:
return create_model(DefaultModel, model_path, device, datatype, pool)
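

# Example usage (an illustrative sketch, not part of the server; the model
# directory below is hypothetical and must contain config.json plus weights):
#
#   from pathlib import Path
#   from text_embeddings_server.models import get_model
#
#   model = get_model(Path("/data/my-model"), dtype="float16", pool="cls")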