server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py (372 lines of code) (raw):

# coding=utf-8 # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. # # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX # and OPT implementations in this library. It has been modified from its # original forms to accommodate minor architectural differences compared # to GPT-NeoX and OPT used by the Meta AI team that trained the model. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import torch import torch.distributed from torch import nn from transformers.activations import ACT2FN from typing import Optional, List, Tuple from text_generation_server.layers.attention import ( paged_attention, attention, Seqlen, ) from text_generation_server.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, TensorParallelEmbedding, SpeculativeHead, get_linear, ) from text_generation_server.layers.attention.kv_cache import get_kv_scales def load_qkv(config, prefix: str, weights, head_size, num_heads): if config.quantize == "gptq": return _load_qkv_gptq( config, prefix, weights, ) elif config.quantize == "marlin": raise RuntimeError( "GPT-2 models with marlin quantization are not yet supported" ) else: return _load_qkv(config, prefix, weights, head_size, num_heads) def _load_qkv_gptq(config, prefix: str, weights): world_size = weights.process_group.size() rank = weights.process_group.rank() # Weights weight = weights.get_weights_col_packed_qkv( f"{prefix}.c_attn", config.num_attention_heads, config.num_attention_heads, ) # Bias slice_ = weights._get_slice(f"{prefix}.c_attn.bias") shape = slice_.get_shape() total_size = shape[0] assert total_size % 3 == 0, f"Prepacked is not divisible by {3}" single_size = total_size // 3 assert single_size % world_size == 0 block_size = single_size // world_size start = rank * block_size stop = (rank + 1) * block_size tensors = [] for i in range(3): tensor = slice_[start + i * single_size : stop + i * single_size] tensors.append(tensor) bias = torch.cat(tensors, dim=0) bias = bias.to(device=weights.device) return TensorParallelColumnLinear(get_linear(weight, bias)) def _load_qkv(config, prefix: str, weights, head_size, num_heads): """Load QKV from a single, transposed matrix.""" slice_ = weights._get_slice(f"{prefix}.c_attn.weight") shape = slice_.get_shape() total_size = shape[1] assert total_size % 3 == 0, f"Prepacked is not divisible by {3}" world_size = weights.process_group.size() single_size = total_size // 3 assert single_size % world_size == 0 rank = weights.process_group.rank() # Weights block_size = single_size // world_size start = rank * block_size stop = (rank + 1) * block_size tensors = [] for i in range(3): tensor = slice_[:, start + i * single_size : stop + i * single_size] tensors.append(tensor) weight = torch.cat(tensors, dim=1).T weight = weight.to(dtype=weights.dtype) weight = weight.to(device=weights.device) # Bias slice_ = weights._get_slice(f"{prefix}.c_attn.bias") shape = slice_.get_shape() total_size = shape[0] single_size = total_size // 3 block_size = single_size // world_size assert single_size % world_size == 0 start = rank * block_size stop = (rank + 1) * block_size b = [] for i in range(3): tensor = slice_[start + i * single_size : stop + i * single_size] b.append(tensor) bias = torch.cat(b, dim=0) bias = bias.to(dtype=weights.dtype) bias = bias.to(device=weights.device) assert list(bias.shape) == [ 3 * num_heads * head_size ], f"{weight.shape} != {[3 * num_heads * head_size]}" return TensorParallelColumnLinear(get_linear(weight, bias)) def load_row(config, prefix: str, weights, bias: bool): """load_row, but with transposed weight matrices.""" if config.quantize == "gptq": weight = weights.get_weights_row(prefix) else: weight = weights.get_sharded(f"{prefix}.weight", dim=0).T if bias and weights.process_group.rank() == 0: # Rank is only on the first rank process bias = weights.get_tensor(f"{prefix}.bias") else: bias = None return TensorParallelRowLinear( get_linear(weight, bias), process_group=weights.process_group ) def load_col(config, prefix: str, weights, bias: bool): """load_col, but with transposed weight matrices.""" if config.quantize == "gptq": weight = weights.get_multi_weights_col([prefix], dim=1) else: weight = weights.get_sharded(f"{prefix}.weight", dim=1).T if bias: bias = weights.get_sharded(f"{prefix}.bias", dim=0) else: bias = None return TensorParallelColumnLinear(get_linear(weight, bias)) class FlashGPT2Attention(torch.nn.Module): def __init__( self, prefix: str, config, weights, ): super().__init__() self.num_heads = config.num_attention_heads self.hidden_size = config.hidden_size self.head_size = self.hidden_size // self.num_heads self.softmax_scale = self.head_size**-0.5 if self.num_heads % weights.process_group.size() != 0: raise ValueError( f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} " f"and `num_shards`: {weights.process_group.size()}" ) self.num_heads = self.num_heads // weights.process_group.size() self.query_key_value = load_qkv( config, prefix=prefix, weights=weights, head_size=self.head_size, num_heads=self.num_heads, ) self.kv_scales = get_kv_scales(weights, f"{prefix}") self.o_proj = load_row( config, prefix=f"{prefix}.c_proj", weights=weights, bias=True, ) self.kv_head_mapping = torch.arange( 0, self.num_heads, dtype=torch.int32, device=weights.device ) def forward( self, hidden_states, cu_seqlen_prefill, kv_cache, block_tables, slots, seqlen, max_s, ): query, key, value = self.query_key_value(hidden_states).split( self.head_size * self.num_heads, dim=1 ) query = query.view(-1, self.num_heads, self.head_size) key = key.view(-1, self.num_heads, self.head_size) value = value.view(-1, self.num_heads, self.head_size) kv_cache.store( key=key, value=value, slots=slots, kv_scales=self.kv_scales, ) # Prefill if cu_seqlen_prefill is not None: # flash attention attn_output = attention( query=query, key=key, value=value, kv_cache=kv_cache, kv_scales=self.kv_scales, seqlen=seqlen, block_tables=block_tables, softmax_scale=self.softmax_scale, ) # Decode else: attn_output = paged_attention( query, kv_cache, self.kv_head_mapping, self.softmax_scale, block_tables, seqlen, max_s, kv_scales=self.kv_scales, ) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) class GPT2MLP(nn.Module): def __init__(self, prefix: str, config, weights): super().__init__() act = config.activation_function self.act = ( ACT2FN[act] if "gelu" not in act else lambda x: torch.nn.functional.gelu( x, approximate=( "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none" ), ) ) self.c_fc = load_col( config, prefix=f"{prefix}.c_fc", weights=weights, bias=True ) self.c_proj = load_row( config, prefix=f"{prefix}.c_proj", weights=weights, bias=True, ) intermediate_size = ( config.n_inner if config.n_inner is not None else 4 * config.hidden_size ) self.intermediate_size = intermediate_size // weights.process_group.size() def forward(self, hidden_states): hidden_states = self.c_fc(hidden_states) hidden_states = self.act(hidden_states) return self.c_proj(hidden_states) class FlashGPT2Layer(nn.Module): def __init__(self, prefix: str, config, weights): super().__init__() self.self_attn = FlashGPT2Attention( prefix=f"{prefix}.attn", config=config, weights=weights ) self.mlp = GPT2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights) self.input_layernorm = nn.LayerNorm.load( prefix=f"{prefix}.ln_1", weights=weights, eps=config.layer_norm_epsilon ) self.post_attention_layernorm = nn.LayerNorm.load( prefix=f"{prefix}.ln_2", weights=weights, eps=config.layer_norm_epsilon, ) def forward( self, hidden_states, residual, cu_seqlen_prefill, kv_cache, block_tables, slots, seqlen, max_s, ): residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # Self Attention attn_output = self.self_attn( hidden_states, cu_seqlen_prefill, kv_cache, block_tables, slots, seqlen, max_s, ) hidden_states = attn_output + residual residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) mlp_output = self.mlp(hidden_states) return residual + mlp_output, residual class FlashGPT2Model(torch.nn.Module): def __init__(self, prefix: str, config, weights): super().__init__() process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() self.layers = nn.ModuleList( [ FlashGPT2Layer( prefix=( f"h.{layer_id}" if not prefix else f"{prefix}.h.{layer_id}" ), config=config, weights=weights, ) for layer_id in range(config.num_hidden_layers) ] ) self.norm = nn.LayerNorm.load( prefix="ln_f" if not prefix else f"{prefix}.ln_f", weights=weights, eps=config.layer_norm_epsilon, ) self.gradient_checkpointing = False self.head_size = self.layers[0].self_attn.head_size self.num_heads = self.layers[0].self_attn.num_heads def forward( self, inputs_embeds: torch.Tensor, position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, max_s: int, true_max_s: int, prefill_cache_indices: Optional[torch.Tensor], ) -> torch.Tensor: hidden_states = inputs_embeds residual = None for i, layer in enumerate(self.layers): hidden_states, residual = layer( hidden_states, residual, cu_seqlen_prefill, kv_cache[i], block_tables, slots, seqlen, max_s, ) hidden_states = self.norm(hidden_states) return hidden_states class FlashGPT2ForCausalLM(torch.nn.Module): def __init__(self, prefix: str, config, weights): super().__init__() self.embed_tokens = TensorParallelEmbedding( prefix=("wte" if not prefix else f"{prefix}.wte"), weights=weights, ) self.embed_positions = TensorParallelEmbedding( prefix=("wpe" if not prefix else f"{prefix}.wpe"), weights=weights, ) self.model = FlashGPT2Model(prefix, config, weights) self.lm_head = SpeculativeHead.load( config, prefix="wte" if not prefix else f"{prefix}.wte", weights=weights, ) def forward( self, input_ids: torch.Tensor, position_ids: torch.Tensor, cu_seqlen_prefill: Optional[torch.Tensor], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], block_tables: torch.Tensor, slots: torch.Tensor, seqlen: Seqlen, max_s: int, prefill_cache_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: token_embeds = self.embed_tokens(input_ids) position_embeds = self.embed_positions(position_ids) inputs_embeds = token_embeds + position_embeds hidden_states = self.model( inputs_embeds, position_ids, cu_seqlen_prefill, kv_cache, block_tables, slots, seqlen, max_s, true_max_s=max_s, prefill_cache_indices=prefill_cache_indices, ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] logits, speculative_logits = self.lm_head(hidden_states) return logits, speculative_logits