megatron_patch/model/deepseek_v2/layer_specs.py

# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

from typing import Optional

from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.spec_utils import ModuleSpec

from .moe.shared_experts import SharedExpertMLP
from .moe.moe_layer import MoELayer, MoESubmodules
from .mlp import MLP, MLPSubmodules
from .transformer_layer import TransformerLayer, TransformerLayerSubmodules
from .multi_latent_attention import (
    MLASelfAttention,
    MLASelfAttentionSubmodules,
)

try:
    from megatron.core.extensions.transformer_engine import (
        TEColumnParallelGroupedLinear,
        TEColumnParallelLinear,
        TEDotProductAttention,
        TELayerNormColumnParallelLinear,
        TENorm,
        TERowParallelGroupedLinear,
        TERowParallelLinear,
    )

    HAVE_TE = True
except ImportError:
    HAVE_TE = False

try:
    import apex  # pylint: disable=unused-import

    from megatron.core.fusions.fused_layer_norm import FusedLayerNorm

    HAVE_APEX = True
    LNImpl = FusedLayerNorm
except ImportError:
    import warnings

    from megatron.core.transformer.torch_layer_norm import WrappedTorchLayerNorm

    warnings.warn('Apex is not installed. Falling back to Torch LayerNorm')
    LNImpl = WrappedTorchLayerNorm


def get_gpt_layer_with_transformer_engine_spec(
    num_experts: Optional[int] = None,
    moe_grouped_gemm: Optional[bool] = False,
    qk_layernorm: Optional[bool] = False,
    multi_latent_attention: Optional[bool] = False,
    fp8: Optional[str] = None,
) -> ModuleSpec:
    """Use this spec to use lower-level Transformer Engine modules (required for fp8 training).

    Args:
        num_experts (int, optional): Number of experts. Defaults to None.
        moe_grouped_gemm (bool, optional): To use Grouped GEMM. Defaults to False.
        qk_layernorm (bool, optional): To use layernorm for queries/keys. Defaults to False.
        multi_latent_attention (bool, optional): To use multi-latent attention. Defaults to False.
        fp8 (str, optional): Flag to decide the linear layer spec for MoE. Defaults to None.

    Returns:
        ModuleSpec: Module specification with TE modules
    """
    # Build both an MoE MLP spec and a dense MLP spec; the patched TransformerLayer
    # selects one of the two per layer.
    mlp_moe = _get_mlp_module_spec(
        use_te=True, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm, fp8=fp8
    )

    mlp_dense = _get_mlp_module_spec(
        use_te=True, num_experts=None, moe_grouped_gemm=moe_grouped_gemm, fp8=fp8
    )

    if multi_latent_attention:
        return ModuleSpec(
            module=TransformerLayer,
            submodules=TransformerLayerSubmodules(
                self_attention=ModuleSpec(
                    module=MLASelfAttention,
                    params={"attn_mask_type": AttnMaskType.causal},
                    submodules=MLASelfAttentionSubmodules(
                        linear_q_proj=TEColumnParallelLinear,
                        linear_q_down_proj=TEColumnParallelLinear,
                        linear_q_up_proj=TEColumnParallelLinear,
                        linear_kv_down_proj=TEColumnParallelLinear,
                        linear_kv_up_proj=TEColumnParallelLinear,
                        core_attention=TEDotProductAttention,
                        linear_proj=TERowParallelLinear,
                        q_layernorm=TENorm if qk_layernorm else IdentityOp,
                        kv_layernorm=TENorm if qk_layernorm else IdentityOp,
                    ),
                ),
                self_attn_bda=get_bias_dropout_add,
                pre_mlp_layernorm=TENorm if num_experts else IdentityOp,
                input_layernorm=TENorm if num_experts else IdentityOp,
                mlp=mlp_moe,
                mlp_dense=mlp_dense,
                mlp_bda=get_bias_dropout_add,
            ),
        )


def _get_mlp_module_spec(
    use_te: Optional[bool] = True,
    num_experts: Optional[int] = None,
    moe_grouped_gemm: Optional[bool] = False,
    fp8: Optional[str] = None,
) -> ModuleSpec:
    """Helper function to get module spec for MLP/MoE"""
    if num_experts is None:
        # Dense MLP w/ or w/o TE modules.
        return ModuleSpec(
            module=MLP,
            submodules=MLPSubmodules(
                linear_fc1=TEColumnParallelLinear if use_te else ColumnParallelLinear,
                linear_fc2=TERowParallelLinear if use_te else RowParallelLinear,
            ),
        )
    else:
        # Mixture of experts with modules in megatron core.
        if use_te and moe_grouped_gemm:
            linear_fc1 = TEColumnParallelGroupedLinear
            linear_fc2 = TERowParallelGroupedLinear
        elif use_te and fp8:
            # fp8 path without grouped GEMM: use per-expert TE linear layers.
            linear_fc1 = TEColumnParallelLinear
            linear_fc2 = TERowParallelLinear
        else:
            linear_fc1 = ColumnParallelLinear
            linear_fc2 = RowParallelLinear

        # TE grouped GEMM is only used when the TE grouped linear classes are available (not None).
        use_te_grouped_gemm = use_te and TEColumnParallelGroupedLinear is not None

        return ModuleSpec(
            module=MoELayer,
            submodules=MoESubmodules(
                experts=(
                    MLPSubmodules(linear_fc1=linear_fc1, linear_fc2=linear_fc2)
                    if not moe_grouped_gemm or use_te_grouped_gemm
                    else None
                ),
                shared_experts=ModuleSpec(
                    module=SharedExpertMLP,
                    params={"gate": False},
                    submodules=MLPSubmodules(
                        linear_fc1=TEColumnParallelLinear if use_te else ColumnParallelLinear,
                        linear_fc2=TERowParallelLinear if use_te else RowParallelLinear,
                    ),
                ),
            ),
        )
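
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the upstream file): how a layer
# spec produced by get_gpt_layer_with_transformer_engine_spec() is typically
# consumed. The argument values below are hypothetical; in practice they come
# from the training arguments / TransformerConfig of the model being built.
#
#   layer_spec = get_gpt_layer_with_transformer_engine_spec(
#       num_experts=64,               # hypothetical expert count
#       moe_grouped_gemm=True,
#       qk_layernorm=True,
#       multi_latent_attention=True,  # must be True; otherwise this function returns None
#       fp8=None,
#   )
#
#   # The returned ModuleSpec is then passed to the model constructor (e.g. as the
#   # transformer_layer_spec argument of megatron.core's GPTModel), which instantiates
#   # MLASelfAttention, the MoE/dense MLPs, and the TE linear layers named in the spec.
# ---------------------------------------------------------------------------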