private_prediction_experiment.py
import copy
import logging

import private_prediction

# NOTE: CROSS_VALIDATION (assumed to be defined elsewhere in this module) maps each
# hyperparameter name to the list of candidate values to try during cross-validation.


def cross_validate(args, data, visualizer=None):
"""
Performs cross-validation over hyperparameters for which this was requested.
"""
    # check if there are any parameters to cross-validate over (indicated by -1):
accuracies = {}
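    # collect the current values of every hyperparameter that supports cross-validation: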
arguments = {key: getattr(args, key) for key in CROSS_VALIDATION.keys()}
if any(value == -1 for value in arguments.values()):
# create validation split from training data:
valid_size = data["train"]["features"].size(0) // 10
original_train, data["valid"] = {}, {}
for key in data["train"].keys():
original_train[key] = data["train"][key]
data["valid"][key] = original_train[key].narrow(0, 0, valid_size)
data["train"][key] = original_train[key].narrow(
0, valid_size, original_train[key].size(0) - valid_size
)
# NOTE: This assumes data is already shuffled.
# NOTE: This makes an additional data copy, which may be bad on GPUs.
# get hyperparameter key and values:
hyper_key = [key for key, val in arguments.items() if val == -1]
assert len(hyper_key) == 1, \
"can only cross-validate over single hyperparameter at the same time"
hyper_key = hyper_key[0]
hyper_values = CROSS_VALIDATION[hyper_key]
# perform the actual cross-validation:
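        # use one tenth of the requested repetitions (but at least one) per hyperparameter value: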
num_repetitions, idx = max(1, args.num_repetitions // 10), 0
for hyper_value in hyper_values:
# make copy of arguments that we can alter:
args_copy = copy.deepcopy(args)
setattr(args_copy, hyper_key, hyper_value)
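            # replace an unspecified inference budget (-1) with a finite default: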
if args_copy.inference_budget == -1:
args_copy.inference_budget = 100
accuracies[hyper_value] = {}
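            # compute_accuracy() will accumulate per-repetition results into this dict: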
# repeat experiment multiple times:
for _ in range(num_repetitions):
logging.info(f"Cross-validation experiment {idx + 1} of "
f"{len(hyper_values) * num_repetitions}...")
private_prediction.compute_accuracy(
args_copy, data,
accuracies=accuracies[hyper_value],
visualizer=visualizer,
)
idx += 1
# find best hyperparameter setting:
for hyper_value in hyper_values:
valid_accuracy = accuracies[hyper_value]["valid"]
            if isinstance(valid_accuracy, dict):  # accuracies are keyed by inference budget
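                # NOTE: args_copy is left over from the last iteration of the loop above.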
valid_accuracy = valid_accuracy[str(args_copy.inference_budget)]
accuracies[hyper_value] = sum(valid_accuracy) / float(num_repetitions)
optimal_value = max(accuracies, key=accuracies.get)
logging.info(f"Selecting {hyper_key} value of {optimal_value}...")
        # restore the original training set and remove the validation split:
data["train"] = original_train
del data["valid"]
# update arguments object:
setattr(args, hyper_key, optimal_value)
# return arguments to use for main experiment:
return args
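
# Usage sketch (hypothetical, for illustration only). It assumes a CROSS_VALIDATION
# grid shaped like the one below and an argparse-style `args` object whose attribute
# names match the grid's keys; `parse_args`, `load_data`, and the hyperparameter
# names are placeholders, not functions or settings defined in this repository.
#
#     CROSS_VALIDATION = {
#         "noise_multiplier": [0.5, 1.0, 2.0, 4.0],
#         "weight_decay": [1e-5, 1e-4, 1e-3],
#     }
#
#     args = parse_args()     # e.g., args.weight_decay == -1 requests cross-validation
#     data = load_data(args)  # data["train"] must be a dict of tensors incl. "features"
#     args = cross_validate(args, data)  # picks the value with the best validation accuracy
#     private_prediction.compute_accuracy(args, data, accuracies={})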