in backends/python/server/text_embeddings_server/models/__init__.py [0:0]
def get_model(model_path: Path, dtype: Optional[str], pool: str):
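    """Select and instantiate the backend model implementation for the weights
    at ``model_path``, based on the requested dtype, the detected device, the
    model's configuration, and the pooling strategy."""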
    if dtype == "float32":
        datatype = torch.float32
    elif dtype == "float16":
        datatype = torch.float16
    elif dtype == "bfloat16":
        datatype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

    device = get_device()
    logger.info(f"backend device: {device}")
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=TRUST_REMOTE_CODE)
    if (
        hasattr(config, "auto_map")
        and isinstance(config.auto_map, dict)
        and "AutoModel" in config.auto_map
        and config.auto_map["AutoModel"]
        == "jinaai/jina-bert-v2-qk-post-norm--modeling_bert.JinaBertModel"
    ):
        # Dedicated offline implementation for "jinaai/jina-embeddings-v2-base-code",
        # whose "auto_map" references modeling code hosted in another repository.
        return create_model(FlashJinaBert, model_path, device, datatype)
if config.model_type == "bert":
config: BertConfig
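        # Flash-attention fast path: taken with IPEX, or on CUDA/HPU when the
        # model uses absolute position embeddings, runs in fp16/bf16, and
        # flash attention support is available.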
        if use_ipex() or (
            device.type in ["cuda", "hpu"]
            and config.position_embedding_type == "absolute"
            and datatype in [torch.float16, torch.bfloat16]
            and FLASH_ATTENTION
        ):
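            # FlashBert is only used for CLS pooling; SPLADE and other pooling
            # strategies fall back to the non-flash implementations.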
if pool != "cls":
if config.architectures[0].endswith("ForMaskedLM") and pool == "splade":
return create_model(
MaskedLanguageModel, model_path, device, datatype, pool
)
return create_model(DefaultModel, model_path, device, datatype, pool)
            try:
                return create_model(FlashBert, model_path, device, datatype)
            except FileNotFoundError:
                logger.info(
                    "No safetensors file found for this model; falling back to the default transformers code path"
                )
                return create_model(DefaultModel, model_path, device, datatype, pool)
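
        # Non-flash BERT dispatch, keyed on the declared architecture and pooling mode.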
if config.architectures[0].endswith("Classification"):
return create_model(ClassificationModel, model_path, device, datatype)
elif config.architectures[0].endswith("ForMaskedLM") and pool == "splade":
return create_model(MaskedLanguageModel, model_path, device, datatype)
else:
return create_model(DefaultModel, model_path, device, datatype, pool)
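
    # Gaudi (HPU) fast paths for decoder-style embedding models; when no
    # safetensors weights are found, fall back to the default implementation.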
if config.model_type == "mistral" and device.type == "hpu":
try:
return create_model(FlashMistral, model_path, device, datatype, pool)
except FileNotFoundError:
return create_model(DefaultModel, model_path, device, datatype, pool)
if config.model_type == "qwen3" and device.type == "hpu":
try:
return create_model(FlashQwen3, model_path, device, datatype, pool)
except FileNotFoundError:
return create_model(DefaultModel, model_path, device, datatype, pool)

    # Default case
    if config.architectures[0].endswith("Classification"):
        return create_model(ClassificationModel, model_path, device, datatype)
    elif config.architectures[0].endswith("ForMaskedLM") and pool == "splade":
        return create_model(MaskedLanguageModel, model_path, device, datatype)
    else:
        return create_model(DefaultModel, model_path, device, datatype, pool)
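

# --- Usage sketch (not part of the original file) ---
# A minimal illustration of how a caller might invoke get_model; the model
# directory below is a hypothetical example, not a path this server ships with.
if __name__ == "__main__":
    model = get_model(
        Path("/data/models/bge-base-en-v1.5"),  # hypothetical local model dir
        dtype="float16",
        pool="cls",
    )
    logger.info(f"loaded backend model: {type(model).__name__}")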