def compute_accuracy()

in private_prediction.py [0:0]


def _prediction_accuracy(preds, targets):
    """
    Returns the fraction of entries of the tensor `preds` that equal `targets`,
    normalized by `targets.size(0)`.

    `preds` is moved to the device of `targets` first if the devices differ.
    Raises `ValueError` if `preds` is not a tensor or `targets` is empty.
    """
    if not torch.is_tensor(preds):
        raise ValueError("Unknown format of preds variable.")
    num_targets = targets.size(0)
    if num_targets == 0:
        raise ValueError("Cannot compute accuracy on a split without targets.")

    # make sure predictions and targets live on the same device:
    if preds.device != targets.device:
        preds = preds.to(device=targets.device)
    return float(preds.eq(targets).sum()) / num_targets


def compute_accuracy(args, data, accuracies=None, visualizer=None):
    """
    Runs a single experiment using the settings in `args` on the specified
    `data`. Accuracies resulting from the experiment are stored in `accuracies`.

    The experiment is run by the module-level function named
    `f"{args.method}_method"`, called as
    `method(data, args, visualizer=visualizer, title="Learning curve")`. It
    must return a dict mapping each split name to either a tensor of
    predictions, or a dict mapping an inference budget to a tensor of
    predictions. Predictions are compared against `data[split]["targets"]`.

    Args:
        args: experiment settings; `args.method` selects the method.
        data: dict mapping each split name to a dict with a `"targets"` tensor.
        accuracies: optional dict that results are appended to (a new dict is
            created if `None`). For budget-independent predictions,
            `accuracies[split]` is a list of floats; for budget-dependent
            predictions, `accuracies[split][str(budget)]` is a list of floats.
        visualizer: optional visdom visualizer, passed through to the method
            so it can plot learning curves.

    Returns:
        The (possibly newly created) `accuracies` dict.

    Raises:
        TypeError: if `accuracies` is neither `None` nor a dict.
        ValueError: if the method is unknown, or predictions for a split have
            an unsupported format, or a split has no targets.
    """

    # check inputs (explicit raise, unlike `assert`, survives `python -O`):
    if accuracies is None:
        accuracies = {}
    elif not isinstance(accuracies, dict):
        raise TypeError("accuracies must be dict")

    # run the specified private prediction method:
    title = "Learning curve"
    method_func = globals().get(f"{args.method}_method")
    if not callable(method_func):
        raise ValueError(f"Unknown private prediction method: {args.method}")
    predictions = method_func(data, args, visualizer=visualizer, title=title)

    # compute accuracy on all splits:
    for split, preds in predictions.items():
        targets = data[split]["targets"]

        # prediction accuracy independent of inference budget:
        if torch.is_tensor(preds):
            accuracy = _prediction_accuracy(preds, targets)
            logging.info(f" => {split} accuracy: {accuracy}")
            accuracies.setdefault(split, []).append(accuracy)

        # prediction accuracy depends on inference budget:
        elif isinstance(preds, dict):
            split_accuracies = accuracies.setdefault(split, {})
            for budget, budget_preds in preds.items():
                budget = str(budget)  # results are keyed by the budget's string form
                accuracy = _prediction_accuracy(budget_preds, targets)
                logging.info(f" => {split} accuracy at {budget} budget: {accuracy}")
                split_accuracies.setdefault(budget, []).append(accuracy)

        # this should never happen:
        else:
            raise ValueError("Unknown format of preds variable.")

    # return:
    return accuracies