megatron_patch/model/mixtral/moe/moe_layer.py:
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Union
import torch
from megatron.core import parallel_state, tensor_parallel
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.moe.legacy_a2a_token_dispatcher import MoEAlltoAllSEQTokenDispatcher
from megatron.core.transformer.spec_utils import ModuleSpec
from .experts import GroupedMLP, SequentialMLP, TEGroupedMLP
from .router import TopKRouter
from .token_dispatcher import (
MoEAllGatherTokenDispatcher,
MoEAlltoAllTokenDispatcher,
)
from ..transformer_config import TransformerConfig
from ..transformer.mlp import MLPSubmodules


@dataclass
class MoESubmodules:
"""MoE Layer Submodule spec"""
experts: Union[ModuleSpec, type] = None
    shared_experts: Union[ModuleSpec, type] = None


class BaseMoELayer(MegatronModule, ABC):
"""Base class for a mixture of experts layer.
Args:
config (TransformerConfig): Configuration object for the transformer model.
"""
def __init__(self, config: TransformerConfig, layer_number: int = None):
super(BaseMoELayer, self).__init__(config)
self.config = config
self.expert_parallel_size = parallel_state.get_expert_model_parallel_world_size()
        assert self.expert_parallel_size > 0, "Expected positive expert parallel size"
if self.config.moe_extended_tp:
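            # With extended TP, every rank keeps all experts locally; the experts are sharded
            # along the combined tensor- and expert-parallel dimension instead of being split
            # across expert-parallel ranks.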
self.num_local_experts = self.config.num_moe_experts
local_expert_indices_offset = 0
else:
            assert (
                self.config.num_moe_experts % self.expert_parallel_size == 0
            ), "num_moe_experts must be divisible by the expert model parallel size"
self.num_local_experts = self.config.num_moe_experts // self.expert_parallel_size
local_expert_indices_offset = (
parallel_state.get_expert_model_parallel_rank() * self.num_local_experts
)
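        # Shared experts, when configured, are applied to every token in addition to the
        # routed experts selected by the router.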
self.use_shared_expert = self.config.moe_shared_expert_intermediate_size is not None
self.shared_expert_overlap = self.config.moe_shared_expert_overlap
self.local_expert_indices = [
local_expert_indices_offset + i for i in range(self.num_local_experts)
]
        assert all(
            x < self.config.num_moe_experts for x in self.local_expert_indices
        ), "Local expert indices must be smaller than the total number of experts"
self.router = None
self.experts = None
self.shared_experts = None
self.token_dispatcher = None
        self.layer_number = layer_number

    @abstractmethod
def forward(self, hidden_states):
"""Forward method for the MoE layer."""
        pass

    def set_layer_number(self, layer_number: int):
"""Set the layer number for the MoE layer."""
self.layer_number = layer_number
        self.router.set_layer_number(layer_number)


class MoELayer(BaseMoELayer):
"""Mixture of experts Layer **currently only supports no token dropping**.
Args:
BaseMoELayer (MegatronModule): Base class for MoE layers
"""
def __init__(
self, config: TransformerConfig, submodules: MLPSubmodules = None, layer_number: int = None
):
self.submodules = submodules
super(MoELayer, self).__init__(config=config, layer_number=layer_number)
self.moe_layer_recompute = config.moe_layer_recompute
# Initialize router
self.router = TopKRouter(config=self.config)
# Initialize experts
if self.config.moe_grouped_gemm:
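            # Grouped GEMM batches all local experts into grouped kernels. TEGroupedMLP is used
            # when per-expert linear submodules are provided; otherwise fall back to GroupedMLP,
            # which owns its expert weights directly.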
if isinstance(self.submodules.experts, MLPSubmodules):
self.experts = TEGroupedMLP(
self.num_local_experts, self.config, self.submodules.experts
)
else:
self.experts = GroupedMLP(self.num_local_experts, self.config)
else:
assert isinstance(self.submodules.experts, MLPSubmodules)
self.experts = SequentialMLP(
self.num_local_experts, self.config, self.submodules.experts
)
# Initialize token dispatcher
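        # "allgather" gathers tokens across the parallel group before selecting local ones;
        # "alltoall" and the legacy "alltoall_seq" exchange tokens between expert-parallel
        # ranks via all-to-all communication.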
if config.moe_token_dispatcher_type == "allgather":
self.token_dispatcher = MoEAllGatherTokenDispatcher(
self.num_local_experts, self.local_expert_indices, config=self.config
)
elif config.moe_token_dispatcher_type == "alltoall":
self.token_dispatcher = MoEAlltoAllTokenDispatcher(
self.num_local_experts, self.local_expert_indices, config=self.config
)
elif config.moe_token_dispatcher_type == "alltoall_seq":
self.token_dispatcher = MoEAlltoAllSEQTokenDispatcher(
self.num_local_experts, self.local_expert_indices, config=self.config
)
else:
raise ValueError(
f"Unsupported token dispatcher type: {config.moe_token_dispatcher_type}"
            )

    def forward(self, hidden_states: torch.Tensor):
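        """Route tokens to the local experts, apply them, and combine the expert outputs."""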
if (
self.training
and self.config.tensor_model_parallel_size > 1
and not self.config.sequence_parallel
):
raise ValueError(
"During training, performance may degrade if MoE and tensor parallelism"
"are enabled without also enabling sequence parallelism."
)
# process MoE
def custom_forward(hidden_states):
probs, routing_map = self.router(hidden_states)
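            # probs holds the per-token routing weights; routing_map records which experts each
            # token is routed to and drives the permutation below.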
(dispatched_input, tokens_per_expert) = self.token_dispatcher.token_permutation(
hidden_states, probs, routing_map
)
expert_output, mlp_bias = self.experts(dispatched_input, tokens_per_expert)
output, mlp_bias = self.token_dispatcher.token_unpermutation(expert_output, mlp_bias)
if self.use_shared_expert and not self.shared_expert_overlap:
# if shared_expert_overlap is True, the expert calculation happens in
# the token_dispatcher to overlap communications and computations
output += self.shared_experts(hidden_states)
return output, mlp_bias
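        # Optionally checkpoint the MoE block so its activations are recomputed in the
        # backward pass, trading compute for activation memory.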
if self.moe_layer_recompute:
output, mlp_bias = tensor_parallel.checkpoint(custom_forward, False, hidden_states)
else:
output, mlp_bias = custom_forward(hidden_states)
return output, mlp_bias