in torchrec/modules/crossnet.py
def forward(self, input: torch.Tensor) -> torch.Tensor:
    x_0 = input.unsqueeze(2)  # (B, N, 1)
    x_l = x_0
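
    # Each layer builds K low-rank "cross" experts of the form
    #   expert_i = x_0 * (U_i @ g(C_i @ g(V_i @ x_l)) + bias)
    # (g = self._activation), mixes them with a softmax gate when K > 1,
    # and adds the result to x_l as a residual connection.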
    for layer in range(self._num_layers):
        # set up gating:
        if self._num_experts > 1:
            gating = []
            for i in range(self._num_experts):
                # pyre-ignore[16]: `Optional` has no attribute `__getitem__`.
                gating.append(self.gates[i](x_l.squeeze(2)))
            gating = torch.stack(gating, 1)  # (B, K, 1)
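            # Each gate maps x_l (B, N) to a single logit, so `gating`
            # carries one mixing weight per expert for every example.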

        # set up experts
        experts = []
        for i in range(self._num_experts):
            expert = torch.matmul(
                # pyre-ignore[29]
                self.V_kernels[layer][i],
                x_l,
            )  # (B, r, 1)
            expert = torch.matmul(
                # pyre-ignore[29]
                self.C_kernels[layer][i],
                self._activation(expert),
            )  # (B, r, 1)
            expert = torch.matmul(
                # pyre-ignore[29]
                self.U_kernels[layer][i],
                self._activation(expert),
            )  # (B, N, 1)
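            # Back in the feature dimension: add the per-layer bias and take
            # the elementwise product with the original input x_0 to form the
            # explicit feature cross.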
            # pyre-ignore[29]
            expert = x_0 * (expert + self.bias[layer])  # (B, N, 1)
            experts.append(expert.squeeze(2))  # (B, N)
        experts = torch.stack(experts, 2)  # (B, N, K)
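
        # Combine the K expert outputs: a softmax-weighted sum over the gate
        # logits when there is more than one expert, otherwise the single
        # expert output is used directly.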
        if self._num_experts > 1:
            # MOE update
            moe = torch.matmul(
                experts,
                # pyre-ignore[61]: `gating` may not be initialized here.
                torch.nn.functional.softmax(gating, 1),
            )  # (B, N, 1)
            x_l = moe + x_l  # (B, N, 1)
        else:
            x_l = experts + x_l  # (B, N, 1)

    return torch.squeeze(x_l, dim=2)  # (B, N)
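
A minimal usage sketch, assuming the method above belongs to torchrec's LowRankMixtureCrossNet and that its constructor accepts in_features, num_layers, num_experts, and low_rank (the class and argument names are an assumption inferred from the kernels referenced above; the sizes are made up for illustration, so verify against the installed torchrec version):

import torch
from torchrec.modules.crossnet import LowRankMixtureCrossNet  # assumed class name

# Hypothetical sizes: B=8 examples, N=16 features, K=4 experts of rank r=2.
cross_net = LowRankMixtureCrossNet(
    in_features=16,
    num_layers=3,
    num_experts=4,
    low_rank=2,
)
dense = torch.rand(8, 16)        # (B, N)
out = cross_net(dense)           # runs the forward() shown above
assert out.shape == (8, 16)      # (B, N)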