in src/quaternion.py [0:0]
def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6):
    """Convert a batch of 3x3 rotation matrices to 4d quaternion vectors.

    This algorithm is based on algorithm described in
    https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201

    Each matrix is transposed and one of four candidate quaternions is
    selected from the signs of its diagonal entries, so that the value placed
    under the square root is the candidate's own (positive) pivot term.

    Args:
        rotation_matrix (Tensor): the rotation matrices to convert.
        eps (float): threshold compared against the ``[2, 2]`` diagonal entry
            when choosing between the candidates. Default: 1e-6.

    Return:
        Tensor: the rotations as quaternions, scalar part first
        (the trace-based term sits at index 0).

    Raises:
        TypeError: if ``rotation_matrix`` is not a ``torch.Tensor``.
        ValueError: if ``rotation_matrix`` is not of shape :math:`(N, 3, 3)`.

    Shape:
        - Input: :math:`(N, 3, 3)`
        - Output: :math:`(N, 4)`

    Example:
        input = torch.eye(3).repeat(4, 1, 1)  # Nx3x3
        output = tgm.rotation_matrix_to_quaternion(input)  # Nx4
    """
    if not torch.is_tensor(rotation_matrix):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(rotation_matrix)))

    # The transpose and indexing below need exactly (N, 3, 3); reject any
    # other rank up front instead of failing later with an IndexError.
    if len(rotation_matrix.shape) != 3:
        raise ValueError(
            "Input size must be a three dimensional tensor. Got {}".format(
                rotation_matrix.shape))
    if not rotation_matrix.shape[-2:] == (3, 3):
        raise ValueError(
            "Input size must be a N x 3 x 3 tensor. Got {}".format(
                rotation_matrix.shape))

    rmat_t = torch.transpose(rotation_matrix, 1, 2)

    # Per-sample entries of the transposed matrices, each of shape (N,).
    m00, m01, m02 = rmat_t[:, 0, 0], rmat_t[:, 0, 1], rmat_t[:, 0, 2]
    m10, m11, m12 = rmat_t[:, 1, 0], rmat_t[:, 1, 1], rmat_t[:, 1, 2]
    m20, m21, m22 = rmat_t[:, 2, 0], rmat_t[:, 2, 1], rmat_t[:, 2, 2]

    mask_d2 = m22 < eps
    mask_d0_d1 = m00 > m11
    mask_d0_nd1 = m00 < -m11

    # Four candidate (unnormalized) quaternions and their pivot terms.
    t0 = 1 + m00 - m11 - m22
    q0 = torch.stack([m12 - m21, t0, m01 + m10, m20 + m02], -1)

    t1 = 1 - m00 + m11 - m22
    q1 = torch.stack([m20 - m02, m01 + m10, t1, m12 + m21], -1)

    t2 = 1 - m00 - m11 + m22
    q2 = torch.stack([m01 - m10, m20 + m02, m12 + m21, t2], -1)

    t3 = 1 + m00 + m11 + m22
    q3 = torch.stack([t3, m12 - m21, m20 - m02, m01 - m10], -1)

    # The four masks are mutually exclusive and cover every sample, so the
    # weighted sums below pick exactly one candidate per row.
    mask_c0 = (mask_d2 & mask_d0_d1).view(-1, 1).type_as(q0)
    mask_c1 = (mask_d2 & ~mask_d0_d1).view(-1, 1).type_as(q1)
    mask_c2 = (~mask_d2 & mask_d0_nd1).view(-1, 1).type_as(q2)
    mask_c3 = (~mask_d2 & ~mask_d0_nd1).view(-1, 1).type_as(q3)

    q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3
    # (N, 1) pivot of the selected candidate; broadcasts over the 4 components.
    t = (t0.unsqueeze(-1) * mask_c0 + t1.unsqueeze(-1) * mask_c1 +
         t2.unsqueeze(-1) * mask_c2 + t3.unsqueeze(-1) * mask_c3)

    q = q / torch.sqrt(t)
    return q * 0.5