megatron_patch/model/qwen1_5/moe/moe_layer.py (78 lines of code) (raw):
# Copyright (c) 2023 Alibaba PAI and Nvidia Megatron-LM Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
import torch
import torch.nn.functional as F
from megatron.core import parallel_state
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig
from .experts import GroupedMLP, SequentialMLP
from .router import TopKRouter
from .token_dispatcher import (
MoEAllGatherTokenDispatcher,
MoEAlltoAllTokenDispatcher,
)
from ..transformer.mlp import MLPSubmodules, MLP
class BaseMoELayer(MegatronModule, ABC):
    """Base class for a mixture of experts layer.

    Computes how many experts live on this expert-parallel rank and which
    global expert indices they correspond to. Subclasses are expected to
    populate ``router``, ``experts`` and ``token_dispatcher`` and to implement
    :meth:`forward`.

    Args:
        config (TransformerConfig): Configuration object for the transformer model.
        layer_number (int, optional): Index of this layer within the model. Stored
            as-is; it can also be set later via :meth:`set_layer_number`.
    """

    def __init__(self, config: TransformerConfig, layer_number: int = None):
        super().__init__(config)
        self.config = config
        self.expert_parallel_size = parallel_state.get_expert_model_parallel_world_size()
        assert self.expert_parallel_size > 0, "Expected positive expert parallel size"
        assert self.config.num_moe_experts % self.expert_parallel_size == 0, (
            f"num_moe_experts ({self.config.num_moe_experts}) must be divisible by "
            f"the expert parallel size ({self.expert_parallel_size})"
        )
        # Experts are split evenly across expert-parallel ranks; this rank owns a
        # contiguous slice of the global expert indices starting at the offset below.
        self.num_local_experts = self.config.num_moe_experts // self.expert_parallel_size
        local_expert_indices_offset = (
            parallel_state.get_expert_model_parallel_rank() * self.num_local_experts
        )
        self.local_expert_indices = [
            local_expert_indices_offset + i for i in range(self.num_local_experts)
        ]
        assert all(idx < self.config.num_moe_experts for idx in self.local_expert_indices)
        # Placeholders; concrete subclasses assign these.
        self.router = None
        self.experts = None
        self.token_dispatcher = None
        self.layer_number = layer_number

    @abstractmethod
    def forward(self, hidden_states):
        """Run the MoE layer on ``hidden_states``; implemented by subclasses."""
        pass

    def set_layer_number(self, layer_number: int):
        """Record ``layer_number`` on this layer and propagate it to the router.

        Requires ``self.router`` to have been assigned by the subclass.
        """
        self.layer_number = layer_number
        self.router.set_layer_number(layer_number)
class MoELayer(BaseMoELayer):
    """Mixture of experts Layer **currently only supports no token dropping**.

    Tokens are routed by a top-k router, dispatched to the local experts,
    processed, and combined back. When ``config.enable_shared_expert`` is set,
    a shared (non-routed) MLP is also applied to every token. Its output is
    scaled by a per-token sigmoid gate and added to the routed-experts output.

    Args:
        config (TransformerConfig): Configuration object for the transformer model.
        submodules (MLPSubmodules, optional): Layer specs used for the
            ``SequentialMLP`` experts (where it must be an ``MLPSubmodules``
            instance) and, when enabled, for the shared expert MLP.
        layer_number (int, optional): Index of this layer within the model.
    """

    def __init__(
        self, config: TransformerConfig, submodules: MLPSubmodules = None, layer_number: int = None
    ):
        self.submodules = submodules
        super().__init__(config=config, layer_number=layer_number)
        # NOTE: module construction order (router, shared expert, gate, experts)
        # is kept stable so parameter registration / init order does not change.
        self.router = TopKRouter(config=self.config)
        self.enable_shared_experts = config.enable_shared_expert
        if self.enable_shared_experts:
            self.shared_expert = MLP(self.config, submodules, is_expert=False, is_shared_expert=True)
            # One logit per token; its sigmoid scales the shared-expert output.
            self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)
        if self.config.moe_grouped_gemm:
            self.experts = GroupedMLP(self.num_local_experts, self.config)
        else:
            assert isinstance(self.submodules, MLPSubmodules)
            self.experts = SequentialMLP(self.num_local_experts, self.config, self.submodules)
        if config.moe_token_dispatcher_type == "allgather":
            self.token_dispatcher = MoEAllGatherTokenDispatcher(
                self.num_local_experts, self.local_expert_indices, config=self.config
            )
        elif config.moe_token_dispatcher_type == "alltoall":
            self.token_dispatcher = MoEAlltoAllTokenDispatcher(
                self.num_local_experts, self.local_expert_indices, config=self.config
            )
        else:
            raise ValueError(
                f"Unsupported token dispatcher type: {config.moe_token_dispatcher_type}"
            )

    def forward(self, hidden_states: torch.Tensor):
        """Apply the routed experts (and optional gated shared expert).

        Args:
            hidden_states (torch.Tensor): Input activations whose last dimension
                is the hidden size (presumably ``[seq, batch, hidden]`` — TODO
                confirm against callers).

        Returns:
            tuple: ``(output, mlp_bias)``. ``output`` is the combined MoE output
            as returned by the token dispatcher, plus the gated shared-expert
            contribution when enabled. ``mlp_bias`` is the routed-expert bias
            returned by ``token_unpermutation``.
        """
        # Routed experts: route -> permute/dispatch -> expert compute -> unpermute.
        scores, indices = self.router(hidden_states)
        dispatched_input, tokens_per_expert = self.token_dispatcher.token_permutation(
            hidden_states, scores, indices
        )
        expert_output, mlp_bias = self.experts(dispatched_input, tokens_per_expert)
        output, mlp_bias = self.token_dispatcher.token_unpermutation(expert_output, mlp_bias)

        if self.enable_shared_experts:
            shared_expert_output, shared_bias = self.shared_expert(hidden_states)
            # The shared MLP may hand back its output bias separately. Fold it in
            # before gating so the gate scales the full shared-expert output
            # instead of silently dropping the bias.
            if shared_bias is not None:
                shared_expert_output = shared_expert_output + shared_bias
            hidden_size = hidden_states.shape[-1]
            # Flatten to per-token rows so the [tokens, 1] gate broadcasts over hidden.
            gate = torch.sigmoid(self.shared_expert_gate(hidden_states).reshape(-1, 1))
            shared_expert_output = gate * shared_expert_output.reshape(-1, hidden_size)
            # Restore the routed output's layout (works for 2-D and 3-D inputs).
            output = output + shared_expert_output.reshape(output.shape)

        return output, mlp_bias