def get_model()

in backends/python/server/text_embeddings_server/models/__init__.py
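
The function relies on external imports along these lines (a sketch; internal helpers such as create_model, get_device, use_ipex, FLASH_ATTENTION, TRUST_REMOTE_CODE, and the model classes come from sibling modules in this package):

from pathlib import Path
from typing import Optional

import torch
from loguru import logger
from transformers import AutoConfig
from transformers.models.bert import BertConfig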


def get_model(model_path: Path, dtype: Optional[str], pool: str):
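    # Map the requested dtype string onto the corresponding torch dtype.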
    if dtype == "float32":
        datatype = torch.float32
    elif dtype == "float16":
        datatype = torch.float16
    elif dtype == "bfloat16":
        datatype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

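    # Select the compute device detected by the backend (e.g. CPU, CUDA, or HPU).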
    device = get_device()
    logger.info(f"backend device: {device}")

    config = AutoConfig.from_pretrained(model_path, trust_remote_code=TRUST_REMOTE_CODE)

    if (
        hasattr(config, "auto_map")
        and isinstance(config.auto_map, dict)
        and "AutoModel" in config.auto_map
        and config.auto_map["AutoModel"]
        == "jinaai/jina-bert-v2-qk-post-norm--modeling_bert.JinaBertModel"
    ):
        # Special case for "jinaai/jina-embeddings-v2-base-code", whose auto_map points at modeling code hosted in another repository; use the local FlashJinaBert implementation instead of fetching remote code.
        return create_model(FlashJinaBert, model_path, device, datatype)

    if config.model_type == "bert":
        config: BertConfig
        if (
            use_ipex()
            or (
                device.type in ["cuda", "hpu"]
                and config.position_embedding_type == "absolute"
                and datatype in [torch.float16, torch.bfloat16]
                and FLASH_ATTENTION
            )
        ):
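            # The flash path below is taken only for CLS pooling; other pooling
            # modes use the non-flash implementations.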
            if pool != "cls":
                if config.architectures[0].endswith("ForMaskedLM") and pool == "splade":
                    return create_model(
                        MaskedLanguageModel, model_path, device, datatype, pool
                    )
                return create_model(DefaultModel, model_path, device, datatype, pool)

            try:
                return create_model(FlashBert, model_path, device, datatype)
            except FileNotFoundError:
                logger.info(
                    "No safetensors weights found for this model; falling back to the default transformers implementation"
                )
                return create_model(DefaultModel, model_path, device, datatype, pool)

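        # Non-flash BERT path: dispatch on architecture and pooling mode.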
        if config.architectures[0].endswith("Classification"):
            return create_model(ClassificationModel, model_path, device, datatype)
        elif config.architectures[0].endswith("ForMaskedLM") and pool == "splade":
            return create_model(MaskedLanguageModel, model_path, device, datatype)
        else:
            return create_model(DefaultModel, model_path, device, datatype, pool)

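    # HPU-specific flash implementations; fall back to the default model when
    # no safetensors weights are available.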
    if config.model_type == "mistral" and device.type == "hpu":
        try:
            return create_model(FlashMistral, model_path, device, datatype, pool)
        except FileNotFoundError:
            return create_model(DefaultModel, model_path, device, datatype, pool)

    if config.model_type == "qwen3" and device.type == "hpu":
        try:
            return create_model(FlashQwen3, model_path, device, datatype, pool)
        except FileNotFoundError:
            return create_model(DefaultModel, model_path, device, datatype, pool)

    # Default case
    if config.architectures[0].endswith("Classification"):
        return create_model(ClassificationModel, model_path, device, datatype)
    elif config.architectures[0].endswith("ForMaskedLM") and pool == "splade":
        return create_model(MaskedLanguageModel, model_path, device, datatype)
    else:
        return create_model(DefaultModel, model_path, device, datatype, pool)
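
For illustration, a minimal call might look like the sketch below; the model path is hypothetical, and the weights are assumed to be already downloaded locally:

from pathlib import Path

# Hypothetical invocation (path and argument values are illustrative only):
model = get_model(
    Path("/data/models/bge-base-en-v1.5"),  # local model snapshot
    dtype="float16",  # one of "float32", "float16", "bfloat16"
    pool="cls",       # pooling strategy; "splade" routes ForMaskedLM models to MaskedLanguageModel
)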