in src/utils.py [0:0]
def finger_direction_loss_finger_pos_input_only(input, target, config, body, **kwargs):
    """Return per-joint distances between predicted finger joints and target finger triangles.

    Both ``input`` and ``target`` are reshaped to ``(input.shape[0], -1, 3)``.
    The skeleton's finger joint indices are remapped to a local layout in
    which the right-hand joints are shifted to start at 0 and the left-hand
    joints are shifted to start at ``len(l_finger_idx)``; both tensors are
    indexed with these local indices.
    NOTE(review): judging by the function name, both tensors presumably hold
    finger positions only (right hand first, then left) -- confirm with callers.

    For every finger joint, the predicted position is passed to
    ``distance_to_triangle`` together with the triangle formed by the three
    target joints ``<finger>1, <finger>2, <finger>3`` of the finger selected
    for that joint.

    Args:
        input: Predicted positions; must support ``reshape``, ``shape`` and
            ``device``.
        target: Target positions; reshaped with the batch size of ``input``.
        config: Object exposing ``config.skeleton.Idx`` with the per-joint
            index attributes used below and ``get_finger_joints()``, which
            is unpacked as ``(left_indices, right_indices)``.
        body: Unused.
        **kwargs: Ignored.

    Returns:
        A tensor created with ``torch.zeros`` on ``input.device`` of shape
        ``(batch, number_of_finger_joints)`` whose column ``i`` holds the
        value returned by ``distance_to_triangle`` for the ``i``-th joint
        (right-hand joints first, then left-hand joints).
    """
    batch_size = input.shape[0]
    target_pos = target.reshape((batch_size, -1, 3))
    out_pos = input.reshape((batch_size, -1, 3))

    Idx = config.skeleton.Idx
    skel_l_joints, skel_r_joints = Idx.get_finger_joints()

    # Shifts from skeleton indices to the local layout: right hand starts at
    # 0, left hand starts at len(skel_l_joints).
    r_shift = -min(skel_r_joints)
    l_shift = len(skel_l_joints) - min(skel_l_joints)

    # Build new lists instead of editing the ones returned by
    # get_finger_joints() in place: if the skeleton hands out shared lists,
    # in-place shifting would corrupt them for every later caller.
    r_finger_idx = [joint + r_shift for joint in skel_r_joints]
    l_finger_idx = [joint + l_shift for joint in skel_l_joints]
    all_finger_idx = [*r_finger_idx, *l_finger_idx]

    # One (joint1, joint2, joint3) triple per finger. The finger order here
    # must match the order of first_finger_joint_idx below, because the
    # triangle for a joint is picked by position in that list.
    r_finger_tris = [
        (Idx.rindex1, Idx.rindex2, Idx.rindex3),
        (Idx.rring1, Idx.rring2, Idx.rring3),
        (Idx.rmiddle1, Idx.rmiddle2, Idx.rmiddle3),
        (Idx.rpinky1, Idx.rpinky2, Idx.rpinky3),
        (Idx.rthumb1, Idx.rthumb2, Idx.rthumb3),
    ]
    l_finger_tris = [
        (Idx.lindex1, Idx.lindex2, Idx.lindex3),
        (Idx.lring1, Idx.lring2, Idx.lring3),
        (Idx.lmiddle1, Idx.lmiddle2, Idx.lmiddle3),
        (Idx.lpinky1, Idx.lpinky2, Idx.lpinky3),
        (Idx.lthumb1, Idx.lthumb2, Idx.lthumb3),
    ]
    # Flattened local indices of all triangle corners, right hand first.
    finger_tri_idx = [joint + r_shift for tri in r_finger_tris for joint in tri]
    finger_tri_idx += [joint + l_shift for tri in l_finger_tris for joint in tri]

    # First joint of every finger in local indices (the thumbs start at
    # thumb0, although their triangles use thumb1..thumb3).
    first_finger_joint_idx = [
        Idx.rindex1 + r_shift,
        Idx.rring1 + r_shift,
        Idx.rmiddle1 + r_shift,
        Idx.rpinky1 + r_shift,
        Idx.rthumb0 + r_shift,
        Idx.lindex1 + l_shift,
        Idx.lring1 + l_shift,
        Idx.lmiddle1 + l_shift,
        Idx.lpinky1 + l_shift,
        Idx.lthumb0 + l_shift,
    ]

    # (batch, n_fingers, 3 corners, 3 coordinates)
    target_finger_tris = target_pos[:, finger_tri_idx].reshape(
        (len(target_pos), len(finger_tri_idx) // 3, 3, 3)
    )
    actual_finger_pos = out_pos[:, all_finger_idx]

    distances = torch.zeros(
        (len(actual_finger_pos), len(all_finger_idx)), device=input.device
    )
    for col, joint in enumerate(all_finger_idx):
        # Count the leading first-joint entries that are <= joint; the last
        # one counted identifies the finger (and thus the triangle).
        # NOTE(review): the scan stops at the first larger entry, so this
        # presumably relies on first_finger_joint_idx being ascending --
        # confirm against the skeleton's index layout.
        # NOTE(review): a joint smaller than every entry keeps tri_idx == -1,
        # which selects the last triangle through negative indexing.
        tri_idx = -1
        for first_joint in first_finger_joint_idx:
            if joint < first_joint:
                break
            tri_idx += 1
        distances[:, col] = distance_to_triangle(
            target_finger_tris[:, tri_idx], actual_finger_pos[:, col]
        )
    return distances