in src/eval.py [0:0]
def print_joint_occ_acc(joint_occ, joint_acc_all, joint_acc_untracked, joint_acc_tracked):
    """Print per-joint occlusion and acceleration-error statistics, then a pooled "body" row.

    Every argument is a dict keyed by joint name; each value is a list of tensors that
    are concatenated along dim 0 before reduction. Keys are taken from ``joint_occ``,
    and the other three dicts must contain the same keys.

    For each joint, one line is printed:
      * ``occ``: mean of the concatenated ``joint_occ`` values.
      * ``acc_*``: ``sqrt(mean(x))`` followed by ``mean(sqrt(x))`` in parentheses.
        NOTE(review): this presumes the ``joint_acc_*`` tensors hold squared errors, so
        the first number is an RMSE. Confirm against the caller.

    After the per-joint lines, a ``body`` line pools every joint except names starting
    with "finger", "thumb" and "wrist_local_finger". The body line is skipped when no
    such joint exists. Previously that case crashed on ``torch.mean(None)``.
    """
    # Parts for the pooled "body" row, in the order (occ, acc_all, acc_untracked,
    # acc_tracked). They are concatenated once at the end rather than regrown per
    # joint, which avoids quadratic copying.
    body_parts = ([], [], [], [])
    for k in joint_occ:
        joint_tensors = tuple(
            torch.cat(metric[k], dim=0)
            for metric in (joint_occ, joint_acc_all, joint_acc_untracked, joint_acc_tracked)
        )
        if _is_body_joint(k):
            for parts, tensor in zip(body_parts, joint_tensors):
                parts.append(tensor)
        _print_occ_acc_row(k, *joint_tensors)

    if not body_parts[0]:
        # No joint qualified for the body aggregate; there is nothing to pool.
        return
    _print_occ_acc_row("body", *(torch.cat(parts, dim=0) for parts in body_parts))


def _is_body_joint(name):
    """Return True if joint ``name`` belongs in the pooled "body" row (not finger/thumb/wrist_local_finger)."""
    return not name.startswith("finger") and name not in ("thumb", "wrist_local_finger")


def _print_occ_acc_row(name, occ, acc_all, acc_untracked, acc_tracked):
    """Print one formatted statistics line: mean occ, then sqrt(mean) (mean(sqrt)) for each acc tensor."""
    mean_occ = torch.mean(occ).item()
    fields = [f"{name:<20} occ: {mean_occ:.2f}"]
    for label, acc in (
        ("acc_all", acc_all),
        ("acc_untracked", acc_untracked),
        ("acc_tracked", acc_tracked),
    ):
        rmse = torch.sqrt(torch.mean(acc)).item()
        mean_sqrt = torch.mean(torch.sqrt(acc)).item()
        fields.append(f"{label}: {rmse:.3f} ({mean_sqrt:.3f})")
    print(" ".join(fields))