in phosa/utils/geometry.py [0:0]
def compute_random_rotations(B=10, upright=False, device=None):
    """
    Randomly samples rotation matrices.

    Args:
        B (int): Batch size.
        upright (bool): If True, samples rotations that are mostly upright: three
            Euler angles are drawn uniformly from [0, 2*pi), [-pi/6, pi/6) and
            [-pi/12, pi/12) respectively and converted with the "YXZ" convention.
            Otherwise, samples uniformly from rotation space.
        device (torch.device or str, optional): Device of the returned tensor.
            Defaults to the current CUDA device when CUDA is available and to
            the CPU otherwise, so the function also works on CPU-only machines.

    Returns:
        rotation_matrices (B x 3 x 3).
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if upright:
        # Angles are drawn on the CPU and then moved, so the CPU random
        # generator is consumed regardless of the target device.
        a1 = torch.empty(B, 1, dtype=torch.float32).uniform_(0, 2 * math.pi)
        a2 = torch.empty(B, 1, dtype=torch.float32).uniform_(
            -math.pi / 6, math.pi / 6
        )
        a3 = torch.empty(B, 1, dtype=torch.float32).uniform_(
            -math.pi / 12, math.pi / 12
        )
        angles = torch.cat((a1, a2, a3), 1).to(device)
        return euler_angles_to_matrix(angles, "YXZ")

    # Reference: J. Arvo. "Fast Random Rotation Matrices." Graphics Gems III (1992)
    # reshape + unbind yields three length-B vectors and, unlike
    # torch.split(..., B), also handles an empty batch (B == 0).
    x1, x2, x3 = torch.rand(3 * B).reshape(3, B).to(device).unbind(0)
    tau = 2 * math.pi

    # Rotation about the z axis by the angle tau * x1.
    cos1 = torch.cos(tau * x1)
    sin1 = torch.sin(tau * x1)
    zeros = torch.zeros_like(x1)
    ones = torch.ones_like(x1)
    R = torch.stack(
        (  # B x 3 x 3
            torch.stack((cos1, sin1, zeros), 1),
            torch.stack((-sin1, cos1, zeros), 1),
            torch.stack((zeros, zeros, ones), 1),
        ),
        1,
    )

    # Unit vector defining the Householder reflection H = I - 2 v v^T.
    sqrt_x3 = torch.sqrt(x3)
    v = torch.stack(
        (  # B x 3
            torch.cos(tau * x2) * sqrt_x3,
            torch.sin(tau * x2) * sqrt_x3,
            torch.sqrt(1 - x3),
        ),
        1,
    )
    identity = torch.eye(3, device=device)  # broadcasts over the batch dimension
    H = identity - 2 * v.unsqueeze(2) * v.unsqueeze(1)

    # H is a reflection (det = -1); negating the 3 x 3 product restores det = +1.
    rotation_matrices = -torch.matmul(H, R)
    return rotation_matrices