# server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
# coding=utf-8
# Copyright 2022 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple, Any
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.kernels import load_kernel
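# Select the fused MoE backend: IPEX provides GatedMLPMOE on Intel systems,
# CUDA loads the "moe" module from the kernels-community hub, and any other
# system falls back to a locally installed `moe_kernels` package.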
if SYSTEM == "ipex":
from intel_extension_for_pytorch.llm.modules import GatedMLPMOE
elif SYSTEM == "cuda":
moe_kernels = load_kernel(module="moe", repo_id="kernels-community/moe")
else:
import moe_kernels
from text_generation_server.layers.attention import (
paged_attention,
attention,
Seqlen,
)
from text_generation_server.layers import (
FastLinear,
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
SpeculativeHead,
get_linear,
)
from text_generation_server.layers.rotary import (
PositionRotaryEmbedding,
)
from text_generation_server.layers.layernorm import (
FastLayerNorm,
)
class DbrxAttentionConfig(PretrainedConfig):
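    """Attention sub-config: attention dropout, optional QKV clipping, number of KV heads and RoPE base."""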
def __init__(
self,
attn_pdrop: float = 0,
clip_qkv: Optional[float] = None,
kv_n_heads: int = 1,
rope_theta: float = 10000.0,
**kwargs: Any,
):
super().__init__(**kwargs)
self.attn_pdrop = attn_pdrop
self.clip_qkv = clip_qkv
self.kv_n_heads = kv_n_heads
self.rope_theta = rope_theta
for k in ["model_type"]:
if k in kwargs:
kwargs.pop(k)
if len(kwargs) != 0:
raise ValueError(f"Found unknown {kwargs=}")
class DbrxFFNConfig(PretrainedConfig):
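    """MoE feed-forward sub-config: activation, hidden size, number of experts and top-k routing."""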
def __init__(
self,
ffn_act_fn: Optional[dict] = None,
ffn_hidden_size: int = 3584,
moe_num_experts: int = 4,
moe_top_k: int = 1,
moe_jitter_eps: Optional[float] = None,
moe_loss_weight: float = 0.01,
moe_normalize_expert_weights: Optional[float] = 1,
uniform_expert_assignment: bool = False,
**kwargs: Any,
):
super().__init__()
if ffn_act_fn is None:
ffn_act_fn = {"name": "silu"}
self.ffn_act_fn = ffn_act_fn
self.ffn_hidden_size = ffn_hidden_size
self.moe_num_experts = moe_num_experts
self.moe_top_k = moe_top_k
self.moe_jitter_eps = moe_jitter_eps
self.moe_loss_weight = moe_loss_weight
self.moe_normalize_expert_weights = moe_normalize_expert_weights
self.uniform_expert_assignment = uniform_expert_assignment
if uniform_expert_assignment:
raise ValueError("`uniform_expert_assignment = True` is not supported")
for k in ["model_type"]:
if k in kwargs:
kwargs.pop(k)
if len(kwargs) != 0:
raise ValueError(f"Found unknown {kwargs=}")
class DbrxConfig(PretrainedConfig):
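    """Top-level DBRX configuration; nests a `DbrxAttentionConfig` and a `DbrxFFNConfig`."""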
attribute_map = {
"hidden_size": "d_model",
"num_attention_heads": "n_heads",
"num_hidden_layers": "n_layers",
}
def __init__(
self,
d_model: int = 2048,
n_heads: int = 16,
n_layers: int = 24,
max_seq_len: int = 2048,
vocab_size: int = 32000,
resid_pdrop: float = 0.0,
emb_pdrop: float = 0.0,
attn_config: Optional[DbrxAttentionConfig] = None,
ffn_config: Optional[DbrxFFNConfig] = None,
use_cache: bool = True,
initializer_range: float = 0.02,
output_router_logits: bool = False,
router_aux_loss_coef: float = 0.05,
**kwargs: Any,
):
if attn_config is None:
self.attn_config = DbrxAttentionConfig()
elif isinstance(attn_config, dict):
self.attn_config = DbrxAttentionConfig(**attn_config)
else:
self.attn_config = attn_config
if ffn_config is None:
self.ffn_config = DbrxFFNConfig()
elif isinstance(ffn_config, dict):
self.ffn_config = DbrxFFNConfig(**ffn_config)
else:
self.ffn_config = ffn_config
self.d_model = d_model
self.n_heads = n_heads
self.n_layers = n_layers
self.max_seq_len = max_seq_len
self.vocab_size = vocab_size
self.resid_pdrop = resid_pdrop
self.emb_pdrop = emb_pdrop
self.use_cache = use_cache
self.initializer_range = initializer_range
self.output_router_logits = output_router_logits
self.router_aux_loss_coef = router_aux_loss_coef
tie_word_embeddings = kwargs.pop("tie_word_embeddings", False)
if tie_word_embeddings:
raise ValueError("tie_word_embeddings is not supported for Dbrx models.")
super().__init__(
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
@property
def num_key_value_heads(self):
        # We can't use the attribute map here, since the number of KV heads
        # is not a top-level attribute of the config.
return self.attn_config.kv_n_heads
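# Turn a 0-dim tensor into a 1-element tensor so it can be used where a 1D tensor is expected.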
def promote_scalar(x: torch.Tensor) -> torch.Tensor:
return x.view(1) if len(x.size()) == 0 else x
def load_attention(config, prefix, weights):
return TensorParallelColumnLinear.load_qkv(
config,
prefix=f"{prefix}.Wqkv",
weights=weights,
bias=False,
num_heads=config.n_heads,
num_key_value_heads=config.attn_config.kv_n_heads,
)
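# Shard the merged expert weights across ranks: each rank takes a contiguous
# block of the FFN dimension from every expert and packs the blocks into one tensor.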
def _load_experts(config, prefix, weights):
world_size = weights.process_group.size()
rank = weights.process_group.rank()
assert (
config.ffn_config.ffn_hidden_size % world_size == 0
), f"The chosen size {config.ffn_config.ffn_hidden_size} is not compatible with sharding on {world_size} shards"
expert_size = config.ffn_config.ffn_hidden_size
block_size = expert_size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = torch.empty(
(config.ffn_config.moe_num_experts * block_size, config.d_model),
dtype=weights.dtype,
device=weights.device,
)
    slice_ = weights._get_slice(prefix)
for i in range(config.ffn_config.moe_num_experts):
offset = i * expert_size
expert_slice = slice_[start + offset : stop + offset]
tensor[i * block_size : (i + 1) * block_size] = expert_slice.to(
dtype=weights.dtype
).to(device=weights.device)
return tensor
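# Same sharding as above, but each expert slice is wrapped in its own
# tensor-parallel linear layer (used by the dense MoE path).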
def _load_experts_quantized(config, prefix, weights, cls):
world_size = weights.process_group.size()
rank = weights.process_group.rank()
assert (
config.ffn_config.ffn_hidden_size % world_size == 0
), f"The chosen size {config.ffn_config.ffn_hidden_size} is not compatible with sharding on {world_size} shards"
expert_size = config.ffn_config.ffn_hidden_size
block_size = expert_size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
    slice_ = weights._get_slice(prefix)
experts = []
for i in range(config.ffn_config.moe_num_experts):
if config.quantize in ["gptq", "awq"]:
raise NotImplementedError(
"Dbrx does not support gptq/awq quantization yet."
)
else:
offset = i * expert_size
expert_slice = (
slice_[start + offset : stop + offset]
.to(dtype=weights.dtype)
.to(device=weights.device)
)
if cls == TensorParallelRowLinear:
expert_slice = expert_slice.t().contiguous()
linear = get_linear(expert_slice, None)
experts.append(cls(linear, weights.process_group))
else:
linear = get_linear(expert_slice, None)
experts.append(cls(linear))
return experts
class DbrxAttention(torch.nn.Module):
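    """Grouped-query attention with rotary embeddings, optional QKV clipping and a paged KV cache."""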
def __init__(
self,
prefix: str,
config,
weights,
):
super().__init__()
self.clip_qkv = config.attn_config.clip_qkv
self.num_heads = config.n_heads
self.hidden_size = config.d_model
self.head_size = self.hidden_size // self.num_heads
self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.head_size,
base=config.attn_config.rope_theta,
device=weights.device,
)
self.softmax_scale = self.head_size**-0.5
if self.num_heads % weights.process_group.size() != 0:
            raise ValueError(
                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                f"and `num_shards`: {weights.process_group.size()})"
            )
self.num_heads = self.num_heads // weights.process_group.size()
self.num_key_value_heads = (
config.attn_config.kv_n_heads // weights.process_group.size()
)
self.query_key_value = load_attention(config, prefix, weights)
        self.kv_scales = get_kv_scales(weights, prefix)
self.o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.out_proj",
weights=weights,
bias=False,
)
self.num_groups = self.num_heads // self.num_key_value_heads
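        # Map each query head to the KV head it attends with (used by paged attention)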
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups)
def forward(
self,
hidden_states,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
max_s,
):
qkv = self.query_key_value(hidden_states)
if self.clip_qkv is not None:
qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)
query, kv = qkv.split(
[
self.head_size * self.num_heads,
2 * self.head_size * self.num_key_value_heads,
],
dim=1,
)
query = query.view(-1, self.num_heads, self.head_size)
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
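        # Write the new keys/values into the paged KV cache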
kv_cache.store(
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
attn_output = attention(
query=query,
key=kv[:, 0],
value=kv[:, 1],
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
block_tables=block_tables,
softmax_scale=self.softmax_scale,
)
# Decode
else:
attn_output = paged_attention(
query,
kv_cache,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
seqlen,
max_s,
kv_scales=self.kv_scales,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
class DbrxNormAttentionNorm(nn.Module):
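    """LayerNorm -> attention -> LayerNorm, with the residual additions fused into the norms."""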
def __init__(
self,
prefix: str,
config,
weights,
):
super().__init__()
self.norm_1 = FastLayerNorm.load_no_bias(
prefix=f"{prefix}.norm_1", weights=weights, eps=1e-5
)
self.self_attn = DbrxAttention(
prefix=f"{prefix}.attn", config=config, weights=weights
)
self.norm_2 = FastLayerNorm.load_no_bias(
prefix=f"{prefix}.norm_2",
weights=weights,
eps=1e-5,
)
def forward(
self,
hidden_states,
residual,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
max_s,
):
normed_hidden_states, res = self.norm_1(hidden_states, residual)
# Self Attention
attn_output = self.self_attn(
normed_hidden_states,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
max_s,
)
        # fused post-attention layer norm (the residual add happens inside the norm)
normed_attn_res_output, attn_res = self.norm_2(attn_output, res)
return normed_attn_res_output, attn_res
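# Routing helper: softmax the router logits, keep the top-k experts and
# optionally renormalize the kept weights with a p-norm.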
@torch.jit.script
def select_experts(
gate_logits: torch.Tensor, top_k: int, moe_normalize_expert_weights: int
):
# all_probs: (sequence_length, n_experts) and upcast for softmax
all_probs = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
# weights, selected_experts: (sequence_length, top-k)
weights, selected_experts = torch.topk(all_probs, top_k, dim=-1)
if moe_normalize_expert_weights:
weights = weights / torch.norm(
weights, p=moe_normalize_expert_weights, dim=-1, keepdim=True
)
weights = weights.view(-1)
selected_experts = selected_experts.view(-1)
return selected_experts, weights
@torch.jit.script
def round_up(x: torch.Tensor, value: int):
return torch.div(x + (value - 1), value, rounding_mode="trunc") * value
class BlockSparseMoE(nn.Module):
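    """Sparse MoE that keeps the expert weights merged and dispatches tokens through a fused MoE kernel."""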
def __init__(self, prefix, config: DbrxConfig, weights):
super().__init__()
self.moe_normalize_expert_weights = (
config.ffn_config.moe_normalize_expert_weights
)
self.hidden_dim = config.d_model
self.ffn_dim = config.ffn_config.ffn_hidden_size // weights.process_group.size()
self.num_experts = config.ffn_config.moe_num_experts
self.top_k = config.ffn_config.moe_top_k
act = config.ffn_config.ffn_act_fn["name"]
if "gelu" in act:
self.act = lambda x: torch.nn.functional.gelu(
x,
approximate=(
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
),
)
elif "silu" in act:
self.act = torch.nn.functional.silu
else:
self.act = ACT2FN[act]
# gating
self.gate = FastLinear.load(
config, f"{prefix}.router.layer", weights, bias=False
)
# merged expert weights, all of size (n_experts * ffn_dim, hidden_dim)
w1 = _load_experts(config, f"{prefix}.experts.mlp.w1", weights).view(
self.num_experts, self.ffn_dim, self.hidden_dim
)
v1 = _load_experts(config, f"{prefix}.experts.mlp.v1", weights).view(
self.num_experts, self.ffn_dim, self.hidden_dim
)
self.wv1 = torch.cat([w1, v1], dim=1)
self.w2 = (
_load_experts(config, f"{prefix}.experts.mlp.w2", weights)
.view(self.num_experts, self.ffn_dim, self.hidden_dim)
.transpose(1, 2)
.contiguous()
)
self.process_group = weights.process_group
if SYSTEM == "ipex":
self.ipex_fused_moe = GatedMLPMOE(
W13=self.wv1, W2=self.w2, use_prepack=True
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(x)
if SYSTEM == "ipex":
out = self.ipex_fused_moe(
hidden_states=x,
router_logits=router_logits,
top_k=self.top_k,
renormalize=self.moe_normalize_expert_weights,
use_grouped_topk=False,
num_expert_group=None,
topk_group=None,
)
else:
out = moe_kernels.fused_moe(
x,
self.wv1,
self.w2,
router_logits,
self.top_k,
renormalize=self.moe_normalize_expert_weights,
inplace=True,
)
        # All-reduce the partial results across tensor-parallel ranks
if self.process_group.size() > 1:
torch.distributed.all_reduce(out, group=self.process_group)
return out.view(*x.shape)
class DenseMoE(nn.Module):
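    """Dense MoE fallback used for quantized checkpoints: every expert runs on every token."""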
def __init__(self, prefix, config: DbrxConfig, weights):
super().__init__()
self.moe_normalize_expert_weights = (
config.ffn_config.moe_normalize_expert_weights
)
self.hidden_dim = config.d_model
self.ffn_dim = config.ffn_config.ffn_hidden_size // weights.process_group.size()
self.num_experts = config.ffn_config.moe_num_experts
self.top_k = config.ffn_config.moe_top_k
act = config.ffn_config.ffn_act_fn["name"]
if "gelu" in act:
self.act = lambda x: torch.nn.functional.gelu(
x,
approximate=(
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
),
)
elif "silu" in act:
self.act = torch.nn.functional.silu
else:
self.act = ACT2FN[act]
# gating
self.gate = FastLinear.load(
config, f"{prefix}.router.layer", weights, bias=False
)
self.w1 = _load_experts_quantized(
config,
prefix=f"{prefix}.experts.mlp.w1",
weights=weights,
cls=TensorParallelColumnLinear,
)
self.w2 = _load_experts_quantized(
config,
prefix=f"{prefix}.experts.mlp.w2",
weights=weights,
cls=TensorParallelRowLinear,
)
self.v1 = _load_experts_quantized(
config,
prefix=f"{prefix}.experts.mlp.v1",
weights=weights,
cls=TensorParallelColumnLinear,
)
self.process_group = weights.process_group
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
x: (sequence_length, model_dim)
gate_logits: (sequence_length, n_experts)
"""
# optional reshape
input_shape = x.shape
x = x.view(-1, input_shape[-1])
# gate_logits: (sequence_length, n_experts)
gate_logits = self.gate(x)
# all_probs: (sequence_length, n_experts) and upcast for softmax
weights = torch.nn.functional.softmax(gate_logits, dim=1, dtype=torch.float)
if self.top_k < self.num_experts:
_, not_selected_experts = torch.topk(
weights,
self.num_experts - self.top_k,
largest=False,
sorted=False,
dim=1,
)
# Mask not selected experts
weights.scatter_(1, not_selected_experts, 0)
# Re-normalize
if self.moe_normalize_expert_weights:
weights = weights / torch.norm(
weights, p=self.moe_normalize_expert_weights, dim=-1, keepdim=True
)
weights = weights.to(x.dtype)
# Final output tensor
out = x.new_zeros(x.shape[0], self.hidden_dim)
for i in range(self.num_experts):
h = self.act(self.w1[i](x)) * self.v1[i](x)
h = self.w2[i](h, reduce=False)
            # Accumulate this expert's output, scaled by its (possibly zeroed) router weight
out += h * weights[:, i].view(-1, 1)
        # All-reduce the partial expert outputs across tensor-parallel ranks
if self.process_group.size() > 1:
torch.distributed.all_reduce(out, group=self.process_group)
return out
class DbrxLayer(nn.Module):
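    """One transformer block: fused norm/attention/norm followed by the MoE feed-forward."""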
def __init__(self, prefix: str, layer_id, config, weights):
super().__init__()
prefix = f"{prefix}.blocks.{layer_id}"
self.attn = DbrxNormAttentionNorm(
prefix=f"{prefix}.norm_attn_norm", config=config, weights=weights
)
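        # Unquantized checkpoints use the fused block-sparse kernel; quantized ones
        # fall back to running every expert densely.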
moe_cls = BlockSparseMoE if config.quantize is None else DenseMoE
self.moe = moe_cls(f"{prefix}.ffn", config, weights)
def forward(
self,
hidden_states,
residual,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
max_s,
):
# Self Attention
attn_output, attn_res = self.attn(
hidden_states,
residual,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
max_s,
)
moe_output = self.moe(attn_output)
return moe_output, attn_res
class DbrxModel(torch.nn.Module):
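    """Token embedding, a stack of DbrxLayer blocks and the final layer norm."""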
def __init__(self, prefix: str, config, weights):
super().__init__()
self.embed_tokens = TensorParallelEmbedding(
prefix=f"{prefix}.wte", weights=weights
)
self.layers = nn.ModuleList(
[
DbrxLayer(
prefix,
layer_id,
config,
weights,
)
for layer_id in range(config.n_layers)
]
)
self.norm = FastLayerNorm.load_no_bias(
prefix=f"{prefix}.norm_f", weights=weights, eps=1e-5
)
self.head_size = self.layers[0].attn.self_attn.head_size
self.num_heads = self.layers[0].attn.self_attn.num_heads
self.num_key_value_heads = self.layers[0].attn.self_attn.num_key_value_heads
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
        # Get rotary cos and sin once for this forward pass
        # to avoid indexing into the rotary cache in every layer
cos, sin = self.layers[0].attn.self_attn.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
residual = None
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
residual,
cos,
sin,
cu_seqlen_prefill,
kv_cache[i],
block_tables,
slots,
seqlen,
max_s,
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class FlashDbrxForCausalLM(torch.nn.Module):
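    """DbrxModel with a speculation-aware language-model head."""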
def __init__(self, prefix: str, config, weights):
super().__init__()
if not prefix:
prefix = "transformer"
else:
prefix = f"{prefix}.transformer"
self.model = DbrxModel(prefix, config, weights)
self.lm_head = SpeculativeHead.load(
config,
prefix="lm_head",
weights=weights,
)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
hidden_states = self.model(
input_ids,
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
max_s,
)
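        # Only run the LM head on the positions that actually need logits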
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.lm_head(hidden_states)
return logits, speculative_logits