# megatron_patch/model/deepseek_v2/transformer_layer.py
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from abc import ABC
from dataclasses import dataclass, field
from typing import Dict, Optional, Union
import torch
from megatron.core import parallel_state
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.dist_checkpointing.utils import apply_prefix_mapping
from megatron.core.transformer.cuda_graphs import CudaGraphManager
from megatron.core.transformer.identity_op import IdentityFuncOp, IdentityOp
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import make_viewless_tensor
@dataclass
class TransformerLayerSubmodules:
"""
Configuration class for specifying the submodules of a transformer layer.
This class defines the structure and default implementations for various
components of a transformer layer, allowing for flexible customization
of the layer's architecture.
Args:
input_layernorm (Union[ModuleSpec, type]): Specification for the input layer normalization.
self_attention (Union[ModuleSpec, type]): Specification for the self-attention mechanism.
self_attn_bda (Union[ModuleSpec, type]): Specification for the bias-dropout-add operation
after self-attention.
pre_cross_attn_layernorm (Union[ModuleSpec, type]): Specification for the layer
normalization before cross-attention.
cross_attention (Union[ModuleSpec, type]): Specification for the cross-attention mechanism.
cross_attn_bda (Union[ModuleSpec, type]): Specification for the bias-dropout-add operation
after cross-attention.
pre_mlp_layernorm (Union[ModuleSpec, type]): Specification for the layer normalization
before the MLP.
        mlp (Union[ModuleSpec, type]): Specification for the MLP used on MoE layers
            (layers where `layer_number > moe_layer_freq`).
        mlp_dense (Union[ModuleSpec, type]): Specification for the dense (non-MoE) MLP used on
            the first `moe_layer_freq` layers (DeepSeek-V2 patch).
        mlp_bda (Union[ModuleSpec, type]): Specification for the bias-dropout-add operation
            after the MLP.
sharded_state_dict_keys_map (Dict[str, str]): Mapping for sharded tensor keys to be applied
in the `sharded_state_dict` method.
"""
input_layernorm: Union[ModuleSpec, type] = IdentityOp
self_attention: Union[ModuleSpec, type] = IdentityOp
self_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp
pre_cross_attn_layernorm: Union[ModuleSpec, type] = IdentityOp
cross_attention: Union[ModuleSpec, type] = IdentityOp
cross_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp
pre_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp
mlp: Union[ModuleSpec, type] = IdentityOp
    mlp_dense: Union[ModuleSpec, type] = IdentityOp  # Dense MLP for the first `moe_layer_freq` layers (DeepSeek-V2 patch)
mlp_bda: Union[ModuleSpec, type] = IdentityFuncOp
# Mapping for sharded tensor keys to be applied in `sharded_state_dict` method
sharded_state_dict_keys_map: Dict[str, str] = field(default_factory=dict)
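# Usage sketch (illustrative only, not executed): a layer spec normally fills these submodules
# in via Megatron-Core's spec helpers rather than by hand. The concrete module names below
# (Norm, SelfAttention, MLP, MoELayer, get_bias_dropout_add) stand in for whatever
# implementations the surrounding spec actually wires in.
#
#   layer_spec = ModuleSpec(
#       module=TransformerLayer,
#       submodules=TransformerLayerSubmodules(
#           input_layernorm=Norm,
#           self_attention=ModuleSpec(module=SelfAttention, submodules=...),
#           self_attn_bda=get_bias_dropout_add,
#           pre_mlp_layernorm=Norm,
#           mlp=ModuleSpec(module=MoELayer),                   # used on MoE layers
#           mlp_dense=ModuleSpec(module=MLP, submodules=...),  # used on dense layers
#           mlp_bda=get_bias_dropout_add,
#       ),
#   )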
class BaseTransformerLayer(ABC):
"""A common parent class for `TransformerLayer` like implementations.
A dummy class that is subclassed by similar `TransformerLayer`s e.g. the
`TransformerLayer` in this file and possibly other `TransformerLayer`
implementations that aim to use `TransformerBlock` as the base module.
The main purpose is to check if any layer (or module) provided in the spec
is a subclass of this class to allow fanning-out of that spec for all the
layers in the `TransformerBlock`. See `_get_block_submodules` method
implementation in `transformer_block.py` file for more details.
"""
def __init__(self):
pass
class TransformerLayer(MegatronModule, BaseTransformerLayer):
"""A single transformer layer.
Transformer layer takes input with size [s, b, h] and returns an
output of the same size.
"""
def __init__(
self,
config: TransformerConfig,
submodules: TransformerLayerSubmodules,
layer_number: int = 1,
        hidden_dropout: Optional[float] = None,
):
super().__init__(config=config)
if config.enable_cuda_graph and self.training:
assert (
not config.cpu_offloading and config.recompute_granularity is None
), "Cudagraphs not supported"
self.cudagraph_manager = CudaGraphManager()
self.submodules_config = submodules
self.layer_number = layer_number + self._get_layer_offset()
self.hidden_dropout = config.hidden_dropout if hidden_dropout is None else hidden_dropout
# [Module 1: Input Layernorm] Optional Layernorm on the input data
# TODO: add pytorch only layernorm
self.input_layernorm = build_module(
submodules.input_layernorm,
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
attention_optional_kwargs = {}
if config.cp_comm_type is not None:
if isinstance(config.cp_comm_type, list):
attention_optional_kwargs["cp_comm_type"] = config.cp_comm_type[self.layer_number]
else:
attention_optional_kwargs["cp_comm_type"] = config.cp_comm_type
# [Module 2: SelfAttention]
self.self_attention = build_module(
submodules.self_attention,
config=self.config,
layer_number=layer_number,
**attention_optional_kwargs,
)
# [Module 3: BiasDropoutFusion]
self.self_attn_bda = build_module(submodules.self_attn_bda)
# [Module 4: Post SelfAttention] Optional Layernorm after self-attn
self.pre_cross_attn_layernorm = build_module(
submodules.pre_cross_attn_layernorm,
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
# [Module 5: CrossAttention]
self.cross_attention = build_module(
submodules.cross_attention,
config=self.config,
layer_number=layer_number,
**attention_optional_kwargs,
)
# [Module 6: BiasDropoutFusion]
self.cross_attn_bda = build_module(submodules.cross_attn_bda, config=self.config)
# [Module 7: Pre MLP] Optional Layernorm before MLP
self.pre_mlp_layernorm = build_module(
submodules.pre_mlp_layernorm,
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
# [Module 8: MLP block]
        # TODO: how to set gpt_layer_spec.py when moe_layer_freq > 1,
        # where dense MLP and MoE layers alternate?
        # DeepSeek-V2 patch: dense MLP for the first `moe_layer_freq` layers, MoE MLP afterwards.
        if self.layer_number > self.config.moe_layer_freq:
            self.mlp = build_module(submodules.mlp, config=self.config)
        else:
            self.mlp = build_module(submodules.mlp_dense, config=self.config)
"""
self.mlp = build_module(submodules.mlp, config=self.config)
"""
if hasattr(self.mlp, 'set_layer_number'):
self.mlp.set_layer_number(self.layer_number)
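        # Example (derived from the MLP selection above): with 1-based `self.layer_number` and
        # `moe_layer_freq == 1`, layer 1 is built from `submodules.mlp_dense` and layers
        # 2..num_layers from `submodules.mlp`, i.e. a "first k layers dense, rest MoE" layout.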
# [Module 9: BiasDropoutFusion]
self.mlp_bda = build_module(submodules.mlp_bda)
# @jcasper how should we handle nvfuser?
# Set bias+dropout+add fusion grad_enable execution handler.
# TORCH_MAJOR = int(torch.__version__.split('.')[0])
# TORCH_MINOR = int(torch.__version__.split('.')[1])
# use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
# self.bias_dropout_add_exec_handler = nullcontext if use_nvfuser else torch.enable_grad
self.bias_dropout_add_exec_handler = torch.enable_grad
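    # Construction sketch (illustrative, not executed here): layers are normally created by
    # `TransformerBlock`, which does the equivalent of
    #
    #   layer = build_module(layer_spec, config=config, layer_number=i + 1)
    #
    # for each layer i it owns, where `layer_spec` is a `ModuleSpec(module=TransformerLayer, ...)`
    # such as the one sketched near the top of this file.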
def _get_layer_offset(self):
"""Get the index number of this layer, given the level of pipelining."""
pipeline_rank = parallel_state.get_pipeline_model_parallel_rank()
if not parallel_state.is_inside_encoder():
pipeline_rank = (
pipeline_rank - parallel_state.get_pipeline_model_parallel_decoder_start()
)
num_layers_per_pipeline_rank = (
self.config.num_layers // self.config.pipeline_model_parallel_size
)
if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
vp_rank = parallel_state.get_virtual_pipeline_model_parallel_rank()
vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size()
total_num_layers = self.config.num_layers
num_layers_per_virtual_rank = num_layers_per_pipeline_rank // vp_size
total_virtual_chunks = total_num_layers // vp_size
offset = vp_rank * total_virtual_chunks + (pipeline_rank * num_layers_per_virtual_rank)
else:
# Each stage gets a contiguous set of layers.
if self.config.pipeline_model_parallel_size > 1:
if (
self.config.first_pipeline_num_layers is not None
or self.config.last_pipeline_num_layers is not None
):
                    # Count the middle pipeline stages over which the remaining layers are distributed
middle_pipeline_stages = self.config.pipeline_model_parallel_size
middle_pipeline_stages -= sum(
[
1 if x is not None else 0
for x in (
self.config.first_pipeline_num_layers,
self.config.last_pipeline_num_layers,
)
]
)
# Calculate layers to distribute
first_pipeline_offset = (
0
if self.config.first_pipeline_num_layers is None
else self.config.first_pipeline_num_layers
)
last_pipeline_offset = (
0
if self.config.last_pipeline_num_layers is None
else self.config.last_pipeline_num_layers
)
middle_num_layers = (
self.config.num_layers - first_pipeline_offset - last_pipeline_offset
)
if middle_pipeline_stages > 0:
num_layers_per_pipeline_rank = middle_num_layers // middle_pipeline_stages
else:
num_layers_per_pipeline_rank = 0
middle_pipeline_rank = (
pipeline_rank
if self.config.first_pipeline_num_layers is None
else pipeline_rank - 1
)
if pipeline_rank == 0:
offset = 0
else:
offset = (
middle_pipeline_rank * num_layers_per_pipeline_rank
) + first_pipeline_offset
else:
offset = pipeline_rank * num_layers_per_pipeline_rank
else:
offset = 0
return offset
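    # Worked example (uniform stages, no virtual pipelining): with num_layers=24 and
    # pipeline_model_parallel_size=4, each stage owns 24 // 4 = 6 layers, so pipeline
    # rank 2 returns offset 2 * 6 = 12 and its local layers 1..6 become global layers 13..18.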
def forward(
self,
hidden_states,
attention_mask=None,
context=None,
context_mask=None,
rotary_pos_emb=None,
rotary_pos_cos=None,
rotary_pos_sin=None,
inference_params=None,
packed_seq_params=None,
):
"""
Perform a forward pass through the transformer layer.
This method implements the core computation of a transformer layer, including
self-attention, cross-attention (if applicable), and feed-forward operations.
Args:
hidden_states (Tensor): Input tensor of shape [s, b, h] where s is sequence length,
b is batch size, and h is hidden size.
            attention_mask (Tensor, optional): Mask tensor for self-attention.
            context (Tensor, optional): Context tensor for cross-attention.
            context_mask (Tensor, optional): Mask tensor for cross-attention.
            rotary_pos_emb (Tensor, optional): Rotary positional embeddings.
            rotary_pos_cos (Tensor, optional): Precomputed cosine of the rotary positional
                embeddings.
            rotary_pos_sin (Tensor, optional): Precomputed sine of the rotary positional
                embeddings.
            inference_params (object, optional): Parameters for inference-time optimizations.
            packed_seq_params (object, optional): Parameters for packed sequence processing.
Returns:
Tuple[Tensor, Tensor]: A tuple containing:
output (Tensor): Transformed hidden states of shape [s, b, h].
context (Tensor): Updated context tensor if cross-attention is used,
otherwise None.
"""
# Residual connection.
residual = hidden_states
# Optional Input Layer norm
input_layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output_with_bias = self.self_attention(
input_layernorm_output,
attention_mask=attention_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
rotary_pos_cos=rotary_pos_cos,
rotary_pos_sin=rotary_pos_sin,
packed_seq_params=packed_seq_params,
)
# TODO: could we move `bias_dropout_add_exec_handler` itself
# inside the module provided in the `bias_dropout_add_spec` module?
with self.bias_dropout_add_exec_handler():
hidden_states = self.self_attn_bda(self.training, self.config.bias_dropout_fusion)(
attention_output_with_bias, residual, self.hidden_dropout
)
# Residual connection.
residual = hidden_states
# Optional Layer norm after self-attention
pre_cross_attn_layernorm_output = self.pre_cross_attn_layernorm(hidden_states)
# Cross attention.
attention_output_with_bias = self.cross_attention(
pre_cross_attn_layernorm_output,
attention_mask=context_mask,
key_value_states=context,
inference_params=inference_params,
)
if isinstance(attention_output_with_bias, dict) and "context" in attention_output_with_bias:
context = attention_output_with_bias["context"]
# TODO: could we move `bias_dropout_add_exec_handler` itself
# inside the module provided in the `bias_dropout_add_spec` module?
with self.bias_dropout_add_exec_handler():
hidden_states = self.cross_attn_bda(self.training, self.config.bias_dropout_fusion)(
attention_output_with_bias, residual, self.hidden_dropout
)
# Residual connection.
residual = hidden_states
# Optional Layer norm post the cross-attention.
pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states)
# MLP.
mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output)
# TODO: could we move `bias_dropout_add_exec_handler` itself
# inside the module provided in the `bias_dropout_add_spec` module?
with self.bias_dropout_add_exec_handler():
hidden_states = self.mlp_bda(self.training, self.config.bias_dropout_fusion)(
mlp_output_with_bias, residual, self.hidden_dropout
)
# Jit compiled function creates 'view' tensor. This tensor
# potentially gets saved in the MPU checkpoint function context,
# which rejects view tensors. While making a viewless tensor here
# won't result in memory savings (like the data loader, or
# p2p_communication), it serves to document the origin of this
# 'view' tensor.
output = make_viewless_tensor(
inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True
)
# CUDA graph requires returned values to be Tensors
if self.config.external_cuda_graph and self.training:
return output
return output, context
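    # Call sketch (illustrative): for a decoder-only layer the usual invocation is
    #
    #   output, context = layer(hidden_states, attention_mask=mask)
    #
    # with `hidden_states` of shape [s, b, h]; `context` is passed through unchanged unless
    # the cross-attention module returns an updated one.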
def sharded_state_dict(
self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None
) -> ShardedStateDict:
"""
Generate a sharded state dictionary for the transformer layer.
Args:
prefix (str, optional): Prefix to be added to all keys in the state dict.
sharded_offsets (tuple, optional): Tuple of sharding offsets.
metadata (Optional[dict], optional): Additional metadata for sharding.
Returns:
ShardedStateDict: A dictionary containing the sharded state of the transformer layer.
"""
sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
prefixed_map = {
f'{prefix}{k}': f'{prefix}{v}'
for k, v in self.submodules_config.sharded_state_dict_keys_map.items()
}
if prefixed_map:
apply_prefix_mapping(sharded_state_dict, prefixed_map)
return sharded_state_dict
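    # Example (illustrative): a layer spec may set
    #   sharded_state_dict_keys_map={'input_layernorm.': 'self_attention.linear_qkv.layer_norm_'}
    # to keep checkpoints interchangeable with specs that fuse the layernorm into the attention
    # projection; the mapping is applied here with the layer `prefix` prepended to both sides.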
    def __call__(self, *args, **kwargs):
        # When a CUDA-graph manager was created in __init__, route the call through it;
        # otherwise fall back to the regular nn.Module call path.
        if hasattr(self, 'cudagraph_manager'):
            return self.cudagraph_manager(self, args, kwargs)
        return super(MegatronModule, self).__call__(*args, **kwargs)