def append_batch_occ_acc()

in src/eval.py [0:0]


def append_batch_occ_acc(joint_occ, joint_acc_all, joint_acc_untracked, joint_acc_tracked, mask, y_batch, out, c: Train_Config, bone_lengths, X_batch,
                         bone_offsets):
    """Append per-joint-group occlusion flags and squared position errors for one batch.

    Reference poses are solved from ``y_batch`` and predicted poses from ``out`` via
    ``c.out_features.features_to_pose().solve_batch``; both are reshaped to
    ``(out.shape[0], -1, 3)``. For every group returned by
    ``get_reduced_joint_group_idx(c)``, the squared Euclidean error of each joint in
    the group is computed and the per-joint results are concatenated along dim 0.
    For the ``"wrist_local_finger"`` group, positions are taken relative to a wrist:
    the first half of the group uses ``Idx['wrist'][0]``, the second half
    ``Idx['wrist'][1]``.

    Args:
        joint_occ: dict of lists keyed by group name; receives float32 0/1 flags
            (1.0 where ``mask`` is True).
        joint_acc_all: dict of lists; receives all squared errors of the group.
        joint_acc_untracked: dict of lists; receives the squared errors where
            ``mask`` is True, or ``err * 0.0`` (same shape as all errors) when no
            entry of the group is masked.
        joint_acc_tracked: dict of lists; receives the squared errors where
            ``mask`` is False, or ``err * 0.0`` when every entry is masked.
        mask: indexed as ``mask[:, joint_idx]`` and used as a boolean index —
            presumably a (batch, n_joints) bool tensor with True meaning
            untracked/occluded; TODO confirm against caller.
        y_batch: target features passed to the pose solver.
        out: predicted features passed to the pose solver; ``out.shape[0]`` is
            used as the batch dimension for the reshape.
        c: config providing ``out_features``; also passed to
            ``get_reduced_joint_group_idx``.
        bone_lengths, X_batch, bone_offsets: forwarded to ``solve_batch``
            (``X_batch`` as ``body``).

    Returns:
        None. Results are appended in place to the four dicts' lists.
    """
    ref_pose = c.out_features.features_to_pose().solve_batch(y_batch, ref_positions=None, mask=None, bone_lengths=bone_lengths, body=X_batch,
                                                             bone_offsets=bone_offsets)
    ref_pose = ref_pose.reshape([out.shape[0], -1, 3])
    pred_pose = c.out_features.features_to_pose().solve_batch(out, ref_positions=None, mask=None, bone_lengths=bone_lengths, body=X_batch,
                                                              bone_offsets=bone_offsets)
    pred_pose = pred_pose.reshape([out.shape[0], -1, 3])

    Idx = get_reduced_joint_group_idx(c)

    for k, v in Idx.items():
        # Collect per-joint pieces and concatenate once; growing a tensor with
        # torch.cat inside the loop copies all previous data on every iteration.
        err_parts = []
        mask_parts = []
        for group_idx, joint_idx in enumerate(v):
            ref_p = ref_pose[:, joint_idx]
            pred_p = pred_pose[:, joint_idx]
            if k == "wrist_local_finger":
                # First half of the group is expressed relative to wrist 0,
                # second half relative to wrist 1.
                wrist_idx = Idx['wrist'][0] if group_idx < len(v) / 2 else Idx['wrist'][1]
                ref_p = ref_p - ref_pose[:, wrist_idx]
                pred_p = pred_p - pred_pose[:, wrist_idx]
            err_parts.append(torch.norm(ref_p - pred_p, dim=1) ** 2)
            mask_parts.append(mask[:, joint_idx])

        # torch.cat always allocates a new tensor, so no explicit clone is needed.
        err = torch.cat(err_parts, dim=0)
        joint_mask = torch.cat(mask_parts, dim=0)

        # Convert on the mask's own device; allocating a CPU zeros tensor and
        # indexing it with the mask fails when the mask lives on another device.
        occ = joint_mask.to(torch.float32)

        # Exact emptiness checks instead of float thresholds on the mean, which
        # misclassify large groups (e.g. exactly one masked entry in 1e6).
        if bool(joint_mask.any()):
            err_untracked = err[joint_mask]
        else:
            # Keep a full-length zero tensor (not an empty one) as before.
            err_untracked = err * 0.0
        if not bool(joint_mask.all()):
            err_tracked = err[~joint_mask]
        else:
            err_tracked = err * 0.0

        joint_occ[k].append(occ)
        joint_acc_all[k].append(err)
        joint_acc_tracked[k].append(err_tracked)
        joint_acc_untracked[k].append(err_untracked)