in models/utils.py [0:0]
def forward(self, rvec):
    """Turn a batch of 4-component rotation vectors into 3x3 rotation matrices.

    Each row of ``rvec`` is first rescaled by its (epsilon-regularized)
    Euclidean length, then expanded with the usual quaternion-to-matrix
    formula. The index pattern used below corresponds to a quaternion laid
    out as ``(x, y, z, w)``, i.e. the scalar part sits in column 3.

    Args:
        rvec: 2-D tensor whose columns 0..3 hold the four components of
            each row (the reduction runs over ``dim=1``).

    Returns:
        Tensor reshaped to ``(-1, 3, 3)``: one row-major 3x3 matrix per
        input row.
    """
    # The 1e-5 under the square root keeps the divisor strictly positive, so
    # an all-zero row does not divide by zero (it yields the identity matrix).
    # As a side effect the rescaled rows are very slightly shorter than unit.
    length = torch.sqrt(1e-5 + torch.sum(rvec ** 2, dim=1))
    unit = rvec / length[:, None]

    x = unit[:, 0]
    y = unit[:, 1]
    z = unit[:, 2]
    w = unit[:, 3]

    # Entries are grouped by matrix row; stacking along dim=1 produces nine
    # values per input row in row-major order, which the final view folds
    # into a 3x3 matrix.
    first_row = (
        1. - 2. * y ** 2 - 2. * z ** 2,
        2. * (x * y - z * w),
        2. * (x * z + y * w),
    )
    second_row = (
        2. * (x * y + z * w),
        1. - 2. * x ** 2 - 2. * z ** 2,
        2. * (y * z - x * w),
    )
    third_row = (
        2. * (x * z - y * w),
        2. * (x * w + y * z),
        1. - 2. * x ** 2 - 2. * y ** 2,
    )

    flat = torch.stack(first_row + second_row + third_row, dim=1)
    return flat.view(-1, 3, 3)