megatron_patch/model/mixtral_bak/transformer/attention.py (322 lines of code) (raw):

# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Union

import torch

from megatron.core import parallel_state, tensor_parallel
from megatron.core.models.common.embeddings.rotary_pos_embedding import apply_rotary_pos_emb
from megatron.core.transformer.custom_layers.transformer_engine import SplitAlongDim
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.utils import divide
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.transformer_config import TransformerConfig


@dataclass
class SelfAttentionSubmodules:
    """Specs (or classes) of the submodules that `SelfAttention` builds."""

    # Fused query/key/value projection, built in `SelfAttention.__init__`.
    linear_qkv: Union[ModuleSpec, type] = None
    # Core attention module, built in `Attention.__init__`.
    core_attention: Union[ModuleSpec, type] = None
    # Output projection, built in `Attention.__init__`.
    linear_proj: Union[ModuleSpec, type] = None


@dataclass
class CrossAttentionSubmodules:
    """Specs (or classes) of the submodules that `CrossAttention` builds."""

    # Query projection applied to `hidden_states`.
    linear_q: Union[ModuleSpec, type] = None
    # Fused key/value projection applied to `key_value_states`.
    linear_kv: Union[ModuleSpec, type] = None
    # Core attention module, built in `Attention.__init__`.
    core_attention: Union[ModuleSpec, type] = None
    # Output projection, built in `Attention.__init__`.
    linear_proj: Union[ModuleSpec, type] = None


class Attention(MegatronModule, ABC):
    """Attention layer abstract class.

    This layer only contains common modules required for the "self attn" and
    "cross attn" specializations.
    """

    def __init__(
        self,
        config: TransformerConfig,
        submodules: Union[SelfAttentionSubmodules, CrossAttentionSubmodules],
        layer_number: int,
        attn_mask_type: AttnMaskType,
        attention_type: str,
    ):
        """Build the modules shared by self- and cross-attention.

        Args:
            config: Transformer configuration; read for `kv_channels`,
                `num_attention_heads`, `num_query_groups`, `hidden_size`,
                `recompute_granularity`, `output_layer_init_method` and
                `add_bias_linear`.
            submodules: Specs used to build `core_attention` and `linear_proj`.
            layer_number: Stored on the instance and forwarded to the core
                attention module; also used as the key of the per-layer
                inference KV cache.
            attn_mask_type: Default mask type for this layer.
            attention_type: Forwarded to the core attention module (the
                subclasses in this file pass "self" or "cross").
        """
        super().__init__(config=config)

        self.config = config
        self.layer_number = layer_number
        self.attn_mask_type = attn_mask_type
        self.attention_type = attention_type

        # For normal attention without groups, num_query_groups == num_attention_heads,
        # so these two will be the same
        self.query_projection_size = self.config.kv_channels * self.config.num_attention_heads
        self.kv_projection_size = self.config.kv_channels * self.config.num_query_groups

        # Per attention head and per partition values.
        world_size = parallel_state.get_tensor_model_parallel_world_size()
        self.hidden_size_per_attention_head = divide(
            self.query_projection_size, self.config.num_attention_heads
        )
        self.num_attention_heads_per_partition = divide(self.config.num_attention_heads, world_size)
        self.num_query_groups_per_partition = divide(self.config.num_query_groups, world_size)

        self.core_attention = build_module(
            submodules.core_attention,
            config=self.config,
            layer_number=self.layer_number,
            attn_mask_type=self.attn_mask_type,
            attention_type=self.attention_type,
        )

        # Core attention is run under activation checkpointing only for
        # 'selective' recompute granularity.
        self.checkpoint_core_attention = self.config.recompute_granularity == 'selective'

        # Output.
        self.linear_proj = build_module(
            submodules.linear_proj,
            self.query_projection_size,
            self.config.hidden_size,
            config=self.config,
            init_method=self.config.output_layer_init_method,
            bias=self.config.add_bias_linear,
            input_is_parallel=True,
            skip_bias_add=True,
            is_expert=False,
            tp_comm_buffer_name='proj',
        )

    def _checkpointed_attention_forward(
        self,
        query,
        key,
        value,
        attention_mask,
        rotary_pos_emb=None,
        attn_mask_type=None,
        packed_seq_params=None,
    ):
        """Forward method with selective activation checkpointing.

        Runs `self.core_attention` through `tensor_parallel.checkpoint`.

        Args:
            query, key, value, attention_mask: Forwarded to `core_attention`.
            rotary_pos_emb: Passed to the checkpoint call but not read by the
                checkpointed function.
            attn_mask_type: Mask type for this call; falls back to
                `self.attn_mask_type` when None.
            packed_seq_params: Forwarded to `core_attention` via closure.

        Returns:
            Whatever `tensor_parallel.checkpoint` returns for the core
            attention call.
        """

        def custom_forward(*inputs):
            query = inputs[0]
            key = inputs[1]
            value = inputs[2]
            attention_mask = inputs[3]
            # inputs[4] is rotary_pos_emb and is intentionally unused here.
            attn_mask_type = inputs[5]
            # Decode the int tensor back into the enum.
            attn_mask_type = AttnMaskType(attn_mask_type.item())
            output_ = self.core_attention(
                query,
                key,
                value,
                attention_mask,
                attn_mask_type=attn_mask_type,
                packed_seq_params=packed_seq_params,
            )
            return output_

        if attn_mask_type is None:
            attn_mask_type = self.attn_mask_type
        # The enum is passed through the checkpoint as a one-element int tensor
        # and decoded inside custom_forward.
        attn_mask_type = torch.tensor([attn_mask_type.value], dtype=torch.int)
        # NOTE(review): the positional `False` is presumably
        # `distribute_saved_activations` - confirm against
        # tensor_parallel.checkpoint.
        hidden_states = tensor_parallel.checkpoint(
            custom_forward,
            False,
            query,
            key,
            value,
            attention_mask,
            rotary_pos_emb,
            attn_mask_type,
        )

        return hidden_states

    def _allocate_memory(self, inference_max_sequence_length, batch_size, dtype):
        """Allocate memory to store kv cache during inference.

        Returns:
            An uninitialized tensor of shape
            (inference_max_sequence_length, batch_size,
            num_query_groups_per_partition, hidden_size_per_attention_head)
            with the given dtype on the current CUDA device.
        """

        return torch.empty(
            inference_max_sequence_length,
            batch_size,
            self.num_query_groups_per_partition,
            self.hidden_size_per_attention_head,
            dtype=dtype,
            device=torch.cuda.current_device(),
        )

    def _adjust_key_value_for_inference(self, inference_params, key, value, rotary_pos_emb):
        """
        Saves the generated key and value tensors to the end of the buffers in inference_params.
        Returns the full size keys and values from the provided inference_params, as well as
        adjusted rotary_pos_emb and the attention mask type to use for this step.

        When inference_params is None the inputs are returned unchanged together with
        self.attn_mask_type.

        Returns a tuple: (key, value, rotary_pos_emb, attn_mask_type)
        """
        attn_mask_type = self.attn_mask_type
        if inference_params is None:
            return key, value, rotary_pos_emb, attn_mask_type

        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
        is_first_step = False
        if self.layer_number not in inference_params.key_value_memory_dict:
            inf_max_seq_length = inference_params.max_sequence_length
            inf_max_batch_size = inference_params.max_batch_size
            inference_key_memory = self._allocate_memory(
                inf_max_seq_length, inf_max_batch_size, key.dtype
            )
            inference_value_memory = self._allocate_memory(
                inf_max_seq_length, inf_max_batch_size, value.dtype
            )
            inference_params.key_value_memory_dict[self.layer_number] = (
                inference_key_memory,
                inference_value_memory,
            )
            is_first_step = True
        else:
            # Get the pre-allocated buffers for this layer
            inference_key_memory, inference_value_memory = inference_params.key_value_memory_dict[
                self.layer_number
            ]
            # Once the cache exists, this step runs without a mask.
            attn_mask_type = AttnMaskType.no_mask

        batch_start = inference_params.batch_size_offset
        batch_end = batch_start + key.size(1)
        assert batch_end <= inference_key_memory.size(1)
        sequence_start = inference_params.sequence_len_offset
        sequence_end = sequence_start + key.size(0)
        assert sequence_end <= inference_key_memory.size(0)
        # Copy key and values.
        inference_key_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = key
        inference_value_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = value
        key = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
        value = inference_value_memory[:sequence_end, batch_start:batch_end, ...]

        # adjust the key rotary positional embedding
        if rotary_pos_emb is not None:
            q_pos_emb, k_pos_emb = rotary_pos_emb
            # need to cross check this condition during inference
            # if not set_inference_key_value_memory:
            if not is_first_step:
                # In inference, we compute one token at a time.
                # Select the correct positional embedding
                # (only the last token in the sequence)
                q_pos_emb = q_pos_emb[sequence_end - 1 : sequence_end]
            else:
                # In the first forward pass of inference,
                # we use the entire provided prefix.
                # q_pos_emb here has the rope embeddings of the entire
                # prefix + to-be-generated output so
                # we slice to just the prefix.
                q_pos_emb = q_pos_emb[:sequence_end, :, :, :]
            k_pos_emb = k_pos_emb[:sequence_end, :, :, :]
            rotary_pos_emb = (q_pos_emb, k_pos_emb)

        return key, value, rotary_pos_emb, attn_mask_type

    @abstractmethod
    def get_query_key_value_tensors(self, hidden_states, key_value_states):
        """
        This method needs to be implemented based on whether the derived class
        is "self-attn" or "cross-attn".
        """

    def forward(
        self,
        hidden_states,
        attention_mask,
        key_value_states=None,
        inference_params=None,
        rotary_pos_emb=None,
        packed_seq_params=None,
    ):
        """Compute attention over `hidden_states` and project the result.

        Args:
            hidden_states: Input tensor, [sq, b, h].
            attention_mask: Forwarded to the core attention module.
            key_value_states: Source of keys/values for cross attention;
                passed to `get_query_key_value_tensors`.
            inference_params: When not None, keys/values are written to and
                read back from its per-layer cache.
            rotary_pos_emb: Rotary embedding, either a single tensor (used
                for both query and key) or a (q_pos_emb, k_pos_emb) tuple.
            packed_seq_params: When not None, supplies `cu_seqlens_q` /
                `cu_seqlens_kv` and switches to the packed-sequence layout.

        Returns:
            The `(output, bias)` pair returned by `self.linear_proj`.
        """
        # hidden_states: [sq, b, h]

        # For self attention we just duplicate the rotary_pos_emb if it isn't already
        if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple):
            rotary_pos_emb = (rotary_pos_emb,) * 2

        # =====================
        # Query, Key, and Value
        # =====================
        # Get the query, key and value tensors based on the type of attention -
        # self or cross attn.
        query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states)

        # ===================================================
        # Adjust key, value, and rotary_pos_emb for inference
        # ===================================================
        key, value, rotary_pos_emb, attn_mask_type = self._adjust_key_value_for_inference(
            inference_params, key, value, rotary_pos_emb
        )

        if packed_seq_params is not None:
            # Drop dim 1 (the dummy batch dimension in the packed case).
            query = query.squeeze(1)
            key = key.squeeze(1)
            value = value.squeeze(1)

        # ================================================
        # relative positional embedding (rotary embedding)
        # ================================================
        if rotary_pos_emb is not None:
            q_pos_emb, k_pos_emb = rotary_pos_emb

            if packed_seq_params is not None:
                cu_seqlens_q = packed_seq_params.cu_seqlens_q
                cu_seqlens_kv = packed_seq_params.cu_seqlens_kv
            else:
                cu_seqlens_q = cu_seqlens_kv = None
            query = apply_rotary_pos_emb(
                query, q_pos_emb, fused=self.config.apply_rope_fusion, cu_seqlens=cu_seqlens_q
            )
            key = apply_rotary_pos_emb(
                key, k_pos_emb, fused=self.config.apply_rope_fusion, cu_seqlens=cu_seqlens_kv
            )

            # TODO, can apply positional embedding to value_layer so it has
            # absolute positional embedding.
            # otherwise, only relative positional embedding takes effect
            # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb)

        # ==================================
        # core attention computation
        # ==================================

        if self.checkpoint_core_attention:
            core_attn_out = self._checkpointed_attention_forward(
                query,
                key,
                value,
                attention_mask,
                attn_mask_type=attn_mask_type,
                packed_seq_params=packed_seq_params,
            )
        else:
            core_attn_out = self.core_attention(
                query,
                key,
                value,
                attention_mask,
                attn_mask_type=attn_mask_type,
                packed_seq_params=packed_seq_params,
            )

        if packed_seq_params is not None:
            # reshape to same output shape as unpacked case
            # (t, np, hn) -> (t, b=1, h=np*hn)
            # t is the pack size = sum (sq_i)
            # note that batch is a dummy dimension in the packed case
            core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1)

        # =================
        # Output. [sq, b, h]
        # =================

        output, bias = self.linear_proj(core_attn_out)

        return output, bias


class SelfAttention(Attention):
    """Self-attention layer class

    Self-attention layer takes input with size [s, b, h]
    and returns output of the same size.
    """

    def __init__(
        self,
        config: TransformerConfig,
        submodules: SelfAttentionSubmodules,
        layer_number: int,
        attn_mask_type=AttnMaskType.padding,
    ):
        """Build the shared attention modules plus the fused QKV projection.

        Args:
            config: Transformer configuration.
            submodules: Specs for `linear_qkv`, `core_attention`, `linear_proj`.
            layer_number: Forwarded to `Attention.__init__`.
            attn_mask_type: Default mask type; defaults to
                `AttnMaskType.padding`.
        """
        super().__init__(
            config=config,
            submodules=submodules,
            layer_number=layer_number,
            attn_mask_type=attn_mask_type,
            attention_type="self",
        )

        # One projection producing query plus key and value
        # (query_projection_size + 2 * kv_projection_size outputs).
        self.linear_qkv = build_module(
            submodules.linear_qkv,
            self.config.hidden_size,
            self.query_projection_size + 2 * self.kv_projection_size,
            config=self.config,
            init_method=self.config.init_method,
            gather_output=False,
            bias=self.config.add_bias_linear,
            skip_bias_add=False,
            is_expert=False,
            tp_comm_buffer_name='qkv',
        )

    def get_query_key_value_tensors(self, hidden_states, key_value_states=None):
        """
        Derives `query`, `key` and `value` tensors from `hidden_states`.

        `key_value_states` is accepted for interface compatibility with the
        base class and is not used.
        """
        # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn]
        mixed_qkv, _ = self.linear_qkv(hidden_states)

        # [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn]
        new_tensor_shape = mixed_qkv.size()[:-1] + (
            self.num_query_groups_per_partition,
            (
                (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 2)
                * self.hidden_size_per_attention_head
            ),
        )
        mixed_qkv = mixed_qkv.view(*new_tensor_shape)

        # Per-group split sizes along the last dim: all query heads of the
        # group, then one key head, then one value head.
        split_arg_list = [
            (
                self.num_attention_heads_per_partition
                // self.num_query_groups_per_partition
                * self.hidden_size_per_attention_head
            ),
            self.hidden_size_per_attention_head,
            self.hidden_size_per_attention_head,
        ]

        if SplitAlongDim is not None:

            # [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
            (query, key, value) = SplitAlongDim(mixed_qkv, 3, split_arg_list,)
        else:

            # [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
            (query, key, value) = torch.split(mixed_qkv, split_arg_list, dim=3,)

        # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
        query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head)

        return query, key, value


class CrossAttention(Attention):
    """Cross-attention layer class

    Cross-attention layer takes input with size [s, b, h] and context with size
    [s, b, h] and returns output of the same size.
    """

    def __init__(
        self,
        config: TransformerConfig,
        submodules: CrossAttentionSubmodules,
        layer_number: int,
        attn_mask_type=AttnMaskType.padding,
    ):
        """Build the shared attention modules plus the Q and KV projections.

        Args:
            config: Transformer configuration.
            submodules: Specs for `linear_q`, `linear_kv`, `core_attention`,
                `linear_proj`.
            layer_number: Forwarded to `Attention.__init__`.
            attn_mask_type: Default mask type; defaults to
                `AttnMaskType.padding`.

        Raises:
            ValueError: If `config.num_query_groups` differs from
                `config.num_attention_heads`.
        """
        super().__init__(
            config=config,
            submodules=submodules,
            layer_number=layer_number,
            attn_mask_type=attn_mask_type,
            attention_type="cross",
        )

        if self.config.num_query_groups != self.config.num_attention_heads:
            raise ValueError(
                "Group query attention is not currently supported in cross attention."
            )
        assert self.query_projection_size == self.kv_projection_size

        self.linear_q = build_module(
            submodules.linear_q,
            self.config.hidden_size,
            self.query_projection_size,
            config=self.config,
            init_method=self.config.init_method,
            gather_output=False,
            bias=self.config.add_bias_linear,
            skip_bias_add=False,
            is_expert=False,
        )

        self.linear_kv = build_module(
            submodules.linear_kv,
            self.config.hidden_size,
            2 * self.kv_projection_size,
            config=self.config,
            init_method=self.config.init_method,
            gather_output=False,
            bias=self.config.add_bias_linear,
            skip_bias_add=False,
            is_expert=False,
        )

    def get_query_key_value_tensors(self, hidden_states, key_value_states):
        """
        Derives `query` tensor from `hidden_states`, and `key`/`value` tensors
        from `key_value_states`.
        """
        # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
        mixed_kv, _ = self.linear_kv(key_value_states)

        # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
        new_tensor_shape = mixed_kv.size()[:-1] + (
            self.num_attention_heads_per_partition,
            2 * self.hidden_size_per_attention_head,
        )
        mixed_kv = mixed_kv.view(*new_tensor_shape)

        # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
        (key, value) = tensor_parallel.split_tensor_along_last_dim(mixed_kv, 2)

        # Attention head [sq, b, h] --> [sq, b, hp]
        query, _ = self.linear_q(hidden_states)

        # [sq, b, hp] --> [sq, b, np, hn]
        new_tensor_shape = query.size()[:-1] + (
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
        )
        query = query.view(*new_tensor_shape)

        return query, key, value