server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py (393 lines of code) (raw):

# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.distributed

from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple

from text_generation_server.layers.attention import (
    paged_attention,
    attention,
    Seqlen,
)
from text_generation_server.layers import (
    TensorParallelRowLinear,
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    SpeculativeHead,
    get_linear,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
    FastRMSNorm,
)
from text_generation_server.utils.weights import UnquantizedWeight


class GemmaConfig(PretrainedConfig):
    """Configuration holder for the Gemma architecture.

    Every keyword argument is stored as an attribute of the same name; the
    token ids and ``tie_word_embeddings`` are forwarded to
    ``PretrainedConfig.__init__`` together with any extra ``kwargs``.
    """

    def __init__(
        self,
        vocab_size=256128,
        hidden_size=3072,
        intermediate_size=24576,
        num_hidden_layers=28,
        num_attention_heads=16,
        num_key_value_heads=16,
        head_dim=256,
        hidden_act="gelu_pytorch_tanh",
        max_position_embeddings=8192,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=True,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.head_dim = head_dim
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # for backward compatibility: an explicit ``None`` means "one KV head
        # per attention head".
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )


class GemmaFastRMSNorm(FastRMSNorm):
    """RMS norm variant whose weight is loaded in float32 and offset by +1."""

    @classmethod
    def load(cls, prefix: str, weights, eps=1e-6):
        """Build the norm from ``{prefix}.weight``.

        The tensor is read with ``weights.dtype`` temporarily forced to
        float32 and ``1`` is added to it (presumably because the checkpoint
        stores the scale as an offset from one - confirm against the
        checkpoint format). The loader's original dtype is remembered on the
        returned module as ``dtype`` and is what ``forward`` casts back to.
        """
        dtype = weights.dtype
        weights.dtype = torch.float32
        try:
            weight = weights.get_tensor(f"{prefix}.weight") + 1
        finally:
            # ``weights`` is shared by every other loader call; always put its
            # dtype back, even when the tensor lookup raises.
            weights.dtype = dtype
        new = cls(weight, eps)
        new.dtype = dtype
        return new

    # perform the multiplication in full precision and downcast after
    def forward(self, hidden_states, residual=None):
        """Add ``residual`` (if any), then RMS-normalise in float32.

        Returns ``(normalised, residual)`` where ``normalised`` has been cast
        to ``self.dtype`` and ``residual`` is the pre-normalisation sum.

        NOTE: when ``residual`` is given, ``hidden_states`` is updated in
        place by ``+=``.
        """
        if residual is not None:
            hidden_states += residual
        residual = hidden_states

        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        hidden_states = hidden_states * self.weight
        return hidden_states.to(self.dtype), residual


def load_attention(config, prefix: str, weights):
    """Load the fused q/k/v projection for one attention layer.

    Uses the GQA-specific loader when the number of attention heads differs
    from the number of key/value heads, otherwise the generic multi-prefix
    column-parallel loader.
    """
    if config.num_attention_heads != config.num_key_value_heads:
        return _load_gqa(config, prefix, weights)
    else:
        return TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
            dim=0,
            weights=weights,
            bias=False,
        )


def _load_gqa(config, prefix: str, weights):
    """Load and shape-check the fused q/k/v projection for grouped-query attention.

    Raises:
        AssertionError: if the attention heads do not divide evenly across the
            process group, or if the loaded weight does not have the expected
            per-shard shape.
    """
    world_size = weights.process_group.size()
    assert config.num_attention_heads % world_size == 0

    weight = weights.get_multi_weights_col(
        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
        dim=0,
    )

    if isinstance(weight, UnquantizedWeight):
        weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)

        head_size = config.head_dim
        num_heads = config.num_attention_heads // world_size
        num_key_value_heads = config.num_key_value_heads // world_size
        # The message reports the same per-shard expectation that is checked;
        # it previously used the unsharded ``config.num_key_value_heads``.
        assert list(weight.weight.shape) == [
            (num_heads + 2 * num_key_value_heads) * head_size,
            config.hidden_size,
        ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * num_key_value_heads) * head_size, config.hidden_size]}"

    return TensorParallelColumnLinear(get_linear(weight, bias=None))


class FlashGemmaAttention(torch.nn.Module):
    """Self-attention with rotary embeddings and a paged KV cache.

    Head counts stored on the module are per shard, i.e. already divided by
    the size of ``weights.process_group``.
    """

    def __init__(self, prefix: str, config, weights, causal: bool):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.head_size = config.head_dim
        self.causal = causal

        self.rotary_emb = PositionRotaryEmbedding.static(
            config=config,
            dim=self.head_size,
            base=config.rope_theta,
            device=weights.device,
        )

        self.softmax_scale = self.head_size**-0.5

        if self.num_heads % weights.process_group.size() != 0:
            raise ValueError(
                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                f"and `num_shards`: {weights.process_group.size()}"
            )
        self.num_heads = self.num_heads // weights.process_group.size()
        self.num_key_value_heads = (
            config.num_key_value_heads // weights.process_group.size()
        )

        self.query_key_value = load_attention(config, prefix, weights)
        self.kv_scales = get_kv_scales(weights, f"{prefix}")

        self.o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=False,
        )
        # Maps each query head to the index of the KV head it shares.
        self.num_groups = self.num_heads // self.num_key_value_heads
        self.kv_head_mapping = torch.arange(
            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
        ).repeat_interleave(self.num_groups)

    def forward(
        self,
        hidden_states,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        block_tables,
        slots,
        seqlen,
        max_s,
    ):
        """Project to q/k/v, apply rotary, update the cache and attend.

        ``cu_seqlen_prefill`` selects the path: not ``None`` runs the prefill
        ``attention`` kernel, ``None`` runs ``paged_attention`` (decode).
        """
        qkv = self.query_key_value(hidden_states)
        # Split the fused projection along dim 1 into queries and packed k/v.
        query, kv = qkv.split(
            [
                self.head_size * self.num_heads,
                2 * self.head_size * self.num_key_value_heads,
            ],
            dim=1,
        )
        query = query.view(-1, self.num_heads, self.head_size)
        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)

        # Return value is ignored; the rotary embedding is presumably applied
        # in place to ``query`` and the key view - confirm in
        # PositionRotaryEmbedding.
        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

        kv_cache.store(
            key=kv[:, 0],
            value=kv[:, 1],
            slots=slots,
            kv_scales=self.kv_scales,
        )

        # Prefill
        if cu_seqlen_prefill is not None:
            # flash attention
            attn_output = attention(
                query=query,
                key=kv[:, 0],
                value=kv[:, 1],
                kv_cache=kv_cache,
                kv_scales=self.kv_scales,
                seqlen=seqlen,
                block_tables=block_tables,
                softmax_scale=self.softmax_scale,
                causal=self.causal,
            )
        # Decode
        else:
            attn_output = paged_attention(
                query,
                kv_cache,
                self.kv_head_mapping,
                self.softmax_scale,
                block_tables,
                seqlen,
                max_s,
                kv_scales=self.kv_scales,
            )

        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))


class GemmaMLP(nn.Module):
    """Gated feed-forward block with a fused gate/up projection."""

    def __init__(self, prefix: str, config, weights):
        super().__init__()
        act = config.hidden_act
        # Any activation whose name contains "gelu" goes through
        # ``torch.nn.functional.gelu``; the tanh approximation is used for
        # "gelu_fast" and "gelu_pytorch_tanh". Everything else is looked up
        # in ``ACT2FN``.
        self.act = (
            ACT2FN[act]
            if "gelu" not in act
            else lambda x: torch.nn.functional.gelu(
                x,
                approximate=(
                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
                ),
            )
        )
        # Fuse gate and up proj
        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
            weights=weights,
            dim=0,
            bias=False,
        )
        self.down_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.down_proj",
            weights=weights,
            bias=False,
        )
        # Per-shard intermediate width.
        self.intermediate_size = (
            config.intermediate_size // weights.process_group.size()
        )

    def forward(self, hidden_states):
        """Return ``down_proj(act(gate) * up)``."""
        gate_up_states = self.gate_up_proj(hidden_states)
        # Index 0 along dim 1 is the gate half, index 1 the up half.
        gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
        return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])


class FlashGemmaLayer(nn.Module):
    """One decoder layer: norm, self-attention, norm, MLP."""

    def __init__(self, prefix: str, config, weights, causal: bool):
        super().__init__()
        self.self_attn = FlashGemmaAttention(
            prefix=f"{prefix}.self_attn", config=config, weights=weights, causal=causal
        )
        self.mlp = GemmaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)

        self.input_layernorm = GemmaFastRMSNorm.load(
            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = GemmaFastRMSNorm.load(
            prefix=f"{prefix}.post_attention_layernorm",
            weights=weights,
            eps=config.rms_norm_eps,
        )

    def forward(
        self,
        hidden_states,
        residual,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        block_tables,
        slots,
        seqlen,
        max_s,
    ):
        """Run the layer and return ``(mlp_output, residual)``.

        The MLP output is returned without its residual added; the caller's
        next norm call is expected to fold the returned residual in.
        """
        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)

        # Self Attention
        attn_output = self.self_attn(
            normed_hidden_states,
            cos,
            sin,
            cu_seqlen_prefill,
            kv_cache,
            block_tables,
            slots,
            seqlen,
            max_s,
        )

        # faster post attention rms norm
        normed_attn_res_output, attn_res = self.post_attention_layernorm(
            attn_output, res
        )

        mlp_output = self.mlp(normed_attn_res_output)

        return mlp_output, attn_res


class FlashGemmaModel(torch.nn.Module):
    """Stack of ``FlashGemmaLayer`` followed by a final norm."""

    def __init__(self, prefix: str, config, weights, causal: bool):
        super().__init__()

        process_group = weights.process_group
        self.tp_rank = process_group.rank()
        self.tp_world_size = process_group.size()
        self.layers = nn.ModuleList(
            [
                FlashGemmaLayer(
                    prefix=f"{prefix}.layers.{layer_id}",
                    config=config,
                    weights=weights,
                    causal=causal,
                )
                for layer_id in range(config.num_hidden_layers)
            ]
        )
        self.norm = GemmaFastRMSNorm.load(
            prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
        )

        # Per-shard attention geometry, copied from the first layer.
        self.head_size = self.layers[0].self_attn.head_size
        self.num_heads = self.layers[0].self_attn.num_heads
        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        seqlen: Seqlen,
        max_s: int,
    ) -> torch.Tensor:
        """Run every layer over ``inputs_embeds`` and apply the final norm.

        ``kv_cache`` is indexed per layer (``kv_cache[i]``).
        NOTE(review): the attention module calls ``.store`` on each entry, so
        the entries are cache objects rather than plain tensor tuples - the
        annotation looks stale; confirm.
        """
        hidden_states = inputs_embeds

        # Get rotary cos and sin for this forward
        # Avoid indexing in each layer
        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
            position_ids, max_s, hidden_states.dtype
        )

        residual = None
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                residual,
                cos,
                sin,
                cu_seqlen_prefill,
                kv_cache[i],
                block_tables,
                slots,
                seqlen,
                max_s,
            )

        hidden_states, _ = self.norm(hidden_states, residual)

        return hidden_states


class FlashGemmaForCausalLM(torch.nn.Module):
    """Gemma decoder with token embeddings and a (speculative) LM head."""

    def __init__(self, prefix: str, config, weights, *, causal: bool = True):
        super().__init__()

        embed_norm = config.hidden_size**0.5
        if not prefix:
            prefix = "model"
        else:
            prefix = f"{prefix}.model"

        self.embed_tokens = TensorParallelEmbedding(
            prefix=f"{prefix}.embed_tokens", weights=weights
        )
        # Scale the embedding table once, in place, by sqrt(hidden_size).
        self.embed_tokens.weight *= embed_norm

        self.model = FlashGemmaModel(
            prefix=prefix, config=config, weights=weights, causal=causal
        )
        # With tied embeddings the head is loaded from the embedding prefix.
        self.lm_head = SpeculativeHead.load(
            prefix=(
                f"{prefix}.embed_tokens"
                if config.tie_word_embeddings
                else f"{prefix}.lm_head"
            ),
            config=config,
            weights=weights,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        seqlen: Seqlen,
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
        lm_head_indices: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Embed ``input_ids``, run the decoder and return the head outputs.

        When ``lm_head_indices`` is given, only those rows of the hidden
        states are passed to the head. ``prefill_cache_indices`` and
        ``adapter_data`` are accepted but not used in this method.

        Returns:
            ``(logits, speculative_logits)`` as produced by ``self.lm_head``.
        """
        input_embeds = self.embed_tokens(input_ids)
        hidden_states = self.model(
            input_embeds,
            position_ids,
            cu_seqlen_prefill,
            kv_cache,
            block_tables,
            slots,
            seqlen,
            max_s,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits, speculative_logits = self.lm_head(hidden_states)
        return logits, speculative_logits