in private_prediction_experiment.py [0:0]
def main(args):
    """
    Runs private predictions experiment on dataset using input arguments `args`.

    Pipeline:
      1. validate incompatible option combinations up front (before any
         expensive data loading);
      2. optionally connect to a Visdom server for visualization;
      3. load the train and test splits;
      4. optionally apply PCA and subsample the training split;
      5. optionally copy the data to the GPU (linear models only);
      6. tune hyperparameters with `cross_validate`;
      7. run `args.num_repetitions` repetitions of
         `private_prediction.compute_accuracy`;
      8. optionally dump the collected accuracies to `args.result_file` as JSON.

    Args:
        args: parsed command-line namespace. Attributes read directly here are
            `visdom`, `dataset`, `model`, `num_classes`, `data_folder`,
            `pca_dims`, `num_samples`, `device`, `num_repetitions` and
            `result_file`. A value of -1 for `num_classes`, `pca_dims` or
            `num_samples` means "not requested". `cross_validate` and
            `compute_accuracy` may read further attributes.

    Raises:
        ValueError: if PCA is requested (`pca_dims != -1`) for a model other
            than "linear".
        RuntimeError: if `device == "gpu"` with the "linear" model but CUDA is
            not available.
    """
    reshape = (args.model == "linear")
    use_gpu = args.device == "gpu" and args.model == "linear"

    # validate option combinations before doing any work. These are explicit
    # checks rather than `assert`s so they are not stripped under `python -O`:
    if args.pca_dims != -1 and not reshape:
        raise ValueError("cannot use PCA with non-linear models")
    if use_gpu and not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available on this machine.")

    # set up visualizer. Fall back to None when no server is given or the
    # connection check fails:
    visualizer = None
    if args.visdom:
        visualizer = visdom.Visdom(args.visdom)
        if not visualizer.check_connection():
            visualizer = None

    # load dataset:
    logging.info(f"Loading {args.dataset} dataset...")
    normalize = args.dataset.startswith("mnist")
    num_classes = None if args.num_classes == -1 else args.num_classes
    data = {}
    for split in ["train", "test"]:
        data[split] = dataloading.load_dataset(
            name=args.dataset,
            split=split,
            normalize=normalize,
            reshape=reshape,
            num_classes=num_classes,
            root=args.data_folder,
        )

    # apply PCA if requested. The projection is fit on the training split only
    # and the same mapping is reused for the test split (non-transductive):
    if args.pca_dims != -1:
        data["train"], mapping = dataloading.pca(data["train"], num_dims=args.pca_dims)
        data["test"], _ = dataloading.pca(data["test"], mapping=mapping)

    # subsample training data if requested:
    if args.num_samples != -1:
        data["train"] = dataloading.subsample(
            data["train"], num_samples=args.num_samples, random=False,
        )

    # copy data to GPU if requested (for linear models only). Values are
    # replaced under existing keys, so the dicts do not change size while
    # they are iterated:
    if use_gpu:
        logging.info("Copying data to GPU...")
        for split in data.keys():
            for key, value in data[split].items():
                data[split][key] = value.cuda()

    # use cross-validation to tune hyperparameters. The returned namespace is
    # presumably `args` updated with the selected values — confirm against
    # cross_validate:
    args = cross_validate(args, data, visualizer=visualizer)

    # repeat the same experiment multiple times. `accuracies` is presumably
    # filled in place by compute_accuracy (it is dumped below as-is):
    accuracies = {}
    for idx in range(args.num_repetitions):
        logging.info(f"Experiment {idx + 1} of {args.num_repetitions}...")
        private_prediction.compute_accuracy(
            args, data, accuracies=accuracies, visualizer=visualizer
        )

    # save results to file (skipped when result_file is None or empty):
    if args.result_file:
        logging.info(f"Writing results to file {args.result_file}...")
        with open(args.result_file, "wt") as json_file:
            json.dump(accuracies, json_file)