in src/python/tensorflow_cloud/tuner/tuner.py [0:0]
def create_trial(self, tuner_id: Text) -> trial_module.Trial:
    """Create a new `Trial` to be run by the `Tuner`.

    Args:
        tuner_id: An ID that identifies the `Tuner` requesting a `Trial`.
            `Tuners` that should run the same trial (for instance, when
            running a multi-worker model) should have the same ID. If
            multiple suggestTrialsRequests have the same tuner_id, the
            service will return the identical suggested trial if the trial
            is PENDING, and provide a new trial if the last suggested trial
            was completed.

    Returns:
        A `Trial` object containing a set of hyperparameter values to run
        in a `Tuner`, with status `RUNNING`. Instead, a `Trial` with status
        `STOPPED` and empty hyperparameter values is returned when
        `self.max_trials` is set and the study already has that many
        trials, when any trial in the study is in the Vizier `STOPPING`
        state, or when the service returns no suggestions.

        On the `RUNNING` path the trial is also registered in
        `self.trials` and `self.ongoing_trials[tuner_id]`, persisted via
        `self._save_trial` and `self.save`, and `self._start_time` is reset.

    Raises:
        SuggestionInactiveError: Indicates that a suggestion was requested
            from an inactive study (presumably propagated from
            `self.service.get_suggestions` -- confirm).
    """

    def _empty_hyperparameters():
        # Both STOPPED paths return a `HyperParameters` copy with no values
        # (rather than a bare dict) so callers reading
        # `trial.hyperparameters.values` behave the same on either path.
        hyperparameters = self.hyperparameters.copy()
        hyperparameters.values = {}
        return hyperparameters

    # List all trials from the same study and see if any trial is being
    # stopped or if the number of trials has reached `max_trials`.
    trial_list = self.service.list_trials()
    # Note that KerasTunerTrialStatus - 'STOPPED' is equivalent to
    # VizierTrialState - 'STOPPING'.
    any_stopping = any(t["state"] == "STOPPING" for t in trial_list)
    # A falsy `max_trials` (e.g. None or 0) means no trial limit.
    budget_exhausted = bool(self.max_trials) and (
        len(trial_list) >= self.max_trials
    )
    if budget_exhausted or any_stopping:
        # A STOPPED trial will break the search loop later.
        return trial_module.Trial(
            hyperparameters=_empty_hyperparameters(),
            trial_id="n",
            status=trial_module.TrialStatus.STOPPED,
        )

    # Ask the service for suggestions for this tuner.
    suggestions = self.service.get_suggestions(tuner_id)
    if not suggestions:
        return trial_module.Trial(
            hyperparameters=_empty_hyperparameters(),
            status=trial_module.TrialStatus.STOPPED,
        )

    # Only the first suggested Vizier trial is used.
    vizier_trial = suggestions[0]
    trial_id = utils.get_trial_id(vizier_trial)
    # Convert the Vizier trial into a KerasTuner Trial instance.
    keras_tuner_trial = trial_module.Trial(
        hyperparameters=utils.convert_vizier_trial_to_hps(
            self.hyperparameters.copy(), vizier_trial
        ),
        trial_id=trial_id,
        status=trial_module.TrialStatus.RUNNING,
    )
    tf.get_logger().info(
        "Hyperparameters requested by tuner ({}): {} ".format(
            tuner_id, keras_tuner_trial.hyperparameters.values
        )
    )
    # NOTE(review): presumably read elsewhere to time the trial -- confirm.
    self._start_time = time.time()
    # Register the trial by id and as this tuner's ongoing trial, then persist.
    self.trials[trial_id] = keras_tuner_trial
    self.ongoing_trials[tuner_id] = keras_tuner_trial
    self._save_trial(keras_tuner_trial)
    self.save()
    return keras_tuner_trial