def forward()

in shap_e/models/nn/ops.py [0:0]


    def forward(self, x, params=None):
        """
        Apply a linear map over the last dimension of ``x``.

        After the map, an optional bias is added, then an optional scale is
        multiplied in, then an optional shift is added.

        The dimensions between the leading batch dimension and the trailing
        feature dimension are flattened for the matrix product and restored
        afterwards. An input of shape ``(batch, *shape, d_in)`` therefore
        yields an output of shape ``(batch, *shape, d_out)``.

        :param x: input tensor with at least two dimensions,
            ``(batch, ..., d_in)``.
        :param params: optional parameter overrides, resolved through
            ``self.update``. The resolved object must expose ``weight``,
            ``bias``, ``scale`` and ``shift``; the last three may be ``None``
            to skip that step. ``weight`` is either ``(d_out, d_in)``, shared
            by every batch element, or ``(batch, d_out, d_in)``, one matrix
            per batch element.
        :return: tensor of shape ``(batch, *shape, d_out)``.
        :raises ValueError: if the resolved ``weight`` is not 2- or
            3-dimensional.
        """
        params = self.update(params)

        batch_size, *shape, d_in = x.shape
        # Use reshape rather than view so non-contiguous inputs (e.g. a
        # transposed tensor) are accepted.
        # Pass the flattened size explicitly instead of -1: with a
        # zero-element input, -1 is ambiguous and torch raises.
        n_points = x.shape[1:-1].numel()
        x = x.reshape(batch_size, n_points, d_in)

        if params.weight.ndim == 2:
            # One weight matrix shared across the whole batch.
            h = torch.einsum("bni,oi->bno", x, params.weight)
        elif params.weight.ndim == 3:
            # A separate weight matrix for each batch element.
            h = torch.einsum("bni,boi->bno", x, params.weight)
        else:
            raise ValueError(
                "expected weight to have 2 or 3 dimensions, "
                f"got {params.weight.ndim}"
            )

        # NOTE(review): self._bcast is defined outside this view; it
        # presumably broadcasts shared or per-batch parameters against h.
        # Confirm.
        if params.bias is not None:
            h = self._bcast(torch.add, h, params.bias)

        if params.scale is not None:
            h = self._bcast(torch.mul, h, params.scale)

        if params.shift is not None:
            h = self._bcast(torch.add, h, params.shift)

        # Restore the original middle dimensions.
        # The explicit feature size keeps this valid for zero-element outputs.
        h = h.reshape(batch_size, *shape, h.shape[-1])
        return h