server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py [158:363]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    def __init__(
        self,
        prefix: str,
        config,
        weights: Weights,
    ):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.kv_lora_rank = config.kv_lora_rank
        self.q_lora_rank = config.q_lora_rank
        self.qk_nope_head_dim = config.qk_nope_head_dim
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.head_size = config.qk_nope_head_dim + config.qk_rope_head_dim
        self.value_head_size = config.v_head_dim
        self.head_pad_size = max(self.head_size, self.value_head_size)

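        # Rotary embeddings only cover the rope portion of each head
        # (qk_rope_head_dim); the nope portion carries no position signal.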
        self.rotary_emb = PositionRotaryEmbedding.static(
            config=config,
            dim=self.qk_rope_head_dim,
            base=config.rope_theta,
            device=weights.device,
        )

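        # YaRN-style correction: the attention softmax scale is adjusted by
        # mscale**2, derived from the rope scaling factor.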
        mscale = get_mscale(
            self.rotary_emb.scaling_factor, self.rotary_emb.mscale_all_dim
        )
        self.softmax_scale = self.head_size**-0.5 * mscale * mscale

        if self.num_heads % weights.process_group.size() != 0:
            raise ValueError(
                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                f"and `num_shards`: {weights.process_group.size()}"
            )
        self.num_heads = self.num_heads // weights.process_group.size()
        self.num_key_value_heads = (
            config.num_key_value_heads // weights.process_group.size()
        )

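        # Query path: either a single projection, or (with query LoRA) a
        # low-rank down-projection, an RMSNorm, and an up-projection.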
        if self.q_lora_rank is None:
            self.q_proj = TensorParallelColumnLinear.load(
                config,
                prefix=f"{prefix}.q_proj",
                weights=weights,
                bias=config.attention_bias,
            )
        else:
            self.q_a_proj = get_linear(
                weight=weights.get_weights(f"{prefix}.q_a_proj"),
                bias=(
                    weights.get_tensor(f"{prefix}.q_a_proj.bias")
                    if config.attention_bias
                    else None
                ),
            )
            self.q_a_layernorm = FastRMSNorm.load(
                prefix=f"{prefix}.q_a_layernorm",
                weights=weights,
                eps=config.rms_norm_eps,
            )
            self.q_b_proj = TensorParallelColumnLinear.load(
                config,
                prefix=f"{prefix}.q_b_proj",
                weights=weights,
                bias=config.attention_bias,
            )

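        # Joint down-projection that yields the compressed KV latent
        # (kv_lora_rank dims) plus a decoupled rope key (qk_rope_head_dim
        # dims) shared across all heads.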
        self.kv_a_proj_with_mqa = get_linear(
            weight=weights.get_weights(f"{prefix}.kv_a_proj_with_mqa"),
            bias=(
                weights.get_tensor(f"{prefix}.kv_a_proj_with_mqa.bias")
                if config.attention_bias
                else None
            ),
        )

        self.kv_scales = get_kv_scales(weights, prefix)

        self.kv_a_layernorm = FastRMSNorm.load(
            prefix=f"{prefix}.kv_a_layernorm", weights=weights, eps=config.rms_norm_eps
        )

        self.kv_b_proj = TensorParallelColumnLinear.load(
            config,
            prefix=f"{prefix}.kv_b_proj",
            weights=weights,
            bias=config.attention_bias,
        )

        self.o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=False,
        )
        self.num_groups = self.num_heads // self.num_key_value_heads
        self.kv_head_mapping = torch.arange(
            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
        ).repeat_interleave(self.num_groups)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        cu_seqlen_prefill: torch.Tensor,
        kv_cache: KVCache,
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        seqlen: Seqlen,
        max_s: int,
    ):
        if self.q_lora_rank is None:
            query = self.q_proj(hidden_states)
        else:
            query = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))[0])
        query = query.view(-1, self.num_heads, self.head_size)

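        # Only the rope slice of each query head receives positional
        # information; the nope slice is used as-is.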
        _, query_pe = torch.split(
            query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
        )

        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        compressed_kv, key_pe = torch.split(
            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
        )

        key_pe = key_pe.view(-1, 1, self.qk_rope_head_dim)
        kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv.contiguous())[0]).view(
            -1, self.num_key_value_heads, self.qk_nope_head_dim + self.value_head_size
        )

        key_nope, value = torch.split(
            kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1
        )

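        # The rope dims come out interleaved as (x0, y0, x1, y1, ...);
        # permute them into the half-split (x0, x1, ..., y0, y1, ...)
        # layout expected by the rotary kernel.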
        batch_size, heads, head_dim = query_pe.shape
        query_pe = (
            query_pe.view(batch_size, heads, head_dim // 2, 2)
            .transpose(2, 3)
            .reshape(batch_size, heads, head_dim)
        )
        batch_size, heads, head_dim = key_pe.shape
        key_pe = (
            key_pe.view(batch_size, heads, head_dim // 2, 2)
            .transpose(2, 3)
            .reshape(batch_size, heads, head_dim)
        )
        self.rotary_emb(query_pe, key_pe, cos, sin)

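        # Write the rotated rope slice back into the query and assemble the
        # full key; the single rope key head broadcasts across all key heads.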
        query[..., self.qk_nope_head_dim :] = query_pe
        key = torch.empty_like(query)
        key[..., : self.qk_nope_head_dim] = key_nope
        key[..., self.qk_nope_head_dim :] = key_pe

        # Pad the heads because Flash Attention does not support query/key
        # and value tensors with different head sizes.
        query = torch.nn.functional.pad(
            query, (0, self.head_pad_size - self.head_size), value=0
        )
        key = torch.nn.functional.pad(
            key, (0, self.head_pad_size - self.head_size), value=0
        )
        value = torch.nn.functional.pad(
            value, (0, self.head_pad_size - self.value_head_size), value=0
        )

        kv_cache.store(
            key=key,
            value=value,
            slots=slots,
            kv_scales=self.kv_scales,
        )

        # Prefill
        if cu_seqlen_prefill is not None:
            # flash attention
            attn_output = attention(
                query=query,
                key=key,
                value=value,
                kv_cache=kv_cache,
                kv_scales=self.kv_scales,
                seqlen=seqlen,
                block_tables=block_tables,
                softmax_scale=self.softmax_scale,
            )
        # Decode
        else:
            attn_output = paged_attention(
                query,
                kv_cache,
                self.kv_head_mapping,
                self.softmax_scale,
                block_tables,
                seqlen,
                max_s,
                kv_scales=self.kv_scales,
            )

        # Remove padding.
        attn_output = attn_output[..., : self.value_head_size]

        return self.o_proj(
            attn_output.reshape(-1, self.num_heads * self.value_head_size)
        )
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
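
The zero-padding before the cache store exists because the flash kernels
require query, key, and value heads to share one head size, while MLA uses
wider q/k heads (qk_nope_head_dim + qk_rope_head_dim) than value heads
(v_head_dim). A minimal sketch of the shape arithmetic, using the published
DeepSeek-V2 head sizes as illustrative values (not read from a real config):

    import torch
    import torch.nn.functional as F

    qk_nope, qk_rope, v_dim = 128, 64, 128  # illustrative DeepSeek-V2 sizes
    head_size = qk_nope + qk_rope           # 192: q/k head width
    head_pad_size = max(head_size, v_dim)   # 192: shared width for the kernel

    tokens, heads = 4, 8
    value = torch.randn(tokens, heads, v_dim)

    # Pad value heads from 128 up to 192 with zeros, mirroring the module.
    value = F.pad(value, (0, head_pad_size - v_dim), value=0)
    assert value.shape == (tokens, heads, head_pad_size)

    # After attention, only the first v_dim channels carry real output,
    # so the module slices the padding back off.
    attn_output = value[..., :v_dim]
    assert attn_output.shape == (tokens, heads, v_dim)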



server/text_generation_server/models/custom_modeling/flash_deepseek_v3_modeling.py [158:363]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    def __init__(
        self,
        prefix: str,
        config,
        weights: Weights,
    ):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.kv_lora_rank = config.kv_lora_rank
        self.q_lora_rank = config.q_lora_rank
        self.qk_nope_head_dim = config.qk_nope_head_dim
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.head_size = config.qk_nope_head_dim + config.qk_rope_head_dim
        self.value_head_size = config.v_head_dim
        self.head_pad_size = max(self.head_size, self.value_head_size)

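        # Rotary embeddings only cover the rope portion of each head
        # (qk_rope_head_dim); the nope portion carries no position signal.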
        self.rotary_emb = PositionRotaryEmbedding.static(
            config=config,
            dim=self.qk_rope_head_dim,
            base=config.rope_theta,
            device=weights.device,
        )

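        # YaRN-style correction: the attention softmax scale is adjusted by
        # mscale**2, derived from the rope scaling factor.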
        mscale = get_mscale(
            self.rotary_emb.scaling_factor, self.rotary_emb.mscale_all_dim
        )
        self.softmax_scale = self.head_size**-0.5 * mscale * mscale

        if self.num_heads % weights.process_group.size() != 0:
            raise ValueError(
                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                f"and `num_shards`: {weights.process_group.size()}"
            )
        self.num_heads = self.num_heads // weights.process_group.size()
        self.num_key_value_heads = (
            config.num_key_value_heads // weights.process_group.size()
        )

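        # Query path: either a single projection, or (with query LoRA) a
        # low-rank down-projection, an RMSNorm, and an up-projection.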
        if self.q_lora_rank is None:
            self.q_proj = TensorParallelColumnLinear.load(
                config,
                prefix=f"{prefix}.q_proj",
                weights=weights,
                bias=config.attention_bias,
            )
        else:
            self.q_a_proj = get_linear(
                weight=weights.get_weights(f"{prefix}.q_a_proj"),
                bias=(
                    weights.get_tensor(f"{prefix}.q_a_proj.bias")
                    if config.attention_bias
                    else None
                ),
            )
            self.q_a_layernorm = FastRMSNorm.load(
                prefix=f"{prefix}.q_a_layernorm",
                weights=weights,
                eps=config.rms_norm_eps,
            )
            self.q_b_proj = TensorParallelColumnLinear.load(
                config,
                prefix=f"{prefix}.q_b_proj",
                weights=weights,
                bias=config.attention_bias,
            )

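        # Joint down-projection that yields the compressed KV latent
        # (kv_lora_rank dims) plus a decoupled rope key (qk_rope_head_dim
        # dims) shared across all heads.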
        self.kv_a_proj_with_mqa = get_linear(
            weight=weights.get_weights(f"{prefix}.kv_a_proj_with_mqa"),
            bias=(
                weights.get_tensor(f"{prefix}.kv_a_proj_with_mqa.bias")
                if config.attention_bias
                else None
            ),
        )

        self.kv_scales = get_kv_scales(weights, prefix)

        self.kv_a_layernorm = FastRMSNorm.load(
            prefix=f"{prefix}.kv_a_layernorm", weights=weights, eps=config.rms_norm_eps
        )

        self.kv_b_proj = TensorParallelColumnLinear.load(
            config,
            prefix=f"{prefix}.kv_b_proj",
            weights=weights,
            bias=config.attention_bias,
        )

        self.o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=False,
        )
        self.num_groups = self.num_heads // self.num_key_value_heads
        self.kv_head_mapping = torch.arange(
            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
        ).repeat_interleave(self.num_groups)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        cu_seqlen_prefill: torch.Tensor,
        kv_cache: KVCache,
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        seqlen: Seqlen,
        max_s: int,
    ):
        if self.q_lora_rank is None:
            query = self.q_proj(hidden_states)
        else:
            query = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))[0])
        query = query.view(-1, self.num_heads, self.head_size)

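        # Only the rope slice of each query head receives positional
        # information; the nope slice is used as-is.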
        _, query_pe = torch.split(
            query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
        )

        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        compressed_kv, key_pe = torch.split(
            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
        )

        key_pe = key_pe.view(-1, 1, self.qk_rope_head_dim)
        kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv.contiguous())[0]).view(
            -1, self.num_key_value_heads, self.qk_nope_head_dim + self.value_head_size
        )

        key_nope, value = torch.split(
            kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1
        )

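        # The rope dims come out interleaved as (x0, y0, x1, y1, ...);
        # permute them into the half-split (x0, x1, ..., y0, y1, ...)
        # layout expected by the rotary kernel.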
        batch_size, heads, head_dim = query_pe.shape
        query_pe = (
            query_pe.view(batch_size, heads, head_dim // 2, 2)
            .transpose(2, 3)
            .reshape(batch_size, heads, head_dim)
        )
        batch_size, heads, head_dim = key_pe.shape
        key_pe = (
            key_pe.view(batch_size, heads, head_dim // 2, 2)
            .transpose(2, 3)
            .reshape(batch_size, heads, head_dim)
        )
        self.rotary_emb(query_pe, key_pe, cos, sin)

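        # Write the rotated rope slice back into the query and assemble the
        # full key; the single rope key head broadcasts across all key heads.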
        query[..., self.qk_nope_head_dim :] = query_pe
        key = torch.empty_like(query)
        key[..., : self.qk_nope_head_dim] = key_nope
        key[..., self.qk_nope_head_dim :] = key_pe

        # Pad the heads because Flash Attention does not support query/key
        # and value tensors with different head sizes.
        query = torch.nn.functional.pad(
            query, (0, self.head_pad_size - self.head_size), value=0
        )
        key = torch.nn.functional.pad(
            key, (0, self.head_pad_size - self.head_size), value=0
        )
        value = torch.nn.functional.pad(
            value, (0, self.head_pad_size - self.value_head_size), value=0
        )

        kv_cache.store(
            key=key,
            value=value,
            slots=slots,
            kv_scales=self.kv_scales,
        )

        # Prefill
        if cu_seqlen_prefill is not None:
            # flash attention
            attn_output = attention(
                query=query,
                key=key,
                value=value,
                kv_cache=kv_cache,
                kv_scales=self.kv_scales,
                seqlen=seqlen,
                block_tables=block_tables,
                softmax_scale=self.softmax_scale,
            )
        # Decode
        else:
            attn_output = paged_attention(
                query,
                kv_cache,
                self.kv_head_mapping,
                self.softmax_scale,
                block_tables,
                seqlen,
                max_s,
                kv_scales=self.kv_scales,
            )

        # Remove padding.
        attn_output = attn_output[..., : self.value_head_size]

        return self.o_proj(
            attn_output.reshape(-1, self.num_heads * self.value_head_size)
        )
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
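
The view/transpose/reshape chains applied to query_pe and key_pe in both
files are a pure layout permutation with no value changes: they gather the
even-indexed rope dims first, then the odd-indexed ones. A self-contained
sketch with toy dimensions:

    import torch

    batch, heads, dim = 1, 1, 8
    pe = torch.arange(dim).view(batch, heads, dim)  # [0, 1, ..., 7]

    permuted = (
        pe.view(batch, heads, dim // 2, 2)
        .transpose(2, 3)
        .reshape(batch, heads, dim)
    )
    # Interleaved pairs (x0, y0, x1, y1, ...) become half-split
    # (x0, x1, ..., y0, y1, ...).
    assert permuted.tolist() == [[[0, 2, 4, 6, 1, 3, 5, 7]]]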



