in plugins/hydra_ax_sweeper/hydra_plugins/hydra_ax_sweeper/_core.py [0:0]
def sweep(self, arguments: List[str]) -> None:
    """Run the Ax optimization loop and save the best parameters found.

    Trials are requested batch by batch while walking through the
    ``(num_trials, parallelism)`` entries returned by
    ``ax_client.get_max_parallelism()``. The loop ends when the
    ``self.max_trials`` budget is used up, a batch reports that the search
    space is exhausted, the early stopper asks to stop, or there are no
    more parallelism entries left to consume. The best parameters seen so
    far are then written to ``optimization_results.yaml`` under
    ``self.sweep_dir`` and logged.

    Args:
        arguments: Passed through to ``self.setup_ax_client`` to build the
            Ax client used for this sweep.
    """
    self.job_idx = 0
    ax_client = self.setup_ax_client(arguments)

    num_trials_left = self.max_trials
    max_parallelism = ax_client.get_max_parallelism()
    # Index of the parallelism entry we are using right now.
    current_parallelism_index = 0
    # Refreshed from every batch's `is_search_space_exhausted` flag.
    is_search_space_exhausted = False

    best_parameters = {}
    while num_trials_left > 0 and not is_search_space_exhausted:
        if current_parallelism_index >= len(max_parallelism):
            # All entries were consumed but trial budget remains. Stop here
            # instead of raising IndexError so that the best parameters
            # found so far are still serialized below.
            log.warning(
                "Ax parallelism settings are exhausted with %d trials left; "
                "stopping the sweep",
                num_trials_left,
            )
            break

        current_parallelism = max_parallelism[current_parallelism_index]
        # Only the trial count is needed here; the whole entry is handed
        # to get_one_batch_of_trials below.
        num_trials, _ = current_parallelism
        num_trials_so_far = 0
        # num_trials == -1 means this entry has no trial limit of its own,
        # so only the overall budget bounds the loop.
        while (
            num_trials > num_trials_so_far or num_trials == -1
        ) and num_trials_left > 0:
            trial_batch = get_one_batch_of_trials(
                ax_client=ax_client,
                parallelism=current_parallelism,
                num_trials_so_far=num_trials_so_far,
                num_max_trials_to_do=num_trials_left,
            )

            # Never launch more trials than the remaining budget allows.
            list_of_trials_to_launch = trial_batch.list_of_trials[:num_trials_left]
            is_search_space_exhausted = trial_batch.is_search_space_exhausted
            num_launched = len(list_of_trials_to_launch)

            log.info("AxSweeper is launching {} jobs".format(num_launched))

            self.sweep_over_batches(
                ax_client=ax_client, list_of_trials=list_of_trials_to_launch
            )

            num_trials_so_far += num_launched
            num_trials_left -= num_launched

            best_parameters, predictions = ax_client.get_best_parameters()
            metric = predictions[0][ax_client.objective_name]

            if self.early_stopper.should_stop(metric, best_parameters):
                # A negative budget also terminates the outer loop.
                num_trials_left = -1
                break

            if is_search_space_exhausted:
                log.info("Ax has exhausted the search space")
                break

        current_parallelism_index += 1

    results_to_serialize = {"optimizer": "ax", "ax": best_parameters}
    OmegaConf.save(
        OmegaConf.create(results_to_serialize),
        f"{self.sweep_dir}/optimization_results.yaml",
    )
    log.info("Best parameters: " + str(best_parameters))