def forward()

in optimum/habana/transformers/models/deepseek_v3/modeling_deepseek_v3.py
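
Forward pass of the DeepSeek-V3 MoE block on Gaudi/HPU. The router (`self.gate`)
selects top-k experts per token. During training, the experts are applied
densely, with routing weights scattered into per-expert weight maps; at
inference, the fused `torch.ops.hpu.mixture_of_experts` op is invoked over
slices of the locally held experts, and partial results are summed across ranks
under expert parallelism. Shared experts, if configured, are added at the end.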


    # Module-level dependencies assumed by this excerpt (imported at the top of
    # the file): `torch`, `habana_frameworks.torch.core as htcore`, and
    # transformers' `is_deepspeed_available`; `_all_reduce` is a helper defined
    # elsewhere in the same module.
    def forward(self, hidden_states):
        identity = hidden_states
        orig_shape = hidden_states.shape
        topk_idx, topk_weight = self.gate(hidden_states)
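        # per-token routing from the gate: expert indices and routing weights,
        # each of shape (batch * seq_len, num_experts_per_tok)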
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        # cast the routing weights back to the activation dtype
        topk_weight = topk_weight.to(hidden_states.dtype)
        batch, sequence_length, hidden_dim = orig_shape
        # changes for expert parallelism -- replacement for moe_infer()
        if self.training:
            padded_weights = torch.zeros(
                (batch * sequence_length, self.config.n_routed_experts),
                dtype=topk_weight.dtype,
                device=topk_weight.device,
            )
            padded_weights.scatter_(-1, topk_idx, topk_weight)
            padded_weights = padded_weights.reshape(-1, sequence_length, self.config.n_routed_experts)
            padded_weights = padded_weights.permute(2, 0, 1).unsqueeze(-1)
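            # e.g. (illustrative) with 4 routed experts and top-2 routing, a token
            # sent to experts {1, 3} with weights {0.7, 0.3} gets the dense row
            # [0.0, 0.7, 0.0, 0.3]; after the permute/unsqueeze, entry i is a
            # (batch, seq_len, 1) weight map broadcastable against expert i's output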

            final_hidden_states = torch.zeros(
                (batch, sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
            )
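            # dense replacement for token gathering: every expert processes all
            # tokens, and the per-expert weight map zeroes out tokens that were
            # not routed to that expert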
            for i, expert in enumerate(self.experts):
                current_hidden_state = expert(hidden_states)
                current_padded_weight = padded_weights[i]
                final_hidden_states = (
                    final_hidden_states
                    + current_hidden_state.reshape(-1, sequence_length, hidden_dim) * current_padded_weight
                )
            final_hidden_states = final_hidden_states.type(hidden_states.dtype)
            final_hidden_states = final_hidden_states.view(*orig_shape)
            # final_hidden_states = AddAuxiliaryLoss.apply(final_hidden_states, aux_loss)
        else:
            final_hidden_states = torch.zeros(
                (batch * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
            )
            # changes to support the HPU fused dynamic MoE op -- replacement for moe_infer()
            # loop over expert slices, since the mixture_of_experts op supports only a limited number of experts per call
            for idx in range(self.expert_slice):
                experts_min = (self.ep_rank * self.experts_per_rank) + (self.expert_chunk * idx)
                experts_max = min((experts_min + self.expert_chunk), (self.ep_rank + 1) * self.experts_per_rank)
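                # illustrative numbers: with ep_rank=1, experts_per_rank=32 and
                # expert_chunk=16, the slices cover experts [32, 48) and [48, 64)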
                experts_range = range(experts_min, experts_max)
                gate_proj_list = [self.experts[i].gate_proj.weight.squeeze() for i in experts_range]
                down_proj_list = [self.experts[i].down_proj.weight.squeeze() for i in experts_range]
                up_proj_list = [self.experts[i].up_proj.weight.squeeze() for i in experts_range]

                # the fused op evaluates the experts' gate/up/down projections with
                # SiLU activation for tokens routed to experts in
                # [experts_min, experts_max]; tokens routed outside the slice
                # contribute zeros, so the per-slice outputs can simply be summed
                hidden_states_slice = torch.ops.hpu.mixture_of_experts(
                    hidden_states=hidden_states,
                    expert_routing_table=topk_idx,
                    router_weights=topk_weight,
                    w1=gate_proj_list,
                    w2=up_proj_list,
                    w3=down_proj_list,
                    permuted_weights=True,
                    activation="silu",
                    experts_min=experts_min,
                    experts_max=experts_max - 1,
                )
                final_hidden_states = final_hidden_states + hidden_states_slice
                htcore.mark_step()  # flush the accumulated lazy-mode graph after each expert slice

            # under expert parallelism each rank holds only a subset of the
            # experts, so the partial sums must be combined across ranks
            if self.ep_size > 1:
                final_hidden_states = _all_reduce(final_hidden_states)
            elif is_deepspeed_available():
                from deepspeed import comm as dist

                if dist.is_initialized():
                    dist.all_reduce(final_hidden_states, op=dist.ReduceOp.SUM)

            final_hidden_states = final_hidden_states.type(hidden_states.dtype)
            final_hidden_states = final_hidden_states.reshape(-1, sequence_length, hidden_dim)  # restore (batch, seq_len, hidden_dim)

        # shared experts always see every token; add their output on top of the routed result
        if self.config.n_shared_experts is not None:
            final_hidden_states = final_hidden_states + self.shared_experts(identity)

        return final_hidden_states
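
For orientation, a minimal usage sketch follows. Everything in it is an
illustrative assumption (`moe` stands for an already-constructed instance of
this MoE module on an HPU device, and the shapes are arbitrary); in the real
model the layer is built and called inside the DeepSeek-V3 decoder stack.

    # hypothetical usage sketch, not part of the modeling file
    import torch

    moe = moe.eval()  # self.training == False -> takes the fused-MoE inference branch
    hidden_states = torch.randn(
        2, 128, moe.config.hidden_size, dtype=torch.bfloat16, device="hpu"
    )
    with torch.no_grad():
        out = moe(hidden_states)
    assert out.shape == hidden_states.shape  # the MoE block preserves the input shape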