# megatron_patch/model/qwen3_moe/moe_module_specs.py

# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

import warnings
from typing import Optional

from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.mlp import MLPSubmodules
from megatron.core.transformer.moe.experts import GroupedMLP, SequentialMLP, TEGroupedMLP
# from megatron.core.transformer.moe.moe_layer import MoELayer, MoESubmodules
from megatron.core.transformer.moe.shared_experts import SharedExpertMLP
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.utils import get_te_version, is_te_min_version

try:
    from megatron.core.extensions.transformer_engine import (
        TEColumnParallelGroupedLinear,
        TEColumnParallelLinear,
        TERowParallelGroupedLinear,
        TERowParallelLinear,
    )

    HAVE_TE = True
except ImportError:
    HAVE_TE = False

from .moe.moe_layer import MoELayer, MoESubmodules


def get_moe_module_spec(
    use_te: Optional[bool] = True,
    num_experts: Optional[int] = None,
    moe_grouped_gemm: Optional[bool] = False,
    moe_use_legacy_grouped_gemm: Optional[bool] = False,
) -> ModuleSpec:
    """Helper function to get module spec for MoE"""
    assert num_experts is not None

    mlp = MLPSubmodules(
        linear_fc1=TEColumnParallelLinear if use_te else ColumnParallelLinear,
        linear_fc2=TERowParallelLinear if use_te else RowParallelLinear,
    )

    # experts spec
    if moe_grouped_gemm:
        ## use GroupedMLP
        if use_te and TEColumnParallelGroupedLinear is not None and not moe_use_legacy_grouped_gemm:
            ## use TEGroupedLinear
            expert_module = TEGroupedMLP
            expert_submodule = MLPSubmodules(
                linear_fc1=TEColumnParallelGroupedLinear, linear_fc2=TERowParallelGroupedLinear
            )
        else:
            ## use legacy GroupedMLP
            expert_module = GroupedMLP
            expert_submodule = None
            warnings.warn(
                'The legacy GroupedMLP will be deprecated in Megatron-Core v0.12.0. '
                'Please update the TransformerEngine to version>=1.7.0 and use TEGroupedMLP.'
            )
    else:
        ## use SequentialMLP
        expert_module = SequentialMLP
        if use_te and not is_te_min_version("1.7.0.dev0"):
            warnings.warn(
                "Only transformer-engine>=1.7.0 supports MoE experts, "
                f"but your version is {get_te_version()}. Use local linear implementation instead."
            )
            expert_submodule = MLPSubmodules(
                linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear
            )
        else:
            expert_submodule = mlp

    experts = ModuleSpec(module=expert_module, submodules=expert_submodule)

    # shared experts spec
    shared_experts = ModuleSpec(module=SharedExpertMLP, params={"gate": False}, submodules=mlp)

    # MoE module spec
    moe_module_spec = ModuleSpec(
        module=MoELayer, submodules=MoESubmodules(experts=experts, shared_experts=shared_experts)
    )

    return moe_module_spec
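

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal, hedged example of
# how this helper might be called to inspect the spec it builds. The expert
# count and flags below are illustrative assumptions, not values taken from any
# particular Qwen3-MoE configuration. Building the spec only constructs
# ModuleSpec/dataclass objects, so no distributed or model-parallel state is
# needed; instantiating the actual MoELayer from it would additionally require
# a TransformerConfig and initialized parallel state, which is out of scope
# here. Because this module uses a relative import, run it as a package module,
# e.g. `python -m megatron_patch.model.qwen3_moe.moe_module_specs`.
if __name__ == "__main__":
    moe_spec = get_moe_module_spec(
        use_te=HAVE_TE,                  # fall back to local linear layers if TE is absent
        num_experts=8,                   # illustrative expert count (assumption)
        moe_grouped_gemm=False,          # exercises the SequentialMLP branch
        moe_use_legacy_grouped_gemm=False,
    )
    # Assumes the patched MoESubmodules mirrors the upstream dataclass fields
    # (`experts`, `shared_experts`), each holding a ModuleSpec.
    print("MoE layer module:", moe_spec.module.__name__)
    print("Experts module:  ", moe_spec.submodules.experts.module.__name__)
    print("Shared experts:  ", moe_spec.submodules.shared_experts.module.__name__)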