def gaudi_qwen2moe_block_sparse_moe_forward()

in optimum/habana/transformers/models/qwen2_moe/modeling_qwen2_moe.py


# Module-level imports this snippet relies on (the import path for is_deepspeed_available is assumed).
from typing import Tuple

import torch
import torch.nn.functional as F
from transformers.integrations.deepspeed import is_deepspeed_available


def gaudi_qwen2moe_block_sparse_moe_forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    - optimized expert forward: removes dynamic control flow and dynamic shapes
    """
    batch_size, sequence_length, hidden_dim = hidden_states.shape
    hidden_states = hidden_states.view(-1, hidden_dim)

    # router_logits: (batch * sequence_length, n_experts)
    router_logits = self.gate(hidden_states)

    routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
    routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)

    if self.norm_topk_prob:
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
    # we cast back to the input dtype
    routing_weights = routing_weights.to(hidden_states.dtype)
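    # Illustrative example (not from the source): with top_k=2 and per-token softmax
    # scores [0.5, 0.3, 0.15, 0.05], topk keeps [0.5, 0.3]; if norm_topk_prob is set,
    # renormalization rescales them to [0.625, 0.375] so the kept weights sum to 1.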

    if self.training:
        final_hidden_states = torch.zeros(
            (batch_size, sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
        )

        padded_weights = torch.zeros(
            (batch_size * sequence_length, self.num_experts), dtype=hidden_states.dtype, device=hidden_states.device
        )
        padded_weights.scatter_(-1, selected_experts, routing_weights)
        padded_weights = padded_weights.reshape(-1, sequence_length, self.num_experts)
        padded_weights = padded_weights.permute(2, 0, 1).unsqueeze(-1)
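        # padded_weights now has static shape (num_experts, batch_size, sequence_length, 1):
        # every token carries a (possibly zero) weight for every expert, so each expert can be
        # applied to all tokens below without gathering a data-dependent subset of rows.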

        # Loop over all available experts in the model and perform the computation on each expert
        for expert_idx in range(self.num_experts):
            expert_layer = self.experts[expert_idx]
            padded_weight = padded_weights[expert_idx]
            current_state_static = hidden_states.reshape(-1, hidden_dim)
            current_hidden_states_static = (
                expert_layer.pre_mlp_forward(current_state_static).reshape(-1, sequence_length, hidden_dim)
                * padded_weight
            )
            final_hidden_states = final_hidden_states + current_hidden_states_static
    else:
        experts_range = range(self.num_experts)
        w1_list = [self.experts[i].gate_proj.weight.squeeze() for i in experts_range]
        w2_list = [self.experts[i].down_proj.weight.squeeze() for i in experts_range]
        w3_list = [self.experts[i].up_proj.weight.squeeze() for i in experts_range]

        final_hidden_states = torch.ops.hpu.mixture_of_experts(
            hidden_states=hidden_states,
            expert_routing_table=selected_experts,
            router_weights=routing_weights,
            # The dynamic MoE kernel's w1/w2/w3 naming differs from this model's
            # gate/up/down projection naming: the kernel's w2 takes the up_proj
            # weights and its w3 takes the down_proj weights.
            w1=w1_list,
            w2=w3_list,
            w3=w2_list,
            permuted_weights=True,
            activation="silu",
            experts_min=0,
            experts_max=(self.num_experts - 1),
        )
        final_hidden_states = final_hidden_states.reshape(-1, sequence_length, hidden_dim)

        if is_deepspeed_available():
            from deepspeed import comm as dist

            if dist.is_initialized():
                dist.all_reduce(final_hidden_states, op=dist.ReduceOp.SUM)
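                # Assumption (not stated in the source): when DeepSpeed shards the expert
                # weights across ranks, each rank computes a partial MoE output and the
                # all_reduce sums the partial results into the full output.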

    shared_expert_output = self.shared_expert(hidden_states)

    shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output

    shared_expert_output = shared_expert_output.reshape(-1, sequence_length, hidden_dim)

    final_hidden_states = final_hidden_states + shared_expert_output

    return final_hidden_states, router_logits
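
The training branch above trades data-dependent routing for a dense computation: every expert processes every token, and the scattered routing weights zero out the non-selected expert/token combinations. The snippet below is an illustrative, self-contained sketch (all names, shapes, and the toy linear "experts" are made up for the example) showing that this static formulation matches the usual per-token top-k dispatch.

# Illustrative equivalence check for the padded-weights trick (not part of the source file).
import torch

tokens, n_experts, top_k, dim = 4, 3, 2, 8
x = torch.randn(tokens, dim)
experts = [torch.nn.Linear(dim, dim, bias=False) for _ in range(n_experts)]

logits = torch.randn(tokens, n_experts)
weights = torch.softmax(logits, dim=-1)
topk_w, topk_idx = torch.topk(weights, top_k, dim=-1)

# Dynamic version: gather each token's selected experts and accumulate.
out_dynamic = torch.zeros_like(x)
for t in range(tokens):
    for k in range(top_k):
        out_dynamic[t] += topk_w[t, k] * experts[topk_idx[t, k].item()](x[t])

# Static version (as in the training branch): scatter the top-k weights into a dense
# (tokens, n_experts) matrix and run every expert over the full token batch.
padded = torch.zeros(tokens, n_experts).scatter_(-1, topk_idx, topk_w)
out_static = sum(padded[:, e:e + 1] * experts[e](x) for e in range(n_experts))

assert torch.allclose(out_dynamic, out_static, atol=1e-5)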
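
Finally, a minimal sketch of how a Gaudi forward like this is typically wired into the upstream model. The helper name and exact patching entry point are assumptions for illustration, not taken from this file; note also that the training branch relies on a Gaudi-side pre_mlp_forward on the expert MLPs, which the stock transformers module does not define.

# Hypothetical patching sketch (assumed Optimum Habana-style monkey patching, not from the source).
from transformers.models.qwen2_moe import modeling_qwen2_moe

def apply_gaudi_qwen2moe_moe_patch():
    # Rebind the sparse-MoE block's forward to the Gaudi implementation defined above,
    # so existing Qwen2-MoE checkpoints pick up the optimized path without code changes.
    modeling_qwen2_moe.Qwen2MoeSparseMoeBlock.forward = gaudi_qwen2moe_block_sparse_moe_forward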