backends/gaudi/server/text_generation_server/layers/moe/fused_moe.py (97 lines of code) (raw):

# coding=utf-8
# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple, Optional

import torch


def grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Select ``topk`` experts per token, restricted to the best expert groups.

    The experts (last dimension of ``gating_output``) are split into
    ``num_expert_group`` contiguous, equally sized groups. The ``topk_group``
    highest-scoring groups are kept for each token, every expert outside those
    groups is masked out, and the ``topk`` best remaining experts are returned.

    Args:
        hidden_states: Only its first dimension is read, to check that it
            matches the number of rows of ``gating_output``.
        gating_output: Router logits, indexed here as ``[n, e]``
            (tokens x experts). Converted to float32 before scoring.
        topk: Number of experts to return per token.
        renormalize: If true, divide each token's weights by their sum.
        num_expert_group: Number of expert groups; must be a positive divisor
            of the number of experts.
        topk_group: Number of groups kept per token; must be in
            ``[1, num_expert_group]``.
        scoring_func: ``"softmax"`` or ``"sigmoid"``, applied to the logits.
        e_score_correction_bias: Optional per-expert bias. When given, the
            biased scores drive group and expert *selection*, while the
            returned weights are gathered from the unbiased scores.

    Returns:
        ``(topk_weights, topk_ids)`` as float32 and int32 tensors. The experts
        of a token are not returned in sorted order (``sorted=False``).

    Raises:
        ValueError: If ``scoring_func`` is unsupported, or if
            ``num_expert_group`` / ``topk_group`` cannot describe a valid
            grouping of the experts.

    Note:
        ``topk`` is not checked against the number of experts left after group
        masking (``topk_group * experts_per_group``); callers are expected to
        keep it within that bound.
    """
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

    # Fail early with a readable message. Without these checks the defaults
    # (0 groups) or a non-dividing group count only surface as an opaque
    # reshape/topk error further down, and ``topk_group == 0`` silently masks
    # every expert to -inf.
    num_experts = gating_output.shape[-1]
    if num_expert_group <= 0 or num_experts % num_expert_group != 0:
        raise ValueError(
            f"num_expert_group ({num_expert_group}) must be a positive divisor "
            f"of the number of experts ({num_experts})"
        )
    if not 0 < topk_group <= num_expert_group:
        raise ValueError(
            f"topk_group ({topk_group}) must be between 1 and "
            f"num_expert_group ({num_expert_group})"
        )

    gating_output = gating_output.float()
    if e_score_correction_bias is not None:
        e_score_correction_bias = e_score_correction_bias.float()

    if scoring_func == "softmax":
        scores = torch.softmax(gating_output, dim=-1)
    elif scoring_func == "sigmoid":
        scores = gating_output.sigmoid()
    else:
        raise ValueError(f"Unsupported scoring function: {scoring_func}")

    num_token = scores.shape[0]
    if e_score_correction_bias is not None:
        # Store original scores before applying correction bias. We use biased
        # scores for expert selection but original scores for routing weights
        original_scores = scores
        scores = scores + e_score_correction_bias.unsqueeze(0)
        # With a bias, a group is ranked by the sum of its two best experts.
        group_scores = (
            scores.view(num_token, num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
        )
    else:
        # Without a bias, a group is ranked by its single best expert.
        group_scores = (
            scores.view(num_token, num_expert_group, -1).max(dim=-1).values
        )  # [n, n_group]

    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
        1
    ]  # [n, top_k_group]
    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
    # Broadcast the per-group 0/1 mask to every expert of that group.
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
        .reshape(num_token, -1)
    )  # [n, e]
    # Experts of non-selected groups get -inf so topk cannot prefer them.
    tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf"))  # [n, e]

    if e_score_correction_bias is not None:
        topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
        # Use original unbiased scores for the routing weights
        topk_weights = original_scores.gather(1, topk_ids)
    else:
        topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)


def fused_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Select the ``topk`` experts per token from softmax router scores.

    Args:
        hidden_states: Not read by this function; kept so the signature
            mirrors ``grouped_topk``.
        gating_output: Router logits. Softmax is taken over dimension 1 in
            float32, and the top-k over the last dimension.
        topk: Number of experts to return per token.
        renormalize: If true, divide each token's weights by their sum.

    Returns:
        ``(topk_weights, topk_ids)``. Weights are float32; ids keep the dtype
        produced by ``torch.topk``.
    """
    topk_weights = torch.nn.functional.softmax(
        gating_output, dim=1, dtype=torch.float32
    )
    topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
    if renormalize:
        topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids


def select_experts(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    use_grouped_topk: bool,
    renormalize: bool,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Route tokens to experts with either grouped or plain top-k selection.

    Args:
        hidden_states: Forwarded unchanged to the selected routing function.
        router_logits: Router logits, forwarded as ``gating_output``.
        top_k: Number of experts to pick per token.
        use_grouped_topk: If true, dispatch to ``grouped_topk``; otherwise to
            ``fused_topk``.
        renormalize: Forwarded to the selected routing function.
        topk_group: Required (not ``None``) when ``use_grouped_topk`` is true.
        num_expert_group: Required (not ``None``) when ``use_grouped_topk`` is
            true.
        scoring_func: Only used by the grouped path.
        e_score_correction_bias: Only used by the grouped path.

    Returns:
        The ``(topk_weights, topk_ids)`` pair of the selected routing function.
    """
    # DeepSeek-V2 uses grouped_top_k
    if use_grouped_topk:
        assert topk_group is not None
        assert num_expert_group is not None
        topk_weights, topk_ids = grouped_topk(
            hidden_states=hidden_states,
            gating_output=router_logits,
            topk=top_k,
            renormalize=renormalize,
            num_expert_group=num_expert_group,
            topk_group=topk_group,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
        )
    else:
        topk_weights, topk_ids = fused_topk(
            hidden_states=hidden_states,
            gating_output=router_logits,
            topk=top_k,
            renormalize=renormalize,
        )
    return topk_weights, topk_ids