server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py (444 lines of code) (raw):

import torch
import torch.distributed

from torch import nn
from transformers.activations import ACT2FN
from typing import Optional, List, Tuple

from text_generation_server.layers.attention import (
    paged_attention,
    attention,
    Seqlen,
)
from text_generation_server.layers import (
    TensorParallelRowLinear,
    TensorParallelColumnLinear,
    SpeculativeHead,
    TensorParallelEmbedding,
    get_linear,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.gptq import GPTQWeightsLoader
from text_generation_server.layers.layernorm import (
    FastLayerNorm,
)


def load_multi_mqa(
    config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
):
    """Load the fused query/key-value projection of a multi-query attention layer.

    Dispatches on ``config.quantize``: ``"gptq"`` goes to the GPTQ loader,
    ``"marlin"`` is rejected, anything else uses the unquantized loader.

    Raises:
        RuntimeError: if ``config.quantize == "marlin"``.
    """
    if config.quantize == "gptq":
        return _load_multi_mqa_gptq(
            config, prefix, weights, bias, head_size, num_heads, hidden_size
        )
    elif config.quantize == "marlin":
        raise RuntimeError(
            "santacoder models with marlin quantization are not yet supported"
        )
    else:
        return _load_multi_mqa(
            config, prefix, weights, bias, head_size, num_heads, hidden_size
        )


def _load_multi_mqa_gptq(
    config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
):
    """Load a GPTQ/AWQ-quantized fused ``c_attn`` projection for this rank.

    For each of ``qweight``, ``scales`` and ``qzeros`` the query columns are
    split evenly across ranks while the trailing key/value columns are kept
    whole on every rank, then both parts are concatenated.

    Only checkpoints that expose a fused ``c_attn`` tensor and do not require
    transposition are handled.

    Raises:
        NotImplementedError: if the checkpoint has no ``c_attn`` entry, if
            ``config.transpose`` is set, or if the loader reports a
            quantization method other than ``"gptq"`` / ``"awq"``.
    """
    from text_generation_server.layers.gptq import GPTQWeight

    if any("c_attn" in k for k in weights.routing.keys()) and not config.transpose:
        world_size = weights.process_group.size()
        rank = weights.process_group.rank()

        # qweight: shard the query columns, replicate the last 2 * head_size
        # (key/value) columns.
        slice_ = weights._get_slice(f"{prefix}.c_attn.qweight")
        shape = slice_.get_shape()
        block_size = (shape[1] - 2 * head_size) // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size
        assert (shape[1] - 2 * head_size) % world_size == 0
        q_tensor = slice_[:, start:stop]
        kv_tensor = slice_[:, -2 * head_size :]
        qweight = torch.cat([q_tensor, kv_tensor], dim=1)
        qweight = qweight.to(device=weights.device)

        # scales: same column layout as qweight.
        slice_ = weights._get_slice(f"{prefix}.c_attn.scales")
        shape = slice_.get_shape()
        block_size = (shape[1] - 2 * head_size) // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size
        assert (shape[1] - 2 * head_size) % world_size == 0
        q_tensor = slice_[:, start:stop]
        kv_tensor = slice_[:, -2 * head_size :]
        scales = torch.cat([q_tensor, kv_tensor], dim=1)
        scales = scales.to(device=weights.device)

        # qzeros: the key/value tail is 2 * head_size * 4 // 32 columns wide
        # (presumably 4-bit zero points packed into 32-bit words -- confirm
        # against the checkpoint format).
        slice_ = weights._get_slice(f"{prefix}.c_attn.qzeros")
        shape = slice_.get_shape()
        block_size = (shape[1] - (2 * head_size) * 4 // 32) // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size
        assert 2 * head_size % (32 // 4) == 0
        q_tensor = slice_[:, start:stop]
        kv_tensor = slice_[:, -2 * head_size * 4 // 32 :]
        qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
        qzeros = qzeros.to(device=weights.device)

        loader = weights.weights_loader
        assert isinstance(loader, GPTQWeightsLoader)
        loader._get_gptq_params(weights)
        if loader.quant_method == "gptq":
            g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
            g_idx = g_idx.to(device=weights.device)
        elif loader.quant_method == "awq":
            g_idx = None
            from text_generation_server.layers.awq.conversion_utils import (
                fast_awq_to_gptq,
            )

            qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
        else:
            # Without this branch `g_idx` would be unbound below and surface
            # as an opaque UnboundLocalError.
            raise NotImplementedError(
                f"Unsupported quantization method for santacoder: {loader.quant_method}"
            )

        from text_generation_server.layers.gptq import HAS_EXLLAMA

        weight = GPTQWeight(
            qweight=qweight,
            qzeros=qzeros,
            scales=scales,
            g_idx=g_idx,
            bits=loader.bits,
            groupsize=loader.groupsize,
            use_awq_kernel=loader.quantize == "awq",
            use_exllama=HAS_EXLLAMA,
        )

        if bias:
            slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
            shape = slice_.get_shape()
            block_size = (shape[0] - 2 * head_size) // world_size
            assert (shape[0] - 2 * head_size) % world_size == 0
            # Bounds must be recomputed here: the previous start/stop were
            # derived from the (narrower) qzeros tensor.
            start = rank * block_size
            stop = (rank + 1) * block_size
            q_tensor = slice_[start:stop]
            kv_tensor = slice_[-2 * head_size :]
            bias = torch.cat([q_tensor, kv_tensor], dim=0)
            bias = bias.to(device=weights.device)
        else:
            # Use None (not the bool False) as the "no bias" sentinel.
            bias = None

        return TensorParallelColumnLinear(get_linear(weight, bias))
    else:
        raise NotImplementedError("Gptq loading with santacoder is not implemented")


def _load_multi_mqa(
    config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
):
    """Load the unquantized fused query/key-value projection for this rank.

    Two checkpoint layouts are handled:

    * a fused ``{prefix}.c_attn`` tensor, from which the query rows (or
      columns when ``config.transpose`` is set) are sharded across ranks and
      the trailing ``2 * head_size`` key/value entries are replicated;
    * separate ``{prefix}.q_attn`` and ``{prefix}.kv_attn`` tensors, where the
      query tensor is sharded and the key/value tensor is loaded whole.

    In every case the resulting weight is asserted to have shape
    ``[(num_heads + 2) * head_size, hidden_size]``, where ``num_heads`` is the
    per-rank head count supplied by the caller.
    """
    if any("c_attn" in k for k in weights.routing.keys()):
        slice_ = weights._get_slice(f"{prefix}.c_attn.weight")
        shape = slice_.get_shape()
        world_size = weights.process_group.size()
        rank = weights.process_group.rank()
        if config.transpose:
            block_size = (shape[1] - 2 * head_size) // world_size
            start = rank * block_size
            stop = (rank + 1) * block_size
            assert (shape[1] - 2 * head_size) % world_size == 0
            q_tensor = slice_[:, start:stop]
            kv_tensor = slice_[:, -2 * head_size :]
            weight = torch.cat([q_tensor, kv_tensor], dim=1).T
        else:
            block_size = (shape[0] - 2 * head_size) // world_size
            start = rank * block_size
            stop = (rank + 1) * block_size
            assert (shape[0] - 2 * head_size) % world_size == 0
            q_tensor = slice_[start:stop]
            kv_tensor = slice_[-2 * head_size :]
            weight = torch.cat([q_tensor, kv_tensor], dim=0)
        if bias:
            slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
            shape = slice_.get_shape()
            block_size = (shape[0] - 2 * head_size) // world_size
            assert (shape[0] - 2 * head_size) % world_size == 0
            start = rank * block_size
            stop = (rank + 1) * block_size
            q_tensor = slice_[start:stop]
            kv_tensor = slice_[-2 * head_size :]
            bias = torch.cat([q_tensor, kv_tensor], dim=0)
        else:
            # Without this, `bias` would stay the bool False and the
            # `bias is not None` branch below would call `False.to(...)`.
            bias = None
    else:
        if config.transpose:
            w = [
                weights.get_sharded(f"{prefix}.q_attn.weight", dim=1).T,
                weights.get_tensor(f"{prefix}.kv_attn.weight").T,
            ]
            weight = torch.cat(w, dim=0)
        else:
            w = [
                weights.get_sharded(f"{prefix}.q_attn.weight", dim=0),
                weights.get_tensor(f"{prefix}.kv_attn.weight"),
            ]
            # Stack query rows on top of key/value rows. Concatenating along
            # dim=1 could never produce the shape asserted below.
            weight = torch.cat(w, dim=0)
        if bias:
            b = [
                weights.get_sharded(f"{prefix}.q_attn.bias", dim=0),
                weights.get_tensor(f"{prefix}.kv_attn.bias"),
            ]
            bias = torch.cat(b, dim=0)
        else:
            bias = None

    weight = weight.to(dtype=weights.dtype).to(device=weights.device)
    assert list(weight.shape) == [
        (num_heads + 2) * head_size,
        hidden_size,
    ], f"{weight.shape} != {[(num_heads + 2) * head_size, hidden_size]}"
    if bias is not None:
        bias = bias.to(dtype=weights.dtype).to(device=weights.device)
        assert list(bias.shape) == [
            (num_heads + 2) * head_size
        ], f"{bias.shape} != {[(num_heads + 2) * head_size]}"
    return TensorParallelColumnLinear(get_linear(weight, bias))


def load_col(config, prefix: str, weights, bias: bool):
    """Load a column-parallel linear layer from ``{prefix}.weight`` / ``.bias``.

    When ``config.transpose`` is set the stored weight is sharded along dim 1
    and transposed; otherwise it is loaded through
    ``weights.get_multi_weights_col``. The bias, if requested, is sharded
    along dim 0.
    """
    if config.transpose:
        weight = weights.get_sharded(f"{prefix}.weight", dim=1).T
    else:
        weight = weights.get_multi_weights_col([prefix], dim=0)

    if bias:
        bias = weights.get_sharded(f"{prefix}.bias", dim=0)
    else:
        bias = None
    return TensorParallelColumnLinear(get_linear(weight, bias))


def load_row(config, prefix: str, weights, bias: bool):
    """Load a row-parallel linear layer from ``{prefix}.weight`` / ``.bias``.

    When ``config.transpose`` is set the stored weight is sharded along dim 0
    and transposed; otherwise it is loaded through ``weights.get_weights_row``.
    """
    if config.transpose:
        weight = weights.get_sharded(f"{prefix}.weight", dim=0).T
    else:
        weight = weights.get_weights_row(prefix)

    if bias and weights.process_group.rank() == 0:
        # Bias is only loaded on the first rank process
        bias = weights.get_tensor(f"{prefix}.bias")
    else:
        bias = None
    return TensorParallelRowLinear(
        get_linear(weight, bias), process_group=weights.process_group
    )


class FlashMQAttention(torch.nn.Module):
    """Multi-query attention: sharded query heads, one key/value head."""

    def __init__(self, prefix, config, weights):
        """Build the attention layer and load its weights under ``prefix``.

        Raises:
            ValueError: if ``config.num_attention_heads`` is not divisible by
                the process-group size.
        """
        super().__init__()
        num_heads = config.num_attention_heads
        hidden_size = config.hidden_size

        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.head_size = hidden_size // num_heads

        if self.num_heads % weights.process_group.size() != 0:
            raise ValueError(
                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                f"and `num_shards`: {weights.process_group.size()}"
            )
        # From here on `num_heads` is the number of query heads on this rank.
        self.num_heads = self.num_heads // weights.process_group.size()

        self.softmax_scale = self.head_size ** (-0.5)

        self.c_attn = load_multi_mqa(
            config,
            prefix=prefix,
            weights=weights,
            bias=True,
            head_size=self.head_size,
            hidden_size=hidden_size,
            num_heads=self.num_heads,
        )
        self.c_proj = load_row(
            config, prefix=f"{prefix}.c_proj", weights=weights, bias=True
        )
        self.kv_scales = get_kv_scales(weights, f"{prefix}")
        # All-zero mapping: every local query head uses key/value head 0.
        self.kv_head_mapping = torch.zeros(
            self.num_heads, dtype=torch.int32, device=weights.device
        )

    def forward(
        self,
        hidden_states,
        cu_seqlen_prefill,
        kv_cache,
        block_tables,
        slots,
        seqlen,
        max_s,
    ):
        """Project to q/kv, store kv in the cache, attend, and project out.

        Uses ``attention`` when ``cu_seqlen_prefill`` is given (prefill) and
        ``paged_attention`` otherwise (decode).
        """
        qkv = self.c_attn(hidden_states)

        # Split query from key_value
        query, key_value = qkv.split(
            [self.head_size * self.num_heads, 2 * self.head_size], dim=1
        )

        # Prepare query and key_value for indexing
        query = query.view(-1, self.num_heads, self.head_size)
        key_value = key_value.view(-1, 2, 1, self.head_size)

        kv_cache.store(
            key=key_value[:, 0],
            value=key_value[:, 1],
            slots=slots,
            kv_scales=self.kv_scales,
        )

        # Prefill
        if cu_seqlen_prefill is not None:
            # flash attention
            attn_output = attention(
                query=query,
                key=key_value[:, 0],
                value=key_value[:, 1],
                kv_cache=kv_cache,
                kv_scales=self.kv_scales,
                seqlen=seqlen,
                block_tables=block_tables,
                softmax_scale=self.softmax_scale,
            )
        # Decode
        else:
            attn_output = paged_attention(
                query,
                kv_cache,
                self.kv_head_mapping,
                self.softmax_scale,
                block_tables,
                seqlen,
                max_s,
                kv_scales=self.kv_scales,
            )

        return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size))


class MLP(nn.Module):
    """Feed-forward sub-layer: ``c_fc`` -> activation -> ``c_proj``."""

    def __init__(self, prefix, config, weights):
        """Load ``c_fc`` / ``c_proj`` under ``prefix`` and pick the activation."""
        super().__init__()
        act = config.activation_function
        # Any activation whose name contains "gelu" is routed to
        # torch.nn.functional.gelu; the tanh approximation is used only for
        # "gelu_fast" and "gelu_pytorch_tanh".
        self.act = (
            ACT2FN[act]
            if "gelu" not in act
            else lambda x: torch.nn.functional.gelu(
                x,
                approximate=(
                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
                ),
            )
        )

        self.c_fc = load_col(
            config, prefix=f"{prefix}.c_fc", weights=weights, bias=True
        )
        self.c_proj = load_row(
            config, prefix=f"{prefix}.c_proj", weights=weights, bias=True
        )

    def forward(self, hidden_states):
        """Apply the two projections with the activation in between."""
        hidden_states = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.c_proj(hidden_states)
        return hidden_states


class Block(nn.Module):
    """One transformer layer: ``ln_1`` -> attention -> ``ln_2`` -> MLP."""

    def __init__(self, prefix: str, layer_id, config, weights):
        """Load the layer stored under ``{prefix}.h.{layer_id}``."""
        super().__init__()
        prefix = f"{prefix}.h.{layer_id}"
        self.ln_1 = FastLayerNorm.load(
            prefix=f"{prefix}.ln_1", weights=weights, eps=config.layer_norm_epsilon
        )
        self.ln_2 = FastLayerNorm.load(
            prefix=f"{prefix}.ln_2", weights=weights, eps=config.layer_norm_epsilon
        )
        self.self_attn = FlashMQAttention(
            prefix=f"{prefix}.attn",
            config=config,
            weights=weights,
        )
        self.mlp = MLP(
            prefix=f"{prefix}.mlp",
            config=config,
            weights=weights,
        )

    def forward(
        self,
        hidden_states,
        residual,
        cu_seqlen_prefill,
        kv_cache,
        block_tables,
        slots,
        seqlen,
        max_s,
    ):
        """Run the layer and return ``(mlp_output, residual)``.

        Each layer norm is called with the running residual and returns the
        normalised activations together with a new residual (presumably the
        residual addition happens inside ``FastLayerNorm`` -- confirm there).
        """
        hidden_states, residual = self.ln_1(hidden_states, residual)
        hidden_states = self.self_attn(
            hidden_states,
            cu_seqlen_prefill,
            kv_cache,
            block_tables,
            slots,
            seqlen,
            max_s,
        )

        hidden_states, residual = self.ln_2(hidden_states, residual)

        mlp_output = self.mlp(hidden_states)

        return mlp_output, residual


class FlashSantacoderModel(nn.Module):
    """Token + position embeddings, a stack of ``Block`` layers, final norm."""

    def __init__(self, prefix: str, config, weights):
        """Load embeddings, layers and the final layer norm under ``prefix``."""
        super().__init__()
        self.config = config

        self.process_group = weights.process_group
        # Embeddings are built with reduce=False; forward() performs a single
        # all-reduce on their sum instead.
        self.wte = TensorParallelEmbedding(
            prefix=f"{prefix}.wte",
            weights=weights,
            reduce=False,
        )
        self.wpe = TensorParallelEmbedding(
            prefix=f"{prefix}.wpe",
            weights=weights,
            reduce=False,
        )

        self.layers = nn.ModuleList(
            [
                Block(
                    prefix,
                    layer_id,
                    config,
                    weights,
                )
                for layer_id in range(config.num_hidden_layers)
            ]
        )
        # Follow `prefix` like every other sub-module; this equals
        # "transformer.ln_f" when prefix == "transformer".
        self.ln_f = FastLayerNorm.load(
            prefix=f"{prefix}.ln_f", weights=weights, eps=config.layer_norm_epsilon
        )

        self.head_size = self.layers[0].self_attn.head_size
        self.num_heads = self.layers[0].self_attn.num_heads

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        seqlen: Seqlen,
        max_s: int,
    ) -> torch.Tensor:
        """Embed the inputs, run every layer, and apply the final layer norm.

        ``kv_cache`` is indexed per layer (``kv_cache[i]`` goes to layer ``i``).
        """
        hidden_states = self.wte(input_ids) + self.wpe(position_ids)

        if self.process_group.size() > 1:
            torch.distributed.all_reduce(hidden_states, group=self.process_group)

        residual = None
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                residual,
                cu_seqlen_prefill,
                kv_cache[i],
                block_tables,
                slots,
                seqlen,
                max_s,
            )

        hidden_states, _ = self.ln_f(hidden_states, residual)

        return hidden_states


class FlashSantacoderForCausalLM(nn.Module):
    """Santacoder transformer with a language-model head loaded from ``wte``."""

    def __init__(self, prefix, config, weights):
        """Build the model.

        An empty ``prefix`` resolves to ``"transformer"``; otherwise
        ``".transformer"`` is appended. ``config.transpose`` is set to True
        when the first entry of ``config.architectures`` starts with "GPT2".
        """
        super().__init__()

        if not prefix:
            prefix = "transformer"
        else:
            prefix = f"{prefix}.transformer"

        config.transpose = config.architectures[0].startswith("GPT2")
        self.model = FlashSantacoderModel(prefix, config, weights)
        # The head is loaded from the token-embedding prefix.
        self.lm_head = SpeculativeHead.load(
            config, prefix=f"{prefix}.wte", weights=weights
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        seqlen: Seqlen,
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
        lm_head_indices: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the model and return the output of the language-model head.

        When ``lm_head_indices`` is given, only those rows of the hidden
        states are passed to the head. ``prefill_cache_indices`` and
        ``adapter_data`` are accepted but not used by this method.
        """
        hidden_states = self.model(
            input_ids,
            position_ids,
            cu_seqlen_prefill,
            kv_cache,
            block_tables,
            slots,
            seqlen,
            max_s,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits = self.lm_head(hidden_states)
        return logits