megatron_patch/model/mixtral/moe/moe_layer.py

# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Union

import torch

from megatron.core import parallel_state, tensor_parallel
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.moe.legacy_a2a_token_dispatcher import MoEAlltoAllSEQTokenDispatcher
from megatron.core.transformer.spec_utils import ModuleSpec

from .experts import GroupedMLP, SequentialMLP, TEGroupedMLP
from .router import TopKRouter
from .token_dispatcher import (
    MoEAllGatherTokenDispatcher,
    MoEAlltoAllTokenDispatcher,
)
from ..transformer_config import TransformerConfig
from ..transformer.mlp import MLPSubmodules


@dataclass
class MoESubmodules:
    """MoE Layer Submodule spec"""

    experts: Union[ModuleSpec, type] = None
    shared_experts: Union[ModuleSpec, type] = None


class BaseMoELayer(MegatronModule, ABC):
    """Base class for a mixture of experts layer.

    Args:
        config (TransformerConfig): Configuration object for the transformer model.
    """

    def __init__(self, config: TransformerConfig, layer_number: int = None):
        super(BaseMoELayer, self).__init__(config)
        self.config = config
        self.expert_parallel_size = parallel_state.get_expert_model_parallel_world_size()
        assert self.expert_parallel_size > 0, "Expected positive expert parallel size"

        if self.config.moe_extended_tp:
            self.num_local_experts = self.config.num_moe_experts
            local_expert_indices_offset = 0
        else:
            assert self.config.num_moe_experts % self.expert_parallel_size == 0
            self.num_local_experts = self.config.num_moe_experts // self.expert_parallel_size
            local_expert_indices_offset = (
                parallel_state.get_expert_model_parallel_rank() * self.num_local_experts
            )

        self.use_shared_expert = self.config.moe_shared_expert_intermediate_size is not None
        self.shared_expert_overlap = self.config.moe_shared_expert_overlap

        self.local_expert_indices = [
            local_expert_indices_offset + i for i in range(self.num_local_experts)
        ]
        assert all(map(lambda x: x < self.config.num_moe_experts, self.local_expert_indices))
        self.router = None
        self.experts = None
        self.shared_experts = None
        self.token_dispatcher = None
        self.layer_number = layer_number

    @abstractmethod
    def forward(self, hidden_states):
        """Forward method for the MoE layer."""
        pass

    def set_layer_number(self, layer_number: int):
        """Set the layer number for the MoE layer."""
        self.layer_number = layer_number
        self.router.set_layer_number(layer_number)


class MoELayer(BaseMoELayer):
    """Mixture of experts Layer **currently only supports no token dropping**.

    Args:
        config (TransformerConfig): Configuration object for the transformer model.
        submodules (MLPSubmodules): Submodule specs used to build the experts.
        layer_number (int): Index of this layer within the transformer block.
    """

    def __init__(
        self, config: TransformerConfig, submodules: MLPSubmodules = None, layer_number: int = None
    ):
        self.submodules = submodules
        super(MoELayer, self).__init__(config=config, layer_number=layer_number)
        self.moe_layer_recompute = config.moe_layer_recompute

        # Initialize router
        self.router = TopKRouter(config=self.config)

        # Initialize experts
        if self.config.moe_grouped_gemm:
            if isinstance(self.submodules.experts, MLPSubmodules):
                self.experts = TEGroupedMLP(
                    self.num_local_experts, self.config, self.submodules.experts
                )
            else:
                self.experts = GroupedMLP(self.num_local_experts, self.config)
        else:
            assert isinstance(self.submodules.experts, MLPSubmodules)
            self.experts = SequentialMLP(
                self.num_local_experts, self.config, self.submodules.experts
            )

        # Initialize token dispatcher
        if config.moe_token_dispatcher_type == "allgather":
            self.token_dispatcher = MoEAllGatherTokenDispatcher(
                self.num_local_experts, self.local_expert_indices, config=self.config
            )
        elif config.moe_token_dispatcher_type == "alltoall":
            self.token_dispatcher = MoEAlltoAllTokenDispatcher(
                self.num_local_experts, self.local_expert_indices, config=self.config
            )
        elif config.moe_token_dispatcher_type == "alltoall_seq":
            self.token_dispatcher = MoEAlltoAllSEQTokenDispatcher(
                self.num_local_experts, self.local_expert_indices, config=self.config
            )
        else:
            raise ValueError(
                f"Unsupported token dispatcher type: {config.moe_token_dispatcher_type}"
            )

    def forward(self, hidden_states: torch.Tensor):
        if (
            self.training
            and self.config.tensor_model_parallel_size > 1
            and not self.config.sequence_parallel
        ):
            raise ValueError(
                "During training, performance may degrade if MoE and tensor parallelism "
                "are enabled without also enabling sequence parallelism."
            )

        # process MoE
        def custom_forward(hidden_states):
            probs, routing_map = self.router(hidden_states)
            (dispatched_input, tokens_per_expert) = self.token_dispatcher.token_permutation(
                hidden_states, probs, routing_map
            )
            expert_output, mlp_bias = self.experts(dispatched_input, tokens_per_expert)
            output, mlp_bias = self.token_dispatcher.token_unpermutation(expert_output, mlp_bias)
            if self.use_shared_expert and not self.shared_expert_overlap:
                # if shared_expert_overlap is True, the expert calculation happens in
                # the token_dispatcher to overlap communications and computations
                output += self.shared_experts(hidden_states)
            return output, mlp_bias

        if self.moe_layer_recompute:
            output, mlp_bias = tensor_parallel.checkpoint(custom_forward, False, hidden_states)
        else:
            output, mlp_bias = custom_forward(hidden_states)

        return output, mlp_bias
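
The expert-to-rank bookkeeping in BaseMoELayer.__init__ can be checked in isolation. The snippet below is a hypothetical single-process sketch of that index arithmetic (no parallel_state involved; the helper name local_expert_indices_for_rank is made up for illustration): in the non-extended-TP branch, each expert-parallel rank owns a contiguous slice of num_moe_experts // expert_parallel_size experts.

def local_expert_indices_for_rank(num_moe_experts: int, expert_parallel_size: int, ep_rank: int):
    # Mirrors the arithmetic in BaseMoELayer.__init__, with the expert-parallel
    # rank passed in explicitly instead of queried from parallel_state.
    assert num_moe_experts % expert_parallel_size == 0
    num_local_experts = num_moe_experts // expert_parallel_size
    offset = ep_rank * num_local_experts
    return [offset + i for i in range(num_local_experts)]


# 8 experts split across an expert-parallel group of size 4 -> 2 experts per rank.
for rank in range(4):
    print(rank, local_expert_indices_for_rank(8, 4, rank))
# 0 [0, 1]
# 1 [2, 3]
# 2 [4, 5]
# 3 [6, 7]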
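
The router -> token_permutation -> experts -> token_unpermutation pipeline inside custom_forward is interleaved with expert/tensor parallelism and communication overlap, so the data flow is hard to see from this file alone. Below is a minimal single-process sketch of the same top-k routing idea in plain PyTorch; ToyTopKMoE and everything inside it are illustrative stand-ins, not the actual TopKRouter, token dispatcher, or expert implementations.

import torch
import torch.nn.functional as F
from torch import nn


class ToyTopKMoE(nn.Module):
    """Single-process stand-in for the router -> permute -> experts -> unpermute flow."""

    def __init__(self, hidden_size: int, ffn_hidden_size: int, num_experts: int, top_k: int):
        super().__init__()
        self.top_k = top_k
        # Router: one logit per expert for every token (stand-in for TopKRouter).
        self.router = nn.Linear(hidden_size, num_experts, bias=False)
        # Experts: independent two-layer MLPs (stand-in for SequentialMLP / GroupedMLP).
        self.experts = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(hidden_size, ffn_hidden_size),
                    nn.GELU(),
                    nn.Linear(ffn_hidden_size, hidden_size),
                )
                for _ in range(num_experts)
            ]
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: [num_tokens, hidden_size]
        probs = F.softmax(self.router(hidden_states), dim=-1)   # [tokens, experts]
        topk_probs, topk_ids = probs.topk(self.top_k, dim=-1)   # [tokens, top_k]

        output = torch.zeros_like(hidden_states)
        for expert_id, expert in enumerate(self.experts):
            # "Permutation": gather every token routed to this expert ...
            token_idx, slot_idx = (topk_ids == expert_id).nonzero(as_tuple=True)
            if token_idx.numel() == 0:
                continue
            expert_out = expert(hidden_states[token_idx])
            # ... "unpermutation": scatter the probability-weighted result back.
            weight = topk_probs[token_idx, slot_idx].unsqueeze(-1)
            output.index_add_(0, token_idx, expert_out * weight)
        return output


if __name__ == "__main__":
    torch.manual_seed(0)
    layer = ToyTopKMoE(hidden_size=16, ffn_hidden_size=32, num_experts=4, top_k=2)
    tokens = torch.randn(8, 16)   # 8 tokens of hidden size 16
    print(layer(tokens).shape)    # torch.Size([8, 16])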