in deepseek_vl2/models/modeling_deepseek.py [0:0]
def forward(self, hidden_states):
identity = hidden_states
orig_shape = hidden_states.shape
topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
flat_topk_idx = topk_idx.view(-1)
if self.training:
hidden_states = hidden_states.repeat_interleave(
self.num_experts_per_tok, dim=0
)
y = torch.empty_like(hidden_states)
for i, expert in enumerate(self.experts):
y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
y = y.to(hidden_states.dtype).view(*orig_shape)
y = AddAuxiliaryLoss.apply(y, aux_loss)
else:
y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)
if self.config.n_shared_experts is not None:
y = y + self.shared_experts(identity)
return y