def local_to_global()

in empose/helpers/utils.py [0:0]


def local_to_global(poses, parents, output_format='aa', input_format='aa'):
    """
    Convert relative joint angles to global ones by unrolling the kinematic chain.

    The global rotation of a joint is the global rotation of its parent multiplied (from the right) by the joint's
    own local rotation; root joints keep their local rotation. The per-joint results are collected in a list and
    stacked at the end instead of being written into a pre-allocated tensor, so that no tensor saved by autograd is
    modified in-place and gradients can flow back to `poses`.

    :param poses: A tensor whose last dimension is N_JOINTS*3 if `input_format` is 'aa' (angle-axis) or N_JOINTS*9
      if `input_format` is 'rotmat' (flattened 3x3 rotation matrices). All leading dimensions are flattened into a
      single batch dimension N.
    :param parents: A list of parents for each joint j, i.e. parent[j] is the parent of joint j. A negative entry
      marks a root joint. Parents must be listed before their children, i.e. parents[j] < j.
    :param output_format: 'aa' or 'rotmat'.
    :param input_format: 'aa' or 'rotmat'
    :return: The global joint angles as a tensor of shape (N, N_JOINTS*3) for output format 'aa' or
      (N, N_JOINTS*9) for output format 'rotmat'.
    :raises ValueError: If the parent of a joint does not come before that joint in the joint ordering.
    """
    assert output_format in ['aa', 'rotmat']
    assert input_format in ['aa', 'rotmat']
    dof = 3 if input_format == 'aa' else 9
    n_joints = poses.shape[-1] // dof
    if input_format == 'aa':
        local_oris = aa2rot(poses.reshape((-1, 3)))
    else:
        local_oris = poses
    local_oris = local_oris.reshape((-1, n_joints, 3, 3))

    # One (N, 3, 3) tensor per joint, in joint order. Building a list avoids the in-place slice assignments that
    # would invalidate the inputs `torch.matmul` keeps around for its backward pass.
    global_rots = []
    for j in range(n_joints):
        local_rot = local_oris[..., j, :, :]
        parent = parents[j]
        if parent < 0:
            # root rotation
            global_rots.append(local_rot)
            continue
        if parent >= j:
            # The parent's global rotation has not been computed yet; continuing would silently yield garbage.
            raise ValueError(
                "parents[{}] = {} but parents must precede their children in the joint ordering".format(j, parent))
        global_rots.append(torch.matmul(global_rots[parent], local_rot))

    # Re-assemble to (N, N_JOINTS, 3, 3).
    global_oris = torch.stack(global_rots, dim=-3)

    if output_format == 'aa':
        global_oris = rot2aa(global_oris.reshape((-1, 3, 3)))
        res = global_oris.reshape((-1, n_joints * 3))
    else:
        res = global_oris.reshape((-1, n_joints * 3 * 3))
    return res