def get_max_accuracy_threshold()

in privacy_lint/attack_results.py [0:0]


    def get_max_accuracy_threshold(self) -> Tuple[float, float]:
        """
        Computes the score threshold that allows for maximum accuracy of the attack.

        Samples are processed in the order returned by
        ``_get_scores_and_labels_ordered``. Every position ``i`` in that order is
        a candidate split: samples strictly before ``i`` are predicted as train
        (label 1), while sample ``i`` and everything after it are predicted as
        heldout (label 0). The split with the most correct predictions wins.

        NOTE(review): whether "before" means a lower or a higher score depends on
        the sort direction used by ``_get_scores_and_labels_ordered``, which is
        not visible from this method -- confirm against that helper.

        Because the sample at the split position is always predicted as heldout,
        the split that labels *every* sample as train is never a candidate.

        Returns:
            A ``(threshold, accuracy)`` tuple, where ``threshold`` is the score
            located at the best split position and ``accuracy`` is the fraction
            of samples that this split classifies correctly.
        """

        labels_ordered, scores_ordered = self._get_scores_and_labels_ordered()

        # Inclusive running count of train samples; cumsum over a boolean mask
        # yields integer counts.
        train_inclusive = torch.cumsum(labels_ordered == 1, 0)

        # Shift right by one to obtain, for each position, the number of train
        # samples strictly before it. The zero pad is created with the same
        # dtype/device as the counts so they stay exact integers instead of
        # being promoted to float32.
        train_before = torch.cat((train_inclusive.new_zeros(1), train_inclusive[:-1]))

        # Number of heldout samples at each position or after it (suffix count).
        heldout_from_here = torch.cumsum(labels_ordered.flip(0) == 0, 0).flip(0)

        # Correct predictions for every candidate split, as exact integers.
        correct = train_before + heldout_from_here
        best_index = correct.argmax()

        n = labels_ordered.shape[0]
        max_accuracy_threshold = scores_ordered[best_index].item()
        # Plain Python division keeps the ratio in double precision.
        max_accuracy = correct[best_index].item() / n

        return max_accuracy_threshold, max_accuracy