def finger_direction_loss()

in src/utils.py [0:0]


def finger_direction_loss(input, target, config, body, **kwargs):
    """Distance of each predicted finger joint to its target finger triangle.

    ``input`` and ``target`` are each converted to joint positions with
    ``config.out_features.features_to_pose().solve_batch(...)`` and viewed as
    ``(batch, joints, 3)``. For every finger of the target pose, joints
    1, 2 and 3 form a triangle. Each predicted finger joint is then measured
    against the triangle of the finger it belongs to with
    ``distance_to_triangle``.

    Args:
        input: Predicted features. ``input.shape[0]`` is used as the batch
            size and ``input.device`` as the device of the result.
        target: Target features, solved the same way as ``input``.
        config: Must provide ``out_features.features_to_pose()`` and
            ``skeleton.Idx`` (named joint indices plus
            ``get_finger_joints()``).
        body: Forwarded unchanged to ``solve_batch``.
        **kwargs: Must contain ``"bone_lengths"``, forwarded to
            ``solve_batch``.

    Returns:
        Tensor of shape ``(batch, n_finger_joints)``: column ``i`` is the
        distance of the i-th finger joint (right-hand joints first, then
        left-hand) to its target triangle. The dtype follows the solved
        positions.

    Raises:
        KeyError: If ``"bone_lengths"`` is not in ``kwargs``.
    """
    batch_size = input.shape[0]

    def solve_positions(features):
        # Features -> joint positions, viewed as (batch, joints, 3).
        # features_to_pose() is called per solve, as before, in case the
        # returned solver is stateful.
        return (
            config.out_features.features_to_pose()
            .solve_batch(
                features,
                ref_positions=None,
                mask=None,
                bone_lengths=kwargs["bone_lengths"],
                body=body,
            )
            .reshape((batch_size, -1, 3))
        )

    target_pos = solve_positions(target)
    out_pos = solve_positions(input)

    Idx = config.skeleton.Idx
    l_finger_idx, r_finger_idx = Idx.get_finger_joints()
    # Right hand first, then left. The two tables below use the same order.
    all_finger_idx = [*r_finger_idx, *l_finger_idx]

    # Joints 1-3 of each finger span that finger's reference triangle.
    finger_tri_idx = [
        [Idx.rindex1, Idx.rindex2, Idx.rindex3],
        [Idx.rring1, Idx.rring2, Idx.rring3],
        [Idx.rmiddle1, Idx.rmiddle2, Idx.rmiddle3],
        [Idx.rpinky1, Idx.rpinky2, Idx.rpinky3],
        [Idx.rthumb1, Idx.rthumb2, Idx.rthumb3],
        [Idx.lindex1, Idx.lindex2, Idx.lindex3],
        [Idx.lring1, Idx.lring2, Idx.lring3],
        [Idx.lmiddle1, Idx.lmiddle2, Idx.lmiddle3],
        [Idx.lpinky1, Idx.lpinky2, Idx.lpinky3],
        [Idx.lthumb1, Idx.lthumb2, Idx.lthumb3],
    ]
    # First joint of each finger, aligned with finger_tri_idx (thumbs start
    # at joint 0, so thumb0 is assigned to the thumb triangle too).
    first_finger_joint_idx = [
        Idx.rindex1, Idx.rring1, Idx.rmiddle1, Idx.rpinky1, Idx.rthumb0,
        Idx.lindex1, Idx.lring1, Idx.lmiddle1, Idx.lpinky1, Idx.lthumb0,
    ]

    def owning_triangle(joint):
        # Index of the last leading entry of first_finger_joint_idx that is
        # <= joint. NOTE(review): assumes the Idx values above increase in
        # list order (confirm against the skeleton definition). A joint below
        # every entry yields -1, which selects the LAST triangle (left thumb).
        tri = -1
        for first in first_finger_joint_idx:
            if joint < first:
                break
            tri += 1
        return tri

    flat_tri_idx = [joint for tri in finger_tri_idx for joint in tri]
    # (batch, n_fingers, 3 vertices, 3 coords)
    target_finger_tris = target_pos[:, flat_tri_idx].reshape(
        (len(target_pos), len(finger_tri_idx), 3, 3)
    )
    actual_finger_pos = out_pos[:, all_finger_idx]

    # Match the dtype of the solved positions instead of defaulting to
    # float32, so float64 (or reduced-precision) pipelines are not recast.
    distances = torch.zeros(
        (len(actual_finger_pos), len(all_finger_idx)),
        dtype=out_pos.dtype,
        device=input.device,
    )
    for i, joint in enumerate(all_finger_idx):
        tri = owning_triangle(joint)
        distances[:, i] = distance_to_triangle(
            target_finger_tris[:, tri], actual_finger_pos[:, i]
        )

    return distances