def default_compute_accuracies()

in privacy_lint/attacks/gap.py [0:0]


def default_compute_accuracies(model: nn.Module, dataloader: DataLoader, device=None):
    """
    Computes 0-1 accuracy of the model for each sample in the dataloader.

    Each sample scores 1.0 if ``outputs.argmax(dim=1)`` equals its target and
    0.0 otherwise. Samples appear in the order the dataloader yields them.

    Args:
        model: classifier whose output is argmax-ed over dim 1.
        dataloader: yields ``(inp, target)`` batches.
        device: device to run inference on. Defaults to CUDA when available,
            otherwise CPU. Note that ``model`` is moved to this device and is
            left there.

    Returns:
        A 1-D CPU tensor of the default float dtype holding one 0/1 value per
        sample. It is empty if the dataloader yields nothing.
    """
    if device is None:
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model.to(device)

    # Evaluate in eval mode so dropout / batch-norm do not make the per-sample
    # accuracies stochastic. Remember the caller's mode and restore it afterwards.
    # model.train(mode) applies the mode to all submodules.
    was_training = model.training
    model.eval()

    correct_chunks = []
    try:
        # No gradients are needed to score predictions. This avoids building
        # autograd graphs.
        with torch.no_grad():
            for inp, target in tqdm(dataloader):
                inp = inp.to(device)
                target = target.to(device)
                outputs = model(inp)
                # Keep results as tensors. A single concatenation at the end
                # avoids a device sync and Python boxing on every batch.
                correct_chunks.append((outputs.argmax(dim=1) == target).cpu())
    finally:
        model.train(was_training)

    if not correct_chunks:
        return torch.Tensor([])
    # Match torch.Tensor(list_of_bools): a CPU tensor in the default float dtype.
    return torch.cat(correct_chunks).to(torch.get_default_dtype())