def _train_with_estimator_spec()

in tensorflow_estimator/python/estimator/estimator.py


  def _train_with_estimator_spec(self, estimator_spec, worker_hooks, hooks,
                                 global_step_tensor, saving_listeners):
    """Train a model with the given Estimator Spec."""
    if (self._warm_start_settings and
        not tf.train.latest_checkpoint(self._model_dir)):
      tf.compat.v1.logging.info('Warm-starting with WarmStartSettings: %s' %
                                (self._warm_start_settings,))
      tf.compat.v1.train.warm_start(*self._warm_start_settings)
    # Check if the user created a loss summary, and add one if they didn't.
    # We assume here that the summary is called 'loss'. If it is not, we will
    # make another one with the name 'loss' to ensure it shows up in the right
    # graph in TensorBoard.
    if not any([
        x.op.name == 'loss' for x in ops.get_collection(ops.GraphKeys.SUMMARIES)
    ]):
      summary.scalar('loss', estimator_spec.loss)
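    # Track the loss in the LOSSES collection, attach the user-supplied hooks,
    # and fail fast if the loss ever becomes NaN.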
    ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)
    worker_hooks.extend(hooks)
    worker_hooks.append(tf.compat.v1.train.NanTensorHook(estimator_spec.loss))
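    # Unless step logging is disabled, periodically log the loss and the
    # global step.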
    if self._config.log_step_count_steps is not None:
      worker_hooks.append(
          tf.compat.v1.train.LoggingTensorHook(
              {
                  'loss': estimator_spec.loss,
                  'step': global_step_tensor
              },
              every_n_iter=self._config.log_step_count_steps))
    worker_hooks.extend(estimator_spec.training_hooks)

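    # If neither the scaffold nor the SAVERS collection supplies a Saver,
    # register a default sharded Saver that honors the RunConfig checkpoint
    # retention settings.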
    if not (estimator_spec.scaffold.saver or
            tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.SAVERS)):
      tf.compat.v1.add_to_collection(
          tf.compat.v1.GraphKeys.SAVERS,
          tf.compat.v1.train.Saver(
              sharded=True,
              max_to_keep=self._config.keep_checkpoint_max,
              keep_checkpoint_every_n_hours=(
                  self._config.keep_checkpoint_every_n_hours),
              defer_build=True,
              save_relative_paths=True))

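    # Multi-worker collective all-reduce strategies take a dedicated
    # distributed training path.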
    if (self._config.cluster_spec and type(
        self._train_distribution).__name__ in ('CollectiveAllReduceStrategy',
                                               'CollectiveAllReduceStrategyV1',
                                               'MultiWorkerMirroredStrategy')):
      return self._train_with_estimator_spec_distributed(
          estimator_spec, worker_hooks, saving_listeners)

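    # Set up chief-only checkpointing: reuse a user-provided
    # CheckpointSaverHook if present, otherwise create one from the RunConfig
    # checkpoint settings, and attach any saving_listeners to it.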
    chief_hooks = []
    all_hooks = worker_hooks + list(estimator_spec.training_chief_hooks)
    saver_hooks = [
        h for h in all_hooks
        if isinstance(h, tf.compat.v1.train.CheckpointSaverHook)
    ]
    if (self._config.save_checkpoints_secs or
        self._config.save_checkpoints_steps):
      if not saver_hooks:
        chief_hooks = [
            tf.compat.v1.train.CheckpointSaverHook(
                self._model_dir,
                save_secs=self._config.save_checkpoints_secs,
                save_steps=self._config.save_checkpoints_steps,
                scaffold=estimator_spec.scaffold,
                save_graph_def=self._config.checkpoint_save_graph_def)
        ]
        saver_hooks = [chief_hooks[0]]
    if saving_listeners:
      if not saver_hooks:
        raise ValueError(
            'There should be a CheckpointSaverHook to use saving_listeners. '
            'Please set one of the RunConfig.save_checkpoints_steps or '
            'RunConfig.save_checkpoints_secs.')
      else:
        # Only one CheckpointSaverHook is expected; if there are several, the
        # listeners are attached to the first one.
        for listener in saving_listeners:
          # pylint: disable=protected-access
          if listener not in saver_hooks[0]._listeners:
            saver_hooks[0]._listeners.append(listener)
          # pylint: enable=protected-access

    # Add summary hooks to worker 0 if we are running with a master, to ensure
    # that summaries are written at correct intervals even with long-running
    # evaluations.
    save_summary_steps = self._config.save_summary_steps
    log_step_count_steps = self._config.log_step_count_steps

    # Check existence of appropriate cluster spec fields, as well as master and
    # worker nodes. As master also performs evaluation, summary writing must
    # occur on a different node. The presence of a worker is also checked to
    # prevent reassigning hooks for single-replica jobs with just a master node.
    if (self._config.cluster_spec and self._config.cluster_spec.jobs and
        (run_config.TaskType.WORKER in self._config.cluster_spec.jobs) and
        (run_config.TaskType.MASTER in self._config.cluster_spec.jobs)):
      # Update config values to prevent the default hooks from being created on
      # the master or other workers.
      save_summary_steps = 0
      log_step_count_steps = None

      if (self._config.task_type == run_config.TaskType.WORKER and
          self._config.task_id == 0):
        if (self._config.save_summary_steps and
            self._config.save_summary_steps > 0):
          worker_hooks.append(
              tf.compat.v1.train.SummarySaverHook(
                  save_steps=self._config.save_summary_steps,
                  output_dir=self._config.model_dir,
                  scaffold=estimator_spec.scaffold))

        if (self._config.log_step_count_steps and
            self._config.log_step_count_steps > 0):
          worker_hooks.append(
              tf.compat.v1.train.StepCounterHook(
                  every_n_steps=self._config.log_step_count_steps,
                  output_dir=self._config.model_dir))

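    # Run the training loop in a MonitoredTrainingSession. Checkpointing is
    # delegated to the hooks above (save_checkpoint_secs=0 disables the
    # session's built-in saving).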
    with training.MonitoredTrainingSession(
        master=self._config.master,
        is_chief=self._config.is_chief,
        checkpoint_dir=self._model_dir,
        scaffold=estimator_spec.scaffold,
        hooks=worker_hooks,
        chief_only_hooks=(tuple(chief_hooks) +
                          tuple(estimator_spec.training_chief_hooks)),
        save_checkpoint_secs=0,  # Saving is handled by a hook.
        save_summaries_steps=save_summary_steps,
        config=self._session_config,
        max_wait_secs=self._config.session_creation_timeout_secs,
        log_step_count_steps=log_step_count_steps,
        save_graph_def=self._config.checkpoint_save_graph_def) as mon_sess:
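      # Step until a hook or the input pipeline signals the session to stop,
      # keeping track of the most recent loss value.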
      loss = None
      current_step = 0
      while not mon_sess.should_stop():
        current_step += 1
        # As in Keras (https://github.com/tensorflow/tensorflow/blob/v2.4.1/tensorflow/python/keras/engine/training.py#L1093),
        # tracing should be enabled for every step.
        with trace.Trace('train', step_num=current_step, _r=1):
          _, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
      if current_step == 0:
        tf.compat.v1.logging.warn('Training with estimator made no steps. '
                                  'Perhaps the input is empty or misspecified.')
    return loss
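
A minimal caller-side sketch for context (a hedged illustration, not part of
estimator.py; it assumes a TensorFlow release that still bundles tf.estimator,
and the names model_fn, input_fn, and _LogAfterSave below are illustrative).
Estimator.train() reaches _train_with_estimator_spec() internally with the
EstimatorSpec returned by the model_fn; the RunConfig fields read above
(checkpoint, summary, and step-logging settings) and any saving_listeners
passed to train() are the ones configured here.

import tensorflow as tf


class _LogAfterSave(tf.estimator.CheckpointSaverListener):
  """Illustrative listener; it ends up attached to the CheckpointSaverHook."""

  def after_save(self, session, global_step_value):
    tf.compat.v1.logging.info('Checkpoint written at step %d',
                              global_step_value)


def model_fn(features, labels, mode):
  # Tiny linear model so the sketch is self-contained.
  predictions = tf.compat.v1.layers.dense(features['x'], 1)
  loss = tf.compat.v1.losses.mean_squared_error(labels, predictions)
  train_op = tf.compat.v1.train.GradientDescentOptimizer(0.1).minimize(
      loss, global_step=tf.compat.v1.train.get_global_step())
  return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)


def input_fn():
  features = {'x': [[1.0], [2.0], [3.0], [4.0]]}
  labels = [[2.0], [4.0], [6.0], [8.0]]
  return tf.data.Dataset.from_tensor_slices(
      (features, labels)).repeat().batch(2)


config = tf.estimator.RunConfig(
    model_dir='/tmp/estimator_sketch',
    save_summary_steps=100,      # read as self._config.save_summary_steps
    save_checkpoints_steps=500,  # enables the default CheckpointSaverHook path
    keep_checkpoint_max=5,       # used by the default sharded Saver
    log_step_count_steps=100)    # enables the loss/step LoggingTensorHook

estimator = tf.estimator.Estimator(model_fn=model_fn, config=config)
estimator.train(
    input_fn=input_fn,
    max_steps=1000,
    saving_listeners=[_LogAfterSave()])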