# megatron_patch/model/mixtral/layer_specs.py
# Copyright (c) 2023 Alibaba PAI and Nvidia Megatron-LM Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.moe.shared_experts import SharedExpertMLP
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
from .transformer.mlp import MLP, MLPSubmodules
from .transformer.attention import SelfAttention, SelfAttentionSubmodules
from .moe.moe_layer import MoELayer, MoESubmodules

try:
    from megatron.core.extensions.transformer_engine import (
        TEColumnParallelGroupedLinear,
        TEColumnParallelLinear,
        TEDotProductAttention,
        TELayerNormColumnParallelLinear,
        TENorm,
        TERowParallelGroupedLinear,
        TERowParallelLinear,
    )

    HAVE_TE = True
except ImportError:
    HAVE_TE = False
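# When Transformer Engine is unavailable, the TE* names above remain undefined, so the
# Transformer-Engine-based spec below is only usable when HAVE_TE is True.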

try:
    import apex  # pylint: disable=unused-import
    from megatron.core.fusions.fused_layer_norm import FusedLayerNorm

    HAVE_APEX = True
    LNImpl = FusedLayerNorm
except ImportError:
    import warnings

    from megatron.core.transformer.torch_layer_norm import WrappedTorchLayerNorm

    warnings.warn('Apex is not installed. Falling back to Torch LayerNorm')
    HAVE_APEX = False
    LNImpl = WrappedTorchLayerNorm
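# LNImpl is the layer-norm implementation used by the Megatron-core-only spec below; the
# Transformer Engine spec instead fuses its layer norms into the TE linear/norm modules.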


# Use this spec to use lower level Transformer Engine modules (required for fp8 training)
def get_gpt_layer_with_transformer_engine_spec(
    num_experts: Optional[int] = None, moe_grouped_gemm: bool = False
) -> ModuleSpec:
"""
Generates a spec for a GPT transformer layer using Transformer Engine modules.
Args:
num_experts: Optional; the number of experts to use in a Mixture of Experts (MoE) setup.
If `None`, a dense multi-layer perceptron (MLP) is used instead of MoE.
moe_grouped_gemm: Optional; if `True`, enables grouped GEMM for MoE operations,
which can be more efficient for certain configurations.
Returns:
A ModuleSpec object that specifies how to construct a GPT transformer layer with
the appropriate submodules for self-attention and MLP/MoE using Transformer Engine optimizations.
"""
    mlp = _get_mlp_module_spec(
        use_te=True, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm
    )
    return ModuleSpec(
        module=TransformerLayer,
        submodules=TransformerLayerSubmodules(
            self_attention=ModuleSpec(
                module=SelfAttention,
                params={"attn_mask_type": AttnMaskType.causal},
                submodules=SelfAttentionSubmodules(
                    linear_qkv=TELayerNormColumnParallelLinear,
                    core_attention=TEDotProductAttention,
                    linear_proj=TERowParallelLinear,
                ),
            ),
            self_attn_bda=get_bias_dropout_add,
            # The dense MLP fuses its layer norm into linear_fc1, so an explicit
            # pre-MLP norm is only needed for the MoE path.
            pre_mlp_layernorm=TENorm if num_experts else IdentityOp,
            mlp=mlp,
            mlp_bda=get_bias_dropout_add,
        ),
    )
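

# Illustrative usage (not executed here): the returned spec is typically consumed by
# megatron.core's build_module, e.g.
#   spec = get_gpt_layer_with_transformer_engine_spec(num_experts=8, moe_grouped_gemm=True)
#   layer = build_module(spec, config=transformer_config, layer_number=1)
# where `transformer_config` is a megatron.core TransformerConfig prepared elsewhere.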


# Use this spec for an implementation using only modules in megatron core
def get_gpt_layer_local_spec(
    num_experts: Optional[int] = None, moe_grouped_gemm: bool = False
) -> ModuleSpec:
"""
Generates a specification for a GPT transformer layer using only the core modules from Megatron.
Args:
num_experts: Optional; the number of experts to use in a Mixture of Experts (MoE) setup.
If `None`, a dense multi-layer perceptron (MLP) is used instead of MoE.
moe_grouped_gemm: Optional; if `True`, enables grouped GEMM for MoE operations,
which can be more efficient for certain configurations.
Returns:
A ModuleSpec object that specifies how to construct a GPT transformer layer with
standard Megatron core modules without the lower-level Transformer Engine optimizations.
"""
    mlp = _get_mlp_module_spec(
        use_te=False, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm
    )
    return ModuleSpec(
        module=TransformerLayer,
        submodules=TransformerLayerSubmodules(
            # Use LNImpl so this spec still works when Apex is unavailable
            # (falls back to WrappedTorchLayerNorm).
            input_layernorm=LNImpl,
            self_attention=ModuleSpec(
                module=SelfAttention,
                params={"attn_mask_type": AttnMaskType.causal},
                submodules=SelfAttentionSubmodules(
                    linear_qkv=ColumnParallelLinear,
                    core_attention=DotProductAttention,
                    linear_proj=RowParallelLinear,
                ),
            ),
            self_attn_bda=get_bias_dropout_add,
            pre_mlp_layernorm=LNImpl,
            mlp=mlp,
            mlp_bda=get_bias_dropout_add,
            # Map the standalone layer-norm parameters onto the names used by the TE spec
            # (where the norms are fused into linear_qkv / linear_fc1), keeping checkpoints
            # interchangeable between the two specs.
            sharded_state_dict_keys_map={
                'input_layernorm.': 'self_attention.linear_qkv.layer_norm_',
                'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_',
            },
        ),
    )
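

# Illustrative usage (not executed here):
#   dense_spec = get_gpt_layer_local_spec()                 # dense MLP layer
#   moe_spec = get_gpt_layer_local_spec(num_experts=8)      # Mixtral-style MoE layer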


def _get_mlp_module_spec(
    use_te: Optional[bool] = True,
    num_experts: Optional[int] = None,
    moe_grouped_gemm: Optional[bool] = False,
    fp8: Optional[str] = None,
) -> ModuleSpec:
"""Helper function to get module spec for MLP/MoE"""
if num_experts is None:
# Dense MLP w/ or w/o TE modules.
return ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=TELayerNormColumnParallelLinear if use_te else ColumnParallelLinear,
linear_fc2=TERowParallelLinear if use_te else RowParallelLinear,
),
)
else:
# Mixture of experts with modules in megatron core.
if use_te and moe_grouped_gemm:
linear_fc1 = TEColumnParallelGroupedLinear
linear_fc2 = TERowParallelGroupedLinear
elif use_te and fp8:
linear_fc1 = TEColumnParallelLinear
linear_fc2 = TERowParallelLinear
else:
linear_fc1 = ColumnParallelLinear
linear_fc2 = RowParallelLinear
use_te_grouped_gemm = use_te and TEColumnParallelGroupedLinear is not None
return ModuleSpec(
module=MoELayer,
submodules=MoESubmodules(
experts=(
MLPSubmodules(linear_fc1=linear_fc1, linear_fc2=linear_fc2)
if not moe_grouped_gemm or use_te_grouped_gemm
else None
),
shared_experts=ModuleSpec(
module=SharedExpertMLP,
params={"gate": False},
submodules=MLPSubmodules(
linear_fc1=TEColumnParallelLinear if use_te else ColumnParallelLinear,
linear_fc2=TERowParallelLinear if use_te else RowParallelLinear,
),
),
),
)
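

# Illustrative example (not executed here): _get_mlp_module_spec(use_te=False, num_experts=8)
# returns a MoELayer spec whose routed experts and shared experts both use the megatron-core
# ColumnParallelLinear / RowParallelLinear submodules, while _get_mlp_module_spec(use_te=True)
# returns a dense MLP spec built on TE linears.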