megatron_patch/model/deepseek_v2/transformer_layer.py

# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from abc import ABC
from dataclasses import dataclass, field
from typing import Dict, Optional, Union

import torch

from megatron.core import parallel_state
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.dist_checkpointing.utils import apply_prefix_mapping
from megatron.core.transformer.cuda_graphs import CudaGraphManager
from megatron.core.transformer.identity_op import IdentityFuncOp, IdentityOp
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import make_viewless_tensor


@dataclass
class TransformerLayerSubmodules:
    """
    Configuration class for specifying the submodules of a transformer layer.

    This class defines the structure and default implementations for various
    components of a transformer layer, allowing for flexible customization of
    the layer's architecture.

    Args:
        input_layernorm (Union[ModuleSpec, type]): Specification for the input layer normalization.
        self_attention (Union[ModuleSpec, type]): Specification for the self-attention mechanism.
        self_attn_bda (Union[ModuleSpec, type]): Specification for the bias-dropout-add operation
            after self-attention.
        pre_cross_attn_layernorm (Union[ModuleSpec, type]): Specification for the layer
            normalization before cross-attention.
        cross_attention (Union[ModuleSpec, type]): Specification for the cross-attention mechanism.
        cross_attn_bda (Union[ModuleSpec, type]): Specification for the bias-dropout-add operation
            after cross-attention.
        pre_mlp_layernorm (Union[ModuleSpec, type]): Specification for the layer normalization
            before the MLP.
        mlp (Union[ModuleSpec, type]): Specification for the MLP.
        mlp_dense (Union[ModuleSpec, type]): Specification for the dense MLP used in place of
            `mlp` for the leading non-MoE layers (see the selection on `moe_layer_freq` in
            `TransformerLayer.__init__`).
        mlp_bda (Union[ModuleSpec, type]): Specification for the bias-dropout-add operation
            after the MLP.
        sharded_state_dict_keys_map (Dict[str, str]): Mapping for sharded tensor keys to be
            applied in the `sharded_state_dict` method.
    """

    input_layernorm: Union[ModuleSpec, type] = IdentityOp
    self_attention: Union[ModuleSpec, type] = IdentityOp
    self_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp

    pre_cross_attn_layernorm: Union[ModuleSpec, type] = IdentityOp
    cross_attention: Union[ModuleSpec, type] = IdentityOp
    cross_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp

    pre_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp
    mlp: Union[ModuleSpec, type] = IdentityOp
    mlp_dense: Union[ModuleSpec, type] = IdentityOp
    mlp_bda: Union[ModuleSpec, type] = IdentityFuncOp

    # Mapping for sharded tensor keys to be applied in `sharded_state_dict` method
    sharded_state_dict_keys_map: Dict[str, str] = field(default_factory=dict)


class BaseTransformerLayer(ABC):
    """A common parent class for `TransformerLayer` like implementations.

    A dummy class that is subclassed by similar `TransformerLayer`s e.g. the
    `TransformerLayer` in this file and possibly other `TransformerLayer`
    implementations that aim to use `TransformerBlock` as the base module.
    The main purpose is to check if any layer (or module) provided in the spec
    is a subclass of this class to allow fanning-out of that spec for all the
    layers in the `TransformerBlock`. See `_get_block_submodules` method
    implementation in `transformer_block.py` file for more details.
    """

    def __init__(self):
        pass


class TransformerLayer(MegatronModule, BaseTransformerLayer):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size.
    """

    def __init__(
        self,
        config: TransformerConfig,
        submodules: TransformerLayerSubmodules,
        layer_number: int = 1,
        hidden_dropout: float = None,
    ):
        super().__init__(config=config)

        if config.enable_cuda_graph and self.training:
            assert (
                not config.cpu_offloading and config.recompute_granularity is None
            ), "Cudagraphs not supported"
            self.cudagraph_manager = CudaGraphManager()

        self.submodules_config = submodules
        self.layer_number = layer_number + self._get_layer_offset()
        self.hidden_dropout = config.hidden_dropout if hidden_dropout is None else hidden_dropout

        # [Module 1: Input Layernorm] Optional Layernorm on the input data
        # TODO: add pytorch only layernorm
        self.input_layernorm = build_module(
            submodules.input_layernorm,
            config=self.config,
            hidden_size=self.config.hidden_size,
            eps=self.config.layernorm_epsilon,
        )

        attention_optional_kwargs = {}
        if config.cp_comm_type is not None:
            if isinstance(config.cp_comm_type, list):
                attention_optional_kwargs["cp_comm_type"] = config.cp_comm_type[self.layer_number]
            else:
                attention_optional_kwargs["cp_comm_type"] = config.cp_comm_type

        # [Module 2: SelfAttention]
        self.self_attention = build_module(
            submodules.self_attention,
            config=self.config,
            layer_number=layer_number,
            **attention_optional_kwargs,
        )

        # [Module 3: BiasDropoutFusion]
        self.self_attn_bda = build_module(submodules.self_attn_bda)

        # [Module 4: Post SelfAttention] Optional Layernorm after self-attn
        self.pre_cross_attn_layernorm = build_module(
            submodules.pre_cross_attn_layernorm,
            config=self.config,
            hidden_size=self.config.hidden_size,
            eps=self.config.layernorm_epsilon,
        )

        # [Module 5: CrossAttention]
        self.cross_attention = build_module(
            submodules.cross_attention,
            config=self.config,
            layer_number=layer_number,
            **attention_optional_kwargs,
        )

        # [Module 6: BiasDropoutFusion]
        self.cross_attn_bda = build_module(submodules.cross_attn_bda, config=self.config)

        # [Module 7: Pre MLP] Optional Layernorm before MLP
        self.pre_mlp_layernorm = build_module(
            submodules.pre_mlp_layernorm,
            config=self.config,
            hidden_size=self.config.hidden_size,
            eps=self.config.layernorm_epsilon,
        )

        # [Module 8: MLP block]
        # TODO how to set the gpt_layer_spec.py when we have moe_frequency > 1,
        # where MLP and MoE layer both appear alternately?
        # DeepSeek-V2 patch: the first `moe_layer_freq` layers use the dense MLP spec,
        # all later layers use the `mlp` spec.
        if self.layer_number > self.config.moe_layer_freq:
            self.mlp = build_module(submodules.mlp, config=self.config)
        else:
            self.mlp = build_module(submodules.mlp_dense, config=self.config)
        # Original (unconditional) MLP construction kept for reference:
        # self.mlp = build_module(submodules.mlp, config=self.config)
        if hasattr(self.mlp, 'set_layer_number'):
            self.mlp.set_layer_number(self.layer_number)

        # [Module 9: BiasDropoutFusion]
        self.mlp_bda = build_module(submodules.mlp_bda)

        # @jcasper how should we handle nvfuser?
        # Set bias+dropout+add fusion grad_enable execution handler.
        # TORCH_MAJOR = int(torch.__version__.split('.')[0])
        # TORCH_MINOR = int(torch.__version__.split('.')[1])
        # use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
        # self.bias_dropout_add_exec_handler = nullcontext if use_nvfuser else torch.enable_grad
        self.bias_dropout_add_exec_handler = torch.enable_grad

    def _get_layer_offset(self):
        """Get the index number of this layer, given the level of pipelining."""
        pipeline_rank = parallel_state.get_pipeline_model_parallel_rank()
        if not parallel_state.is_inside_encoder():
            pipeline_rank = (
                pipeline_rank - parallel_state.get_pipeline_model_parallel_decoder_start()
            )

        num_layers_per_pipeline_rank = (
            self.config.num_layers // self.config.pipeline_model_parallel_size
        )

        if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
            vp_rank = parallel_state.get_virtual_pipeline_model_parallel_rank()
            vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size()

            total_num_layers = self.config.num_layers
            num_layers_per_virtual_rank = num_layers_per_pipeline_rank // vp_size
            total_virtual_chunks = total_num_layers // vp_size
            offset = vp_rank * total_virtual_chunks + (pipeline_rank * num_layers_per_virtual_rank)

        else:
            # Each stage gets a contiguous set of layers.
            if self.config.pipeline_model_parallel_size > 1:
                if (
                    self.config.first_pipeline_num_layers is not None
                    or self.config.last_pipeline_num_layers is not None
                ):
                    # Calculate number of pipelines for distributing layers
                    middle_pipeline_stages = self.config.pipeline_model_parallel_size
                    middle_pipeline_stages -= sum(
                        [
                            1 if x is not None else 0
                            for x in (
                                self.config.first_pipeline_num_layers,
                                self.config.last_pipeline_num_layers,
                            )
                        ]
                    )

                    # Calculate layers to distribute
                    first_pipeline_offset = (
                        0
                        if self.config.first_pipeline_num_layers is None
                        else self.config.first_pipeline_num_layers
                    )
                    last_pipeline_offset = (
                        0
                        if self.config.last_pipeline_num_layers is None
                        else self.config.last_pipeline_num_layers
                    )

                    middle_num_layers = (
                        self.config.num_layers - first_pipeline_offset - last_pipeline_offset
                    )

                    if middle_pipeline_stages > 0:
                        num_layers_per_pipeline_rank = middle_num_layers // middle_pipeline_stages
                    else:
                        num_layers_per_pipeline_rank = 0

                    middle_pipeline_rank = (
                        pipeline_rank
                        if self.config.first_pipeline_num_layers is None
                        else pipeline_rank - 1
                    )

                    if pipeline_rank == 0:
                        offset = 0
                    else:
                        offset = (
                            middle_pipeline_rank * num_layers_per_pipeline_rank
                        ) + first_pipeline_offset
                else:
                    offset = pipeline_rank * num_layers_per_pipeline_rank
            else:
                offset = 0

        return offset

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        context=None,
        context_mask=None,
        rotary_pos_emb=None,
        rotary_pos_cos=None,
        rotary_pos_sin=None,
        inference_params=None,
        packed_seq_params=None,
    ):
        """
        Perform a forward pass through the transformer layer.

        This method implements the core computation of a transformer layer, including
        self-attention, cross-attention (if applicable), and feed-forward operations.

        Args:
            hidden_states (Tensor): Input tensor of shape [s, b, h] where s is sequence length,
                b is batch size, and h is hidden size.
            attention_mask (Tensor): Mask tensor for self-attention.
            context (Tensor, optional): Context tensor for cross-attention.
            context_mask (Tensor, optional): Mask tensor for cross-attention.
            rotary_pos_emb (Tensor, optional): Rotary positional embeddings.
            rotary_pos_cos (Tensor, optional): Precomputed cosine component of the rotary
                positional embeddings.
            rotary_pos_sin (Tensor, optional): Precomputed sine component of the rotary
                positional embeddings.
            inference_params (object, optional): Parameters for inference-time optimizations.
            packed_seq_params (object, optional): Parameters for packed sequence processing.

        Returns:
            Tuple[Tensor, Tensor]: A tuple containing:
                output (Tensor): Transformed hidden states of shape [s, b, h].
                context (Tensor): Updated context tensor if cross-attention is used,
                    otherwise None.
        """

        # Residual connection.
        residual = hidden_states

        # Optional Input Layer norm
        input_layernorm_output = self.input_layernorm(hidden_states)

        # Self attention.
        attention_output_with_bias = self.self_attention(
            input_layernorm_output,
            attention_mask=attention_mask,
            inference_params=inference_params,
            rotary_pos_emb=rotary_pos_emb,
            rotary_pos_cos=rotary_pos_cos,
            rotary_pos_sin=rotary_pos_sin,
            packed_seq_params=packed_seq_params,
        )

        # TODO: could we move `bias_dropout_add_exec_handler` itself
        # inside the module provided in the `bias_dropout_add_spec` module?
        with self.bias_dropout_add_exec_handler():
            hidden_states = self.self_attn_bda(self.training, self.config.bias_dropout_fusion)(
                attention_output_with_bias, residual, self.hidden_dropout
            )

        # Residual connection.
        residual = hidden_states

        # Optional Layer norm after self-attention
        pre_cross_attn_layernorm_output = self.pre_cross_attn_layernorm(hidden_states)

        # Cross attention.
        attention_output_with_bias = self.cross_attention(
            pre_cross_attn_layernorm_output,
            attention_mask=context_mask,
            key_value_states=context,
            inference_params=inference_params,
        )

        if isinstance(attention_output_with_bias, dict) and "context" in attention_output_with_bias:
            context = attention_output_with_bias["context"]

        # TODO: could we move `bias_dropout_add_exec_handler` itself
        # inside the module provided in the `bias_dropout_add_spec` module?
        with self.bias_dropout_add_exec_handler():
            hidden_states = self.cross_attn_bda(self.training, self.config.bias_dropout_fusion)(
                attention_output_with_bias, residual, self.hidden_dropout
            )

        # Residual connection.
        residual = hidden_states

        # Optional Layer norm post the cross-attention.
        pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states)

        # MLP.
        mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output)

        # TODO: could we move `bias_dropout_add_exec_handler` itself
        # inside the module provided in the `bias_dropout_add_spec` module?
        with self.bias_dropout_add_exec_handler():
            hidden_states = self.mlp_bda(self.training, self.config.bias_dropout_fusion)(
                mlp_output_with_bias, residual, self.hidden_dropout
            )

        # Jit compiled function creates 'view' tensor. This tensor
        # potentially gets saved in the MPU checkpoint function context,
        # which rejects view tensors. While making a viewless tensor here
        # won't result in memory savings (like the data loader, or
        # p2p_communication), it serves to document the origin of this
        # 'view' tensor.
        output = make_viewless_tensor(
            inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True
        )

        # CUDA graph requires returned values to be Tensors
        if self.config.external_cuda_graph and self.training:
            return output

        return output, context

    def sharded_state_dict(
        self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None
    ) -> ShardedStateDict:
        """
        Generate a sharded state dictionary for the transformer layer.

        Args:
            prefix (str, optional): Prefix to be added to all keys in the state dict.
            sharded_offsets (tuple, optional): Tuple of sharding offsets.
            metadata (Optional[dict], optional): Additional metadata for sharding.

        Returns:
            ShardedStateDict: A dictionary containing the sharded state of the transformer layer.
        """
        sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
        prefixed_map = {
            f'{prefix}{k}': f'{prefix}{v}'
            for k, v in self.submodules_config.sharded_state_dict_keys_map.items()
        }
        if prefixed_map:
            apply_prefix_mapping(sharded_state_dict, prefixed_map)
        return sharded_state_dict

    def __call__(self, *args, **kwargs):
        if hasattr(self, 'cudagraph_manager'):
            return self.cudagraph_manager(self, args, kwargs)
        return super(MegatronModule, self).__call__(*args, **kwargs)