In optimum/habana/transformers/models/qwen2_moe/modeling_qwen2_moe.py:
# Imports required by this excerpt (the full file may pull them in from other locations)
from typing import Tuple

import torch
import torch.nn.functional as F
from transformers.integrations.deepspeed import is_deepspeed_available


def gaudi_qwen2moe_block_sparse_moe_forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Optimized expert forward pass: removes dynamic control flow and dynamic shapes.
    """
    batch_size, sequence_length, hidden_dim = hidden_states.shape
    hidden_states = hidden_states.view(-1, hidden_dim)
    # router_logits: (batch * sequence_length, n_experts)
    router_logits = self.gate(hidden_states)

    routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
    routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
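    # Optionally re-normalize the top-k routing weights so they sum to 1 per token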
    if self.norm_topk_prob:
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
    # we cast back to the input dtype
    routing_weights = routing_weights.to(hidden_states.dtype)
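
    # Training path: rather than gathering a variable number of tokens per expert (dynamic
    # shapes), every expert processes every token and the result is weighted by a dense,
    # zero-padded routing matrix. Shapes stay static at the cost of extra compute.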
    if self.training:
        final_hidden_states = torch.zeros(
            (batch_size, sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
        )

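        # Scatter the top-k routing weights into a dense (tokens, num_experts) matrix;
        # unselected experts keep a weight of zero, so their contribution vanishes below.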
        padded_weights = torch.zeros(
            (batch_size * sequence_length, self.num_experts), dtype=hidden_states.dtype, device=hidden_states.device
        )
        padded_weights.scatter_(-1, selected_experts, routing_weights)
        padded_weights = padded_weights.reshape(-1, sequence_length, self.num_experts)
        padded_weights = padded_weights.permute(2, 0, 1).unsqueeze(-1)

        # Loop over all available experts in the model and perform the computation on each expert
        for expert_idx in range(self.num_experts):
            expert_layer = self.experts[expert_idx]
            padded_weight = padded_weights[expert_idx]
            current_state_static = hidden_states.reshape(-1, hidden_dim)
            current_hidden_states_static = (
                expert_layer.pre_mlp_forward(current_state_static).reshape(-1, sequence_length, hidden_dim)
                * padded_weight
            )
            final_hidden_states = final_hidden_states + current_hidden_states_static
    else:
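        # Inference path: hand every expert's gate/up/down projection weights to the fused
        # HPU dynamic MoE kernel, which applies the routing table, the SiLU-gated expert
        # MLPs, and the weighted combination in a single op.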
        experts_range = range(self.num_experts)
        w1_list = [self.experts[i].gate_proj.weight.squeeze() for i in experts_range]
        w2_list = [self.experts[i].down_proj.weight.squeeze() for i in experts_range]
        w3_list = [self.experts[i].up_proj.weight.squeeze() for i in experts_range]

        final_hidden_states = torch.ops.hpu.mixture_of_experts(
            hidden_states=hidden_states,
            expert_routing_table=selected_experts,
            router_weights=routing_weights,
            w1=w1_list,
            # Note: the w1/w2/w3 naming convention differs between Optimum Habana's Mixtral
            # model and the dynamic MoE kernel, hence the swapped w2/w3 arguments.
            w2=w3_list,
            w3=w2_list,
            permuted_weights=True,
            activation="silu",
            experts_min=0,
            experts_max=(self.num_experts - 1),
        )

        final_hidden_states = final_hidden_states.reshape(-1, sequence_length, hidden_dim)

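    # When the expert weights are sharded across devices (e.g. under DeepSpeed), each rank
    # holds only a partial result, so sum the expert outputs across ranks.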
    if is_deepspeed_available():
        from deepspeed import comm as dist

        if dist.is_initialized():
            dist.all_reduce(final_hidden_states, op=dist.ReduceOp.SUM)

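    # Shared expert: every token also goes through a shared expert whose output is scaled
    # by a sigmoid gate before being added to the routed-expert output.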
    shared_expert_output = self.shared_expert(hidden_states)
    shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output
    shared_expert_output = shared_expert_output.reshape(-1, sequence_length, hidden_dim)

    final_hidden_states = final_hidden_states + shared_expert_output

    return final_hidden_states, router_logits
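

# --- Illustrative usage sketch, not part of the original file ---
# Optimum Habana normally installs this forward automatically when a model is adapted for
# Gaudi; the lines below only sketch, as an assumption, how it could be bound manually onto
# the stock transformers block. The eval path needs an HPU device (it calls
# torch.ops.hpu.mixture_of_experts) and the training path needs Gaudi-patched expert MLPs
# that expose `pre_mlp_forward`.
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock

Qwen2MoeSparseMoeBlock.forward = gaudi_qwen2moe_block_sparse_moe_forward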