megatron_patch/model/qwen3_moe/moe/router.py:

# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from functools import partial

import torch
from megatron.core.tensor_parallel import gather_from_sequence_parallel_region
from megatron.core.transformer.moe.moe_utils import (
    sequence_load_balancing_loss_func,
    switch_load_balancing_loss_func,
)
from megatron.core.transformer.moe.router import TopKRouter as _TopKRouter

from .moe_utils import topk_softmax_with_capacity


class TopKRouter(_TopKRouter):
    """Route each token to the top-k experts."""

    def compute_routing_scores_for_aux_loss(self, logits: torch.Tensor) -> torch.Tensor:
        """Compute routing scores based on the score function.

        Args:
            logits (torch.Tensor): The logits tensor after gating, shape: [num_tokens, num_experts].

        Returns:
            torch.Tensor: The normalized routing scores.
        """
        if self.score_function == "softmax":
            scores = torch.softmax(logits, dim=-1, dtype=torch.float32)
            # NOTE: norm_topk_prob is hard-coded here.
            scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20)
        elif self.score_function == "sigmoid":
            scores = torch.sigmoid(logits)
            scores = (
                scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if self.topk > 1 else scores
            )
        else:
            raise ValueError(f"Invalid score_function: {self.score_function}")
        return scores

    def aux_loss_load_balancing(self, logits: torch.Tensor):
        """Apply auxiliary loss-based load balancing to the logits tensor.

        Args:
            logits (torch.Tensor): The logits tensor after gating, shape: [num_tokens, num_experts].

        Returns:
            probs (torch.Tensor): The probabilities of token to experts assignment.
            routing_map (torch.Tensor): The mask of token to experts assignment.
        """
        probs, routing_map, tokens_per_expert = topk_softmax_with_capacity(
            logits,
            self.topk,
            capacity_factor=self.config.moe_expert_capacity_factor,
            pad_to_capacity=self.config.moe_pad_expert_input_to_capacity,
            drop_policy=self.config.moe_token_drop_policy,
            use_pre_softmax=self.config.moe_router_pre_softmax,
            num_groups=self.config.moe_router_num_groups,
            group_topk=self.config.moe_router_group_topk,
            scaling_factor=self.config.moe_router_topk_scaling_factor,
            deterministic_mode=self.config.deterministic_mode,
            score_function=self.score_function,
            expert_bias=self.expert_bias,
        )

        if self.training and torch.is_grad_enabled():
            # Apply the auxiliary load balancing loss. Skipped under
            # torch.no_grad() and during activation checkpointing.
            scores = self.compute_routing_scores_for_aux_loss(logits)
            aux_loss_func = partial(
                switch_load_balancing_loss_func,
                probs=scores,
                tokens_per_expert=tokens_per_expert,
                topk=self.topk,
            )
            probs = self.apply_load_balancing_loss(
                activation=probs, load_balancing_loss_func=aux_loss_func
            )
        return probs, routing_map
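    # For orientation: switch_load_balancing_loss_func implements a
    # Switch-Transformer-style auxiliary loss. A rough sketch of the quantity
    # it penalizes (the exact normalization inside Megatron-Core may differ;
    # `coeff` below stands in for the configured aux-loss coefficient):
    #
    #     f = tokens_per_expert / tokens_per_expert.sum()  # dispatch fraction per expert
    #     p = scores.mean(dim=0)                           # mean routing score per expert
    #     aux_loss = coeff * num_experts * (f * p).sum()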
""" probs, routing_map, tokens_per_expert = topk_softmax_with_capacity( logits, self.topk, capacity_factor=self.config.moe_expert_capacity_factor, pad_to_capacity=self.config.moe_pad_expert_input_to_capacity, drop_policy=self.config.moe_token_drop_policy, use_pre_softmax=self.config.moe_router_pre_softmax, num_groups=self.config.moe_router_num_groups, group_topk=self.config.moe_router_group_topk, scaling_factor=self.config.moe_router_topk_scaling_factor, deterministic_mode=self.config.deterministic_mode, score_function=self.score_function, expert_bias=self.expert_bias, ) if self.training and torch.is_grad_enabled(): # Apply sequence-auxiliary load balancing loss scores = self.compute_routing_scores_for_aux_loss(logits) aux_loss_func = partial( sequence_load_balancing_loss_func, probs=scores, routing_map=routing_map, batch_size=bsz, seq_length=seq_length, topk=self.topk, ) probs = self.apply_load_balancing_loss( activation=probs, load_balancing_loss_func=aux_loss_func ) return probs, routing_map def routing(self, logits: torch.Tensor): """Top-k routing function Args: logits (torch.Tensor): Logits tensor after gating. Returns: probs (torch.Tensor): The probabilities of token to experts assignment. routing_map (torch.Tensor): The mapping of token to experts assignment, with shape [num_tokens, num_experts]. """ seq_length, bsz = logits.shape[:2] logits = logits.view(-1, self.config.num_moe_experts) # Apply Z-Loss logits = self.apply_z_loss(logits) if self.config.moe_token_dispatcher_type == "alltoall_seq": # Gather the logits from the TP region logits = gather_from_sequence_parallel_region(logits) if self.routing_type == "sinkhorn": scores, routing_map = self.sinkhorn_load_balancing(logits) elif self.routing_type == "aux_loss": scores, routing_map = self.aux_loss_load_balancing(logits) elif self.routing_type == "seq_aux_loss": scores, routing_map = self.seq_aux_loss_load_balancing(logits, bsz, seq_length) elif self.routing_type == "none": # A naive top-k routing without load balancing scores, routing_map, _ = topk_softmax_with_capacity( logits, self.topk, capacity_factor=self.config.moe_expert_capacity_factor, pad_to_capacity=self.config.moe_pad_expert_input_to_capacity, drop_policy=self.config.moe_token_drop_policy, use_pre_softmax=self.config.moe_router_pre_softmax, num_groups=self.config.moe_router_num_groups, group_topk=self.config.moe_router_group_topk, scaling_factor=self.config.moe_router_topk_scaling_factor, deterministic_mode=self.config.deterministic_mode, score_function=self.score_function, expert_bias=self.expert_bias, ) else: raise ValueError(f"Unsupported MoE routing type: {self.routing_type}") # Prevent extra local tokens accumulation on evaluation or activation recomputation if self.enable_expert_bias and torch.is_grad_enabled(): with torch.no_grad(): self.local_tokens_per_expert += routing_map.sum(dim=0) return scores, routing_map