megatron_patch/model/mixtral/transformer/attention.py (517 lines of code) (raw):

# Copyright (c) 2023 Alibaba PAI and Nvidia Megatron-LM Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch import Tensor
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional, Tuple, Union

from megatron.core import InferenceParams, parallel_state, tensor_parallel
from megatron.core.models.common.embeddings.rope_utils import (
    apply_rotary_pos_emb_with_cos_sin,
)
from megatron.core.parallel_state import (
    get_data_parallel_group,
    get_data_parallel_rank,
    get_data_parallel_world_size,
    get_tensor_model_parallel_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)

# `apply_rotary_pos_emb` is bound from `rotary_pos_embedding`; this module previously also
# imported it from `rope_utils`, but this import always took precedence.
from megatron.core.models.common.embeddings.rotary_pos_embedding import apply_rotary_pos_emb
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.utils import divide
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.transformer_config import TransformerConfig

try:
    from flash_attn import flash_attn_with_kvcache
except Exception:  # flash-attn is optional; any import failure just disables flash decoding.
    flash_attn_with_kvcache = None

try:
    import transformer_engine  # pylint: disable=unused-import

    HAVE_TE = True
    from megatron.core.extensions.transformer_engine import SplitAlongDim
except ImportError:
    HAVE_TE = False
    SplitAlongDim = None


@dataclass
class SelfAttentionSubmodules:
    """
    Configuration class for specifying the submodules of a self-attention.
    """

    linear_qkv: Union[ModuleSpec, type] = None
    core_attention: Union[ModuleSpec, type] = None
    linear_proj: Union[ModuleSpec, type] = None
    # Optional per-head norms applied to query / key; left as None to skip them.
    q_layernorm: Union[ModuleSpec, type] = None
    k_layernorm: Union[ModuleSpec, type] = None


@dataclass
class CrossAttentionSubmodules:
    """
    Configuration class for specifying the submodules of a cross-attention.
    """

    linear_q: Union[ModuleSpec, type] = None
    linear_kv: Union[ModuleSpec, type] = None
    core_attention: Union[ModuleSpec, type] = None
    linear_proj: Union[ModuleSpec, type] = None


class Attention(MegatronModule, ABC):
    """Attention layer abstract class.

    This layer only contains common modules required for the "self attn" and
    "cross attn" specializations. Subclasses provide the Q/K/V projections by
    implementing ``get_query_key_value_tensors``.
    """

    def __init__(
        self,
        config: TransformerConfig,
        submodules: Union[SelfAttentionSubmodules, CrossAttentionSubmodules],
        layer_number: int,
        attn_mask_type: AttnMaskType,
        attention_type: str,
        cp_comm_type: Optional[str] = None,
    ):
        """Build the core attention and output projection shared by all attention types.

        Args:
            config: transformer configuration.
            submodules: module specs; ``core_attention`` and ``linear_proj`` are used here.
            layer_number: index of this layer; also the key of its KV cache entry in
                ``InferenceParams.key_value_memory_dict``.
            attn_mask_type: default mask type passed to the core attention.
            attention_type: ``"self"`` or ``"cross"`` (forwarded to the core attention).
            cp_comm_type: forwarded unchanged to the core attention module.
        """
        super().__init__(config=config)

        self.config = config
        self.layer_number = layer_number
        self.attn_mask_type = attn_mask_type
        self.attention_type = attention_type

        # For normal attention without groups, num_query_groups == num_attention_heads,
        # so these two will be the same
        self.query_projection_size = self.config.kv_channels * self.config.num_attention_heads
        self.kv_projection_size = self.config.kv_channels * self.config.num_query_groups

        # Per attention head and per partition values.
        world_size = parallel_state.get_tensor_model_parallel_world_size()
        self.hidden_size_per_attention_head = divide(
            self.query_projection_size, self.config.num_attention_heads
        )
        self.num_attention_heads_per_partition = divide(self.config.num_attention_heads, world_size)
        self.num_query_groups_per_partition = divide(self.config.num_query_groups, world_size)

        self.core_attention = build_module(
            submodules.core_attention,
            config=self.config,
            layer_number=self.layer_number,
            attn_mask_type=self.attn_mask_type,
            attention_type=self.attention_type,
            cp_comm_type=cp_comm_type,
        )

        # Selective recompute checkpoints only the core attention (see forward()).
        self.checkpoint_core_attention = self.config.recompute_granularity == 'selective'

        # Output.
        self.linear_proj = build_module(
            submodules.linear_proj,
            self.query_projection_size,
            self.config.hidden_size,
            config=self.config,
            init_method=self.config.output_layer_init_method,
            bias=self.config.add_bias_linear,
            input_is_parallel=True,
            skip_bias_add=True,
            is_expert=False,
            tp_comm_buffer_name='proj',
        )

    def _checkpointed_attention_forward(
        self,
        query,
        key,
        value,
        attention_mask,
        rotary_pos_emb=None,
        attn_mask_type=None,
        packed_seq_params=None,
    ):
        """Forward method with selective activation checkpointing.

        Runs ``self.core_attention`` through ``tensor_parallel.checkpoint``. When
        ``attn_mask_type`` is None, ``self.attn_mask_type`` is used.
        """

        def custom_forward(*inputs):
            # `inputs` mirror the positional args given to tensor_parallel.checkpoint below:
            # (query, key, value, attention_mask, rotary_pos_emb, attn_mask_type).
            # inputs[4] (rotary_pos_emb) is not consumed here.
            query = inputs[0]
            key = inputs[1]
            value = inputs[2]
            attention_mask = inputs[3]
            # Decode the mask type that was encoded as a tensor before checkpointing.
            attn_mask_type = AttnMaskType(inputs[5].item())
            output_ = self.core_attention(
                query,
                key,
                value,
                attention_mask,
                attn_mask_type=attn_mask_type,
                packed_seq_params=packed_seq_params,
            )
            return output_

        if attn_mask_type is None:
            attn_mask_type = self.attn_mask_type
        # The enum travels through the checkpoint call as a 1-element int tensor
        # (presumably because checkpoint arguments are expected to be tensors).
        attn_mask_type = torch.tensor([attn_mask_type.value], dtype=torch.int)
        hidden_states = tensor_parallel.checkpoint(
            custom_forward, False, query, key, value, attention_mask, rotary_pos_emb, attn_mask_type
        )

        return hidden_states

    def _allocate_memory(self, inference_max_sequence_length, batch_size, dim, dtype):
        """Allocate memory to store kv cache during inference.

        Returns an uninitialized tensor of shape
        ``(inference_max_sequence_length, batch_size, num_query_groups_per_partition, dim)``
        on the current CUDA device.
        """

        return torch.empty(
            inference_max_sequence_length,
            batch_size,
            self.num_query_groups_per_partition,
            dim,
            dtype=dtype,
            device=torch.cuda.current_device(),
        )

    def _adjust_key_value_for_inference(
        self,
        inference_params: InferenceParams,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        rotary_pos_emb: Optional[Tuple[Tensor, Tensor]],
        rotary_pos_cos: Optional[Tensor] = None,
        rotary_pos_sin: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor, Tensor, Optional[Tuple[Tensor, Tensor]], AttnMaskType]:
        """Update the inference KV cache and return tensors adjusted for inference.

        Writes ``key``/``value`` into this layer's pre-allocated buffers in
        ``inference_params.key_value_memory_dict`` (allocating them on first use) and
        returns views of the cache that cover every position up to the current step.

        Args:
            inference_params: inference state; when None the inputs are returned unchanged.
            query, key, value: projections for the current step.
            rotary_pos_emb: optional ``(q_pos_emb, k_pos_emb)`` pair; the query part is sliced
                to the current positions and the key part to all cached positions.
            rotary_pos_cos, rotary_pos_sin: precomputed RoPE tables, required when
                ``config.flash_decode`` is set; RoPE is then applied to ``query`` and ``key``
                here, before the keys are written to the cache.

        Returns:
            ``(query, key, value, rotary_pos_emb, attn_mask_type)``. ``attn_mask_type`` is
            ``AttnMaskType.no_mask`` once ``inference_params.sequence_len_offset > 0`` and
            ``self.attn_mask_type`` otherwise.
        """
        attn_mask_type = self.attn_mask_type
        if inference_params is None:
            return query, key, value, rotary_pos_emb, attn_mask_type

        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
        if self.layer_number not in inference_params.key_value_memory_dict:
            inf_max_seq_length = inference_params.max_sequence_length
            inf_max_batch_size = inference_params.max_batch_size
            inference_key_memory = self._allocate_memory(
                inf_max_seq_length, inf_max_batch_size, key.shape[-1], key.dtype
            )
            inference_value_memory = self._allocate_memory(
                inf_max_seq_length, inf_max_batch_size, value.shape[-1], value.dtype
            )
            inference_params.key_value_memory_dict[self.layer_number] = (
                inference_key_memory,
                inference_value_memory,
            )
        else:
            # Get the pre-allocated buffers for this layer
            inference_key_memory, inference_value_memory = inference_params.key_value_memory_dict[
                self.layer_number
            ]

        if inference_params.sequence_len_offset > 0:
            # This should mean that we are past the prompt forward_step
            # and so we need to turn off masking
            attn_mask_type = AttnMaskType.no_mask

        batch_start = inference_params.batch_size_offset
        batch_end = batch_start + key.size(1)
        assert batch_end <= inference_key_memory.size(1)
        sequence_start = inference_params.sequence_len_offset
        sequence_end = sequence_start + key.size(0)
        assert sequence_end <= inference_key_memory.size(0)

        if self.config.flash_decode:
            assert (
                rotary_pos_cos is not None and rotary_pos_sin is not None
            ), "Flash decoding requires precomputed cos and sin tensors"
            if inference_params.sequence_len_offset > 0:
                # Decode phase, not prefill: only position sequence_end - 1 is rotated.
                rope_positions = slice(sequence_end - 1, sequence_end)
            else:
                # Prefill: rotate every position up to sequence_end.
                rope_positions = slice(None, sequence_end)
            # Query and key use the same cos/sin rows.
            rope_cos = rotary_pos_cos[rope_positions]
            rope_sin = rotary_pos_sin[rope_positions]

            # Flash Decoding assumes that the keys stored in the KV Cache already have RoPE applied.
            # Apply RoPE before we store the keys to make it compatible with flash decoding kernel.
            key = apply_rotary_pos_emb_with_cos_sin(key, rope_cos, rope_sin)
            query = apply_rotary_pos_emb_with_cos_sin(query, rope_cos, rope_sin)

        # Copy key and values.
        inference_key_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = key
        inference_value_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = value
        key = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
        value = inference_value_memory[:sequence_end, batch_start:batch_end, ...]

        # adjust the key rotary positional embedding
        if rotary_pos_emb is None:
            return query, key, value, rotary_pos_emb, attn_mask_type

        q_pos_emb, k_pos_emb = rotary_pos_emb
        q_pos_emb = q_pos_emb[sequence_start:sequence_end, :, :, :]
        k_pos_emb = k_pos_emb[:sequence_end, :, :, :]
        rotary_pos_emb = (q_pos_emb, k_pos_emb)

        return query, key, value, rotary_pos_emb, attn_mask_type

    @abstractmethod
    def get_query_key_value_tensors(self, hidden_states, key_value_states):
        """
        This method needs to be implemented based on whether the derived class
        is "self-attn" or "cross-attn".
        """

    def flash_decoding(
        self,
        sequence_len_offset: Tensor,
        query_layer: Tensor,
        key_layer: Tensor,
        value_layer: Tensor,
        inference_key_memory: Tensor,
        inference_value_memory: Tensor,
        rotary_cos: Tensor,
        rotary_sin: Tensor,
    ) -> Tensor:
        """
        The flash decoding kernel will do the following in a single execution:
        1. Compute RoPE embedding with precomputed cos & sin tensors
        2. Update the KV Cache
        3. Performs the flash attention operation

        Returns the single output tensor produced by ``flash_attn_with_kvcache``.

        Raises:
            AssertionError: if flash-attn (``flash_attn_with_kvcache``) is not importable.
        """
        assert flash_attn_with_kvcache is not None, (
            "Flash Decoding requires the flash_attn_with_kvcache kernel, "
            "available in the flash-attn package."
        )
        # NOTE(review): the cache length passed to the kernel is sequence_len_offset - 1;
        # confirm against how the caller advances sequence_len_offset.
        cache_seqlens = sequence_len_offset - 1

        # Swap the first two dims. The KV cache buffers are allocated as
        # (seq, batch, groups, dim) in _allocate_memory, so this yields (batch, seq, ...);
        # presumably the batch-first layout flash-attn expects.
        q = query_layer.permute(1, 0, 2, 3)
        k = key_layer.permute(1, 0, 2, 3)
        v = value_layer.permute(1, 0, 2, 3)
        k_cache = inference_key_memory.permute(1, 0, 2, 3)
        v_cache = inference_value_memory.permute(1, 0, 2, 3)

        # The RoPE tables are cast to the activation dtype before being handed to the kernel.
        if rotary_cos is not None:
            rotary_cos = rotary_cos.to(query_layer.dtype)
        if rotary_sin is not None:
            rotary_sin = rotary_sin.to(query_layer.dtype)

        out = flash_attn_with_kvcache(
            q=q,
            k_cache=k_cache,
            v_cache=v_cache,
            k=k,
            v=v,
            rotary_cos=rotary_cos,
            rotary_sin=rotary_sin,
            cache_seqlens=cache_seqlens,
            rotary_interleaved=False,
        )
        return out

    def forward(
        self,
        hidden_states,
        attention_mask,
        key_value_states=None,
        inference_params=None,
        rotary_pos_emb=None,
        rotary_pos_cos=None,
        rotary_pos_sin=None,
        packed_seq_params=None,
    ):
        """Perform a forward pass through the attention module.

        Args:
            hidden_states: input activations, laid out as [sq, b, h].
            attention_mask: mask forwarded to the core attention module.
            key_value_states: passed to ``get_query_key_value_tensors`` (source of keys and
                values for cross attention).
            inference_params: optional inference state holding the KV cache.
            rotary_pos_emb: a rotary embedding tensor (used for both query and key) or a
                ``(q_pos_emb, k_pos_emb)`` pair. Ignored when ``config.flash_decode`` is set.
            rotary_pos_cos, rotary_pos_sin: precomputed RoPE tables; must be None unless
                ``config.flash_decode`` is set.
            packed_seq_params: optional packed-sequence metadata; when given, dim 1 of
                query/key/value is squeezed before core attention and re-added as size 1 after.

        Returns:
            ``(output, bias)`` as returned by ``self.linear_proj``.
        """

        # hidden_states: [sq, b, h]
        if self.config.flash_decode:
            rotary_pos_emb = None
        else:
            assert rotary_pos_cos is None and rotary_pos_sin is None

        # For self attention we just duplicate the rotary_pos_emb if it isn't already
        if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple):
            rotary_pos_emb = (rotary_pos_emb,) * 2

        # =====================
        # Query, Key, and Value
        # =====================
        # Get the query, key and value tensors based on the type of attention -
        # self or cross attn.
        query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states)

        # ===================================================
        # Adjust key, value, and rotary_pos_emb for inference
        # ===================================================

        # This branch only runs in the decode phase of flash decoding and returns after the linear
        # projection. This conditional is not used in the prefill phase or non-flash-decoding cases.
        is_flash_decode_step = (
            self.config.flash_decode
            and inference_params is not None
            # Decode phase if key already exists
            and self.layer_number in inference_params.key_value_memory_dict
        )
        if is_flash_decode_step:
            assert inference_params.sequence_len_offset is not None
            inference_key_memory, inference_value_memory = inference_params.key_value_memory_dict[
                self.layer_number
            ]
            output = self.flash_decoding(
                sequence_len_offset=inference_params.sequence_len_offset,
                query_layer=query,
                key_layer=key,
                value_layer=value,
                inference_key_memory=inference_key_memory,
                inference_value_memory=inference_value_memory,
                rotary_cos=rotary_pos_cos,
                rotary_sin=rotary_pos_sin,
            )
            # Swap dims 0 and 1 back and merge the trailing dims before the output projection.
            out = output.transpose(0, 1).contiguous()
            context_layer = out.view(out.size(0), out.size(1), -1)
            output, bias = self.linear_proj(context_layer)
            return output, bias

        query, key, value, rotary_pos_emb, attn_mask_type = self._adjust_key_value_for_inference(
            inference_params, query, key, value, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin
        )

        if packed_seq_params is not None:
            query = query.squeeze(1)
            key = key.squeeze(1)
            value = value.squeeze(1)

        # ================================================
        # relative positional embedding (rotary embedding)
        # ================================================
        if rotary_pos_emb is not None and not self.config.flash_decode:
            q_pos_emb, k_pos_emb = rotary_pos_emb

            if packed_seq_params is not None:
                # Prefer the padded cumulative sequence lengths when they are provided.
                cu_seqlens_q = (
                    packed_seq_params.cu_seqlens_q_padded
                    if packed_seq_params.cu_seqlens_q_padded is not None
                    else packed_seq_params.cu_seqlens_q
                )
                cu_seqlens_kv = (
                    packed_seq_params.cu_seqlens_kv_padded
                    if packed_seq_params.cu_seqlens_kv_padded is not None
                    else packed_seq_params.cu_seqlens_kv
                )
            else:
                cu_seqlens_q = cu_seqlens_kv = None
            query = apply_rotary_pos_emb(
                query, q_pos_emb, config=self.config, cu_seqlens=cu_seqlens_q
            )
            key = apply_rotary_pos_emb(key, k_pos_emb, config=self.config, cu_seqlens=cu_seqlens_kv)

            # TODO, can apply positional embedding to value_layer so it has
            # absolute positional embedding.
            # otherwise, only relative positional embedding takes effect
            # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb)

        # ==================================
        # core attention computation
        # ==================================

        if self.checkpoint_core_attention and self.training:
            core_attn_out = self._checkpointed_attention_forward(
                query,
                key,
                value,
                attention_mask,
                attn_mask_type=attn_mask_type,
                packed_seq_params=packed_seq_params,
            )
        else:
            core_attn_out = self.core_attention(
                query,
                key,
                value,
                attention_mask,
                attn_mask_type=attn_mask_type,
                packed_seq_params=packed_seq_params,
            )

        if packed_seq_params is not None:
            # reshape to same output shape as unpacked case
            # (t, np, hn) -> (t, b=1, h=np*hn)
            # t is the pack size = sum (sq_i)
            # note that batch is a dummy dimension in the packed case
            core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1)

        # =================
        # Output. [sq, b, h]
        # =================

        output, bias = self.linear_proj(core_attn_out)

        return output, bias


class SelfAttention(Attention):
    """Self-attention layer class

    Self-attention layer takes input with size [s, b, h]
    and returns output of the same size.
    """

    def __init__(
        self,
        config: TransformerConfig,
        submodules: SelfAttentionSubmodules,
        layer_number: int,
        attn_mask_type=AttnMaskType.padding,
        cp_comm_type: Optional[str] = None,
    ):
        """Build the fused QKV projection and the optional Q/K layernorms."""
        super().__init__(
            config=config,
            submodules=submodules,
            layer_number=layer_number,
            attn_mask_type=attn_mask_type,
            attention_type="self",
            cp_comm_type=cp_comm_type,
        )

        # One fused projection producing query (np heads) plus key and value (ng groups each).
        self.linear_qkv = build_module(
            submodules.linear_qkv,
            self.config.hidden_size,
            self.query_projection_size + 2 * self.kv_projection_size,
            config=self.config,
            init_method=self.config.init_method,
            gather_output=False,
            bias=self.config.add_bias_linear or self.config.add_qkv_bias,
            skip_bias_add=False,
            is_expert=False,
            tp_comm_buffer_name='qkv',
        )

        if submodules.q_layernorm is not None:
            self.q_layernorm = build_module(
                submodules.q_layernorm,
                hidden_size=self.hidden_size_per_attention_head,
                config=self.config,
                eps=self.config.layernorm_epsilon,
            )
        else:
            self.q_layernorm = None

        if submodules.k_layernorm is not None:
            self.k_layernorm = build_module(
                submodules.k_layernorm,
                hidden_size=self.hidden_size_per_attention_head,
                config=self.config,
                eps=self.config.layernorm_epsilon,
            )
        else:
            self.k_layernorm = None

    def run_realtime_tests(self):
        """Performs a consistency check.

        This function makes sure that tensors across devices are the same during an experiment.
        This is often not guaranteed to be so because of silent hardware failures (eg, memory
        corruption loading a checkpoint, network traffic corruption encountered during data
        transmission).

        (TODO) In the future, more tensors should be checked across the training run and
        checked every X iterations. This is left for future work. Equality of tensors is
        probably not required; transmitting hashes is sufficient.

        Only runs when ``config.qk_layernorm`` is set. Raises AssertionError on a mismatch.
        """

        if not self.config.qk_layernorm:
            return

        # check that all tensor parallel and data parallel ranks have the same
        # Q & K layernorm parameters.
        # NOTE(review): assumes both layernorms expose a `bias`; a bias-free norm fails here.
        local_params = [
            self.q_layernorm.weight.data,
            self.q_layernorm.bias.data,
            self.k_layernorm.weight.data,
            self.k_layernorm.bias.data,
        ]
        param_names = ["q_w", "q_b", "k_w", "k_b"]
        inputs = torch.stack(local_params)

        def _compare(srcs, tgts, names, parallelism, src_rank, local_rank):
            # Assert element-wise equality of every gathered tensor with the local copy.
            assert len(srcs) == len(tgts) == len(names)
            for src, tgt, name in zip(srcs, tgts, names):
                assert torch.all(src == tgt), (
                    f"Discrepancy between {name} in {parallelism} ranks {src_rank} and {local_rank}. "
                    f"Diff: {torch.norm(src - tgt)}"
                )

        def _check_group(rank, world_size, group, parallelism):
            # All-gather the stacked parameters within `group` and compare each peer's copy.
            gathered = [torch.empty_like(inputs) for _ in range(world_size)]
            gathered[rank] = inputs
            torch.distributed.all_gather(gathered, inputs, group=group)
            for peer_rank, peer_params in enumerate(gathered):
                _compare(
                    list(torch.unbind(peer_params)),
                    local_params,
                    param_names,
                    parallelism,
                    peer_rank,
                    rank,
                )

        _check_group(
            get_data_parallel_rank(),
            get_data_parallel_world_size(),
            get_data_parallel_group(),
            "DP",
        )
        _check_group(
            get_tensor_model_parallel_rank(),
            get_tensor_model_parallel_world_size(),
            get_tensor_model_parallel_group(),
            "TP",
        )

    def get_query_key_value_tensors(self, hidden_states, key_value_states=None):
        """
        Derives `query`, `key` and `value` tensors from `hidden_states`.

        ``key_value_states`` is accepted for interface compatibility and ignored.
        Returns ``(query, key, value)`` shaped [sq, b, np, hn], [sq, b, ng, hn], [sq, b, ng, hn].
        """
        heads_per_group = (
            self.num_attention_heads_per_partition // self.num_query_groups_per_partition
        )

        # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn)]
        mixed_qkv, _ = self.linear_qkv(hidden_states)

        # [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn]
        new_tensor_shape = mixed_qkv.size()[:-1] + (
            self.num_query_groups_per_partition,
            (heads_per_group + 2) * self.hidden_size_per_attention_head,
        )
        mixed_qkv = mixed_qkv.view(*new_tensor_shape)

        split_arg_list = [
            heads_per_group * self.hidden_size_per_attention_head,
            self.hidden_size_per_attention_head,
            self.hidden_size_per_attention_head,
        ]

        # [sq, b, ng, (np/ng + 2) * hn]
        # --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
        if SplitAlongDim is not None:
            (query, key, value) = SplitAlongDim(mixed_qkv, 3, split_arg_list)
        else:
            (query, key, value) = torch.split(mixed_qkv, split_arg_list, dim=3)

        # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
        query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head)

        if self.q_layernorm is not None:
            query = self.q_layernorm(query)

        if self.k_layernorm is not None:
            key = self.k_layernorm(key)

        if self.config.test_mode:
            self.run_realtime_tests()

        return query, key, value


class CrossAttention(Attention):
    """Cross-attention layer class

    Cross-attention layer takes input with size [s, b, h] and context with size
    [s, b, h] and returns output of the same size.
    """

    def __init__(
        self,
        config: TransformerConfig,
        submodules: CrossAttentionSubmodules,
        layer_number: int,
        attn_mask_type=AttnMaskType.padding,
        cp_comm_type: Optional[str] = None,
    ):
        """Build separate query and fused key/value projections.

        Raises:
            ValueError: if ``num_query_groups != num_attention_heads`` (GQA is unsupported).
        """
        super().__init__(
            config=config,
            submodules=submodules,
            layer_number=layer_number,
            attn_mask_type=attn_mask_type,
            attention_type="cross",
            cp_comm_type=cp_comm_type,
        )

        if self.config.num_query_groups != self.config.num_attention_heads:
            raise ValueError("Group query attention is not currently supported in cross attention.")
        assert self.query_projection_size == self.kv_projection_size

        self.linear_q = build_module(
            submodules.linear_q,
            self.config.hidden_size,
            self.query_projection_size,
            config=self.config,
            init_method=self.config.init_method,
            gather_output=False,
            bias=self.config.add_bias_linear,
            skip_bias_add=False,
            is_expert=False,
        )

        self.linear_kv = build_module(
            submodules.linear_kv,
            self.config.hidden_size,
            2 * self.kv_projection_size,
            config=self.config,
            init_method=self.config.init_method,
            gather_output=False,
            bias=self.config.add_bias_linear,
            skip_bias_add=False,
            is_expert=False,
        )

    def get_query_key_value_tensors(self, hidden_states, key_value_states):
        """
        Derives `query` tensor from `hidden_states`, and `key`/`value` tensors
        from `key_value_states`.
        """
        # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
        mixed_kv, _ = self.linear_kv(key_value_states)

        # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
        new_tensor_shape = mixed_kv.size()[:-1] + (
            self.num_attention_heads_per_partition,
            2 * self.hidden_size_per_attention_head,
        )
        mixed_kv = mixed_kv.view(*new_tensor_shape)

        # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
        (key, value) = tensor_parallel.split_tensor_along_last_dim(mixed_kv, 2)

        # Attention head [sq, b, h] --> [sq, b, hp]
        query, _ = self.linear_q(hidden_states)

        # [sq, b, hp] --> [sq, b, np, hn]
        new_tensor_shape = query.size()[:-1] + (
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
        )
        query = query.view(*new_tensor_shape)

        return query, key, value