in src/utils.py [0:0]
def fix_bone_lengths(positions, mask, c, target_lengths, it=10, k=0.95):
    """Iteratively move joints so that bone lengths approach ``target_lengths``.

    In every iteration each joint gets a set of candidate positions: one
    derived from its parent, one from each sibling and one from each of its
    children. A candidate keeps the current direction between the joint and
    that neighbour but rescales the offset by the ``target / actual`` length
    ratio of the bone involved. The joint is then moved a fraction ``k`` of
    the way towards the mean of its candidates. All joints of one iteration
    are computed from the positions of the previous iteration.

    Args:
        positions: tensor indexed as ``[sample, joint]`` yielding vectors of
            3 components. It is cloned; the input is not modified.
        mask: tensor indexed as ``[sample, joint]`` and used as a boolean
            index; only entries where it is true are moved
            (assumes a bool tensor -- TODO confirm against callers).
        c: object providing ``c.skeleton.parent_idx_vector()``,
            ``c.skeleton.child_idx_vector()`` and ``c.skeleton.Idx.all``.
        target_lengths: numpy array with one target bone length per joint.
        it: number of relaxation iterations.
        k: fraction of the distance to the mean candidate position that a
            joint is moved per iteration.

    Returns:
        A new tensor shaped like ``positions`` holding the adjusted joints.
    """

    def mean_target_position(joint_idx, p, ratio):
        """Return the mean candidate position of ``joint_idx`` for all samples."""

        def rescaled_from(anchor_idx, bone_idx):
            # Keep the direction anchor -> joint, but scale the offset by the
            # target/actual length ratio of bone `bone_idx`.
            anchor = p[:, anchor_idx]
            offset = p[:, joint_idx] - anchor
            scale = torch_tile(ratio[:, bone_idx][:, None], dim=1, n_tile=3)
            return anchor + offset * scale

        parent = parent_idx[joint_idx]
        candidates = []
        if parent >= 0:
            candidates.append(rescaled_from(parent, joint_idx))
            # Some bones have several children, so siblings also constrain
            # each other; the sibling's own bone ratio is applied here.
            candidates.extend(
                rescaled_from(sibling, sibling)
                for sibling in child_idx[parent]
                if sibling != joint_idx
            )
        candidates.extend(
            rescaled_from(child, child) for child in child_idx[joint_idx]
        )
        if not candidates:
            # A joint without parent and children has no constraint;
            # torch.stack would raise on an empty list, so keep it in place.
            return p[:, joint_idx]
        return torch.mean(torch.stack(candidates), dim=0)

    p = torch.clone(positions)
    parent_idx = c.skeleton.parent_idx_vector()
    child_idx = c.skeleton.child_idx_vector()
    n_joints = len(c.skeleton.Idx.all)
    # One row of target lengths per sample.
    target = torch_tile(
        torch.from_numpy(target_lengths[None, :]).to(p.device),
        dim=0,
        n_tile=len(p),
    ).float().to(p.device)

    for _ in range(it):
        actual_lengths = bone_lengths_batch(p, parent_idx).to(p.device)

        # Diagnostic: mean |actual / target| with nan/inf entries zeroed.
        # NOTE(review): prints on every iteration; consider moving to logging.
        err = torch.abs(actual_lengths / target)
        err[torch.isnan(err)] = 0
        err[torch.isinf(err)] = 0
        print(torch.mean(err))

        # Computed once per iteration instead of once per joint/neighbour.
        # A zero-length bone yields inf/nan here, which would turn the
        # positions into NaN; a ratio of 1 leaves such an offset unchanged.
        ratio = target / actual_lengths
        ratio = torch.where(torch.isfinite(ratio), ratio, torch.ones_like(ratio))

        # Start from the current positions so every entry that is not
        # explicitly moved below keeps its value.
        new_p = torch.clone(p)
        for joint_idx in range(n_joints):
            current = p[:, joint_idx]
            step = (mean_target_position(joint_idx, p, ratio) - current) * k
            joint_mask = mask[:, joint_idx]
            new_p[joint_mask, joint_idx] = (current + step)[joint_mask]
        p = new_p
    return p