in inference/model.py [0:0]
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for the MoE module.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Output tensor after expert routing and computation.
"""
shape = x.size()
x = x.view(-1, self.dim)
weights, indices = self.gate(x)
y = torch.zeros_like(x)
counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()
for i in range(self.experts_start_idx, self.experts_end_idx):
if counts[i] == 0:
continue
expert = self.experts[i]
idx, top = torch.where(indices == i)
y[idx] += expert(x[idx]) * weights[idx, top, None]
z = self.shared_experts(x)
if world_size > 1:
dist.all_reduce(y)
return (y + z).view(shape)