def main()

in private_prediction_experiment.py [0:0]


def main(args):
    """
    Runs private predictions experiment on dataset using input arguments `args`.

    Loads the train and test splits of `args.dataset`, optionally applies PCA,
    subsampling, and GPU placement, tunes hyperparameters with
    `cross_validate`, runs `private_prediction.compute_accuracy`
    `args.num_repetitions` times, and optionally writes the collected
    accuracies to `args.result_file` as JSON.

    Args:
        args: parsed command-line namespace. Attributes read directly here:
            `visdom`, `dataset`, `model`, `num_classes`, `data_folder`,
            `pca_dims`, `num_samples`, `device`, `num_repetitions`, and
            `result_file` (the helpers it is passed to may read more).
            A value of -1 for `num_classes`, `pca_dims`, or `num_samples`
            disables that option (`num_classes == -1` is forwarded as None).

    Raises:
        ValueError: if PCA is requested (`pca_dims != -1`) for a model other
            than "linear".
        RuntimeError: if `device == "gpu"` with a linear model but CUDA is
            not available.
    """
    # Validate arguments up front, before connecting to visdom or loading any
    # data. Explicit raises (instead of `assert`) keep these checks active
    # when Python runs with -O.
    reshape = args.model == "linear"  # only linear models use flattened inputs
    if args.pca_dims != -1 and not reshape:
        raise ValueError("cannot use PCA with non-linear models")
    copy_to_gpu = args.device == "gpu" and reshape
    if copy_to_gpu and not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available on this machine.")

    # Set up the visualizer; fall back to None when visdom is not requested
    # or the server cannot be reached.
    visualizer = None
    if args.visdom:
        visualizer = visdom.Visdom(args.visdom)
        if not visualizer.check_connection():
            visualizer = None

    # Load both dataset splits:
    logging.info(f"Loading {args.dataset} dataset...")
    normalize = args.dataset.startswith("mnist")
    num_classes = None if args.num_classes == -1 else args.num_classes
    data = {
        split: dataloading.load_dataset(
            name=args.dataset,
            split=split,
            normalize=normalize,
            reshape=reshape,
            num_classes=num_classes,
            root=args.data_folder,
        )
        for split in ("train", "test")
    }

    # Apply PCA if requested. The mapping is computed from the training split
    # and the same mapping is then applied to the test split, so test data is
    # not used to build the projection (non-transductive setting).
    if args.pca_dims != -1:
        data["train"], mapping = dataloading.pca(data["train"], num_dims=args.pca_dims)
        data["test"], _ = dataloading.pca(data["test"], mapping=mapping)

    # Subsample training data if requested (random=False presumably selects a
    # deterministic subset -- confirm against dataloading.subsample).
    if args.num_samples != -1:
        data["train"] = dataloading.subsample(
            data["train"], num_samples=args.num_samples, random=False,
        )

    # Copy data to GPU if requested (linear models only). Values are replaced
    # in place under their existing keys, so the per-split containers keep
    # their original type.
    if copy_to_gpu:
        logging.info("Copying data to GPU...")
        for split_data in data.values():
            for key, value in split_data.items():
                split_data[key] = value.cuda()

    # Use cross-validation to tune hyperparameters. `args` is rebound to the
    # returned namespace, so all later reads use the tuned values.
    args = cross_validate(args, data, visualizer=visualizer)

    # Repeat the same experiment multiple times. `accuracies` is passed to
    # every run and is presumably filled in place by compute_accuracy.
    accuracies = {}
    for idx in range(args.num_repetitions):
        logging.info(f"Experiment {idx + 1} of {args.num_repetitions}...")
        private_prediction.compute_accuracy(
            args, data, accuracies=accuracies, visualizer=visualizer
        )

    # Save results to file (skipped when result_file is None or empty):
    if args.result_file:
        logging.info(f"Writing results to file {args.result_file}...")
        with open(args.result_file, "wt") as json_file:
            json.dump(accuracies, json_file)