# backends/python/server/text_embeddings_server/models/flash_qwen3.py
import torch
import json
from pathlib import Path
from torch import nn
import torch.nn.functional as F
from typing import List, Optional, Type, Union
from safetensors import safe_open
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.qwen3 import Qwen3Config
from opentelemetry import trace
from text_embeddings_server.models import Model
from text_embeddings_server.models.pooling import DefaultPooling
from text_embeddings_server.models.types import FlashBatch, Embedding, PaddedBatch
from text_embeddings_server.utils.flash_attn import attention
tracer = trace.get_tracer(__name__)
def load_weight(model_path, weight_map, name, dtype, device):
"""
Helper function to load a weight tensor from safetensors.
"""
target_file = weight_map[name]
with safe_open(f"{model_path}/{target_file}", framework="pt") as f:
return f.get_tensor(name).to(dtype).to(device)
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
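    """Apply rotary position embeddings (cos/sin) to the query and key tensors."""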
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
def compute_default_rope_parameters(
config: Qwen3Config,
device: torch.device,
) -> tuple["torch.Tensor", float]:
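    """Compute the standard RoPE inverse frequencies and attention scaling factor from the model config."""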
base = config.rope_theta
partial_rotary_factor = (
config.partial_rotary_factor
if hasattr(config, "partial_rotary_factor")
else 1.0
)
head_dim = (
getattr(config, "head_dim", None)
or config.hidden_size // config.num_attention_heads
)
dim = int(head_dim * partial_rotary_factor)
attention_factor = 1.0
inv_freq = 1.0 / (
base
** (
torch.arange(0, dim, 2, dtype=torch.int64).to(
device=device, dtype=torch.float
)
/ dim
)
)
return inv_freq, attention_factor
class Qwen3RMSNorm:
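    """RMSNorm with its weight loaded directly from safetensors.

    Uses Habana's FusedRMSNorm kernel on HPU and a float32 reference implementation elsewhere.
    """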
def __init__(
self,
model_path,
weight_map,
name,
device,
dtype,
eps=1e-6,
):
self.weight = load_weight(model_path, weight_map, name, dtype, device)
self.variance_epsilon = eps
def forward(self, hidden_states):
if hidden_states.device.type == "hpu":
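            # Lazy import so that non-HPU environments do not need habana_frameworks installed.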
from habana_frameworks.torch.hpex.normalization import (
FusedRMSNorm as FusedRMSNorm,
)
hidden_states = FusedRMSNorm.apply(
hidden_states, self.weight, self.variance_epsilon
)
return hidden_states
else:
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(
variance + self.variance_epsilon
)
return self.weight * hidden_states.to(input_dtype)
class Qwen3Attention:
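    """Multi-head self-attention with RMSNorm applied to the query and key heads (Qwen3 QK-norm) and rotary embeddings."""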
def __init__(
self,
model_path,
weight_map,
device,
dtype,
config: Qwen3Config,
layer_idx: Optional[int] = None,
):
self.num_heads = config.num_attention_heads
self.head_dim = config.head_dim
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.softmax_scale = self.head_dim**-0.5
self.q_proj_weight = load_weight(
model_path,
weight_map,
f"layers.{layer_idx}.self_attn.q_proj.weight",
dtype,
device,
)
self.k_proj_weight = load_weight(
model_path,
weight_map,
f"layers.{layer_idx}.self_attn.k_proj.weight",
dtype,
device,
)
self.v_proj_weight = load_weight(
model_path,
weight_map,
f"layers.{layer_idx}.self_attn.v_proj.weight",
dtype,
device,
)
self.o_proj_weight = load_weight(
model_path,
weight_map,
f"layers.{layer_idx}.self_attn.o_proj.weight",
dtype,
device,
)
self.q_norm = Qwen3RMSNorm(
model_path,
weight_map,
f"layers.{layer_idx}.self_attn.q_norm.weight",
device,
dtype,
eps=config.rms_norm_eps,
)
self.k_norm = Qwen3RMSNorm(
model_path,
weight_map,
f"layers.{layer_idx}.self_attn.k_norm.weight",
device,
dtype,
eps=config.rms_norm_eps,
)
def forward(
self, hidden_states, position_embeddings, cu_seqlens, max_s, attn_mask=None
):
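        # Reshape projections to (..., num_heads, head_dim), apply QK-norm, then RoPE, then flash attention.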
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
q = self.q_norm.forward(
F.linear(hidden_states, self.q_proj_weight).view(hidden_shape)
)
k = self.k_norm.forward(
F.linear(hidden_states, self.k_proj_weight).view(hidden_shape)
)
v = F.linear(hidden_states, self.v_proj_weight).view(hidden_shape)
cos, sin = position_embeddings
q, k = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=2)
attn_output = torch.empty_like(q)
attention(
q,
k,
v,
attn_output,
cu_seqlens,
max_s,
self.softmax_scale,
is_causal=True,
attn_mask=attn_mask,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = F.linear(attn_output, self.o_proj_weight, bias=None)
return attn_output
class Qwen3MLP:
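    """Gated feed-forward block: down_proj(act_fn(gate_proj(x)) * up_proj(x)), with weights loaded from safetensors."""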
def __init__(
self,
model_path,
weight_map,
device,
dtype,
config: Qwen3Config,
layer_idx: Optional[int] = None,
):
self.gate_proj_weight = load_weight(
model_path,
weight_map,
f"layers.{layer_idx}.mlp.gate_proj.weight",
dtype,
device,
)
self.up_proj_weight = load_weight(
model_path,
weight_map,
f"layers.{layer_idx}.mlp.up_proj.weight",
dtype,
device,
)
self.down_proj_weight = load_weight(
model_path,
weight_map,
f"layers.{layer_idx}.mlp.down_proj.weight",
dtype,
device,
)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, hidden_state):
gated_hidden_states = F.linear(hidden_state, self.gate_proj_weight)
        up_hidden_states = F.linear(hidden_state, self.up_proj_weight)
        return F.linear(
            self.act_fn(gated_hidden_states) * up_hidden_states,
self.down_proj_weight,
)
class Qwen3DecoderLayer:
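    """A single decoder layer: RMSNorm + self-attention and RMSNorm + MLP, each wrapped in a residual connection."""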
def __init__(
self,
model_path,
weight_map,
device,
dtype,
config: Qwen3Config,
layer_idx: Optional[int] = None,
):
self.attention = Qwen3Attention(
model_path, weight_map, device, dtype, config, layer_idx
)
self.mlp = Qwen3MLP(model_path, weight_map, device, dtype, config, layer_idx)
self.input_layernorm = Qwen3RMSNorm(
model_path,
weight_map,
f"layers.{layer_idx}.input_layernorm.weight",
device,
dtype,
eps=config.rms_norm_eps,
)
self.post_attention_layernorm = Qwen3RMSNorm(
model_path,
weight_map,
f"layers.{layer_idx}.post_attention_layernorm.weight",
device,
dtype,
eps=config.rms_norm_eps,
)
def forward(
self, hidden_states, position_embeddings, cu_seqlens, max_s, attn_mask=None
):
residual = hidden_states
hidden_states = self.input_layernorm.forward(hidden_states)
# Self Attention
hidden_states = self.attention.forward(
hidden_states, position_embeddings, cu_seqlens, max_s, attn_mask
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm.forward(hidden_states)
hidden_states = self.mlp.forward(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class Qwen3RotaryEmbedding(nn.Module):
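    """Rotary position embedding module; cos/sin tables are computed in float32 and cast back to the input dtype."""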
def __init__(self, config: Qwen3Config, device=None):
super().__init__()
inv_freq, self.attention_scaling = compute_default_rope_parameters(
config, device
)
self.register_buffer("inv_freq", inv_freq, persistent=False)
def forward(self, x, position_ids):
inv_freq_expanded = (
self.inv_freq[None, :, None]
.float()
.expand(position_ids.shape[0], -1, 1)
.to(x.device)
)
position_ids_expanded = position_ids[:, None, :].float()
device_type = (
x.device.type
if isinstance(x.device.type, str) and x.device.type != "mps"
else "cpu"
)
with torch.autocast(device_type=device_type, enabled=False): # Force float32
freqs = (
inv_freq_expanded.float() @ position_ids_expanded.float()
).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos() * self.attention_scaling
sin = emb.sin() * self.attention_scaling
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
class FlashQwen3Model:
"""
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen3DecoderLayer`].

    Args:
        config: Qwen3Config
"""
def __init__(self, model_path, weight_map_json, device, dtype, config: Qwen3Config):
self.word_embeddings_weight = load_weight(
model_path,
weight_map_json["weight_map"],
"embed_tokens.weight",
dtype,
device,
)
self.layers = [
Qwen3DecoderLayer(
model_path,
weight_map_json["weight_map"],
device,
dtype,
config,
layer_idx,
)
for layer_idx in range(config.num_hidden_layers)
]
self.rotary_emb = Qwen3RotaryEmbedding(config=config, device=device)
self.norm = Qwen3RMSNorm(
model_path,
weight_map_json["weight_map"],
f"norm.weight",
device,
dtype,
eps=config.rms_norm_eps,
)
def forward(
self,
input_ids,
position_ids,
cu_seqlens,
max_s,
mask=None,
attn_mask=None,
):
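        # Embed tokens, compute rotary cos/sin once for the batch, then run every decoder layer and the final norm.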
inputs_embeds = nn.functional.embedding(input_ids, self.word_embeddings_weight)
hidden_states = inputs_embeds
position_embeddings = self.rotary_emb(hidden_states, position_ids)
for layer in self.layers:
hidden_states = layer.forward(
hidden_states, position_embeddings, cu_seqlens, max_s, attn_mask
)
hidden_states = self.norm.forward(hidden_states)
return BaseModelOutputWithPast(last_hidden_state=hidden_states)
class FlashQwen3(Model):
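    """Qwen3-based embedding model for the Python backend.

    Uses varlen FlashBatch on GPU-like devices and PaddedBatch with a dense additive
    attention mask on HPU, where a real varlen forward is not available yet.
    """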
def __init__(
self,
model_path: Path,
device: torch.device,
dtype: torch.dtype,
pool: str = "cls",
trust_remote: bool = False,
):
config = Qwen3Config.from_pretrained(model_path)
if hasattr(config, "max_seq_length"):
self.max_input_length = config.max_seq_length
else:
self.max_input_length = config.max_position_embeddings
with open(model_path / "model.safetensors.index.json", "r") as f:
index_data = json.load(f)
model = FlashQwen3Model(model_path, index_data, device, dtype, config)
self.hidden_size = config.hidden_size
self.pooling = DefaultPooling(self.hidden_size, pooling_mode=pool)
self.device = device
self.dtype = dtype
        super().__init__(model=model, dtype=dtype, device=device)
@property
    def batch_type(self) -> Type[Union[FlashBatch, PaddedBatch]]:
# for hpu devices, we use PaddedBatch as we do not have real varlen fwd yet
return FlashBatch if self.device.type != "hpu" else PaddedBatch
@tracer.start_as_current_span("embed")
def embed(self, batch: Union[FlashBatch, PaddedBatch]) -> List[Embedding]:
if isinstance(batch, PaddedBatch):
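            # HPU/padded path: recover per-sequence lengths from the attention mask and build a
            # dense additive mask (0.0 at valid positions, dtype minimum at padded positions).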
input_lens = batch.attention_mask.cumsum(-1)[:, -1].to(torch.int32)
max_input_lens = 0
cu_seqlens = torch.cat(
(input_lens.new_tensor([0]), input_lens.cumsum(-1).int())
)
mask = batch.attention_mask.bool()
bsz, tgt_len = mask.size()
min_val = torch.finfo(self.dtype).min
attn_mask = torch.full(
[bsz, 1, tgt_len, tgt_len],
fill_value=min_val,
device=self.device,
dtype=self.dtype,
)
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, tgt_len)
attn_mask = attn_mask.masked_fill(expanded_mask, 0.0)
elif isinstance(batch, FlashBatch):
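            # Varlen flash-attention path: cu_seqlens comes directly from the batch and no dense mask is needed.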
cu_seqlens = batch.cu_seqlens
mask = None
attn_mask = None
max_input_lens = batch.max_s
output = self.model.forward(
input_ids=batch.input_ids,
position_ids=batch.position_ids,
cu_seqlens=cu_seqlens,
max_s=max_input_lens,
mask=mask,
attn_mask=attn_mask,
)
embedding = self.pooling.forward(output, batch.attention_mask)
cpu_results = embedding.view(-1).tolist()
return [
Embedding(
values=cpu_results[i * self.hidden_size : (i + 1) * self.hidden_size]
)
for i in range(len(batch))
]