megatron_patch/model/mixtral/layer_specs.py

# Copyright (c) 2023 Alibaba PAI and Nvidia Megatron-LM Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.moe.shared_experts import SharedExpertMLP
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import (
    TransformerLayer,
    TransformerLayerSubmodules,
)

from .transformer.mlp import MLP, MLPSubmodules
from .transformer.attention import SelfAttention, SelfAttentionSubmodules
from .moe.moe_layer import MoELayer, MoESubmodules

# Transformer Engine provides the fused/low-precision kernels used by the TE spec;
# fall back gracefully when it is not installed.
try:
    from megatron.core.extensions.transformer_engine import (
        TEColumnParallelGroupedLinear,
        TEColumnParallelLinear,
        TEDotProductAttention,
        TELayerNormColumnParallelLinear,
        TENorm,
        TERowParallelGroupedLinear,
        TERowParallelLinear,
    )

    HAVE_TE = True
except ImportError:
    HAVE_TE = False

# Prefer Apex's fused LayerNorm when available; otherwise use the Torch LayerNorm wrapper.
try:
    import apex  # pylint: disable=unused-import

    from megatron.core.fusions.fused_layer_norm import FusedLayerNorm

    HAVE_APEX = True
    LNImpl = FusedLayerNorm
except ImportError:
    import warnings

    from megatron.core.transformer.torch_layer_norm import WrappedTorchLayerNorm

    warnings.warn('Apex is not installed. Falling back to Torch LayerNorm')
    LNImpl = WrappedTorchLayerNorm


# Use this spec to use lower-level Transformer Engine modules (required for fp8 training).
def get_gpt_layer_with_transformer_engine_spec(
    num_experts: int = None, moe_grouped_gemm: bool = False
) -> ModuleSpec:
    """
    Generates a spec for a GPT transformer layer using Transformer Engine modules.

    Args:
        num_experts: Optional; the number of experts to use in a Mixture of Experts
            (MoE) setup. If `None`, a dense multi-layer perceptron (MLP) is used
            instead of MoE.
        moe_grouped_gemm: Optional; if `True`, enables grouped GEMM for MoE
            operations, which can be more efficient for certain configurations.

    Returns:
        A ModuleSpec object that specifies how to construct a GPT transformer layer
        with the appropriate submodules for self-attention and MLP/MoE using
        Transformer Engine optimizations.
""" mlp = _get_mlp_module_spec( use_te=True, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm ) return ModuleSpec( module=TransformerLayer, submodules=TransformerLayerSubmodules( self_attention=ModuleSpec( module=SelfAttention, params={"attn_mask_type": AttnMaskType.causal}, submodules=SelfAttentionSubmodules( linear_qkv=TELayerNormColumnParallelLinear, core_attention=TEDotProductAttention, linear_proj=TERowParallelLinear, ), ), self_attn_bda=get_bias_dropout_add, pre_mlp_layernorm=TENorm if num_experts else IdentityOp, mlp=mlp, mlp_bda=get_bias_dropout_add, ), ) # Use this spec for an implementation using only modules in megatron core def get_gpt_layer_local_spec(num_experts: int = None, moe_grouped_gemm: bool = False) -> ModuleSpec: """ Generates a specification for a GPT transformer layer using only the core modules from Megatron. Args: num_experts: Optional; the number of experts to use in a Mixture of Experts (MoE) setup. If `None`, a dense multi-layer perceptron (MLP) is used instead of MoE. moe_grouped_gemm: Optional; if `True`, enables grouped GEMM for MoE operations, which can be more efficient for certain configurations. Returns: A ModuleSpec object that specifies how to construct a GPT transformer layer with standard Megatron core modules without the lower-level Transformer Engine optimizations. """ mlp = _get_mlp_module_spec( use_te=False, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm ) return ModuleSpec( module=TransformerLayer, submodules=TransformerLayerSubmodules( input_layernorm=FusedLayerNorm, self_attention=ModuleSpec( module=SelfAttention, params={"attn_mask_type": AttnMaskType.causal}, submodules=SelfAttentionSubmodules( linear_qkv=ColumnParallelLinear, core_attention=DotProductAttention, linear_proj=RowParallelLinear, ), ), self_attn_bda=get_bias_dropout_add, pre_mlp_layernorm=FusedLayerNorm, mlp=mlp, mlp_bda=get_bias_dropout_add, sharded_state_dict_keys_map={ 'input_layernorm.': 'self_attention.linear_qkv.layer_norm_', 'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_', }, ), ) def _get_mlp_module_spec( use_te: Optional[bool] = True, num_experts: Optional[int] = None, moe_grouped_gemm: Optional[bool] = False, fp8: Optional[str] = None, ) -> ModuleSpec: """Helper function to get module spec for MLP/MoE""" if num_experts is None: # Dense MLP w/ or w/o TE modules. return ModuleSpec( module=MLP, submodules=MLPSubmodules( linear_fc1=TELayerNormColumnParallelLinear if use_te else ColumnParallelLinear, linear_fc2=TERowParallelLinear if use_te else RowParallelLinear, ), ) else: # Mixture of experts with modules in megatron core. if use_te and moe_grouped_gemm: linear_fc1 = TEColumnParallelGroupedLinear linear_fc2 = TERowParallelGroupedLinear elif use_te and fp8: linear_fc1 = TEColumnParallelLinear linear_fc2 = TERowParallelLinear else: linear_fc1 = ColumnParallelLinear linear_fc2 = RowParallelLinear use_te_grouped_gemm = use_te and TEColumnParallelGroupedLinear is not None return ModuleSpec( module=MoELayer, submodules=MoESubmodules( experts=( MLPSubmodules(linear_fc1=linear_fc1, linear_fc2=linear_fc2) if not moe_grouped_gemm or use_te_grouped_gemm else None ), shared_experts=ModuleSpec( module=SharedExpertMLP, params={"gate": False}, submodules=MLPSubmodules( linear_fc1=TEColumnParallelLinear if use_te else ColumnParallelLinear, linear_fc2=TERowParallelLinear if use_te else RowParallelLinear, ), ), ), )