in src/python/tensorflow_cloud/tuner/tuner.py [0:0]
def run_trial(self, trial, *fit_args, **fit_kwargs):
    """Evaluates a set of hyperparameter values.

    This method is called during `search` to evaluate a set of
    hyperparameters using AI Platform training.

    Args:
        trial: A `Trial` instance that contains the information
            needed to run this trial. `Hyperparameters` can be accessed
            via `trial.hyperparameters`.
        *fit_args: Positional arguments passed by `search`.
        **fit_kwargs: Keyword arguments passed by `search`.

    Raises:
        RuntimeError: If the AI Platform training job fails.
    """
    # Running the training remotely.
    copied_fit_kwargs = copy.copy(fit_kwargs)
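    # A trial-specific callback list is attached to the copied kwargs below
    # before they are forwarded to the remote training job.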
    # Handle any callbacks passed to `fit`.
    callbacks = fit_kwargs.pop("callbacks", [])
    callbacks = self._deepcopy_callbacks(callbacks)

    # Note: run_trial does not use `TunerCallback` calls, since
    # training is performed on AI Platform training remotely.

    # Handle TensorBoard/hyperparameter logging here. The TensorBoard
    # logs are used for passing metrics back from remote execution.
    self._add_logging(callbacks, trial)

    # Create a model checkpoint callback with a save path specific to this
    # trial, so that different trials do not overwrite each other's models.
    self._add_model_checkpoint_callback(
        callbacks, trial.trial_id)

    copied_fit_kwargs["callbacks"] = callbacks
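    # Build the model for this trial and derive a trial-specific remote
    # directory and job ID so that artifacts and jobs from different trials
    # do not collide.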
    model = self.hypermodel.build(trial.hyperparameters)
    remote_dir = os.path.join(self.directory, str(trial.trial_id))
    job_id = f"{self._study_id}_{trial.trial_id}"

    # Create job spec from worker count and config.
    job_spec = self._get_job_spec_from_config(job_id)
    tf.get_logger().info("Calling cloud_fit with %s", {
        "model": model,
        "remote_dir": remote_dir,
        "region": self._region,
        "project_id": self._project_id,
        "image_uri": self._container_uri,
        "job_id": job_id,
        "*fit_args": fit_args,
        "job_spec": job_spec,
        "**copied_fit_kwargs": copied_fit_kwargs})
    cloud_fit_client.cloud_fit(
        model=model,
        remote_dir=remote_dir,
        region=self._region,
        project_id=self._project_id,
        image_uri=self._container_uri,
        job_id=job_id,
        job_spec=job_spec,
        *fit_args,
        **copied_fit_kwargs)
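    # `cloud_fit` returns once the AI Platform training job has been
    # submitted; job completion is polled for below.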
    # Create an instance of tensorboard DirectoryWatcher to retrieve the
    # logs for this trial run.
    train_log_path = os.path.join(
        self._get_tensorboard_log_dir(trial.trial_id), "train")

    # The TensorBoard log watcher expects the path to exist.
    tf.io.gfile.makedirs(train_log_path)

    tf.get_logger().info(
        f"Retrieving training logs for trial {trial.trial_id} from"
        f" {train_log_path}")
    train_log_reader = tf_utils.get_tensorboard_log_watcher_from_path(
        train_log_path)
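    # `_TrainingMetrics` holds the metrics of fully reported epochs
    # (`completed_epoch_metrics`) and of the epoch still in progress
    # (`partial_epoch_metrics`); `epoch` is the step reported to the Oracle.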
    training_metrics = _TrainingMetrics([], {})
    epoch = 0
    while google_api_client.is_aip_training_job_running(
            job_id, self._project_id):
        time.sleep(_POLLING_INTERVAL_IN_SECONDS)

        # Retrieve available metrics, if any.
        training_metrics = self._get_remote_training_metrics(
            train_log_reader, training_metrics.partial_epoch_metrics)
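        # Report each newly completed epoch to the Oracle, which may decide
        # to stop the trial early.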
        for epoch_metrics in training_metrics.completed_epoch_metrics:
            # TODO(b/169197272) Validate metrics contain oracle objective
            if epoch_metrics:
                trial.status = self.oracle.update_trial(
                    trial_id=trial.trial_id,
                    metrics=epoch_metrics,
                    step=epoch)
            epoch += 1
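        # If the Oracle marked the trial as stopped, cancel the remote job
        # rather than waiting for it to finish.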
        if trial.status == "STOPPED":
            google_api_client.stop_aip_training_job(
                job_id, self._project_id)
            break
    # Ensure the training job has completed successfully.
    if not google_api_client.wait_for_aip_training_job_completion(
            job_id, self._project_id):
        raise RuntimeError(
            "AI Platform Training job failed, see logs for details at "
            "https://console.cloud.google.com/ai-platform/jobs/"
            "{}/charts/cpu?project={}"
            .format(job_id, self._project_id))
    # Retrieve and report any remaining metrics.
    training_metrics = self._get_remote_training_metrics(
        log_reader=train_log_reader,
        partial_epoch_metrics=training_metrics.partial_epoch_metrics)

    for epoch_metrics in training_metrics.completed_epoch_metrics:
        # TODO(b/169197272) Validate metrics contain oracle objective
        # TODO(b/170907612) Support submit partial results to Oracle
        if epoch_metrics:
            self.oracle.update_trial(
                trial_id=trial.trial_id,
                metrics=epoch_metrics,
                step=epoch)
        epoch += 1
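    # Metrics that remain in `partial_epoch_metrics` at this point
    # correspond to the final epoch of training.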
    # Submit final epoch metrics.
    if training_metrics.partial_epoch_metrics:
        self.oracle.update_trial(
            trial_id=trial.trial_id,
            metrics=training_metrics.partial_epoch_metrics,
            step=epoch)
    # At the end of the trial, submit validation metrics if `eval_files`
    # was provided.
    if copied_fit_kwargs.get("eval_files"):
        # Create an instance of tensorboard DirectoryWatcher to retrieve the
        # logs for the validation run.
        val_log_path = os.path.join(
            self._get_tensorboard_log_dir(trial.trial_id), "validation")

        # The TensorBoard log watcher expects the path to exist.
        tf.io.gfile.makedirs(val_log_path)

        tf.get_logger().info(
            f"Retrieving validation logs for trial {trial.trial_id} from"
            f" {val_log_path}")
        val_log_reader = tf_utils.get_tensorboard_log_watcher_from_path(
            val_log_path)

        validation_metrics = _TrainingMetrics([], {})
        validation_metrics = self._get_remote_training_metrics(
            log_reader=val_log_reader,
            partial_epoch_metrics=validation_metrics.partial_epoch_metrics,
            is_validation=True)
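        # Report the validation metrics gathered at the end of training to
        # the Oracle.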
        for metric in validation_metrics.completed_epoch_metrics:
            if metric:
                self.oracle.update_trial(
                    trial_id=trial.trial_id,
                    metrics=metric)