# Excerpt from empose/eval/metrics.py — Metrics.compute()

    def compute(self, pose, shape, pose_hat, shape_hat=None, seq_lengths=None, pose_root=None, pose_root_hat=None,
                frame_mask=None):
        """
        Evaluate all configured metrics on a batch of predictions.

        :param pose: Ground-truth pose without the root as a tensor of shape (N, F, N_JOINTS*3).
        :param shape: Ground-truth shape as a tensor of shape (N, N_BETAS).
        :param pose_hat: Predicted pose, same layout as `pose`.
        :param shape_hat: Predicted shape; falls back to the ground-truth shape when None.
        :param seq_lengths: Optional (N, ) tensor with the valid length of each sequence.
        :param pose_root: Optional (N, F, 3) tensor with the ground-truth root pose.
        :param pose_root_hat: Optional (N, F, 3) tensor with the estimated root pose.
        :param frame_mask: Optional boolean tensor of shape (N, F) or (N, F, M) selecting the
          frames that take part in the evaluation. With an (N, F, M) mask, a frame is dropped
          as soon as any of its M entries is False.
        """
        batch_size, n_frames = pose.shape[0], pose.shape[1]

        # Without a predicted shape we evaluate against the ground-truth betas.
        if shape_hat is None:
            shape_hat = shape

        mask = self._get_mask(seq_lengths, batch_size, n_frames, frame_mask, pose.device)
        if mask.sum() == 0:
            # Nothing selected for this batch — no metrics to accumulate.
            return

        # Broadcast per-sequence betas over frames and keep only the selected frames,
        # then flatten the pose tensors the same way so everything lines up frame-wise.
        shape = self._pad_shapes(shape, n_frames, mask)
        shape_hat = self._pad_shapes(shape_hat, n_frames, mask)
        pose = self._masked_flatten(pose, mask)
        pose_hat = self._masked_flatten(pose_hat, mask)

        if pose_root is not None:
            pose_root = self._masked_flatten(pose_root, mask)
            pose_root_hat = self._masked_flatten(pose_root_hat, mask)
        else:
            # No ground-truth root supplied: evaluate both skeletons with a zero root.
            # NOTE(review): a pose_root_hat passed without pose_root is silently ignored here
            # (same as the original behavior) — confirm this is intended.
            zeros = torch.zeros([pose.shape[0], 3]).to(dtype=pose.dtype, device=pose.device)
            pose_root, pose_root_hat = zeros, zeros.clone()

        # Forward kinematics to obtain 3D joint positions for both skeletons.
        _, joints = self.smpl_model.fk(pose, shape, poses_root=pose_root, window_size=1000)
        _, joints_hat = self.smpl_model.fk(pose_hat, shape_hat, poses_root=pose_root_hat, window_size=1000)

        # Restrict to the body joints (plus root); hand joints are ignored.
        joints = joints[:, :C.N_JOINTS + 1]
        joints_hat = joints_hat[:, :C.N_JOINTS + 1]
        self._compute_eucl_dist(joints, joints_hat)
        self._compute_eucl_dist(joints, joints_hat, procrustes=True)

        if not self.angle_glob:
            # Angular error directly on the local joint rotations.
            self._compute_angular_dist(pose, pose_hat)
        else:
            # Angular error on global rotations: prepend a zero (identity) root, convert the
            # kinematic chain to global orientations, then drop the root entry again.
            identity_root = torch.zeros((pose.shape[0], 3)).to(dtype=pose.dtype, device=pose.device)
            pose_global = local_to_global(torch.cat([identity_root, pose], dim=-1), C.SMPL_PARENTS)
            pose_hat_global = local_to_global(torch.cat([identity_root, pose_hat], dim=-1), C.SMPL_PARENTS)
            self._compute_angular_dist(pose_global[:, 3:], pose_hat_global[:, 3:])