backends/gaudi/server/text_generation_server/models/__init__.py (984 lines of code) (raw):
# ruff: noqa: F821
# the above line disables the `undefined-name` rule for the model type variables
import torch
import os
from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from huggingface_hub import hf_hub_download, HfApi
from typing import Optional
from pathlib import Path
from typing import List, Dict
import enum
# Needed to properly setup habana_frameworks
from text_generation_server.utils.speculate import get_speculate, set_speculate
from text_generation_server.models.model import Model
from text_generation_server.models.custom_modeling.flash_phi_moe_modeling import (
PhiMoEConfig,
)
from text_generation_server.utils.adapter import (
AdapterParameters,
build_layer_weight_lookup,
load_and_merge_adapters,
AdapterInfo,
)
from text_generation_server.adapters.lora import LoraWeights
from text_generation_server.utils.log import log_master
# Names exported by `from text_generation_server.models import *`.
# NOTE(review): "CausalLM" and "Seq2SeqLM" are neither defined nor imported in
# the visible part of this module - confirm they still exist, otherwise a star
# import will fail when it reaches them.
__all__ = [
    "Model",
    "CausalLM",
    "Seq2SeqLM",
    "get_model_with_lora_adapters",
]
# Batch classes used by vision-language models. Starts empty and is replaced
# further down once the flash-attention model imports have succeeded.
VLM_BATCH_TYPES = set()
# Optimistic default: flipped to False by the guarded import block below when
# any flash-attention model fails to import.
FLASH_ATTENTION = True
# Import every flash-attention enabled model implementation in one guarded
# block. If any single import raises ImportError, the failure is logged as a
# warning and the module falls back to FLASH_ATTENTION = False, in which case
# get_model() cannot dispatch to any of the classes imported here.
try:
    from text_generation_server.models.flash_causal_lm import FlashCausalLM
    from text_generation_server.models.flash_vlm_causal_lm import FlashVlmCausalLM
    from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLM
    from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
        FlashDeepseekV2ForCausalLM,
        DeepseekV2Config,
    )
    from text_generation_server.models.custom_modeling.flash_deepseek_v3_modeling import (
        FlashDeepseekV3ForCausalLM,
        DeepseekV3Config,
    )
    from text_generation_server.models.custom_modeling.flash_llama_modeling import (
        FlashLlamaForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_llama4_modeling import (
        Llama4ForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
        FlashCohereForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
        FlashGemmaForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
        FlashGemma2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gemma3_modeling import (
        Gemma3ForConditionalGeneration,
        FlashGemma3ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
        FlashDbrxForCausalLM,
        DbrxConfig,
    )
    from text_generation_server.models.custom_modeling.flash_rw_modeling import (
        RWConfig,
        FlashRWForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_neox_modeling import (
        FlashGPTNeoXForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
        PaliGemmaForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.flash_phi_modeling import (
        FlashPhiForCausalLM,
    )
    from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLMBatch
    from text_generation_server.models.custom_modeling.flash_mllama import (
        FlashMllamaForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.flash_llava_next import (
        FlashLlavaNextForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
        FlashSantacoderForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
        FlashStarcoder2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
        Qwen2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_qwen3_modeling import (
        Qwen3ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_qwen3_moe_modeling import (
        Qwen3MoeForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
        FlashMistralForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
        FlashMixtralForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (
        FlashGPT2ForCausalLM,
    )
    from text_generation_server.models.custom_modeling.flash_gptj_modeling import (
        FlashGPTJForCausalLM,
    )
    from text_generation_server.models.custom_modeling.idefics2 import (
        Idefics2ForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.idefics3 import (
        Idefics3ForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.qwen2_vl import (
        Qwen2VLForConditionalGeneration,
    )
    from text_generation_server.models.custom_modeling.qwen2_5_vl import (
        Qwen2_5VLForConditionalGeneration,
        Qwen2_5_VLConfig,
        Qwen2_5_VLProcessor,
    )
    from text_generation_server.layers.attention import SUPPORTS_WINDOWING
except ImportError as e:
    # Only the first failing import is reported; names imported before the
    # failure stay bound, names after it are left undefined.
    log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}")
    SUPPORTS_WINDOWING = False
    FLASH_ATTENTION = False
    VLM_BATCH_TYPES = set()
# Publish the flash-attention specific names once the guarded imports above
# have succeeded.
if FLASH_ATTENTION:
    # `__all__` entries must be strings: appending the class object itself
    # makes `from text_generation_server.models import *` raise TypeError.
    __all__.append("FlashCausalLM")
    from text_generation_server.models.flash_vlm_causal_lm import (
        FlashVlmCausalLMBatch,
    )

    # Batch types that carry image inputs in addition to token ids.
    VLM_BATCH_TYPES = {
        FlashVlmCausalLMBatch,
        FlashMllamaCausalLMBatch,
    }

# VLM_BATCH_TYPES is always bound (it is an empty set when flash attention is
# unavailable), so its *name* can be exported unconditionally.
__all__.append("VLM_BATCH_TYPES")
class ModelType(enum.Enum):
    """Catalogue of the model architectures known to this module.

    Each member's value is a dict with the keys ``type`` (the ``model_type``
    string compared against in ``get_model``), ``name`` (display name) and
    ``url`` (a Hugging Face page for the model). Some members additionally
    carry ``multimodal: True``.
    """

    DEEPSEEK_V2 = dict(
        type="deepseek_v2",
        name="Deepseek V2",
        url="https://huggingface.co/deepseek-ai/DeepSeek-V2",
    )
    DEEPSEEK_V3 = dict(
        type="deepseek_v3",
        name="Deepseek V3",
        url="https://huggingface.co/deepseek-ai/DeepSeek-V3",
    )
    IDEFICS2 = dict(
        type="idefics2",
        name="Idefics 2",
        url="https://huggingface.co/HuggingFaceM4/idefics2-8b",
        multimodal=True,
    )
    IDEFICS3 = dict(
        type="idefics3",
        name="Idefics 3",
        url="https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3",
        multimodal=True,
    )
    LLAVA_NEXT = dict(
        type="llava_next",
        name="Llava Next (1.6)",
        url="https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf",
        multimodal=True,
    )
    LLAMA = dict(
        type="llama",
        name="Llama",
        url="https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f",
    )
    LLAMA4 = dict(
        type="llama4",
        name="Llama4",
        url="https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f",
    )
    PHI3 = dict(
        type="phi3",
        name="Phi 3",
        url="https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
    )
    GRANITE = dict(
        type="granite",
        name="Granite",
        url="https://huggingface.co/ibm-granite/granite-3.0-8b-instruct",
    )
    GEMMA = dict(
        type="gemma",
        name="Gemma",
        url="https://huggingface.co/google/gemma-7b",
    )
    PALIGEMMA = dict(
        type="paligemma",
        name="PaliGemma",
        url="https://huggingface.co/google/paligemma-3b-pt-224",
    )
    GEMMA2 = dict(
        type="gemma2",
        name="Gemma2",
        url="https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315",
    )
    GEMMA3 = dict(
        type="gemma3",
        name="Gemma3",
        url="https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d",
    )
    GEMMA3_TEXT = dict(
        type="gemma3_text",
        name="Gemma3 Text",
        url="https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d",
    )
    COHERE = dict(
        type="cohere",
        name="Cohere",
        url="https://huggingface.co/CohereForAI/c4ai-command-r-plus",
    )
    DBRX = dict(
        type="dbrx",
        name="Dbrx",
        url="https://huggingface.co/databricks/dbrx-instruct",
    )
    MAMBA = dict(
        type="mamba",
        name="Mamba",
        url="https://huggingface.co/state-spaces/mamba-2.8b-slimpj",
    )
    MISTRAL = dict(
        type="mistral",
        name="Mistral",
        url="https://huggingface.co/mistralai/Mistral-Nemo-Instruct-2407",
    )
    MIXTRAL = dict(
        type="mixtral",
        name="Mixtral",
        url="https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1",
    )
    GPT_BIGCODE = dict(
        type="gpt_bigcode",
        name="Gpt Bigcode",
        url="https://huggingface.co/bigcode/gpt_bigcode-santacoder",
    )
    PHI = dict(
        type="phi",
        name="Phi",
        url="https://huggingface.co/microsoft/phi-1_5",
    )
    PHI_MOE = dict(
        type="phimoe",
        name="PhiMoe",
        url="https://huggingface.co/microsoft/Phi-3.5-MoE-instruct",
    )
    BAICHUAN = dict(
        type="baichuan",
        name="Baichuan",
        url="https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat",
    )
    FALCON = dict(
        type="falcon",
        name="Falcon",
        url="https://huggingface.co/tiiuae/falcon-7b-instruct",
    )
    STARCODER2 = dict(
        type="starcoder2",
        name="StarCoder 2",
        url="https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1",
    )
    QWEN2 = dict(
        type="qwen2",
        name="Qwen 2",
        url="https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f",
    )
    QWEN2_VL = dict(
        type="qwen2_vl",
        name="Qwen 2 VL",
        url="https://huggingface.co/collections/Qwen/qwen2-vl-66cee7455501d7126940800d",
    )
    QWEN2_5_VL = dict(
        type="qwen2_5_vl",
        name="Qwen 2.5 VL",
        url="https://huggingface.co/collections/Qwen/qwen25-66e81a666513e518adb90d9e",
    )
    QWEN3 = dict(
        type="qwen3",
        name="Qwen 3",
        url="https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f",
    )
    QWEN3_MOE = dict(
        type="qwen3_moe",
        name="Qwen 3 Moe",
        url="https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f",
    )
    GALACTICA = dict(
        type="galactica",
        name="Galactica",
        url="https://huggingface.co/facebook/galactica-120b",
    )
    SANTACODER = dict(
        type="santacoder",
        name="SantaCoder",
        url="https://huggingface.co/bigcode/santacoder",
    )
    GPT2 = dict(
        type="gpt2",
        name="Gpt2",
        url="https://huggingface.co/openai-community/gpt2",
    )
    GPT_NEOX = dict(
        type="gpt_neox",
        name="Gpt Neox",
        url="https://huggingface.co/EleutherAI/gpt-neox-20b",
    )
    GPTJ = dict(
        type="gptj",
        name="Gptj",
        url="https://huggingface.co/EleutherAI/gpt-j-6b",
    )
    MLLAMA = dict(
        type="mllama",
        name="Mllama",
        url="https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct",
        multimodal=True,
    )
# Publish every ModelType member as a module-level string constant named after
# the member and holding its "type" value (e.g. LLAMA == "llama"). At module
# scope `locals()` is the module namespace, so writing into it creates real
# globals. get_model() compares `model_type` against these names; because they
# are created dynamically, linters cannot see them (hence the F821 suppression
# at the top of the file).
__GLOBALS = locals()
for data in ModelType:
    __GLOBALS[data.name] = data.value["type"]

# Parsed once at import time; 0 when the environment variable is unset.
SDP_ON_BF16 = int(os.environ.get("SDP_ON_BF16", 0))
# Disable gradients
torch.set_grad_enabled(False)
def _deepseek_head_size(config_dict: dict) -> int:
    """Return the head size to use for a DeepSeek V2/V3 checkpoint.

    The result is the larger of the query/key head width (non-rotary part plus
    rotary part) and the value head width. Hugging Face DeepSeek configs name
    the two query/key components ``qk_nope_head_dim`` and ``qk_rope_head_dim``;
    the shorter ``qk_nope_dim`` / ``qk_rope_dim`` spellings that were read
    previously are kept as a fallback, followed by the historical defaults.
    """
    qk_nope_dim = config_dict.get(
        "qk_nope_head_dim", config_dict.get("qk_nope_dim", 128)
    )
    qk_rope_dim = config_dict.get(
        "qk_rope_head_dim", config_dict.get("qk_rope_dim", 64)
    )
    return max(qk_nope_dim + qk_rope_dim, config_dict.get("v_head_dim", 128))


def get_model(
    model_id: str,
    lora_adapter_ids: Optional[List[str]],
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[torch.dtype],
    kv_cache_dtype: Optional[str],
    trust_remote_code: bool,
    max_input_tokens: int,
) -> Model:
    """Instantiate the model wrapper matching ``model_id``'s configuration.

    The function first resolves speculative decoding (Medusa heads, an MLP
    speculator, or plain n-gram), which may redirect ``model_id`` to the
    speculator's base model, and then dispatches on the config's
    ``model_type`` to the corresponding flash-attention implementation.

    Args:
        model_id: Hub id or local path of the model (or of a speculator whose
            config points at a base model).
        lora_adapter_ids: Adapter ids forwarded to the model wrapper.
        revision: Revision used to fetch the config and weights. When a
            speculator is detected the base model is loaded from ``"main"``.
        sharded: Only consulted for Falcon/RefinedWeb models, which are
            handled here solely when not sharded.
        quantize: Quantization scheme forwarded to the model wrapper.
        speculate: Requested number of speculative tokens; ``None`` selects 0,
            or the speculator's own head count when a speculator is detected.
        dtype: Requested compute dtype, forwarded to the model wrapper.
        kv_cache_dtype: ``"fp8_e4m3fn"`` or ``"fp8_e5m2"`` select the matching
            torch float8 dtype; any other value falls back to ``dtype``.
        trust_remote_code: Forwarded to config loading and the model wrapper.
        max_input_tokens: Accepted for interface compatibility; not used in
            this function.

    Returns:
        The constructed model wrapper.

    Raises:
        RuntimeError: If ``speculate`` exceeds the number of heads provided by
            the Medusa / MLP speculator checkpoint.
        KeyError: If the (base) model config has no ``model_type`` entry.
        ValueError: If the model type is not supported, including the case
            where the flash-attention models could not be imported.
    """
    if speculate is not None:
        set_speculate(speculate)
    else:
        set_speculate(0)

    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
    model_type = config_dict.get("model_type", None)

    speculator = None
    if "medusa_num_heads" in config_dict:
        # `model_id` is a Medusa checkpoint: remember it, then switch to the
        # base model it was trained on.
        medusa_model_id = model_id
        medusa_revision = revision
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_medusa = config_dict["medusa_num_heads"]
        if speculate is not None:
            if speculate > speculate_medusa:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match"
                )
            set_speculate(speculate)
        else:
            set_speculate(speculate_medusa)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)

        is_local = Path(medusa_model_id).exists()
        if not is_local:
            medusa_config = hf_hub_download(
                medusa_model_id, revision=medusa_revision, filename="config.json"
            )
            hf_hub_download(
                medusa_model_id,
                revision=medusa_revision,
                filename="medusa_lm_head.safetensors",
            )
            speculator = {
                "path": Path(medusa_config).parent,
                "model_paths": ["medusa_lm_head.safetensors"],
            }
        else:
            speculator = {
                "path": Path(medusa_model_id),
                "model_paths": ["medusa_lm_head.safetensors"],
            }

        method = "medusa"
    elif model_type == "mlp_speculator":
        # Same redirection as above, for MLP speculator checkpoints.
        mlp_model_id = model_id
        mlp_revision = revision
        model_id = config_dict["base_model_name_or_path"]
        revision = "main"
        speculate_mlp = config_dict["n_predict"]
        if speculate is not None:
            if speculate > speculate_mlp:
                raise RuntimeError(
                    f"Speculate is set to `{speculate}` but this mlp_speculator models only has `{speculate_mlp}` heads, please make them match"
                )
            set_speculate(speculate)
        else:
            set_speculate(speculate_mlp)

        config_dict, _ = PretrainedConfig.get_config_dict(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        # Reload model type from parent.
        model_type = config_dict.get("model_type", None)

        is_local = Path(mlp_model_id).exists()
        extension = ".safetensors"
        if not is_local:
            mlp_speculator_config = hf_hub_download(
                mlp_model_id, revision=mlp_revision, filename="config.json"
            )
            api = HfApi()
            info = api.model_info(mlp_model_id, revision=mlp_revision)
            # Top-level weight files only, skipping training artefacts.
            filenames = [
                s.rfilename
                for s in info.siblings
                if s.rfilename.endswith(extension)
                and len(s.rfilename.split("/")) == 1
                and "arguments" not in s.rfilename
                and "args" not in s.rfilename
                and "training" not in s.rfilename
            ]
            for filename in filenames:
                hf_hub_download(
                    mlp_model_id,
                    revision=mlp_revision,
                    filename=filename,
                )
            speculator_dir_path = Path(mlp_speculator_config).parent
            # if these are downloaded, they get converted to safetensors
            filenames.extend(
                [p for p in os.listdir(speculator_dir_path) if p.endswith(extension)]
            )
            speculator = {
                "path": speculator_dir_path,
                "model_paths": filenames,
            }
        else:
            speculator = Path(mlp_model_id)
            filenames = [p for p in os.listdir(speculator) if p.endswith(extension)]
            speculator = {"path": speculator, "model_paths": filenames}

        method = "mlp_speculator"
    else:
        method = "n-gram"

    speculate = get_speculate()
    if speculate > 0:
        logger.info(f"Using speculation {method} with {speculate} input ids.")

    # Strict lookup on purpose: a config without `model_type` fails here with
    # KeyError rather than falling through to the dispatch below.
    model_type = config_dict["model_type"]

    if kv_cache_dtype == "fp8_e4m3fn":
        kv_cache_dtype = torch.float8_e4m3fn
    elif kv_cache_dtype == "fp8_e5m2":
        kv_cache_dtype = torch.float8_e5m2
    else:
        kv_cache_dtype = dtype

    if FLASH_ATTENTION:
        # Keyword arguments forwarded unchanged to (almost) every wrapper.
        shared = dict(
            model_id=model_id,
            revision=revision,
            quantize=quantize,
            speculator=speculator,
            dtype=dtype,
            kv_cache_dtype=kv_cache_dtype,
            trust_remote_code=trust_remote_code,
            lora_adapter_ids=lora_adapter_ids,
        )

        if model_type == DEEPSEEK_V2:
            return FlashCausalLM(
                model_class=FlashDeepseekV2ForCausalLM,
                default_dtype=torch.bfloat16,
                config_class=DeepseekV2Config,
                head_size=_deepseek_head_size(config_dict),
                **shared,
            )
        elif model_type == DEEPSEEK_V3:
            return FlashCausalLM(
                model_class=FlashDeepseekV3ForCausalLM,
                default_dtype=torch.bfloat16,
                config_class=DeepseekV3Config,
                head_size=_deepseek_head_size(config_dict),
                **shared,
            )
        elif model_type == GPT_BIGCODE or (
            model_type == GPT2 and model_id.startswith("bigcode/")
        ):
            # Must stay ahead of the plain GPT2 branch: bigcode checkpoints
            # that declare `gpt2` are served by the Santacoder implementation.
            return FlashCausalLM(
                model_class=FlashSantacoderForCausalLM,
                aliases={"transformer.wte.weight": ["lm_head.weight"]},
                num_kv_heads=1,
                **shared,
            )
        elif model_type == GPT2:
            return FlashCausalLM(model_class=FlashGPT2ForCausalLM, **shared)
        elif model_type == GPTJ:
            return FlashCausalLM(model_class=FlashGPTJForCausalLM, **shared)
        elif model_type == GPT_NEOX:
            from text_generation_server.models.custom_modeling.flash_neox_modeling import (
                GPTNeoXConfig,
            )

            return FlashCausalLM(
                model_class=FlashGPTNeoXForCausalLM,
                config_class=GPTNeoXConfig,
                **shared,
            )
        elif model_type == PHI:
            return FlashCausalLM(model_class=FlashPhiForCausalLM, **shared)
        elif model_type == PHI_MOE:
            return FlashCausalLM(
                model_class=FlashLlamaForCausalLM,
                config_class=PhiMoEConfig,
                **shared,
            )
        elif model_type == LLAMA or model_type == PHI3 or model_type == GRANITE:
            return FlashCausalLM(model_class=FlashLlamaForCausalLM, **shared)
        elif model_type == LLAMA4:
            # Routed through the logger (master rank only) instead of a bare
            # print so the message follows the server's log configuration.
            log_master(logger.info, f"Llama4 model detected: {model_id}")
            return FlashVlmCausalLM(
                model_class=Llama4ForConditionalGeneration,
                default_dtype=torch.bfloat16,
                support_chunking=False,
                **shared,
            )
        elif model_type == BAICHUAN:
            return FlashCausalLM(model_class=FlashLlamaForCausalLM, **shared)
        elif model_type == GEMMA:
            return FlashCausalLM(
                model_class=FlashGemmaForCausalLM,
                # Works better for these models
                default_dtype=torch.bfloat16,
                **shared,
            )
        elif model_type == GEMMA2:
            return FlashCausalLM(
                model_class=FlashGemma2ForCausalLM,
                # Works better for these models
                default_dtype=torch.bfloat16,
                **shared,
            )
        elif model_type == GEMMA3:
            return FlashVlmCausalLM(
                model_class=Gemma3ForConditionalGeneration,
                default_dtype=torch.bfloat16,
                support_chunking=False,
                **shared,
            )
        elif model_type == GEMMA3_TEXT:
            return FlashCausalLM(
                model_class=FlashGemma3ForCausalLM,
                # Works better for these models
                default_dtype=torch.bfloat16,
                **shared,
            )
        elif model_type == COHERE:
            return FlashCausalLM(model_class=FlashCohereForCausalLM, **shared)
        elif model_type == DBRX:
            return FlashCausalLM(
                model_class=FlashDbrxForCausalLM,
                # Dbrx works better in bfloat16.
                default_dtype=torch.bfloat16,
                config_class=DbrxConfig,
                **shared,
            )
        elif (
            model_type in ["RefinedWeb", "RefinedWebModel", FALCON]
            and not sharded
            and not config_dict.get("alibi", False)
        ):
            return FlashCausalLM(
                model_class=FlashRWForCausalLM,
                aliases={
                    "lm_head.weight": ["transformer.word_embeddings.weight"],
                    "transformer.word_embeddings.weight": ["lm_head.weight"],
                },
                config_class=RWConfig,
                **shared,
            )
        elif model_type == MISTRAL:
            return FlashCausalLM(model_class=FlashMistralForCausalLM, **shared)
        elif model_type == MIXTRAL:
            return FlashCausalLM(model_class=FlashMixtralForCausalLM, **shared)
        elif model_type == STARCODER2:
            return FlashCausalLM(model_class=FlashStarcoder2ForCausalLM, **shared)
        elif model_type == QWEN2:
            return FlashCausalLM(model_class=Qwen2ForCausalLM, **shared)
        elif model_type == QWEN2_VL:
            return FlashVlmCausalLM(
                model_class=Qwen2VLForConditionalGeneration,
                default_dtype=torch.bfloat16,
                # TODO: Fix bug in rust image_text_replacement implementation
                support_chunking=False,
                **shared,
            )
        elif model_type == QWEN2_5_VL:
            return FlashVlmCausalLM(
                model_class=Qwen2_5VLForConditionalGeneration,
                default_dtype=torch.bfloat16,
                config_class=Qwen2_5_VLConfig,
                processor_class=Qwen2_5_VLProcessor,
                # TODO: Fix bug in rust image_text_replacement implementation
                support_chunking=False,
                **shared,
            )
        elif model_type == QWEN3:
            return FlashCausalLM(model_class=Qwen3ForCausalLM, **shared)
        elif model_type == QWEN3_MOE:
            return FlashCausalLM(model_class=Qwen3MoeForCausalLM, **shared)
        elif model_type == MLLAMA:
            # NOTE(review): this branch has never forwarded `kv_cache_dtype`;
            # the omission is preserved here - confirm whether it is intended.
            mllama_kwargs = {
                key: value for key, value in shared.items() if key != "kv_cache_dtype"
            }
            return FlashMllamaCausalLM(
                model_class=FlashMllamaForConditionalGeneration,
                batch_class=FlashMllamaCausalLMBatch,
                default_dtype=torch.bfloat16,
                support_chunking=False,
                **mllama_kwargs,
            )
        elif model_type == IDEFICS2:
            return FlashVlmCausalLM(
                model_class=Idefics2ForConditionalGeneration,
                # XXX: Extremely important to cap resolution in order to limit
                # VRAM usage.
                processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
                **shared,
            )
        elif model_type == IDEFICS3:
            return FlashVlmCausalLM(
                model_class=Idefics3ForConditionalGeneration,
                default_dtype=torch.bfloat16,
                # XXX: Extremely important to cap resolution in order to limit
                # VRAM usage.
                processor_kwargs={"size": {"longest_edge": 1456}},
                **shared,
            )
        elif model_type == PALIGEMMA:
            return FlashVlmCausalLM(
                model_class=PaliGemmaForConditionalGeneration,
                # Works better for these models
                default_dtype=torch.bfloat16,
                **shared,
            )
        elif model_type == LLAVA_NEXT:
            # NOTE(review): this branch has never forwarded
            # `lora_adapter_ids`; the omission is preserved here - confirm
            # whether it is intended.
            llava_kwargs = {
                key: value
                for key, value in shared.items()
                if key != "lora_adapter_ids"
            }
            return FlashVlmCausalLM(
                model_class=FlashLlavaNextForConditionalGeneration,
                **llava_kwargs,
            )

    raise ValueError(f"Unsupported model type {model_type}")
# get_model_with_lora_adapters wraps the internal get_model function and adds support for loading adapters
# this provides a post model loading hook to load adapters into the model after the model has been loaded
def get_model_with_lora_adapters(
    model_id: str,
    lora_adapters: Optional[List[AdapterInfo]],
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    speculate: Optional[int],
    dtype: Optional[torch.dtype],
    kv_cache_dtype: Optional[str],
    trust_remote_code: bool,
    max_input_tokens: int,
    adapter_to_index: Dict[str, int],
):
    """Load a model through ``get_model`` and attach the given LoRA adapters.

    Args:
        model_id, revision, sharded, quantize, speculate, dtype,
        kv_cache_dtype, trust_remote_code, max_input_tokens: Passed through to
            ``get_model`` unchanged.
        lora_adapters: Adapters to load; ``None`` is treated as "no adapters".
        adapter_to_index: Mutated in place: each adapter id is mapped to the
            1-based index under which it was registered on the model.

    Returns:
        The loaded model, with each adapter's weights (and tokenizer, when the
        adapter ships one) registered on it.
    """
    # The parameter is Optional: normalise None so the code below can iterate.
    lora_adapters = lora_adapters or []
    lora_adapter_ids = [adapter.id for adapter in lora_adapters]
    model = get_model(
        model_id,
        lora_adapter_ids,
        revision,
        sharded,
        quantize,
        speculate,
        dtype,
        kv_cache_dtype,
        trust_remote_code,
        max_input_tokens,
    )

    if not lora_adapters:
        return model

    target_to_layer = build_layer_weight_lookup(model.model)
    # Both values are the same for every adapter, so compute them once.
    weight_names = tuple(v[0] for v in target_to_layer.values())
    adapter_layers = [
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
        "qkv_proj",
    ]

    for index, adapter in enumerate(lora_adapters):
        # The AdapterParameters object allows for merging multiple adapters into a single adapter.
        # At the moment, we only support loading a single adapter into the model, but we keep the
        # AdapterParameters object for easier extension in the future.
        adapter_parameters = AdapterParameters(
            adapter_info=[adapter],
            # when merging multiple adapters we can weight them differently
            # if this is not set, all adapters will be weighted equally
            # see: text_generation_server.utils.merges.strategies for impl
            weights=None,
            merge_strategy=0,
            density=1.0,
            majority_sign_method=0,
        )

        # Adapter indices start at 1.
        adapter_index = index + 1
        adapter_to_index[adapter.id] = adapter_index

        logger.info(
            f"Loading adapter weights into model: {','.join([adapter.id for adapter in adapter_parameters.adapter_info])}"
        )
        (
            module_map,
            adapter_config,
            adapter_weight_names,
            adapter_tokenizer,
        ) = load_and_merge_adapters(
            model.model_id,
            adapter_parameters,
            adapter_index,
            weight_names,
            False,
        )

        # Passed to prepare_weights for every layer type; whatever is still
        # in it after the loop is reported below.
        unused_weight_names = adapter_weight_names.copy()

        for layer_name in adapter_layers:
            nlayers = (
                1 if layer_name == "lm_head" else len(model.model.model.layers)
            )
            adapter_weights = LoraWeights.prepare_weights(
                config=adapter_config,
                module_map=module_map,
                layer_type=layer_name,
                unused_weight_names=unused_weight_names,
                nlayers=nlayers,
                dtype=model.dtype,
                world_size=model.world_size,
                process_group=model.process_group,
                target_to_layer=target_to_layer,
            )

            if adapter_weights is None:
                continue

            model.layer_to_adapter_weights[layer_name].add_adapter(
                adapter_index, adapter_weights
            )

        if len(unused_weight_names) > 0:
            logger.warning(
                f"{','.join([a.id for a in lora_adapters])} unused adapter weights: {unused_weight_names}"
            )

        if adapter_tokenizer is not None:
            model.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer)

        model.loaded_adapters.add(adapter_index)

    return model