in optimum/habana/transformers/models/deepseek_v2/modeling_deepseek_v2.py [0:0]
def forward(self, hidden_states):
bsz, seq_len, h = hidden_states.shape
### compute gating score
hidden_states = hidden_states.view(-1, h)
logits = F.linear(hidden_states.type(torch.bfloat16), self.weight.type(torch.bfloat16), None).to(
dtype=torch.float32
)
if self.scoring_func == "softmax":
scores = logits.softmax(dim=-1, dtype=torch.float32)
else:
raise NotImplementedError(f"insupportable scoring function for MoE gating: {self.scoring_func}")
### select top-k experts
if self.topk_method == "greedy":
topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=True)
elif self.topk_method == "group_limited_greedy":
group_scores = scores.view(bsz * seq_len, self.n_group, -1).max(dim=-1).values # [n, n_group]
group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=True)[1] # [n, top_k_group]
group_mask = torch.zeros_like(group_scores) # [n, n_group]
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
score_mask = (
group_mask.unsqueeze(-1)
.expand(bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group)
.reshape(bsz * seq_len, -1)
) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
topk_weight, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=True)
### norm gate to sum 1
if self.top_k > 1 and self.norm_topk_prob:
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
topk_weight = topk_weight / denominator
else:
topk_weight = topk_weight * self.routed_scaling_factor
### expert-level computation auxiliary loss
if self.training and self.alpha > 0.0:
scores_for_aux = scores
aux_topk = self.top_k
# always compute aux loss based on the naive greedy topk method
topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
if self.seq_aux:
scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
ce.scatter_add_(
1,
topk_idx_for_aux_loss,
torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device),
).div_(seq_len * aux_topk / self.n_routed_experts)
aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
else:
mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
ce = mask_ce.float().mean(0)
Pi = scores_for_aux.mean(0)
fi = ce * self.n_routed_experts
aux_loss = (Pi * fi).sum() * self.alpha
else:
aux_loss = None
return topk_idx, topk_weight, aux_loss