# megatron_patch/model/qwen2_moe/layer_specs.py
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import warnings
from typing import Optional
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import (
TransformerBlockSubmodules,
get_num_layers_to_build,
)
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.transformer_layer import (
TransformerLayer,
TransformerLayerSubmodules,
get_transformer_layer_offset,
)
from megatron.core.utils import get_te_version, is_te_min_version
from megatron.core.transformer.moe.experts import GroupedMLP, SequentialMLP, TEGroupedMLP
from megatron.core.transformer.moe.moe_layer import MoELayer, MoESubmodules
from megatron.core.transformer.moe.shared_experts import SharedExpertMLP

try:
from megatron.core.extensions.transformer_engine import (
TEColumnParallelLinear,
TEDotProductAttention,
TELayerNormColumnParallelLinear,
TELinear,
TENorm,
TERowParallelLinear,
TEColumnParallelGroupedLinear,
TERowParallelGroupedLinear,
)
HAVE_TE = True
except ImportError:
    HAVE_TE = False

    # Transformer Engine is optional; provide None fallbacks so feature checks such as
    # `TEColumnParallelGroupedLinear is not None` in get_moe_module_spec do not raise a
    # NameError when TE is unavailable.
    TEColumnParallelGroupedLinear = None
    TERowParallelGroupedLinear = None

try:
import apex # pylint: disable=unused-import
from megatron.core.fusions.fused_layer_norm import FusedLayerNorm
HAVE_APEX = True
LNImpl = FusedLayerNorm
except ImportError:
from megatron.core.transformer.torch_norm import WrappedTorchNorm
warnings.warn('Apex is not installed. Falling back to Torch Norm')
LNImpl = WrappedTorchNorm


def get_moe_module_spec(
use_te: Optional[bool] = True,
num_experts: Optional[int] = None,
moe_grouped_gemm: Optional[bool] = False,
moe_use_legacy_grouped_gemm: Optional[bool] = False,
use_shared_expert_gate: Optional[bool] = False,
) -> ModuleSpec:
"""Helper function to get module spec for MoE"""
assert num_experts is not None
mlp = MLPSubmodules(
linear_fc1=TEColumnParallelLinear if use_te else ColumnParallelLinear,
linear_fc2=TERowParallelLinear if use_te else RowParallelLinear,
)
# experts spec
if moe_grouped_gemm:
## use GroupedMLP
if use_te and TEColumnParallelGroupedLinear is not None and not moe_use_legacy_grouped_gemm:
## use TEGroupedLinear
expert_module = TEGroupedMLP
expert_submodule = MLPSubmodules(
linear_fc1=TEColumnParallelGroupedLinear, linear_fc2=TERowParallelGroupedLinear
)
else:
## use legacy GroupedMLP
expert_module = GroupedMLP
expert_submodule = None
warnings.warn(
'The legacy GroupedMLP will be deprecated in Megatron-Core v0.12.0. '
                'Please upgrade TransformerEngine to >= 1.7.0 and use TEGroupedMLP instead.'
)
else:
## use SequentialMLP
expert_module = SequentialMLP
if use_te and not is_te_min_version("1.7.0.dev0"):
            warnings.warn(
                "Only transformer-engine>=1.7.0 supports MoE experts, "
                f"but the installed version is {get_te_version()}. "
                "Falling back to the local linear implementation."
            )
expert_submodule = MLPSubmodules(
linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear
)
else:
expert_submodule = mlp
experts = ModuleSpec(module=expert_module, submodules=expert_submodule)
# shared experts spec
    shared_experts = ModuleSpec(
        module=SharedExpertMLP, params={"gate": use_shared_expert_gate}, submodules=mlp
    )
# MoE module spec
moe_module_spec = ModuleSpec(
module=MoELayer, submodules=MoESubmodules(experts=experts, shared_experts=shared_experts)
)
return moe_module_spec
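
# Illustrative usage sketch (not part of the original module; `spec` is a hypothetical name
# and TE is assumed to be installed for use_te=True):
#
#     spec = get_moe_module_spec(
#         use_te=True, num_experts=8, moe_grouped_gemm=True, use_shared_expert_gate=True
#     )
#     # spec.module is MoELayer; spec.submodules.experts and spec.submodules.shared_experts
#     # carry the expert and shared-expert sub-specs built above.
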
def get_gpt_layer_with_transformer_engine_spec(
num_experts: Optional[int] = None,
moe_grouped_gemm: Optional[bool] = False,
qk_layernorm: Optional[bool] = False,
    fp8: Optional[str] = None,  # pylint: disable=unused-argument
moe_use_legacy_grouped_gemm: Optional[bool] = False,
) -> ModuleSpec:
"""Use this spec to use lower-level Transformer Engine modules (required for fp8 training).
Args:
num_experts (int, optional): Number of experts. Defaults to None.
moe_grouped_gemm (bool, optional): To use Grouped GEMM. Defaults to False.
qk_layernorm (bool, optional): To use layernorm for queries/keys. Defaults to False.
fp8 (str, optional): Deprecated. For temporary Nemo compatibility.
        moe_use_legacy_grouped_gemm (bool, optional): Force the use of the legacy GroupedMLP.
            Defaults to False.
Returns:
ModuleSpec: Module specification with TE modules
"""
if fp8 is not None:
warnings.warn(
'The fp8 argument in "get_gpt_layer_with_transformer_engine_spec" has been deprecated'
' and will be removed soon. Please update your code accordingly.'
)
mlp = get_mlp_module_spec(
use_te=True,
num_experts=num_experts,
moe_grouped_gemm=moe_grouped_gemm,
moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
)
# TENorm significantly harms convergence when used
# for QKLayerNorm if TE Version < 1.9;
# we instead use the Apex implementation.
qk_norm = TENorm if is_te_min_version("1.9.0") else FusedLayerNorm
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=SelfAttentionSubmodules(
linear_qkv=TELayerNormColumnParallelLinear,
core_attention=TEDotProductAttention,
linear_proj=TERowParallelLinear,
q_layernorm=qk_norm if qk_layernorm else IdentityOp,
k_layernorm=qk_norm if qk_layernorm else IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=TENorm if num_experts else IdentityOp,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
),
)
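
# Illustrative usage sketch (assumes Transformer Engine is installed so the TE classes
# imported above resolve; `layer_spec` is a hypothetical local name):
#
#     layer_spec = get_gpt_layer_with_transformer_engine_spec(
#         num_experts=8, moe_grouped_gemm=True, qk_layernorm=True
#     )
#     # layer_spec.module is TransformerLayer and layer_spec.submodules.mlp is the MoE
#     # spec returned by get_mlp_module_spec.
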
def get_gpt_layer_local_spec(
num_experts: Optional[int] = None,
moe_grouped_gemm: Optional[bool] = False,
qk_layernorm: Optional[bool] = False,
    fp8: Optional[str] = None,  # pylint: disable=unused-argument
moe_use_legacy_grouped_gemm: Optional[bool] = False,
) -> ModuleSpec:
"""Use this spec for an implementation using only modules in Megatron-Core.
Args:
num_experts (int, optional): Number of experts. Defaults to None.
moe_grouped_gemm (bool, optional): To use Grouped GEMM. Defaults to False.
qk_layernorm (bool, optional): To use layernorm for queries/keys. Defaults to False.
fp8 (str, optional): Deprecated. For temporary Nemo compatibility.
        moe_use_legacy_grouped_gemm (bool, optional): Force the use of the legacy GroupedMLP.
            Defaults to False.
Returns:
ModuleSpec: Module specification with Megatron-Core modules
"""
if fp8 is not None:
warnings.warn(
'The fp8 argument in "get_gpt_layer_local_spec" has been deprecated'
' and will be removed soon. Please update your code accordingly.'
)
mlp = get_mlp_module_spec(
use_te=False,
num_experts=num_experts,
moe_grouped_gemm=moe_grouped_gemm,
moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
)
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
input_layernorm=LNImpl,
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=SelfAttentionSubmodules(
linear_qkv=ColumnParallelLinear,
core_attention=DotProductAttention,
linear_proj=RowParallelLinear,
q_layernorm=LNImpl if qk_layernorm else IdentityOp,
k_layernorm=LNImpl if qk_layernorm else IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=LNImpl,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
sharded_state_dict_keys_map={
'input_layernorm.': 'self_attention.linear_qkv.layer_norm_',
'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_',
},
),
)
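
# Illustrative usage sketch for the TE-free path (hypothetical local names; a dense layer
# spec is returned when num_experts is left as None):
#
#     dense_spec = get_gpt_layer_local_spec(qk_layernorm=True)
#     moe_spec = get_gpt_layer_local_spec(num_experts=8)
#     # Both specs use LNImpl (FusedLayerNorm, or WrappedTorchNorm without Apex) for the
#     # input layernorm.
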
def _get_mlp_module_spec(
use_te: Optional[bool] = True,
num_experts: Optional[int] = None,
moe_grouped_gemm: Optional[bool] = False,
    fp8: Optional[str] = None,  # pylint: disable=unused-argument
moe_use_legacy_grouped_gemm: Optional[bool] = False,
):
warnings.warn(
"""This private function is on a deprecation track. Please switch to `get_mlp_module_spec`
since it will be removed in a future release."""
)
return get_mlp_module_spec(
use_te=use_te,
num_experts=num_experts,
moe_grouped_gemm=moe_grouped_gemm,
fp8=fp8,
moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
)


def get_mlp_module_spec(
use_te: Optional[bool] = True,
num_experts: Optional[int] = None,
moe_grouped_gemm: Optional[bool] = False,
    fp8: Optional[str] = None,  # pylint: disable=unused-argument
moe_use_legacy_grouped_gemm: Optional[bool] = False,
) -> ModuleSpec:
"""Helper function to get module spec for MLP/MoE"""
if fp8 is not None:
warnings.warn(
            'The fp8 argument in "get_mlp_module_spec" has been deprecated'
' and will be removed soon. Please update your code accordingly.'
)
if num_experts is None:
# Dense MLP w/ or w/o TE modules.
return ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=TEColumnParallelLinear if use_te else ColumnParallelLinear,
linear_fc2=TERowParallelLinear if use_te else RowParallelLinear,
),
)
else:
# Mixture of experts with modules in megatron core.
return get_moe_module_spec(
use_te=use_te,
num_experts=num_experts,
moe_grouped_gemm=moe_grouped_gemm,
moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
use_shared_expert_gate=True,
)
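
# Illustrative sketch of the dense/MoE switch (hypothetical names; assumes TE is installed
# for the use_te=True path):
#
#     dense_mlp = get_mlp_module_spec(use_te=True)               # plain MLP spec
#     moe_mlp = get_mlp_module_spec(use_te=True, num_experts=8)  # MoELayer spec with a
#                                                                # gated shared expert
#     # dense_mlp.module is MLP; moe_mlp.module is MoELayer.
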
def get_gpt_decoder_block_spec(
config: TransformerConfig, use_transformer_engine: bool
) -> TransformerBlockSubmodules:
"""GPT block spec."""
if use_transformer_engine:
layer_norm_impl = TENorm
else:
layer_norm_impl = LNImpl
# Layer specs.
dense_layer_spec = (
get_gpt_layer_with_transformer_engine_spec(
num_experts=None,
moe_grouped_gemm=False,
qk_layernorm=config.qk_layernorm,
moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm,
)
if use_transformer_engine
else get_gpt_layer_local_spec(
num_experts=None,
moe_grouped_gemm=False,
qk_layernorm=config.qk_layernorm,
moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm,
)
)
moe_layer_spec = (
get_gpt_layer_with_transformer_engine_spec(
num_experts=config.num_moe_experts,
moe_grouped_gemm=config.moe_grouped_gemm,
qk_layernorm=config.qk_layernorm,
moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm,
)
if use_transformer_engine
else get_gpt_layer_local_spec(
num_experts=config.num_moe_experts,
moe_grouped_gemm=config.moe_grouped_gemm,
qk_layernorm=config.qk_layernorm,
moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm,
)
)
    # Parse config.moe_layer_freq to determine the pattern of expert/dense layers.
    # 0 stands for a dense layer, 1 stands for an expert (MoE) layer.
    # For an integer N: build a pattern with one expert layer every N layers.
    # For a list pattern (e.g. [1, 0, 1]): use the per-layer pattern directly.
if isinstance(config.moe_layer_freq, int):
moe_layer_pattern = [
1 if (i % config.moe_layer_freq == 0) else 0 for i in range(config.num_layers)
]
elif isinstance(config.moe_layer_freq, list):
moe_layer_pattern = config.moe_layer_freq
assert len(moe_layer_pattern) == config.num_layers, (
f"Invalid length of moe_layer_pattern: {len(moe_layer_pattern)}, "
f"expected {config.num_layers}, "
f"current moe layer pattern: {config.moe_layer_freq}"
)
else:
raise ValueError(
f"Invalid moe_layer_freq: {type(config.moe_layer_freq)}, {config.moe_layer_freq}"
)
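    # Worked example (illustrative): with num_layers=4 and moe_layer_freq=2 the integer
    # branch above yields moe_layer_pattern == [1, 0, 1, 0]; a list value such as
    # [0, 1, 1, 1] keeps only the first layer dense.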
# Create the layer specs for the model.
layer_specs = []
for layer_number in range(config.num_layers):
if moe_layer_pattern[layer_number] == 1:
layer_specs.append(moe_layer_spec)
elif moe_layer_pattern[layer_number] == 0:
layer_specs.append(dense_layer_spec)
else:
raise ValueError(f"Invalid layer pattern: {moe_layer_pattern}")
# Slice the layer specs to only include the layers that are built in this pipeline stage.
# Note: MCore layer_number starts at 1
offset = get_transformer_layer_offset(config)
num_layers_to_build = get_num_layers_to_build(config)
layer_specs = layer_specs[offset : offset + num_layers_to_build]
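    # Illustrative example (assumes an even layer split and no virtual pipeline stages):
    # with num_layers=8 and pipeline_model_parallel_size=2, offset is 0 or 4 and each
    # stage keeps four consecutive entries of layer_specs.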
# Block spec.
block_spec = TransformerBlockSubmodules(layer_specs=layer_specs, layer_norm=layer_norm_impl)
return block_spec
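
# Illustrative usage sketch (assumptions: `config` is a populated TransformerConfig with the
# MoE fields set and model-parallel state already initialized; names are hypothetical):
#
#     block_submodules = get_gpt_decoder_block_spec(config, use_transformer_engine=True)
#     # block_submodules.layer_specs holds one ModuleSpec per layer built on this pipeline
#     # stage; block_submodules.layer_norm is TENorm (or LNImpl on the local path).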