in ax/service/scheduler.py [0:0]
def _get_next_trials(self, num_trials: int = 1) -> List[BaseTrial]:
    """Produce up to ``num_trials`` new generator runs from the underlying
    generation strategy and create new trials with them.

    If generation is not currently possible, the reason is logged (at info
    level, gated by ``self._log_next_no_trials_reason``, with the underlying
    message at debug level) and an empty list is returned.

    NOTE: Fewer than ``num_trials`` trials may be produced if the generation
    strategy runs into its parallelism limit or needs more data to proceed.

    Args:
        num_trials: Maximum number of new trials to generate.

    Returns:
        List of newly created trials. Empty if the optimization is complete,
        the model requires more data, or max parallelism has been reached.

    Raises:
        SchedulerInternalError: If ``self.options.trial_type`` is ``Trial``
            but any produced generator run contains more than one arm.
    """
    # Observation features for trials that are still pending, derived from
    # trial status; handed to generation (presumably so in-flight candidates
    # are accounted for -- confirm against the generation strategy).
    pending = get_pending_observation_features_based_on_trial_status(
        experiment=self.experiment
    )
    try:
        generator_runs = self._gen_new_trials_from_generation_strategy(
            num_trials=num_trials, pending=pending
        )
    except OptimizationComplete as err:
        completion_str = f"Optimization complete: {err}"
        self.logger.info(completion_str)
        self.markdown_messages["Optimization complete"] = completion_str
        self._optimization_complete = True
        return []
    except DataRequiredError as err:
        # TODO[T62606107]: consider adding a `more_data_required` check to the
        # generation strategy to avoid running into this exception.
        if self._log_next_no_trials_reason:
            self.logger.info(
                "Generated all trials that can be generated currently. "
                "Model requires more data to generate more trials."
            )
        self.logger.debug(f"Message from generation strategy: {err}")
        return []
    except MaxParallelismReachedException as err:
        # TODO[T62606107]: consider adding a `step_max_parallelism_reached`
        # check to the generation strategy to avoid running into this
        # exception.
        if self._log_next_no_trials_reason:
            self.logger.info(
                "Generated all trials that can be generated currently. "
                "Max parallelism currently reached."
            )
        self.logger.debug(f"Message from generation strategy: {err}")
        return []

    # Single-arm trials must come from single-arm generator runs. Check every
    # run (not just the first) and stay safe when no runs were produced.
    if self.options.trial_type is Trial and any(
        len(generator_run.arms) > 1 for generator_run in generator_runs
    ):
        raise SchedulerInternalError(
            "Generation strategy produced multiple arms when only one was expected."
        )

    # Pick the trial factory once based on the configured trial type.
    make_trial = (
        self.experiment.new_batch_trial
        if self.options.trial_type is BatchTrial
        else self.experiment.new_trial
    )
    return [
        make_trial(
            generator_run=generator_run,
            ttl_seconds=self.options.ttl_seconds_for_trials,
        )
        for generator_run in generator_runs
    ]