in shap_e/models/nn/ops.py [0:0]
def forward(self, x, params=None):
params = self.update(params)
batch_size, *shape, d_in = x.shape
x = x.view(batch_size, -1, d_in)
if params.weight.ndim == 2:
h = torch.einsum("bni,oi->bno", x, params.weight)
elif params.weight.ndim == 3:
h = torch.einsum("bni,boi->bno", x, params.weight)
if params.bias is not None:
h = self._bcast(torch.add, h, params.bias)
if params.scale is not None:
h = self._bcast(torch.mul, h, params.scale)
if params.shift is not None:
h = self._bcast(torch.add, h, params.shift)
h = h.view(batch_size, *shape, -1)
return h