def compute_random_rotations()

in phosa/utils/geometry.py [0:0]


def compute_random_rotations(B=10, upright=False, device="cuda"):
    """
    Randomly samples rotation matrices.

    Args:
        B (int): Batch size. B == 0 returns an empty (0 x 3 x 3) result in the
            uniform branch.
        upright (bool): If True, samples rotations that are mostly upright. Otherwise,
            samples uniformly from rotation space.
        device (str or torch.device): Device the returned matrices are placed on.
            Defaults to "cuda" (the previously hard-coded behavior); pass "cpu"
            to run without a GPU. Random numbers are always drawn from the CPU
            generator and then moved, so seeded results do not depend on device.

    Returns:
        rotation_matrices (B x 3 x 3).
    """
    if upright:
        # The first angle covers the full circle; the other two are limited to
        # +/- pi/6 and +/- pi/12, keeping the sampled rotation close to upright
        # (axis order given by the "YXZ" convention passed below).
        a1 = torch.empty(B, 1, dtype=torch.float32).uniform_(0, 2 * math.pi)
        a2 = torch.empty(B, 1, dtype=torch.float32).uniform_(
            -math.pi / 6, math.pi / 6
        )
        a3 = torch.empty(B, 1, dtype=torch.float32).uniform_(
            -math.pi / 12, math.pi / 12
        )
        angles = torch.cat((a1, a2, a3), 1).to(device)
        return euler_angles_to_matrix(angles, "YXZ")

    # Reference: J. Arvo. "Fast Random Rotation Matrices." (1992)
    # Draw all 3*B uniforms in one call (same RNG stream as splitting a flat
    # vector into consecutive chunks of B) and unpack the three rows. Unlike
    # torch.split(..., B), this also works when B == 0.
    x1, x2, x3 = torch.rand(3 * B).view(3, B).to(device)
    tau = 2 * math.pi

    # Random rotation about the z axis by angle tau * x1.
    cos_t = torch.cos(tau * x1)
    sin_t = torch.sin(tau * x1)
    zeros = torch.zeros_like(x1)
    ones = torch.ones_like(x1)
    R = torch.stack(
        (
            torch.stack((cos_t, sin_t, zeros), 1),
            torch.stack((-sin_t, cos_t, zeros), 1),
            torch.stack((zeros, zeros, ones), 1),
        ),
        1,
    )  # B x 3 x 3

    # Unit vectors: |v|^2 = x3 * (cos^2 + sin^2) + (1 - x3) = 1.
    sqrt_x3 = torch.sqrt(x3)
    v = torch.stack(
        (
            torch.cos(tau * x2) * sqrt_x3,
            torch.sin(tau * x2) * sqrt_x3,
            torch.sqrt(1 - x3),
        ),
        1,
    )  # B x 3

    # Householder reflection H = I - 2 v v^T; the identity broadcasts over B.
    H = torch.eye(3, device=device) - 2 * v.unsqueeze(2) * v.unsqueeze(1)
    # det(H) = -1 and det(R) = +1, so negating H @ R gives det = +1, i.e. a
    # proper rotation rather than a reflection.
    return -torch.matmul(H, R)