megatron_patch/model/qwen2/moe/router.py
# Copyright (c) 2024 Alibaba PAI and Nvidia Megatron-LM Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
import torch
from megatron.core import parallel_state
from megatron.core.tensor_parallel import gather_from_sequence_parallel_region
from megatron.core.tensor_parallel.random import (
    get_cuda_rng_tracker,
    get_data_parallel_rng_tracker_name,
)
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.moe.moe_utils import (
MoEAuxLossAutoScaler,
save_to_aux_losses_tracker,
sinkhorn,
get_capacity,
switch_load_balancing_loss_func,
z_loss_func,
)
from megatron.core.transformer.transformer_config import TransformerConfig
def topk_softmax_with_capacity(
logits: torch.Tensor,
topk: int,
capacity_factor: float = None,
pad_to_capacity: bool = False,
drop_policy: str = "probs",
):
"""Apply capacity and padding to the top-k selection.
Args:
logits (torch.Tensor): Logits tensor.
topk (int): The number of experts to select for each token.
        capacity_factor (float): The capacity factor of each expert. Tokens exceeding the capacity are dropped.
        pad_to_capacity (bool): Whether to pad each expert's tokens to the expert capacity in token-drop mode.
        drop_policy (str): The policy used to drop tokens. Can be either "probs" or "position". If "probs", the tokens with the lowest probabilities are dropped. If "position", tokens at the end of each batch are dropped.
Returns:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Probs, indices and tokens_per_expert tensor.
(1) If there's no token padding, the shape of probs and indices is [tokens, top_k], indicating the selected experts for each token.
(2) If there's token padding, the shape of probs and indices is [num_expert, capacity], indicating the tokens selected for each expert.
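    Example (illustrative, no capacity):
        probs, indices, counts = topk_softmax_with_capacity(logits, topk=2)
        # logits: [8, 4] (8 tokens, 4 experts) -> probs, indices: [8, 2]; counts: [4]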
"""
# TODO: Add Pre softmax.
assert logits.dim() == 2, f"Expected 2D logits [num_tokens, num_experts], got {logits.dim()}."
num_tokens = logits.shape[0]
num_experts = logits.shape[1]
    # Unlike upstream Megatron (top-k over the logits, then softmax over the
    # selected scores), apply softmax over all experts first and then take the
    # top-k, matching the Hugging Face Qwen2-MoE gating.
    routing_weights = torch.softmax(logits, dim=1, dtype=torch.float32).type_as(logits)
    probs, top_indices = torch.topk(routing_weights, k=topk, dim=-1)
if capacity_factor is None:
# TopK without capacity
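        # Count tokens assigned to each expert via a histogram over the selected indices.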
tokens_per_expert = torch.histc(top_indices, bins=num_experts, min=0, max=num_experts)
return probs, top_indices, tokens_per_expert
else:
# TopK with capacity
expert_capacity = get_capacity(
num_tokens=num_tokens * topk, num_experts=num_experts, capacity_factor=capacity_factor,
)
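        # expert_capacity is the maximum number of token slots each expert keeps;
        # tokens routed beyond this limit are dropped (or their slots padded below).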
# TopK selection, Maskout unused experts
topk_masked_gates = torch.zeros_like(logits).scatter(1, top_indices, probs)
topk_mask = torch.zeros_like(logits).scatter(1, top_indices, 1)
# Maskout exceeded tokens
if drop_policy == "probs":
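            # Keep the expert_capacity highest-probability tokens for each expert
            # (top-k along the token dimension).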
capacity_probs, capacity_indices = torch.topk(
topk_masked_gates, k=expert_capacity, dim=0, sorted=False
)
capacity_mask = torch.zeros_like(logits).scatter(0, capacity_indices, 1)
elif drop_policy == "position":
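            # Keep the first expert_capacity tokens routed to each expert;
            # tokens later in the batch are dropped first.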
_, capacity_indices = torch.topk(topk_mask, k=expert_capacity, dim=0, sorted=False)
capacity_mask = torch.zeros_like(logits).scatter(0, capacity_indices, 1)
capacity_probs = torch.gather(topk_masked_gates, 0, capacity_indices)
else:
raise ValueError(f"Invalid drop_policy: {drop_policy}")
if pad_to_capacity:
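            # Padded mode: return per-expert tensors of shape [num_experts, expert_capacity].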
final_probs, final_indices = (
capacity_probs.T.contiguous(),
capacity_indices.T.contiguous(),
)
tokens_per_expert_before_capacity = topk_mask.sum(dim=0)
else:
# Get exceed mask and maskout exceeded probs and indices
final_mask = torch.logical_and(topk_mask, capacity_mask)
drop_mask = torch.logical_not(final_mask)
exceed_mask = torch.gather(drop_mask, 1, top_indices)
final_probs = probs * torch.logical_not(exceed_mask)
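            # Mark dropped tokens with a sentinel index (max int64).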
final_indices = top_indices.clone().masked_fill_(
exceed_mask, torch.iinfo(torch.long).max
)
tokens_per_expert_before_capacity = topk_mask.sum(dim=0)
return final_probs, final_indices, tokens_per_expert_before_capacity
class Router(ABC, MegatronModule):
"""Base Router class"""
def __init__(self, config: TransformerConfig) -> None:
"""
Initialize the Router module.
Args:
config (TransformerConfig): Configuration object for the Transformer model.
"""
super().__init__(config)
self.config = config
self.num_experts = self.config.num_moe_experts
self.moe_aux_loss_func = None
self.layer_number = None
# Initialize the gate weights.
self.weight = torch.nn.Parameter(
torch.empty((self.config.num_moe_experts, self.config.hidden_size))
)
with get_cuda_rng_tracker().fork(get_data_parallel_rng_tracker_name()):
config.init_method(self.weight)
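        # Mark the gate weight for sequence parallelism so Megatron reduces its
        # gradient across the tensor-parallel group when sequence parallelism is on.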
setattr(self.weight, 'sequence_parallel', config.sequence_parallel)
def gating(self, input: torch.Tensor):
"""Forward pass of the router gate.
Args:
input (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Logits tensor.
"""
logits = torch.nn.functional.linear(input, self.weight)
return logits
@abstractmethod
def routing(self, logits: torch.Tensor):
"""Routing function.
Args:
logits (torch.Tensor): Logits tensor.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple of tensors representing max probs and the indices.
"""
raise NotImplementedError("Routing function not implemented.")
@abstractmethod
def forward(self, input: torch.Tensor):
"""
Forward pass of the router.
Args:
input (torch.Tensor): Input tensor.
"""
raise NotImplementedError("Forward function not implemented.")
def set_layer_number(self, layer_number: int):
"""Set the layer number for the router."""
self.layer_number = layer_number
class TopKRouter(Router):
"""Route each token to the top-k experts."""
    def __init__(self, config: TransformerConfig) -> None:
        """Initialize the top-k router.
Args:
config (TransformerConfig): The configuration for the transformer model.
"""
super().__init__(config=config)
self.topk = self.config.moe_router_topk
self.routing_type = self.config.moe_router_load_balancing_type
self.input_jitter = None
def sinkhorn_load_balancing(self, logits: torch.Tensor):
"""Apply sinkhorn routing to the logits tensor.
Args:
logits (torch.Tensor): The logits tensor.
Returns:
torch.Tensor: The logits tensor after applying sinkhorn routing.
"""
def _sinkhorn_activation(logits):
if self.topk == 1:
logits = torch.sigmoid(logits)
else: # k > 1
logits = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits)
return logits
assert self.config.moe_aux_loss_coeff == 0, "Sinkhorn routing does not support aux loss."
if self.training:
with torch.no_grad():
norm_logits = sinkhorn(
logits.to(dtype=torch.float32)
) # explicit fp32 conversion for stability
_, indices = torch.topk(norm_logits, k=self.topk, dim=1)
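            # Expert choice uses the Sinkhorn-balanced matrix; gate values come
            # from the plain activation of the original logits.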
logits = _sinkhorn_activation(logits)
scores = torch.gather(logits, 1, indices)
else:
logits = _sinkhorn_activation(logits)
scores, indices = torch.topk(logits, k=self.topk, dim=1)
return scores, indices
def aux_loss_load_balancing(self, logits: torch.Tensor):
"""Apply loss-based load balancing to the logits tensor.
Args:
logits (torch.Tensor): the logits tensor after gating, shape: [num_tokens, num_experts].
Returns:
probs (torch.Tensor): the probabilities tensor after load balancing.
indices (torch.Tensor): the indices tensor after top-k selection.
"""
probs, indices, tokens_per_expert = topk_softmax_with_capacity(
logits,
self.topk,
capacity_factor=self.config.moe_expert_capacity_factor,
pad_to_capacity=self.config.moe_pad_expert_input_to_capacity,
drop_policy=self.config.moe_token_drop_policy,
)
# Apply load balancing loss
scores = torch.softmax(logits, dim=-1, dtype=torch.float32)
probs = self.apply_load_balancing_loss(scores, tokens_per_expert, activation=probs)
return probs, indices
def apply_load_balancing_loss(
self,
probs: torch.Tensor,
num_local_tokens_per_expert: torch.Tensor,
activation: torch.Tensor,
):
"""Applies auxiliary loss to the MoE layer.
Args:
probs (torch.Tensor): The probs output by the router for each token. [num_tokens, num_experts]
num_local_tokens_per_expert (torch.Tensor): The number of tokens per expert. [num_experts]
activation (torch.Tensor): The activation tensor to attach the gradient function to.
Returns:
torch.Tensor: The activation tensor with the attached gradient function.
"""
moe_aux_loss_coeff = (
self.config.moe_aux_loss_coeff / parallel_state.get_tensor_model_parallel_world_size()
)
aux_loss = switch_load_balancing_loss_func(
probs, num_local_tokens_per_expert, self.topk, moe_aux_loss_coeff
)
save_to_aux_losses_tracker(
"load_balancing_loss",
aux_loss / moe_aux_loss_coeff,
self.layer_number,
self.config.num_layers,
)
activation = MoEAuxLossAutoScaler.apply(activation, aux_loss)
return activation
def apply_z_loss(self, logits):
"""Encourages the router's logits to remain small to enhance stability.
Please refer to the ST-MoE paper (https://arxiv.org/pdf/2202.08906.pdf) for details.
Args:
logits (torch.Tensor): The logits of the router.
Returns:
torch.Tensor: The logits after applying the z-loss.
"""
if self.config.moe_z_loss_coeff is not None:
moe_z_loss_coeff = (
self.config.moe_z_loss_coeff / parallel_state.get_tensor_model_parallel_world_size()
)
z_loss = z_loss_func(logits, moe_z_loss_coeff)
logits = MoEAuxLossAutoScaler.apply(logits, z_loss)
save_to_aux_losses_tracker(
"z_loss",
z_loss / self.config.moe_z_loss_coeff,
self.layer_number,
self.config.num_layers,
)
return logits
def apply_input_jitter(self, input: torch.Tensor):
"""Add noise to the input tensor.
Refer to https://arxiv.org/abs/2101.03961.
Args:
input (Tensor): Input tensor.
Returns:
Tensor: Jittered input.
"""
if self.config.moe_input_jitter_eps is not None:
eps = self.config.moe_input_jitter_eps
if self.input_jitter is None:
self.input_jitter = torch.distributions.uniform.Uniform(
torch.tensor(1.0 - eps, device=input.device),
torch.tensor(1.0 + eps, device=input.device),
).rsample
return input * self.input_jitter(input.shape)
else:
return input
def routing(self, logits: torch.Tensor):
"""Top-k routing function
Args:
logits (torch.Tensor): Logits tensor after gating.
Returns:
probs (torch.Tensor): the probabilities tensor after load balancing.
indices (torch.Tensor): the indices tensor after top-k selection.
"""
logits = logits.view(-1, self.config.num_moe_experts)
# Apply Z-Loss
logits = self.apply_z_loss(logits)
if (
parallel_state.get_tensor_model_parallel_world_size() > 1
and self.config.moe_token_dispatcher_type == "alltoall"
):
            # Gather the logits from the sequence-parallel region so every TP rank
            # routes over the full set of tokens
logits = gather_from_sequence_parallel_region(logits)
if self.routing_type == "sinkhorn":
scores, indices = self.sinkhorn_load_balancing(logits)
elif self.routing_type == "aux_loss":
scores, indices = self.aux_loss_load_balancing(logits)
elif self.routing_type == "none":
# A naive top-k routing without load balancing
scores, indices, _ = topk_softmax_with_capacity(
logits,
self.topk,
capacity_factor=self.config.moe_expert_capacity_factor,
pad_to_capacity=self.config.moe_pad_expert_input_to_capacity,
drop_policy=self.config.moe_token_drop_policy,
)
else:
raise ValueError(f"Unsupported MoE routing type: {self.routing_type}")
return scores, indices
def forward(self, input: torch.Tensor):
"""
Forward pass of the router.
Args:
input (torch.Tensor): Input tensor.
"""
self.hidden = input.shape[-1]
# Apply input jitter
input = self.apply_input_jitter(input)
logits = self.gating(input)
logits = logits.view(-1, self.config.num_moe_experts)
scores, indices = self.routing(logits)
return scores, indices
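
# Usage sketch: how a TopKRouter might be wired up. The TransformerConfig values
# below are illustrative assumptions; a real run also needs Megatron's parallel
# state and model-parallel RNG initialized before construction, so the snippet is
# left as a comment.
#
#     config = TransformerConfig(
#         num_layers=2,
#         hidden_size=16,
#         num_attention_heads=4,
#         num_moe_experts=4,
#         moe_router_topk=2,
#         moe_router_load_balancing_type="aux_loss",
#         moe_aux_loss_coeff=1e-2,
#     )
#     router = TopKRouter(config)
#     router.set_layer_number(1)
#     scores, indices = router(hidden_states)  # hidden_states: [num_tokens, hidden_size]
#     # scores, indices: [num_tokens, moe_router_topk]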