megatron_patch/model/qwen3_moe/moe/moe_layer.py

# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

import torch

from megatron.core import tensor_parallel
from megatron.core.transformer.moe.legacy_a2a_token_dispatcher import MoEAlltoAllSEQTokenDispatcher
from megatron.core.transformer.moe.moe_layer import BaseMoELayer, MoESubmodules
from megatron.core.transformer.moe.token_dispatcher import (
    MoEAllGatherTokenDispatcher,
    MoEAlltoAllTokenDispatcher,
    MoEFlexTokenDispatcher,
)
from megatron.core.transformer.spec_utils import build_module
from megatron.core.transformer.transformer_config import TransformerConfig

from .router import TopKRouter


class MoELayer(BaseMoELayer):
    """Mixture-of-experts (MoE) layer. Currently supports only dropless routing
    (no token dropping).

    Args:
        config (TransformerConfig): Transformer configuration, including the MoE settings.
        submodules (MoESubmodules): Specs used to build the expert (and optional
            shared-expert) modules.
        layer_number (int): Index of the transformer layer this MoE block belongs to.
    """

    def __init__(
        self, config: TransformerConfig, submodules: MoESubmodules = None, layer_number: int = None
    ):
        self.submodules = submodules
        super(MoELayer, self).__init__(config=config, layer_number=layer_number)
        self.moe_layer_recompute = config.moe_layer_recompute

        # Initialize router
        self.router = TopKRouter(config=self.config)

        # Initialize token dispatcher
        if config.moe_token_dispatcher_type == "allgather":
            self.token_dispatcher = MoEAllGatherTokenDispatcher(
                self.num_local_experts, self.local_expert_indices, config=self.config
            )
        elif config.moe_token_dispatcher_type == "alltoall":
            self.token_dispatcher = MoEAlltoAllTokenDispatcher(
                self.num_local_experts, self.local_expert_indices, config=self.config
            )
        elif config.moe_token_dispatcher_type == "alltoall_seq":
            self.token_dispatcher = MoEAlltoAllSEQTokenDispatcher(
                self.num_local_experts, self.local_expert_indices, config=self.config
            )
        elif config.moe_token_dispatcher_type == "flex":
            self.token_dispatcher = MoEFlexTokenDispatcher(
                self.num_local_experts, self.local_expert_indices, config=self.config
            )
        else:
            raise ValueError(
                f"Unsupported token dispatcher type: {config.moe_token_dispatcher_type}"
            )

        # Initialize experts
        self.experts = build_module(self.submodules.experts, self.num_local_experts, self.config)

        # Initialize shared experts
        if self.use_shared_expert:
            self.shared_experts = build_module(self.submodules.shared_experts, config=self.config)
            if self.shared_expert_overlap:
                self.token_dispatcher.set_shared_experts(self.shared_experts)

    def forward(self, hidden_states: torch.Tensor):
        if (
            self.training
            and self.config.tensor_model_parallel_size > 1
            and not self.config.sequence_parallel
        ):
            raise ValueError(
                "During training, performance may degrade if MoE and tensor parallelism "
                "are enabled without also enabling sequence parallelism."
            )

        # Core MoE computation: route tokens, dispatch them to experts, and combine the outputs.
        def custom_forward(hidden_states):
            probs, routing_map = self.router(hidden_states)
            (dispatched_input, tokens_per_expert) = self.token_dispatcher.token_permutation(
                hidden_states, probs, routing_map
            )
            expert_output, mlp_bias = self.experts(dispatched_input, tokens_per_expert)
            output, mlp_bias = self.token_dispatcher.token_unpermutation(expert_output, mlp_bias)
            if self.use_shared_expert and not self.shared_expert_overlap:
                # If shared_expert_overlap is True, the shared-expert calculation happens in
                # the token_dispatcher to overlap communication and computation.
                output = output + self.shared_experts(hidden_states)
            return output, mlp_bias

        if self.moe_layer_recompute:
            output, mlp_bias = tensor_parallel.checkpoint(custom_forward, False, hidden_states)
        else:
            output, mlp_bias = custom_forward(hidden_states)

        return output, mlp_bias
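
The sketch below is not part of the patch file; it only illustrates, under stated assumptions, which TransformerConfig fields the layer above reads when selecting a token dispatcher and enabling layer recompute. Exact field names and defaults vary across Megatron-Core versions, and actually instantiating MoELayer additionally requires expert/shared-expert build specs (MoESubmodules) plus initialized torch.distributed and model-parallel state, all of which are omitted here.

# Illustrative configuration sketch only; the numeric values are placeholders,
# not defaults taken from this repository.
from megatron.core.transformer.transformer_config import TransformerConfig

config = TransformerConfig(
    num_layers=2,
    hidden_size=1024,
    num_attention_heads=16,
    num_moe_experts=8,                     # number of routed experts
    moe_router_topk=2,                     # top-k routing consumed by TopKRouter
    moe_token_dispatcher_type="alltoall",  # "allgather" | "alltoall" | "alltoall_seq" | "flex"
    moe_layer_recompute=False,             # True wraps custom_forward in tensor_parallel.checkpoint
)
# Note: with tensor_model_parallel_size > 1, forward() raises a ValueError during
# training unless sequence_parallel=True, matching the check in the layer above.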