# megatron_patch/model/qwen3_moe/moe/moe_layer.py
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
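"""Qwen3-MoE mixture-of-experts (MoE) layer.

A patched variant of Megatron-Core's MoELayer that plugs in the local TopKRouter
from this package while reusing Megatron-Core's token dispatchers and expert modules.
"""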
import torch
from megatron.core import tensor_parallel
from megatron.core.transformer.moe.legacy_a2a_token_dispatcher import MoEAlltoAllSEQTokenDispatcher
from megatron.core.transformer.moe.token_dispatcher import (
MoEAllGatherTokenDispatcher,
MoEAlltoAllTokenDispatcher,
MoEFlexTokenDispatcher,
)
from megatron.core.transformer.spec_utils import build_module
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.moe.moe_layer import BaseMoELayer, MoESubmodules

from .router import TopKRouter


class MoELayer(BaseMoELayer):
"""Mixture of experts Layer **currently only supports no token dropping**.
Args:
BaseMoELayer (MegatronModule): Base class for MoE layers
"""
def __init__(
self, config: TransformerConfig, submodules: MoESubmodules = None, layer_number: int = None
):
self.submodules = submodules
super(MoELayer, self).__init__(config=config, layer_number=layer_number)
self.moe_layer_recompute = config.moe_layer_recompute
# Initialize router
self.router = TopKRouter(config=self.config)
# Initialize token dispatcher
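        # The dispatcher moves tokens to the ranks that host their routed experts;
        # the communication pattern (allgather / alltoall / alltoall_seq / flex)
        # is selected by config.moe_token_dispatcher_type.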
if config.moe_token_dispatcher_type == "allgather":
self.token_dispatcher = MoEAllGatherTokenDispatcher(
self.num_local_experts, self.local_expert_indices, config=self.config
)
elif config.moe_token_dispatcher_type == "alltoall":
self.token_dispatcher = MoEAlltoAllTokenDispatcher(
self.num_local_experts, self.local_expert_indices, config=self.config
)
elif config.moe_token_dispatcher_type == "alltoall_seq":
self.token_dispatcher = MoEAlltoAllSEQTokenDispatcher(
self.num_local_experts, self.local_expert_indices, config=self.config
)
elif config.moe_token_dispatcher_type == "flex":
self.token_dispatcher = MoEFlexTokenDispatcher(
self.num_local_experts, self.local_expert_indices, config=self.config
)
else:
raise ValueError(
f"Unsupported token dispatcher type: {config.moe_token_dispatcher_type}"
)
# Initialize experts
self.experts = build_module(self.submodules.experts, self.num_local_experts, self.config)
# Initialize shared experts
if self.use_shared_expert:
self.shared_experts = build_module(self.submodules.shared_experts, config=self.config)
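            # With shared-expert overlap enabled, the token dispatcher runs the shared
            # experts itself so their computation overlaps with dispatch/combine
            # communication (see custom_forward below).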
if self.shared_expert_overlap:
self.token_dispatcher.set_shared_experts(self.shared_experts)

    def forward(self, hidden_states: torch.Tensor):
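        """Forward pass of the MoE layer.

        Args:
            hidden_states (torch.Tensor): Input hidden states.

        Returns:
            A (output, mlp_bias) tuple: the combined expert output (including the
            shared-expert output when enabled) and the associated MLP bias.
        """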
if (
self.training
and self.config.tensor_model_parallel_size > 1
and not self.config.sequence_parallel
):
            raise ValueError(
                "During training, performance may degrade if MoE and tensor parallelism "
                "are enabled without also enabling sequence parallelism."
            )

        # MoE computation: route tokens, dispatch them to the experts, run the expert
        # MLPs, then combine the outputs back into the original token order.
        def custom_forward(hidden_states):
probs, routing_map = self.router(hidden_states)
(dispatched_input, tokens_per_expert) = self.token_dispatcher.token_permutation(
hidden_states, probs, routing_map
)
expert_output, mlp_bias = self.experts(dispatched_input, tokens_per_expert)
output, mlp_bias = self.token_dispatcher.token_unpermutation(expert_output, mlp_bias)
if self.use_shared_expert and not self.shared_expert_overlap:
# if shared_expert_overlap is True, the expert calculation happens in
# the token_dispatcher to overlap communications and computations
output = output + self.shared_experts(hidden_states)
return output, mlp_bias
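
        # Optionally checkpoint the whole MoE block: activations are discarded in the
        # forward pass and recomputed during backward to reduce memory usage.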
if self.moe_layer_recompute:
output, mlp_bias = tensor_parallel.checkpoint(custom_forward, False, hidden_states)
else:
output, mlp_bias = custom_forward(hidden_states)
return output, mlp_bias