def finger_direction_loss_finger_pos_input_only()

in src/utils.py [0:0]


def finger_direction_loss_finger_pos_input_only(input, target, config, body, **kwargs):
    """Distance from every predicted finger joint to a target finger triangle.

    ``input`` and ``target`` are both reshaped to ``(input.shape[0], -1, 3)``
    and then indexed with *shifted* finger indices: right-hand indices are
    shifted so that the smallest one becomes 0, left-hand indices so that the
    smallest one becomes ``len(left joints)``.
    NOTE(review): this presumably means both tensors contain only finger
    joints, right hand first, then left hand -- confirm against the caller.

    For each finger, a triangle is built from the target positions of its
    joints 1, 2 and 3. Each predicted joint is then compared against one of
    those triangles with ``distance_to_triangle``.

    Args:
        input: Predicted positions; must support ``reshape`` and expose
            ``shape`` and ``device`` (a torch tensor).
        target: Target positions, reshaped with the same leading size as
            ``input``.
        config: Object providing ``config.skeleton.Idx`` with the per-joint
            index attributes and ``get_finger_joints()``, which returns
            ``(left_finger_indices, right_finger_indices)``.
        body: Unused; kept for interface compatibility.
        **kwargs: Unused; kept for interface compatibility.

    Returns:
        Tensor of shape ``(input.shape[0], number of finger joints)`` on
        ``input.device``; column ``i`` holds the value returned by
        ``distance_to_triangle`` for the i-th joint (right hand first).
    """
    batch = input.shape[0]
    target_pos = target.reshape((batch, -1, 3))
    out_pos = input.reshape((batch, -1, 3))

    Idx = config.skeleton.Idx
    l_joints, r_joints = Idx.get_finger_joints()

    r_off = min(r_joints)
    # Left-hand indices are moved to start at len(l_joints).
    l_shift = len(l_joints) - min(l_joints)

    # Build shifted copies instead of editing the lists in place: the lists
    # come from the skeleton config and may be shared with other callers.
    r_finger_idx = [j - r_off for j in r_joints]
    l_finger_idx = [j + l_shift for j in l_joints]
    all_finger_idx = r_finger_idx + l_finger_idx

    # Joints 1-3 of each finger span that finger's target triangle.
    r_tris = (
        (Idx.rindex1, Idx.rindex2, Idx.rindex3),
        (Idx.rring1, Idx.rring2, Idx.rring3),
        (Idx.rmiddle1, Idx.rmiddle2, Idx.rmiddle3),
        (Idx.rpinky1, Idx.rpinky2, Idx.rpinky3),
        (Idx.rthumb1, Idx.rthumb2, Idx.rthumb3),
    )
    l_tris = (
        (Idx.lindex1, Idx.lindex2, Idx.lindex3),
        (Idx.lring1, Idx.lring2, Idx.lring3),
        (Idx.lmiddle1, Idx.lmiddle2, Idx.lmiddle3),
        (Idx.lpinky1, Idx.lpinky2, Idx.lpinky3),
        (Idx.lthumb1, Idx.lthumb2, Idx.lthumb3),
    )
    # Flat list of triangle corner indices, right hand first (same order as
    # all_finger_idx), three consecutive entries per triangle.
    finger_tri_idx = [j - r_off for tri in r_tris for j in tri]
    finger_tri_idx += [j + l_shift for tri in l_tris for j in tri]

    # First joint of every finger, in the same finger order as the triangles.
    r_first = (Idx.rindex1, Idx.rring1, Idx.rmiddle1, Idx.rpinky1, Idx.rthumb0)
    l_first = (Idx.lindex1, Idx.lring1, Idx.lmiddle1, Idx.lpinky1, Idx.lthumb0)
    first_finger_joint_idx = [j - r_off for j in r_first] + [j + l_shift for j in l_first]

    target_finger_tris = target_pos[:, finger_tri_idx].reshape(
        (batch, len(finger_tri_idx) // 3, 3, 3)
    )
    actual_finger_pos = out_pos[:, all_finger_idx]

    # Pick a triangle per joint: count the leading first-joint entries that
    # are <= the joint index, minus one.
    # NOTE(review): the scan stops at the first larger entry, so this selects
    # the owning finger only if first_finger_joint_idx is ascending; a joint
    # index below the first entry yields -1, i.e. the last triangle via
    # negative indexing -- confirm neither case occurs for the skeleton used.
    tri_for_joint = []
    for joint_idx in all_finger_idx:
        tri_idx = -1
        for first_idx in first_finger_joint_idx:
            if joint_idx < first_idx:
                break
            tri_idx += 1
        tri_for_joint.append(tri_idx)

    distances = torch.zeros((batch, len(all_finger_idx)), device=input.device)
    for col, tri_idx in enumerate(tri_for_joint):
        distances[:, col] = distance_to_triangle(
            target_finger_tris[:, tri_idx], actual_finger_pos[:, col]
        )

    return distances