# Function: fix_bone_lengths()
# Location: src/utils.py [0:0]


def fix_bone_lengths(positions, mask, c, target_lengths, it=10, k=0.95):
    """Iteratively move joints so that bone lengths approach ``target_lengths``.

    In every iteration a set of candidate positions is computed for each
    joint: one from rescaling the offset to its parent, one per sibling
    (other children of the same parent) and one per child.  The joint is then
    moved by the fraction ``k`` towards the mean of those candidates.  All
    joints of an iteration are computed from the positions of the previous
    iteration; updated positions only become visible in the next iteration.

    Args:
        positions: torch tensor indexed as ``positions[:, joint_idx]``.  It is
            cloned, the input is not modified.
            NOTE(review): presumably shaped (batch, joints, 3), since the
            per-joint scale is tiled 3 times -- confirm against callers.
        mask: boolean tensor indexed as ``mask[:, joint_idx]``.  Entries that
            are False keep their original position.
        c: object exposing ``c.skeleton.parent_idx_vector()``,
            ``c.skeleton.child_idx_vector()`` and ``c.skeleton.Idx.all``.
            A negative parent index is treated as "no parent".
        target_lengths: numpy array of desired bone lengths.
            NOTE(review): assumed to be 1-D and indexed by joint in the same
            way as the output of ``bone_lengths_batch`` -- confirm.
        it: number of relaxation iterations.
        k: fraction of the distance to the mean candidate position a joint is
            moved per iteration.

    Returns:
        A new tensor with the same shape as ``positions`` holding the
        adjusted joint positions.

    Side effects:
        Prints the mean ``|actual / target|`` length ratio (NaN/inf entries
        counted as 0) once per iteration.
    """

    def rescale_from(anchor, joint_idx, bone_idx):
        # Candidate position for `joint_idx`: keep the direction from `anchor`
        # to the joint, but scale the offset by the target/actual ratio of
        # the bone identified by `bone_idx`.
        offset = p[:, joint_idx] - anchor
        scale = torch_tile(ratio[:, bone_idx][:, None], dim=1, n_tile=3)
        return anchor + offset * scale

    def get_mean_target_position(joint_idx):
        # Mean of all candidate positions for this joint.  Because some bones
        # have multiple children, the siblings are included as well so that
        # the children keep their distance to each other.
        candidates = []
        parent = parent_idx[joint_idx]
        if parent >= 0:
            candidates.append(rescale_from(p[:, parent], joint_idx, joint_idx))
            for sibling in child_idx[parent]:
                if sibling != joint_idx:
                    candidates.append(rescale_from(p[:, sibling], joint_idx, sibling))

        for child in child_idx[joint_idx]:
            candidates.append(rescale_from(p[:, child], joint_idx, child))

        if not candidates:
            # Isolated joint (no parent, no children): nothing constrains it,
            # and torch.stack would raise on an empty list.
            return p[:, joint_idx]
        return torch.mean(torch.stack(candidates), dim=0)

    p = positions.clone()
    new_p = torch.empty_like(p)
    parent_idx = c.skeleton.parent_idx_vector()
    child_idx = c.skeleton.child_idx_vector()
    n_joints = len(c.skeleton.Idx.all)
    target = torch_tile(torch.from_numpy(target_lengths[None, :]).to(p.device), dim=0, n_tile=len(p)).float().to(p.device)

    for _ in range(it):
        actual_lengths = bone_lengths_batch(p, parent_idx).to(p.device)

        # Diagnostic output only; it does not influence the update below.
        err = torch.abs(actual_lengths / target)
        err[torch.isnan(err)] = 0
        err[torch.isinf(err)] = 0
        print(torch.mean(err))

        # Per-bone scale factor.  A zero actual length yields inf/NaN, which
        # would otherwise be written into the positions and spread to every
        # connected joint in later iterations.  A ratio of 1 leaves the
        # affected offset unscaled instead.
        ratio = target / actual_lengths
        ratio[torch.isnan(ratio) | torch.isinf(ratio)] = 1

        for joint_idx in range(n_joints):
            current = p[:, joint_idx]
            step = (get_mean_target_position(joint_idx) - current) * k
            joint_mask = mask[:, joint_idx]
            # Only masked entries are moved; the rest are copied unchanged.
            new_p[joint_mask, joint_idx] = (current + step)[joint_mask]
            new_p[~joint_mask, joint_idx] = current[~joint_mask]
        p = torch.clone(new_p)

    return p