in private_prediction.py [0:0]
def _prediction_accuracy(preds, targets):
    """
    Return the fraction of entries of tensor `preds` that equal `targets`.

    `preds` is moved to the device of `targets` before comparing. Raises a
    `ValueError` if `preds` is not a tensor, or if its shape differs from the
    shape of `targets` (a mismatch would otherwise broadcast silently inside
    `eq` and yield a meaningless "accuracy", possibly larger than 1).
    """
    if not torch.is_tensor(preds):
        raise ValueError("Unknown format of preds variable.")
    # `.to()` returns the tensor itself when it already lives on that device:
    preds = preds.to(device=targets.device)
    if preds.shape != targets.shape:
        raise ValueError(
            f"Shape of predictions {tuple(preds.shape)} does not match shape "
            f"of targets {tuple(targets.shape)}."
        )
    return float(preds.eq(targets).sum()) / targets.size(0)


def compute_accuracy(args, data, accuracies=None, visualizer=None):
    """
    Runs a single experiment using the settings in `args` on the specified
    `data`. Accuracies resulting from the experiment are stored in `accuracies`.
    If a visdom `visualizer` is specified, the function plots learning curves.

    Args:
        args: experiment settings. `args.method` selects the prediction method,
            which is looked up as the module-level function
            `<args.method>_method` and called as
            `method(data, args, visualizer=..., title=...)`.
        data: mapping from split name to a dict holding a "targets" tensor for
            every split the prediction method returns predictions for.
        accuracies: optional dict of earlier results to append to. It is
            updated in place and returned. A fresh dict is created when `None`.
        visualizer: forwarded unchanged to the prediction method.

    Returns:
        The `accuracies` dict. For a split whose predictions are a single
        tensor, `accuracies[split]` is a list of floats. For a split whose
        predictions are a dict keyed by inference budget,
        `accuracies[split][str(budget)]` is a list of floats. Each call appends
        one value per list.

    Raises:
        TypeError: if `accuracies` is neither `None` nor a dict.
        ValueError: if the prediction method is unknown, if the predictions
            have an unexpected format, or if their shape differs from the
            targets.
    """
    # check inputs (explicit raise: `assert` would be stripped under `python -O`):
    if accuracies is None:
        accuracies = {}
    elif not isinstance(accuracies, dict):
        raise TypeError("accuracies must be dict")

    # run the specified private prediction method:
    title = "Learning curve"
    method_func = globals().get(f"{args.method}_method")
    if not callable(method_func):
        raise ValueError(f"Unknown private prediction method: {args.method}")
    predictions = method_func(data, args, visualizer=visualizer, title=title)

    # compute accuracy on all splits:
    for split, preds in predictions.items():
        targets = data[split]["targets"]

        # prediction accuracy independent of inference budget:
        if torch.is_tensor(preds):
            accuracy = _prediction_accuracy(preds, targets)
            logging.info(f" => {split} accuracy: {accuracy}")
            accuracies.setdefault(split, []).append(accuracy)

        # prediction accuracy depends on inference budget:
        elif isinstance(preds, dict):
            split_accuracies = accuracies.setdefault(split, {})
            for budget, budget_preds in preds.items():
                accuracy = _prediction_accuracy(budget_preds, targets)
                # results are keyed by the string form of the budget:
                budget = str(budget)
                logging.info(f" => {split} accuracy at {budget} budget: {accuracy}")
                split_accuracies.setdefault(budget, []).append(accuracy)

        # this should never happen:
        else:
            raise ValueError("Unknown format of preds variable.")

    return accuracies