def cross_validate()

in private_prediction_experiment.py [0:0]


def cross_validate(args, data, visualizer=None):
    """
    Performs cross-validation over hyperparameters for which this was requested.

    A hyperparameter is marked for cross-validation by setting its attribute on
    `args` to -1. The candidate values are looked up in `CROSS_VALIDATION`
    under the same name. Each candidate is evaluated
    `max(1, args.num_repetitions // 10)` times on a validation split carved off
    the front of the training data, and the candidate with the highest mean
    validation accuracy is written back into `args`.

    Args:
        args: Experiment arguments. Must expose every key of `CROSS_VALIDATION`
            as an attribute, plus `num_repetitions` and `inference_budget`.
            Updated in place with the selected hyperparameter value.
        data: Dict with a "train" entry that maps names (at least "features")
            to objects supporting `.size(0)` and `.narrow(...)` (presumably
            torch tensors -- confirm against the caller). While the
            cross-validation runs, `data["train"]` is temporarily replaced by
            the reduced training split and `data["valid"]` is added; both
            changes are undone before returning, also when an error occurs.
        visualizer: Optional object forwarded unchanged to
            `private_prediction.compute_accuracy`.

    Returns:
        The `args` object, with the cross-validated hyperparameter (if any)
        set to its selected value.

    Raises:
        AssertionError: If more than one hyperparameter is set to -1.
    """

    # check if there are any parameters to cross-validate over:
    requested = [key for key in CROSS_VALIDATION if getattr(args, key) == -1]
    if not requested:
        return args

    # get hyperparameter key and values; this is validated *before* the data
    # is touched so that a failure leaves the caller's data unchanged:
    assert len(requested) == 1, \
        "can only cross-validate over single hyperparameter at the same time"
    hyper_key = requested[0]
    hyper_values = CROSS_VALIDATION[hyper_key]

    # create validation split from training data; the split is built in fresh
    # dicts so that the caller's original training dict is never mutated:
    original_train = data["train"]
    valid_size = original_train["features"].size(0) // 10
    train_split, valid_split = {}, {}
    for key, value in original_train.items():
        valid_split[key] = value.narrow(0, 0, valid_size)
        train_split[key] = value.narrow(
            0, valid_size, value.size(0) - valid_size
        )
    # NOTE: The validation split is the first tenth of the training data, so
    # this assumes data is already shuffled.
    data["train"], data["valid"] = train_split, valid_split

    accuracies, budget_key = {}, None
    num_repetitions = max(1, args.num_repetitions // 10)
    try:
        # perform the actual cross-validation:
        num_experiments = len(hyper_values) * num_repetitions
        idx = 0
        for hyper_value in hyper_values:

            # make copy of arguments that we can alter:
            args_copy = copy.deepcopy(args)
            setattr(args_copy, hyper_key, hyper_value)
            if args_copy.inference_budget == -1:
                # -1 is replaced by a fixed budget of 100 for cross-validation
                args_copy.inference_budget = 100
            accuracies[hyper_value] = {}

            # repeat experiment multiple times:
            for _ in range(num_repetitions):
                logging.info(f"Cross-validation experiment {idx + 1} of "
                             f"{num_experiments}...")
                private_prediction.compute_accuracy(
                    args_copy, data,
                    accuracies=accuracies[hyper_value],
                    visualizer=visualizer,
                )
                idx += 1

            # key used below to look up per-budget accuracies; as before, the
            # budget of the last evaluated configuration is the one used:
            budget_key = str(args_copy.inference_budget)
    finally:
        # clean up validation set, even if an experiment failed:
        data["train"] = original_train
        data.pop("valid", None)

    # find best hyperparameter setting:
    mean_accuracy = {}
    for hyper_value in hyper_values:
        valid_accuracy = accuracies[hyper_value]["valid"]
        if isinstance(valid_accuracy, dict):  # inference budget in accuracies
            valid_accuracy = valid_accuracy[budget_key]
        mean_accuracy[hyper_value] = (
            sum(valid_accuracy) / float(num_repetitions)
        )
    optimal_value = max(mean_accuracy, key=mean_accuracy.get)
    logging.info(f"Selecting {hyper_key} value of {optimal_value}...")

    # update arguments object and return it for use in the main experiment:
    setattr(args, hyper_key, optimal_value)
    return args