in privacy_lint/attack_results.py [0:0]
def get_max_accuracy_threshold(self) -> Tuple[float, float]:
    """
    Computes the score threshold that allows for maximum accuracy of the attack.
    All samples below this threshold will be classified as train and all samples
    at or above it as test.

    Every realizable split of the score-sorted samples is evaluated, including
    "everything is test" (threshold = smallest score) and "everything is train"
    (threshold = ``float("inf")``). Splits that would fall between two equal
    scores are skipped, because no threshold can separate tied samples.

    Returns:
        A ``(threshold, accuracy)`` tuple, where ``accuracy`` is the fraction of
        samples that the rule ``score < threshold -> train`` classifies correctly.

    Raises:
        ValueError: if there are no samples.
    """
    # NOTE(review): assumes both tensors are sorted by ascending score, with
    # label 1 = train and 0 = heldout, on the same device -- confirm against
    # _get_scores_and_labels_ordered.
    labels_ordered, scores_ordered = self._get_scores_and_labels_ordered()
    n = labels_ordered.shape[0]
    if n == 0:
        raise ValueError("cannot compute a threshold from zero samples")

    device = labels_ordered.device
    # Keep the counts integral: a float32 pad would promote them to float32,
    # which cannot represent counts above 2**24 exactly.
    zero = torch.zeros(1, dtype=torch.long, device=device)

    # Split i (0 <= i <= n) classifies positions [0, i) as train and [i, n) as test.
    # train_left[i]    = number of train samples in [0, i)
    # heldout_right[i] = number of heldout samples in [i, n)
    train_left = torch.cat((zero, torch.cumsum((labels_ordered == 1).long(), 0)))
    heldout_right = torch.cat(
        (torch.cumsum((labels_ordered.flip(0) == 0).long(), 0).flip(0), zero)
    )
    correct = train_left + heldout_right

    # Split i is produced by threshold scores_ordered[i] only if the sample just
    # before it has a different score; splits inside a run of tied scores are
    # unreachable, so exclude them from the argmax.
    realizable = torch.ones(n + 1, dtype=torch.bool, device=device)
    realizable[1:n] = scores_ordered[1:] != scores_ordered[:-1]
    correct = correct.masked_fill(~realizable, -1)

    # torch.argmax returns the first maximal index, so ties keep the earliest split.
    best = int(correct.argmax())
    max_accuracy = correct[best].item() / n
    # Split n ("everything is train") needs a threshold above every score.
    if best == n:
        max_accuracy_threshold = float("inf")
    else:
        max_accuracy_threshold = scores_ordered[best].item()
    return max_accuracy_threshold, max_accuracy