in src/eval.py [0:0]
def append_batch_occ_acc(joint_occ, joint_acc_all, joint_acc_untracked, joint_acc_tracked, mask, y_batch, out, c: Train_Config, bone_lengths, X_batch,
                         bone_offsets) -> None:
    """Compute per-joint-group squared position errors for one batch and append them to accumulators.

    The reference features ``y_batch`` and the predicted features ``out`` are both converted to joint
    positions with ``c.out_features.features_to_pose().solve_batch(...)`` and reshaped to
    ``(out.shape[0], -1, 3)``. For every group returned by ``get_reduced_joint_group_idx(c)``, the
    squared Euclidean distance between reference and predicted position is computed per sample and
    per joint of the group.

    For the ``"wrist_local_finger"`` group, positions are first made relative to a wrist joint.
    The first half of the group's joints uses ``groups['wrist'][0]`` and the remaining joints use
    ``groups['wrist'][1]``.

    Args:
        joint_occ: Mapping from group name to a list. Receives the float occupancy mask of the group
            (1.0 where ``mask`` is set, 0.0 elsewhere).
        joint_acc_all: Mapping from group name to a list. Receives the squared errors of all entries.
        joint_acc_untracked: Mapping from group name to a list. Receives the squared errors where
            ``mask`` is set, or zeros of the full length when no entry is set.
        joint_acc_tracked: Mapping from group name to a list. Receives the squared errors where
            ``mask`` is not set, or zeros of the full length when every entry is set.
        mask: Tensor indexed as ``mask[:, joint_idx]``. Entries that are set are reported as
            "untracked". Assumed to be boolean (it is negated with ``~`` and used as an index);
            TODO confirm.
        y_batch: Reference output features forwarded to ``solve_batch``.
        out: Predicted output features forwarded to ``solve_batch``. ``out.shape[0]`` is the batch size.
        c: Training configuration providing ``out_features`` and the joint grouping.
        bone_lengths: Forwarded to ``solve_batch``.
        X_batch: Forwarded to ``solve_batch`` as ``body``.
        bone_offsets: Forwarded to ``solve_batch``.

    Each appended tensor is 1-D. It concatenates the per-sample values joint by joint, in the order
    in which the group lists its joints. Groups with no joints are skipped.
    """

    def solve(features):
        # A fresh solver is requested per call, exactly as before, in case solvers carry state.
        pose = c.out_features.features_to_pose().solve_batch(features, ref_positions=None, mask=None,
                                                             bone_lengths=bone_lengths, body=X_batch,
                                                             bone_offsets=bone_offsets)
        return pose.reshape([out.shape[0], -1, 3])

    ref_pose = solve(y_batch)
    pred_pose = solve(out)

    joint_groups = get_reduced_joint_group_idx(c)
    for k, v in joint_groups.items():
        if len(v) == 0:
            # Nothing to measure, and no tensor to build the group's mask from.
            continue

        # Presumably each joint_idx is a single joint index, so every entry below is (batch,) — TODO confirm.
        errs = []
        masks = []
        for group_idx, joint_idx in enumerate(v):
            if k == "wrist_local_finger":
                # First half of the group is measured relative to wrist[0], the rest relative to wrist[1].
                wrist_idx = joint_groups['wrist'][0] if group_idx < len(v) / 2 else joint_groups['wrist'][1]
                ref_p = ref_pose[:, joint_idx] - ref_pose[:, wrist_idx]
                pred_p = pred_pose[:, joint_idx] - pred_pose[:, wrist_idx]
            else:
                ref_p = ref_pose[:, joint_idx]
                pred_p = pred_pose[:, joint_idx]
            errs.append(torch.norm(ref_p - pred_p, dim=1) ** 2)
            masks.append(mask[:, joint_idx])

        # Concatenate once instead of re-copying a growing tensor for every joint.
        # torch.cat always allocates, so the result never aliases `mask`.
        err = torch.cat(errs, dim=0)
        joint_mask = torch.cat(masks, dim=0)

        # Allocate on the mask's device so the boolean indexing below also works for CUDA masks.
        occ = torch.zeros(joint_mask.shape, dtype=torch.float32, device=joint_mask.device)
        occ[joint_mask] += 1.0

        # When a split would be empty, fall back to zeros of the full length.
        # NOTE(review): those zeros are appended like real errors and may dilute averages computed
        # downstream — confirm that is intended.
        if occ.any():
            err_untracked = err[joint_mask]
        else:
            err_untracked = err * 0.0
        if (occ == 0).any():
            err_tracked = err[~joint_mask]
        else:
            err_tracked = err * 0.0

        joint_occ[k].append(occ)
        joint_acc_all[k].append(err)
        joint_acc_tracked[k].append(err_tracked)
        joint_acc_untracked[k].append(err_untracked)