megatron_patch/model/mixtral_bak/moe/moe_utils.py (39 lines of code) (raw):

# Copyright (c) 2023 Alibaba PAI and Nvidia Megatron-LM Team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import torch def switch_load_balancing_loss_func(gates, mask, moe_aux_loss_coeff): """Calculate the auxiliary loss for better load balacing. Please refer to the Switch Transformer paper (https://arxiv.org/abs/2101.03961) for details. Args: gates (torch.Tensor): The gates tensor representing the routing probabilities for each expert. mask (torch.Tensor): The 2D mask tensor indicating which experts are selected. Returns: torch.Tensor: The auxiliary loss for load balancing. """ num_experts = mask.size(-1) gates_mean = gates.mean(dim=0) selection_mean = mask.float().mean(dim=0) aux_loss = torch.sum(gates_mean * selection_mean) * num_experts aux_loss *= moe_aux_loss_coeff return aux_loss def z_loss_func(logits, z_loss_coeff): """Encourages the router's logits to remain small to enhance stability. Please refer to the ST-MoE paper (https://arxiv.org/pdf/2202.08906.pdf) for details. Args: logits (torch.Tensor): The logits of the router. Returns: torch.Tensor: The logits after applying the z-loss. """ z_loss = torch.mean(torch.square(torch.logsumexp(logits, dim=-1))) * z_loss_coeff return z_loss def sinkhorn(cost: torch.Tensor, tol: float = 0.0001): """Sinkhorn based MoE routing function Args: cost: A 2D tensor representing the cost matrix to be normalized. tol: A float value specifying the tolerance for convergence. Default is 0.0001. Returns: A 2D tensor representing the doubly stochastic matrix after Sinkhorn normalization. """ cost = torch.exp(cost) d0 = torch.ones(cost.size(0), device=cost.device, dtype=cost.dtype) d1 = torch.ones(cost.size(1), device=cost.device, dtype=cost.dtype) eps = 0.00000001 error = 1e9 d1_old = d1 while error > tol: d0 = (1 / d0.size(0)) * 1 / (torch.sum(d1 * cost, 1) + eps) d1 = (1 / d1.size(0)) * 1 / (torch.sum(d0.unsqueeze(1) * cost, 0) + eps) error = torch.mean(torch.abs(d1_old - d1)) d1_old = d1 return d1 * cost * d0.unsqueeze(1) class MoEAuxLossAutoScaler(torch.autograd.Function): """An AutoScaler that compute and scales the grad for auxiliary loss. """ main_loss_backward_scale: int = 1 @staticmethod def forward(ctx, output: torch.Tensor, aux_loss: torch.Tensor): """Preserve the aux_loss by storing it in the context to avoid garbage collection. Args: output (torch.Tensor): The output tensor. aux_loss (torch.Tensor): The auxiliary loss tensor. Returns: torch.Tensor: The output tensor. """ ctx.save_for_backward(aux_loss) return output @staticmethod def backward(ctx, grad_output: torch.Tensor): """Compute and scale the gradient for auxiliary loss.. Args: grad_output (torch.Tensor): The gradient of the output. Returns: Tuple[torch.Tensor, torch.Tensor]: The gradient of the output, scaled auxiliary loss gradient. """ (aux_loss,) = ctx.saved_tensors aux_loss_backward_scale = MoEAuxLossAutoScaler.main_loss_backward_scale scaled_aux_loss_grad = torch.ones_like(aux_loss) * aux_loss_backward_scale return grad_output, scaled_aux_loss_grad @staticmethod def set_loss_scale(scale: int): """set the scale of the aux loss. Args: scale (int): The scale value to set. Please ensure that the scale passed in matches the scale of the main_loss. """ MoEAuxLossAutoScaler.main_loss_backward_scale = scale