def run_trial()

in src/python/tensorflow_cloud/tuner/tuner.py [0:0]


    def run_trial(self, trial, *fit_args, **fit_kwargs):
        """Runs a single trial as a remote AI Platform training job.

        `search` calls this method for each set of hyperparameters. The
        model built for the trial is submitted through `cloud_fit`, the job
        is polled while it is running, and the metrics read back from the
        trial's TensorBoard logs are reported to the oracle epoch by epoch.

        Args:
            trial: A `Trial` instance that carries everything needed for
              this run; its hyperparameters are available as
              `trial.hyperparameters`.
            *fit_args: Positional arguments passed by `search`.
            **fit_kwargs: Keyword arguments passed by `search`.

        Raises:
            RuntimeError: If AI Platform training job fails.
        """

        # Shallow copy of the keyword arguments forwarded to the remote
        # call; its "callbacks" entry is replaced further down.
        copied_fit_kwargs = copy.copy(fit_kwargs)

        # Work on a copy of whatever callbacks were handed to `fit`.
        trial_callbacks = self._deepcopy_callbacks(
            fit_kwargs.pop("callbacks", []))

        # Note: no `TunerCallback` is involved, because the training itself
        # is performed remotely on AI Platform training.

        # TensorBoard/hyperparameter logging is set up here; the TensorBoard
        # logs are how metrics are passed back from the remote execution.
        self._add_logging(trial_callbacks, trial)

        # Add a save_model checkpoint callback whose saved model file path
        # is specific to this trial, so trials do not overwrite each other.
        self._add_model_checkpoint_callback(trial_callbacks, trial.trial_id)

        copied_fit_kwargs["callbacks"] = trial_callbacks
        model = self.hypermodel.build(trial.hyperparameters)

        remote_dir = os.path.join(self.directory, str(trial.trial_id))

        job_id = f"{self._study_id}_{trial.trial_id}"

        # Create job spec from worker count and config
        job_spec = self._get_job_spec_from_config(job_id)

        project_id = self._project_id

        cloud_fit_summary = {
            "model": model,
            "remote_dir": remote_dir,
            "region": self._region,
            "project_id": project_id,
            "image_uri": self._container_uri,
            "job_id": job_id,
            "*fit_args": fit_args,
            "job_spec": job_spec,
            "**copied_fit_kwargs": copied_fit_kwargs,
        }
        tf.get_logger().info("Calling cloud_fit with %s", cloud_fit_summary)

        cloud_fit_client.cloud_fit(
            *fit_args,
            model=model,
            remote_dir=remote_dir,
            region=self._region,
            project_id=project_id,
            image_uri=self._container_uri,
            job_id=job_id,
            job_spec=job_spec,
            **copied_fit_kwargs)

        # A tensorboard DirectoryWatcher is used to retrieve the logs of
        # this trial run. The watcher expects the path to exist, so the
        # directory is created before the watcher is.
        train_log_path = os.path.join(
            self._get_tensorboard_log_dir(trial.trial_id), "train")
        tf.io.gfile.makedirs(train_log_path)

        tf.get_logger().info(
            f"Retrieving training logs for trial {trial.trial_id} from"
            f" {train_log_path}")
        train_watcher = tf_utils.get_tensorboard_log_watcher_from_path(
            train_log_path)

        latest_metrics = _TrainingMetrics([], {})
        next_step = 0

        while google_api_client.is_aip_training_job_running(
                job_id, project_id):

            time.sleep(_POLLING_INTERVAL_IN_SECONDS)

            # Pick up whatever metrics have become available, carrying the
            # partial epoch metrics over from the previous read.
            latest_metrics = self._get_remote_training_metrics(
                train_watcher, latest_metrics.partial_epoch_metrics)

            for finished_epoch in latest_metrics.completed_epoch_metrics:
                # TODO(b/169197272) Validate metrics contain oracle objective
                if not finished_epoch:
                    continue
                trial.status = self.oracle.update_trial(
                    trial_id=trial.trial_id,
                    metrics=finished_epoch,
                    step=next_step)
                next_step += 1

            # Stop the remote job and leave the polling loop once the trial
            # status reads "STOPPED".
            if trial.status == "STOPPED":
                google_api_client.stop_aip_training_job(job_id, project_id)
                break

        # Ensure the training job has completed successfully.
        completed_ok = (
            google_api_client.wait_for_aip_training_job_completion(
                job_id, project_id))
        if not completed_ok:
            raise RuntimeError(
                "AI Platform Training job failed, see logs for details at "
                "https://console.cloud.google.com/ai-platform/jobs/"
                "{}/charts/cpu?project={}"
                .format(job_id, project_id))

        # Retrieve and report any remaining metrics
        latest_metrics = self._get_remote_training_metrics(
            log_reader=train_watcher,
            partial_epoch_metrics=latest_metrics.partial_epoch_metrics)

        for finished_epoch in latest_metrics.completed_epoch_metrics:
            # TODO(b/169197272) Validate metrics contain oracle objective
            # TODO(b/170907612) Support submit partial results to Oracle
            if not finished_epoch:
                continue
            self.oracle.update_trial(
                trial_id=trial.trial_id,
                metrics=finished_epoch,
                step=next_step)
            next_step += 1

        # Submit the final epoch metrics, if any are left over.
        leftover_metrics = latest_metrics.partial_epoch_metrics
        if leftover_metrics:
            self.oracle.update_trial(
                trial_id=trial.trial_id,
                metrics=leftover_metrics,
                step=next_step)

        # Validation metrics are only submitted, at the end of the trial,
        # when eval_files is provided.
        if not copied_fit_kwargs.get("eval_files"):
            return

        # A second tensorboard DirectoryWatcher retrieves the logs of the
        # validation run; as above, the watcher expects the path to exist.
        val_log_path = os.path.join(
            self._get_tensorboard_log_dir(trial.trial_id), "validation")
        tf.io.gfile.makedirs(val_log_path)
        tf.get_logger().info(
            f"Retrieving validation logs for trial {trial.trial_id} from"
            f" {val_log_path}")
        val_watcher = tf_utils.get_tensorboard_log_watcher_from_path(
            val_log_path)

        initial_metrics = _TrainingMetrics([], {})
        validation_metrics = self._get_remote_training_metrics(
            log_reader=val_watcher,
            partial_epoch_metrics=initial_metrics.partial_epoch_metrics,
            is_validation=True)

        for metric in validation_metrics.completed_epoch_metrics:
            if not metric:
                continue
            self.oracle.update_trial(
                trial_id=trial.trial_id,
                metrics=metric)