# megatron_patch/model/qwen1_5/moe/moe_layer.py

# Copyright (c) 2023 Alibaba PAI and Nvidia Megatron-LM Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod

import torch
import torch.nn.functional as F

from megatron.core import parallel_state
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig

from .experts import GroupedMLP, SequentialMLP
from .router import TopKRouter
from .token_dispatcher import (
    MoEAllGatherTokenDispatcher,
    MoEAlltoAllTokenDispatcher,
)
from ..transformer.mlp import MLPSubmodules, MLP


class BaseMoELayer(MegatronModule, ABC):
    """Base class for a mixture of experts layer.

    Computes which global experts this expert-parallel rank owns. Subclasses are
    expected to populate ``router``, ``experts`` and ``token_dispatcher``.

    Args:
        config (TransformerConfig): Configuration object for the transformer model.
        layer_number (int, optional): Index of this layer; also forwarded to the
            router by ``set_layer_number``.
    """

    def __init__(self, config: TransformerConfig, layer_number: int = None):
        super(BaseMoELayer, self).__init__(config)
        self.config = config
        self.expert_parallel_size = parallel_state.get_expert_model_parallel_world_size()
        assert self.expert_parallel_size > 0, "Expected positive expert parallel size"
        assert self.config.num_moe_experts % self.expert_parallel_size == 0, (
            f"num_moe_experts ({self.config.num_moe_experts}) must be divisible by "
            f"the expert parallel size ({self.expert_parallel_size})"
        )
        # Experts are split evenly across expert-parallel ranks; this rank owns the
        # contiguous slice of global expert ids starting at its rank * num_local_experts.
        self.num_local_experts = self.config.num_moe_experts // self.expert_parallel_size
        local_expert_indices_offset = (
            parallel_state.get_expert_model_parallel_rank() * self.num_local_experts
        )
        self.local_expert_indices = list(
            range(
                local_expert_indices_offset,
                local_expert_indices_offset + self.num_local_experts,
            )
        )
        assert all(idx < self.config.num_moe_experts for idx in self.local_expert_indices)
        # Filled in by subclasses.
        self.router = None
        self.experts = None
        self.token_dispatcher = None
        self.layer_number = layer_number

    @abstractmethod
    def forward(self, hidden_states):
        pass

    def set_layer_number(self, layer_number: int):
        """Record this layer's index and propagate it to the router."""
        self.layer_number = layer_number
        self.router.set_layer_number(layer_number)


class MoELayer(BaseMoELayer):
    """Mixture of experts Layer **currently only supports no token dropping**.

    Tokens are routed by a top-k router to the routed experts. When
    ``config.enable_shared_expert`` is set, a shared expert MLP is also applied to
    every token. Its output is scaled by a per-token sigmoid gate and added to the
    routed experts' output.

    Args:
        config (TransformerConfig): Configuration object for the transformer model.
        submodules (MLPSubmodules, optional): Linear-layer specs used to build the
            experts (required unless ``config.moe_grouped_gemm`` is set) and, when
            enabled, the shared expert.
        layer_number (int, optional): Index of this layer.
    """

    def __init__(
        self, config: TransformerConfig, submodules: MLPSubmodules = None, layer_number: int = None
    ):
        self.submodules = submodules
        super(MoELayer, self).__init__(config=config, layer_number=layer_number)
        self.router = TopKRouter(config=self.config)
        self.enable_shared_experts = config.enable_shared_expert
        if config.enable_shared_expert:
            self.shared_expert = MLP(self.config, submodules, is_expert=False, is_shared_expert=True)
            # One logit per token; its sigmoid gates the shared expert's contribution.
            self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)
        if self.config.moe_grouped_gemm:
            self.experts = GroupedMLP(self.num_local_experts, self.config)
        else:
            assert isinstance(self.submodules, MLPSubmodules)
            self.experts = SequentialMLP(self.num_local_experts, self.config, self.submodules)
        if config.moe_token_dispatcher_type == "allgather":
            self.token_dispatcher = MoEAllGatherTokenDispatcher(
                self.num_local_experts, self.local_expert_indices, config=self.config
            )
        elif config.moe_token_dispatcher_type == "alltoall":
            self.token_dispatcher = MoEAlltoAllTokenDispatcher(
                self.num_local_experts, self.local_expert_indices, config=self.config
            )
        else:
            raise ValueError(
                f"Unsupported token dispatcher type: {config.moe_token_dispatcher_type}"
            )

    def forward(self, hidden_states: torch.Tensor):
        """Run the routed experts (plus the gated shared expert, if enabled).

        Args:
            hidden_states (torch.Tensor): Input activations whose last dim is the
                hidden size (presumably [seq, batch, hidden] as elsewhere in
                Megatron -- TODO confirm).

        Returns:
            tuple: ``(output, mlp_bias)`` as returned by the dispatcher's
            ``token_unpermutation``. When shared experts are enabled, the gated
            shared-expert output has been added to ``output``.
        """
        # Routed experts: score/select, send tokens to their experts, run, restore order.
        scores, indices = self.router(hidden_states)
        dispatched_input, tokens_per_expert = self.token_dispatcher.token_permutation(
            hidden_states, scores, indices
        )
        expert_output, mlp_bias = self.experts(dispatched_input, tokens_per_expert)
        output, mlp_bias = self.token_dispatcher.token_unpermutation(expert_output, mlp_bias)

        if self.enable_shared_experts:
            output = output + self._gated_shared_expert(hidden_states)
        return output, mlp_bias

    def _gated_shared_expert(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return sigmoid(gate(x)) * shared_expert(x), shaped (-1, shape[-2], shape[-1])."""
        hidden_size = hidden_states.shape[-1]
        shared_expert_output, shared_bias = self.shared_expert(hidden_states)
        if shared_bias is not None:
            # The MLP handed its output bias back separately instead of adding it;
            # fold it in before gating rather than silently dropping it.
            shared_expert_output = shared_expert_output + shared_bias
        gate = F.sigmoid(self.shared_expert_gate(hidden_states)).reshape(-1, 1)
        # reshape (not view) so a non-contiguous MLP output is handled as well.
        gated = gate * shared_expert_output.reshape(-1, hidden_size)
        return gated.reshape(-1, hidden_states.shape[-2], hidden_size)