in optimum/habana/transformers/models/deepseek_v3/modeling_deepseek_v3.py [0:0]
def forward(self, hidden_states):
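    """
    MoE forward pass: route each token through its top-k routed experts and, when configured,
    add the always-active shared experts on top.

    Training evaluates every routed expert densely with zero-padded routing weights; inference
    uses the Gaudi (HPU) fused mixture_of_experts op over slices of the locally hosted experts.
    """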
    identity = hidden_states
    orig_shape = hidden_states.shape
    topk_idx, topk_weight = self.gate(hidden_states)
    hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
    # cast the router weights back to the input hidden_states dtype
    topk_weight = topk_weight.to(hidden_states.dtype)
    batch = orig_shape[0]
    sequence_length = orig_shape[1]
    hidden_dim = orig_shape[2]
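    # Two paths: training applies every routed expert to every token with dense (zero-padded)
    # routing weights, which keeps tensor shapes static; inference dispatches tokens through the
    # fused HPU mixture_of_experts kernel instead.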
    # changes for expert parallelism -- replacement for moe_infer()
    if self.training:
        padded_weights = torch.zeros(
            (batch * sequence_length, self.config.n_routed_experts),
            dtype=topk_weight.dtype,
            device=topk_weight.device,
        )
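        # Scatter the top-k routing weights into a dense [batch * seq, n_routed_experts] tensor,
        # then reshape/permute it to [n_routed_experts, batch, seq, 1] so each expert's weight
        # broadcasts over that expert's output below.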
        padded_weights.scatter_(-1, topk_idx, topk_weight)
        padded_weights = padded_weights.reshape(-1, sequence_length, self.config.n_routed_experts)
        padded_weights = padded_weights.permute(2, 0, 1).unsqueeze(-1)
        final_hidden_states = torch.zeros(
            (batch, sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
        )
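        # Apply every expert to the full token stream and accumulate the weighted outputs;
        # tokens that were not routed to a given expert contribute with a zero weight.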
        for i, expert in enumerate(self.experts):
            current_hidden_state = expert(hidden_states)
            current_padded_weight = padded_weights[i]
            final_hidden_states = (
                final_hidden_states
                + current_hidden_state.reshape(-1, sequence_length, hidden_dim) * current_padded_weight
            )
        final_hidden_states = final_hidden_states.type(hidden_states.dtype)
        final_hidden_states = final_hidden_states.view(*orig_shape)
        # final_hidden_states = AddAuxiliaryLoss.apply(final_hidden_states, aux_loss)
    else:
        final_hidden_states = torch.zeros(
            (batch * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
        )
        # changes to support the HPU fused dynamic MoE op -- replacement for moe_infer()
        # loop over expert slices due to the limit on the maximum number of experts supported by the mixture_of_experts op
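        # With expert parallelism, this rank hosts `experts_per_rank` consecutive routed experts starting
        # at ep_rank * experts_per_rank; they are processed in chunks of `expert_chunk` experts per
        # fused-op call, across `expert_slice` loop iterations.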
        for idx in range(self.expert_slice):
            experts_min = (self.ep_rank * self.experts_per_rank) + (self.expert_chunk * idx)
            experts_max = min((experts_min + self.expert_chunk), (self.ep_rank + 1) * self.experts_per_rank)
            experts_range = range(experts_min, experts_max)
            gate_proj_list = [self.experts[i].gate_proj.weight.squeeze() for i in experts_range]
            down_proj_list = [self.experts[i].down_proj.weight.squeeze() for i in experts_range]
            up_proj_list = [self.experts[i].up_proj.weight.squeeze() for i in experts_range]
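            # Fused HPU MoE over the inclusive expert range [experts_min, experts_max - 1]: the experts'
            # gate/up/down projections are passed as w1/w2/w3 with a SiLU activation (SwiGLU-style MLP).
            # Tokens routed to experts outside this slice contribute nothing here, so the per-slice
            # partial results can simply be summed.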
            hidden_states_slice = torch.ops.hpu.mixture_of_experts(
                hidden_states=hidden_states,
                expert_routing_table=topk_idx,
                router_weights=topk_weight,
                w1=gate_proj_list,
                w2=up_proj_list,
                w3=down_proj_list,
                permuted_weights=True,
                activation="silu",
                experts_min=experts_min,
                experts_max=experts_max - 1,
            )
            final_hidden_states = final_hidden_states + hidden_states_slice
            htcore.mark_step()
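        # Under expert parallelism each rank has only accumulated its local experts' outputs, so the
        # partial results are summed across ranks; the DeepSpeed branch covers setups where the
        # process group is managed by DeepSpeed instead.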
        if self.ep_size > 1:
            final_hidden_states = _all_reduce(final_hidden_states)
        elif is_deepspeed_available():
            from deepspeed import comm as dist

            if dist.is_initialized():
                dist.all_reduce(final_hidden_states, op=dist.ReduceOp.SUM)
        final_hidden_states = final_hidden_states.type(hidden_states.dtype)
        final_hidden_states = final_hidden_states.reshape(-1, sequence_length, hidden_dim)
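    # Shared experts (always active in DeepSeek-V3) are applied to the original input and added on top
    # of the routed-expert output.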
    if self.config.n_shared_experts is not None:
        final_hidden_states = final_hidden_states + self.shared_experts(identity)
    return final_hidden_states