def forward()

in torchrec/modules/crossnet.py [0:0]
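This is the forward pass of TorchRec's low-rank, mixture-of-experts cross network (the DCN-v2 style `LowRankMixtureCrossNet`). Each of the `_num_layers` iterations applies the update

    x_{l+1} = MoE_i( x_0 * (U_i g(C_i g(V_i x_l)) + b_l) ) + x_l

where g is `self._activation`, (V_i, C_i, U_i) are the per-expert low-rank kernels, b_l is a per-layer bias shared across experts, and MoE is a softmax-gated sum over the K experts (reducing to a plain residual update when there is a single expert).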


    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input (torch.Tensor): tensor of shape (B, N), where B is the batch
                size and N is the number of input features.

        Returns:
            torch.Tensor: tensor of shape (B, N).
        """
        x_0 = input.unsqueeze(2)  # (B, N, 1); held fixed across all layers
        x_l = x_0

        for layer in range(self._num_layers):
            # set up gating: only needed with more than one expert. Each gate
            # maps the (B, N) layer input to one scalar score per example.
            if self._num_experts > 1:
                gating = []
                for i in range(self._num_experts):
                    # pyre-ignore[16]: `Optional` has no attribute `__getitem__`.
                    gating.append(self.gates[i](x_l.squeeze(2)))  # (B, 1)
                gating = torch.stack(gating, 1)  # (B, K, 1)

            # set up experts: each expert factorizes a full (N, N) cross kernel
            # into U_i @ g(C_i @ g(V_i @ x_l)) through a rank-r bottleneck
            experts = []
            for i in range(self._num_experts):
                expert = torch.matmul(
                    # pyre-ignore[29]
                    self.V_kernels[layer][i],  # (r, N): project down to rank r
                    x_l,
                )  # (B, r, 1)
                expert = torch.matmul(
                    # pyre-ignore[29]
                    self.C_kernels[layer][i],  # (r, r): mix within the low-rank space
                    self._activation(expert),
                )  # (B, r, 1)
                expert = torch.matmul(
                    # pyre-ignore[29]
                    self.U_kernels[layer][i],  # (N, r): project back up to N
                    self._activation(expert),
                )  # (B, N, 1)
                # cross the expert output with the original input, plus a layer bias
                # pyre-ignore[29]
                expert = x_0 * (expert + self.bias[layer])  # (B, N, 1)
                experts.append(expert.squeeze(2))  # (B, N)
            experts = torch.stack(experts, 2)  # (B, N, K)

            if self._num_experts > 1:
                # MOE update: softmax-normalize the gates over the K experts,
                # then take the gate-weighted sum: (B, N, K) @ (B, K, 1)
                moe = torch.matmul(
                    experts,
                    # pyre-ignore[61]: `gating` may not be initialized here.
                    torch.nn.functional.softmax(gating, 1),
                )  # (B, N, 1)
                x_l = moe + x_l  # residual connection, (B, N, 1)
            else:
                # single expert: no gating; add the residual directly
                x_l = experts + x_l  # (B, N, 1)

        return torch.squeeze(x_l, dim=2)  # (B, N)
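
Below is a minimal, self-contained sketch of one such layer using plain tensors, intended only to make the shapes concrete. The kernel names, shapes, and example sizes here are assumptions inferred from the shape comments above, not the module's actual parameters:

    import torch

    B, N, r, K = 4, 8, 2, 3   # batch size, in_features, low rank, num experts (example values)
    g = torch.relu            # stand-in for self._activation

    x_0 = torch.randn(B, N, 1)
    x_l = x_0

    # hypothetical kernels for a single layer, shaped to match the comments above
    V = [torch.randn(r, N) for _ in range(K)]      # down-projections
    C = [torch.randn(r, r) for _ in range(K)]      # low-rank mixing
    U = [torch.randn(N, r) for _ in range(K)]      # up-projections
    bias = torch.zeros(N, 1)
    gates = [torch.nn.Linear(N, 1) for _ in range(K)]

    gating = torch.stack([gates[i](x_l.squeeze(2)) for i in range(K)], 1)  # (B, K, 1)
    experts = torch.stack(
        [
            (x_0 * (U[i] @ g(C[i] @ g(V[i] @ x_l)) + bias)).squeeze(2)  # (B, N)
            for i in range(K)
        ],
        2,
    )  # (B, N, K)
    x_l = experts @ torch.nn.functional.softmax(gating, 1) + x_l  # (B, N, 1)
    print(x_l.squeeze(2).shape)  # torch.Size([4, 8])

One iteration of the module's loop corresponds to this computation, with the kernels taken from `self.V_kernels[layer]`, `self.C_kernels[layer]`, and `self.U_kernels[layer]`, and the gates from `self.gates`.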