# megatron_patch/model/qwen3_moe/moe/moe_utils.py
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from typing import Optional

import torch
from megatron.core.transformer.moe.moe_utils import get_capacity, group_limited_topk


def topk_softmax_with_capacity(
logits: torch.Tensor,
topk: int,
capacity_factor: Optional[float] = None,
pad_to_capacity: bool = False,
drop_policy: str = "probs",
use_pre_softmax: bool = False,
num_groups: Optional[int] = None,
group_topk: Optional[int] = None,
scaling_factor: Optional[float] = None,
deterministic_mode: bool = False,
score_function: str = "softmax",
expert_bias: Optional[torch.Tensor] = None,
):
"""Apply capacity and padding to the top-k selection.
Args:
logits (torch.Tensor): Logits tensor.
topk (int): The number of experts to select for each token.
capacity_factor (float): The capacity factor of each expert. Will drop tokens if the number
of tokens exceeds the capacity.
        pad_to_capacity (bool): Whether to pad the selection to capacity in token drop
            mode. The probs for padded tokens will be 0.
        drop_policy (str): The policy used to drop tokens. Can be either "probs" or
            "position". If "probs", the tokens with the lowest probabilities are dropped.
            If "position", tokens at the end of each batch are dropped.
use_pre_softmax (bool): Whether to apply softmax before top-k selection.
num_groups (int): Number of groups for routed experts.
group_topk (int): Number of selected groups for each token.
scaling_factor (float): Scaling factor of routing score in top-k selection.
deterministic_mode (bool): Deprecated.
score_function (str): The score function to use. Can be either "softmax" or "sigmoid".
        expert_bias (torch.Tensor): The bias added to logits for expert routing.

    Returns:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- routing_probs (torch.Tensor): A tensor of shape [num_tokens, num_experts] containing
the routing probabilities for each token to each expert.
- routing_map (torch.Tensor): A mask tensor of shape [num_tokens, num_experts]
indicating which experts were selected for each token. True values represent
the selected experts.
- tokens_per_expert (torch.Tensor): A tensor of shape [num_experts] containing
the number of local tokens assigned to each expert before dropping and padding.
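
    Example (illustrative only; shapes are assumed for demonstration):
        >>> logits = torch.randn(4, 8)  # 4 tokens, 8 experts
        >>> probs, routing_map, tokens_per_expert = topk_softmax_with_capacity(
        ...     logits, topk=2
        ... )
        >>> int(routing_map.sum())  # each token selects exactly topk experts
        8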
"""
assert logits.dim() == 2, f"Expected 2D logits [num_tokens, num_experts], got {logits.dim()}."
    num_tokens, num_experts = logits.shape

    def compute_topk(scores, topk, num_groups=None, group_topk=None):
if group_topk:
return group_limited_topk(
scores=scores,
topk=topk,
num_tokens=num_tokens,
num_experts=num_experts,
num_groups=num_groups,
group_topk=group_topk,
)
else:
return torch.topk(scores, k=topk, dim=1)
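
    # Routing scores come either from a softmax (optionally applied before the top-k
    # selection) or from a sigmoid with an optional additive expert bias.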
if score_function == "softmax":
if use_pre_softmax:
scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits)
probs, top_indices = compute_topk(scores, topk, num_groups, group_topk)
else:
scores, top_indices = compute_topk(logits, topk, num_groups, group_topk)
probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(logits)
            # NOTE: norm_topk_prob is hard-coded to True here: renormalize the
            # selected top-k probabilities so they sum to 1 for each token.
            probs = probs / (probs.sum(dim=-1, keepdim=True) + 1e-20)
elif score_function == "sigmoid":
scores = torch.sigmoid(logits)
if expert_bias is not None:
scores_for_routing = scores + expert_bias
_, top_indices = compute_topk(scores_for_routing, topk, num_groups, group_topk)
scores = torch.gather(scores, dim=1, index=top_indices).type_as(logits)
else:
scores, top_indices = compute_topk(scores, topk, num_groups, group_topk)
probs = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores
else:
raise ValueError(f"Invalid score_function: {score_function}")
if scaling_factor:
probs = probs * scaling_factor
# TODO Try using element-wise operations instead of scatter?
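    # Scatter the top-k probs and indices back into dense [num_tokens, num_experts]
    # tensors; unselected experts get probability 0 and a False mask entry.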
topk_masked_gates = torch.zeros_like(logits).scatter(1, top_indices, probs)
topk_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool()
tokens_per_expert = topk_map.sum(dim=0)
if capacity_factor is None:
# TopK without capacity
return topk_masked_gates, topk_map, tokens_per_expert
else:
# TopK with capacity
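        # Each expert keeps at most expert_capacity tokens; get_capacity computes
        # roughly ceil((num_tokens * topk / num_experts) * capacity_factor).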
expert_capacity = get_capacity(
num_tokens=num_tokens * topk, num_experts=num_experts, capacity_factor=capacity_factor
)
        # Mask out tokens that exceed each expert's capacity.
if drop_policy == "probs":
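            # Keep, for each expert (column), the expert_capacity selections with
            # the highest gate probabilities.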
_, capacity_indices = torch.topk(
topk_masked_gates, k=expert_capacity, dim=0, sorted=False
)
capacity_mask = torch.zeros_like(logits).scatter(0, capacity_indices, 1).bool()
elif drop_policy == "position":
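            # Keep expert_capacity tokens per expert from the 0/1 routing map; ties
            # favor earlier positions, so tokens at the end of each batch are dropped.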
_, capacity_indices = torch.topk(topk_map.int(), k=expert_capacity, dim=0, sorted=False)
capacity_mask = torch.zeros_like(logits).scatter(0, capacity_indices, 1).bool()
else:
raise ValueError(f"Invalid drop_policy: {drop_policy}")
if pad_to_capacity:
final_map = capacity_mask
final_probs = topk_masked_gates * final_map
else:
            # Drop the selections that exceed expert capacity from both the probs
            # and the routing map.
final_map = torch.logical_and(topk_map, capacity_mask)
final_probs = topk_masked_gates * final_map
return final_probs, final_map, tokens_per_expert
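

# Minimal smoke test, illustrative only (not part of the upstream module): routes a
# few random tokens with and without a capacity limit. All sizes below are
# assumptions chosen for demonstration.
if __name__ == "__main__":
    torch.manual_seed(0)
    logits = torch.randn(6, 4)  # 6 tokens, 4 experts (hypothetical sizes)

    # Plain top-2 routing: every token keeps both of its selected experts.
    probs, routing_map, tokens_per_expert = topk_softmax_with_capacity(logits, topk=2)
    assert routing_map.sum().item() == 6 * 2
    print("tokens per expert (no capacity):", tokens_per_expert.tolist())

    # Capacity-limited routing: each expert keeps at most
    # ceil(6 * 2 / 4 * 1.0) = 3 tokens, so over-subscribed experts drop selections.
    probs_c, map_c, tpe_c = topk_softmax_with_capacity(
        logits, topk=2, capacity_factor=1.0, drop_policy="probs"
    )
    assert map_c.sum().item() <= 3 * 4
    print("tokens per expert (capacity 1.0):", tpe_c.tolist())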