megatron_patch/model/qwen2/moe/router.py (206 lines of code) (raw):

# Copyright (c) 2024 Alibaba PAI and Nvidia Megatron-LM Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod

import torch

from megatron.core import parallel_state
from megatron.core.tensor_parallel import gather_from_sequence_parallel_region
from megatron.core.tensor_parallel.random import (
    get_cuda_rng_tracker,
    get_data_parallel_rng_tracker_name,
)
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.moe.moe_utils import (
    MoEAuxLossAutoScaler,
    get_capacity,
    save_to_aux_losses_tracker,
    sinkhorn,
    switch_load_balancing_loss_func,
    z_loss_func,
)
from megatron.core.transformer.transformer_config import TransformerConfig


def topk_softmax_with_capacity(
    logits: torch.Tensor,
    topk: int,
    capacity_factor: float = None,
    pad_to_capacity: bool = False,
    drop_policy: str = "probs",
):
    """Apply top-k routing, optionally with per-expert capacity and padding.

    The softmax is taken over *all* experts first and the top-k probabilities are
    returned as-is, i.e. they are not renormalized to sum to one per token.

    Args:
        logits (torch.Tensor): Router logits of shape [num_tokens, num_experts].
        topk (int): The number of experts to select for each token.
        capacity_factor (float, optional): The capacity factor of each expert. When set,
            tokens exceeding an expert's capacity are dropped. ``None`` disables the
            capacity limit (no token dropping).
        pad_to_capacity (bool): Whether to pad each expert's input up to its capacity
            in token-drop mode.
        drop_policy (str): The policy used to drop tokens, either "probs" or "position".
            With "probs", the tokens with the lowest probabilities are dropped. With
            "position", tokens at the end of the batch are dropped.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Probs, indices and
        tokens_per_expert tensors.

        (1) Without padding, probs and indices have shape [num_tokens, topk] and give
            the selected experts for each token. When a capacity is applied, dropped
            entries have probability 0 and index ``torch.iinfo(torch.long).max``.
        (2) With padding, probs and indices have shape [num_experts, capacity] and give
            the tokens selected for each expert.

        tokens_per_expert has shape [num_experts] and counts the top-k assignments per
        expert before any capacity-based dropping.

    Raises:
        ValueError: If a capacity is used and ``drop_policy`` is neither "probs" nor
            "position".
    """
    # TODO: Add Pre softmax.
    assert logits.dim() == 2, f"Expected 2D logits [num_tokens, num_experts], got {logits.dim()}."
    num_tokens, num_experts = logits.shape

    # Softmax over every expert first, then keep the top-k probabilities without
    # renormalizing them (rather than a top-k-then-softmax ordering).
    routing_weights = torch.softmax(logits, dim=1, dtype=torch.float32).type_as(logits)
    probs, top_indices = torch.topk(routing_weights, k=topk, dim=-1)

    if capacity_factor is None:
        # TopK without capacity. Count assignments with bincount: it accepts integer
        # indices on every device, whereas torch.histc rejects integer input on CPU.
        tokens_per_expert = torch.bincount(top_indices.reshape(-1), minlength=num_experts)
        return probs, top_indices, tokens_per_expert

    # TopK with capacity.
    expert_capacity = get_capacity(
        num_tokens=num_tokens * topk,
        num_experts=num_experts,
        capacity_factor=capacity_factor,
    )
    # The column-wise torch.topk below fails if k exceeds the number of tokens. A
    # capacity that large can never be exceeded, so clamping it drops nothing.
    expert_capacity = min(expert_capacity, num_tokens)

    # Scatter the selected probs back into a dense [num_tokens, num_experts] map, and
    # build the matching 0/1 assignment mask; unselected experts stay zero.
    topk_masked_gates = torch.zeros_like(logits).scatter(1, top_indices, probs)
    topk_mask = torch.zeros_like(logits).scatter(1, top_indices, 1)

    # Keep at most `expert_capacity` tokens per expert (top-k along the token dim).
    if drop_policy == "probs":
        capacity_probs, capacity_indices = torch.topk(
            topk_masked_gates, k=expert_capacity, dim=0, sorted=False
        )
        capacity_mask = torch.zeros_like(logits).scatter(0, capacity_indices, 1)
    elif drop_policy == "position":
        _, capacity_indices = torch.topk(topk_mask, k=expert_capacity, dim=0, sorted=False)
        capacity_mask = torch.zeros_like(logits).scatter(0, capacity_indices, 1)
        capacity_probs = torch.gather(topk_masked_gates, 0, capacity_indices)
    else:
        raise ValueError(f"Invalid drop_policy: {drop_policy}")

    # Per-expert assignment counts before any token is dropped.
    tokens_per_expert_before_capacity = topk_mask.sum(dim=0)

    if pad_to_capacity:
        # Expert-major layout: [num_experts, capacity].
        final_probs = capacity_probs.T.contiguous()
        final_indices = capacity_indices.T.contiguous()
    else:
        # Token-major layout: invalidate the (token, expert) slots that were chosen by
        # top-k but did not fit within that expert's capacity.
        final_mask = torch.logical_and(topk_mask, capacity_mask)
        drop_mask = torch.logical_not(final_mask)
        exceed_mask = torch.gather(drop_mask, 1, top_indices)
        final_probs = probs * torch.logical_not(exceed_mask)
        final_indices = top_indices.clone().masked_fill_(
            exceed_mask, torch.iinfo(torch.long).max
        )
    return final_probs, final_indices, tokens_per_expert_before_capacity


class Router(ABC, MegatronModule):
    """Base Router class: owns the gate weight and defines the routing interface."""

    def __init__(self, config: TransformerConfig) -> None:
        """Initialize the Router module.

        Args:
            config (TransformerConfig): Configuration object for the Transformer model.
        """
        super().__init__(config)
        self.config = config
        self.num_experts = self.config.num_moe_experts
        self.moe_aux_loss_func = None
        self.layer_number = None

        # Gate weight of shape [num_moe_experts, hidden_size].
        self.weight = torch.nn.Parameter(
            torch.empty((self.config.num_moe_experts, self.config.hidden_size))
        )
        # Initialize inside the data-parallel RNG tracker fork (presumably so that
        # ranks sharing that RNG stream get identical router weights -- verify).
        with get_cuda_rng_tracker().fork(get_data_parallel_rng_tracker_name()):
            config.init_method(self.weight)
        # Tag the parameter with the sequence-parallel flag; NOTE(review): presumably
        # consumed by gradient synchronization elsewhere -- confirm.
        setattr(self.weight, 'sequence_parallel', config.sequence_parallel)

    def gating(self, input: torch.Tensor):
        """Forward pass of the router gate.

        Args:
            input (torch.Tensor): Input tensor whose last dim is hidden_size.

        Returns:
            torch.Tensor: Logits tensor, ``input @ weight.T`` (last dim num_moe_experts).
        """
        logits = torch.nn.functional.linear(input, self.weight)
        return logits

    @abstractmethod
    def routing(self, logits: torch.Tensor):
        """Routing function.

        Args:
            logits (torch.Tensor): Logits tensor.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Tuple of tensors representing max probs
            and the indices.
        """
        raise NotImplementedError("Routing function not implemented.")

    @abstractmethod
    def forward(self, input: torch.Tensor):
        """Forward pass of the router.

        Args:
            input (torch.Tensor): Input tensor.
        """
        raise NotImplementedError("Forward function not implemented.")

    def set_layer_number(self, layer_number: int):
        """Set the layer number for the router (used when logging auxiliary losses)."""
        self.layer_number = layer_number


class TopKRouter(Router):
    """Route each token to the top-k experts."""

    def __init__(self, config: TransformerConfig) -> None:
        """Initialize the top-k router.

        Args:
            config (TransformerConfig): The configuration for the transformer model.
        """
        super().__init__(config=config)
        self.topk = self.config.moe_router_topk
        self.routing_type = self.config.moe_router_load_balancing_type
        # Lazily-created jitter sampler, see apply_input_jitter.
        self.input_jitter = None

    def sinkhorn_load_balancing(self, logits: torch.Tensor):
        """Apply sinkhorn routing to the logits tensor.

        During training the expert selection is made on sinkhorn-normalized logits
        (without gradient), while the returned scores come from the activated logits.

        Args:
            logits (torch.Tensor): The logits tensor, [num_tokens, num_experts].

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The scores of the selected experts and
            their indices, both [num_tokens, topk].
        """

        def _sinkhorn_activation(logits):
            if self.topk == 1:
                logits = torch.sigmoid(logits)
            else:  # k > 1
                logits = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits)
            return logits

        assert self.config.moe_aux_loss_coeff == 0, "Sinkhorn routing does not support aux loss."
        if self.training:
            with torch.no_grad():
                # Explicit fp32 conversion for numerical stability.
                norm_logits = sinkhorn(logits.to(dtype=torch.float32))
                _, indices = torch.topk(norm_logits, k=self.topk, dim=1)
            logits = _sinkhorn_activation(logits)
            scores = torch.gather(logits, 1, indices)
        else:
            logits = _sinkhorn_activation(logits)
            scores, indices = torch.topk(logits, k=self.topk, dim=1)
        return scores, indices

    def aux_loss_load_balancing(self, logits: torch.Tensor):
        """Apply loss-based load balancing to the logits tensor.

        Args:
            logits (torch.Tensor): The logits tensor after gating, shape
                [num_tokens, num_experts].

        Returns:
            probs (torch.Tensor): The probabilities tensor after load balancing.
            indices (torch.Tensor): The indices tensor after top-k selection.
        """
        probs, indices, tokens_per_expert = topk_softmax_with_capacity(
            logits,
            self.topk,
            capacity_factor=self.config.moe_expert_capacity_factor,
            pad_to_capacity=self.config.moe_pad_expert_input_to_capacity,
            drop_policy=self.config.moe_token_drop_policy,
        )

        # The aux loss is computed on the full (all-expert) softmax distribution and
        # attached to the routed probs so its gradient flows through them.
        scores = torch.softmax(logits, dim=-1, dtype=torch.float32)
        probs = self.apply_load_balancing_loss(scores, tokens_per_expert, activation=probs)
        return probs, indices

    def apply_load_balancing_loss(
        self,
        probs: torch.Tensor,
        num_local_tokens_per_expert: torch.Tensor,
        activation: torch.Tensor,
    ):
        """Applies auxiliary loss to the MoE layer.

        Args:
            probs (torch.Tensor): The probs output by the router for each token,
                [num_tokens, num_experts].
            num_local_tokens_per_expert (torch.Tensor): The number of tokens per expert,
                [num_experts].
            activation (torch.Tensor): The activation tensor to attach the gradient
                function to.

        Returns:
            torch.Tensor: The activation tensor with the attached gradient function.
        """
        # NOTE(review): the coefficient is divided by the TP world size, presumably to
        # compensate for every TP rank contributing this loss -- confirm.
        moe_aux_loss_coeff = (
            self.config.moe_aux_loss_coeff / parallel_state.get_tensor_model_parallel_world_size()
        )
        aux_loss = switch_load_balancing_loss_func(
            probs, num_local_tokens_per_expert, self.topk, moe_aux_loss_coeff
        )
        # Log the loss with the applied (TP-scaled) coefficient divided back out.
        save_to_aux_losses_tracker(
            "load_balancing_loss",
            aux_loss / moe_aux_loss_coeff,
            self.layer_number,
            self.config.num_layers,
        )
        activation = MoEAuxLossAutoScaler.apply(activation, aux_loss)
        return activation

    def apply_z_loss(self, logits):
        """Encourages the router's logits to remain small to enhance stability.

        Please refer to the ST-MoE paper (https://arxiv.org/pdf/2202.08906.pdf) for
        details.

        Args:
            logits (torch.Tensor): The logits of the router.

        Returns:
            torch.Tensor: The logits after applying the z-loss.
        """
        if self.config.moe_z_loss_coeff is not None:
            # Scale the coefficient by the TP world size, as apply_load_balancing_loss does.
            moe_z_loss_coeff = (
                self.config.moe_z_loss_coeff
                / parallel_state.get_tensor_model_parallel_world_size()
            )
            z_loss = z_loss_func(logits, moe_z_loss_coeff)
            logits = MoEAuxLossAutoScaler.apply(logits, z_loss)
            # Divide by the coefficient that was actually applied (the TP-scaled one),
            # matching the load-balancing loss, so the logged value does not depend on
            # the TP world size (assumes z_loss_func scales linearly in the coefficient).
            save_to_aux_losses_tracker(
                "z_loss",
                z_loss / moe_z_loss_coeff,
                self.layer_number,
                self.config.num_layers,
            )
        return logits

    def apply_input_jitter(self, input: torch.Tensor):
        """Add multiplicative uniform noise to the input tensor.

        Refer to https://arxiv.org/abs/2101.03961.

        Args:
            input (Tensor): Input tensor.

        Returns:
            Tensor: Jittered input, or the input unchanged when
            ``moe_input_jitter_eps`` is None.
        """
        if self.config.moe_input_jitter_eps is not None:
            eps = self.config.moe_input_jitter_eps
            if self.input_jitter is None:
                # Built once, on the device of the first input seen.
                self.input_jitter = torch.distributions.uniform.Uniform(
                    torch.tensor(1.0 - eps, device=input.device),
                    torch.tensor(1.0 + eps, device=input.device),
                ).rsample
            return input * self.input_jitter(input.shape)
        else:
            return input

    def routing(self, logits: torch.Tensor):
        """Top-k routing function.

        Args:
            logits (torch.Tensor): Logits tensor after gating.

        Returns:
            probs (torch.Tensor): The probabilities tensor after load balancing.
            indices (torch.Tensor): The indices tensor after top-k selection.

        Raises:
            ValueError: If the configured load-balancing type is not supported.
        """
        logits = logits.view(-1, self.config.num_moe_experts)

        # Apply Z-Loss
        logits = self.apply_z_loss(logits)

        if (
            parallel_state.get_tensor_model_parallel_world_size() > 1
            and self.config.moe_token_dispatcher_type == "alltoall"
        ):
            # Gather the logits from the TP region
            logits = gather_from_sequence_parallel_region(logits)

        if self.routing_type == "sinkhorn":
            scores, indices = self.sinkhorn_load_balancing(logits)
        elif self.routing_type == "aux_loss":
            scores, indices = self.aux_loss_load_balancing(logits)
        elif self.routing_type == "none":
            # A naive top-k routing without load balancing
            scores, indices, _ = topk_softmax_with_capacity(
                logits,
                self.topk,
                capacity_factor=self.config.moe_expert_capacity_factor,
                pad_to_capacity=self.config.moe_pad_expert_input_to_capacity,
                drop_policy=self.config.moe_token_drop_policy,
            )
        else:
            raise ValueError(f"Unsupported MoE routing type: {self.routing_type}")

        return scores, indices

    def forward(self, input: torch.Tensor):
        """Forward pass of the router.

        Args:
            input (torch.Tensor): Input tensor whose last dim is hidden_size.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The routing scores and expert indices
            produced by :meth:`routing`.
        """
        self.hidden = input.shape[-1]

        # Apply input jitter
        input = self.apply_input_jitter(input)
        logits = self.gating(input)
        logits = logits.view(-1, self.config.num_moe_experts)

        scores, indices = self.routing(logits)

        return scores, indices