def train()

in adanet/core/estimator.py [0:0]


  def train(self,
            input_fn,
            hooks=None,
            steps=None,
            max_steps=None,
            saving_listeners=None):
    # pyformat: disable
    """Trains a model given training data :code:`input_fn`.

    NOTE: If a given input_fn raises an :code:`OutOfRangeError`, then *all* of
    training will exit. The best practice is to make the training dataset repeat
    forever, in order to perform model search for more than one iteration.

    Args:
      input_fn: A function that provides input data for training as minibatches.
        See [Premade Estimators](
        https://tensorflow.org/guide/premade_estimators#create_input_functions)
        for more information. The function should construct and return one of
        the following:
          * A :code:`tf.data.Dataset` object: Outputs of `Dataset` object must
            be a tuple `(features, labels)` with same constraints as below.
          * A tuple `(features, labels)`: Where `features` is a
            :code:`tf.Tensor` or a dictionary of string feature name to
            `Tensor` and `labels` is a :code:`Tensor` or a dictionary of string
            label name to `Tensor`. Both `features` and `labels` are consumed by
            `model_fn`. They should satisfy the expectation of `model_fn` from
            inputs.
      hooks: List of :code:`tf.train.SessionRunHook` subclass instances. Used
        for callbacks inside the training loop.
      steps: Number of steps for which to train the model. If :code:`None`,
        train forever or train until `input_fn` generates the
        :code:`tf.errors.OutOfRange` error or :code:`StopIteration` exception.
        `steps` works incrementally. If you call two times `train(steps=10)`
        then training occurs in total 20 steps. If :code:`OutOfRange` or
        :code:`StopIteration` occurs in the middle, training stops before 20
        steps. If you don't want to have incremental behavior please set
        `max_steps` instead. If set, `max_steps` must be :code:`None`.
      max_steps: Number of total steps for which to train model. If
        :code:`None`, train forever or train until `input_fn` generates the
        :code:`tf.errors.OutOfRange` error or :code:`StopIteration` exception.
        If set, `steps` must be `None`. If :code:`OutOfRange` or
        :code:`StopIteration` occurs in the middle, training stops before
        `max_steps` steps. Two calls to `train(steps=100)` means 200 training
        iterations. On the other hand, two calls to `train(max_steps=100)`
        means that the second call will not do any iteration since first call
        did all 100 steps.
      saving_listeners: list of :code:`CheckpointSaverListener` objects. Used
        for callbacks that run immediately before or after checkpoint savings.

    Returns:
      `self`, for chaining.

    Raises:
      ValueError: If both `steps` and `max_steps` are not `None`.
      ValueError: If either `steps` or `max_steps` is `<= 0`.
    """
    # pyformat: enable

    if (steps is not None) and (max_steps is not None):
      raise ValueError("Can not provide both steps and max_steps.")
    if steps is not None and steps <= 0:
      raise ValueError("Must specify steps > 0, given: {}".format(steps))
    # Validate here, as documented above, before any checkpoint reads or
    # temporary estimators are created.
    if max_steps is not None and max_steps <= 0:
      raise ValueError(
          "Must specify max_steps > 0, given: {}".format(max_steps))

    # Global step at the moment this call started. Incremental `steps` is
    # turned into an absolute `max_steps` so the rest of the method only has
    # to reason about a single stopping bound.
    latest_checkpoint = tf.train.latest_checkpoint(self.model_dir)
    start_global_steps = self._checkpoint_global_step(latest_checkpoint)
    if steps is not None:
      max_steps = start_global_steps + steps

    # Each iteration of this AdaNet loop represents an `_Iteration`. The
    # current iteration number is stored as a variable in the checkpoint so
    # that training can be stopped and started at anytime.
    with monkey_patch_default_variable_placement_strategy():
      while True:
        latest_checkpoint = tf.train.latest_checkpoint(self.model_dir)
        current_iteration = self._checkpoint_iteration_number(latest_checkpoint)
        logging.info("Beginning training AdaNet iteration %s",
                     current_iteration)
        # Reset before training; it is expected to become True during training
        # when the current iteration finishes (checked below).
        self._iteration_ended = False

        # Delegate training to a temporary estimator instead of super to make
        # passing arguments more functional (via params).
        temp_estimator = self._create_temp_estimator(
            config=self.config,
            is_inside_training_loop=True,
            checkpoint_path=latest_checkpoint,
            hooks=hooks)
        temp_estimator.train(
            input_fn=input_fn,
            hooks=hooks,
            max_steps=max_steps,
            saving_listeners=saving_listeners)
        # In TensorFlow versions older than 2.0.0.rc1, saving listeners are
        # attached to the first CheckpointSaverHook each time train is called.
        # Instead, we pass in the saving_listeners in the first AdaNet
        # iteration only.
        if not tf_compat.version_greater_or_equal("2.0.0.rc1"):
          saving_listeners = None
        logging.info("Finished training Adanet iteration %s", current_iteration)

        # If training ended because the maximum number of training steps
        # occurred, exit training.
        latest_checkpoint = tf.train.latest_checkpoint(self.model_dir)
        global_steps = self._checkpoint_global_step(latest_checkpoint)
        if max_steps is not None and global_steps >= max_steps:
          logging.info("Training ended after %s global steps", global_steps)
          return self

        # If training ended for any reason other than the iteration ending,
        # exit training.
        if not self._iteration_ended:
          logging.info("Training stop requested")
          return self

        max_iterations = self._max_iterations
        if max_iterations and current_iteration + 1 >= max_iterations:
          logging.info(
              "Training ended after exceeding maximum AdaNet iterations")
          if (steps is not None and
              global_steps - start_global_steps < steps):
            logging.warning(
                "Both `max_iterations` and `steps` were specified, but "
                "`max_iterations` takes precedence over `steps`")
          return self

        logging.info("Beginning bookkeeping phase for iteration %s",
                     current_iteration)

        # The chief prepares the next AdaNet iteration, and increments the
        # iteration number by 1.
        if self.config.is_chief:
          with self._force_replication_strategy():
            self._execute_bookkeeping_phase(
                input_fn,
                current_iteration,
                train_hooks=hooks or [],
                checkpoint_path=latest_checkpoint)

        # This inner loop serves mainly for synchronizing the workers with the
        # chief during distributed training. Workers that finish training early
        # wait for the chief to prepare the next iteration and increment the
        # iteration number. Workers that are slow to finish training quickly
        # move onto the next iteration. And workers that go offline and return
        # online after training ended terminate gracefully.
        wait_for_chief = not self.config.is_chief
        timer = _CountDownTimer(self._worker_wait_timeout_secs)
        while wait_for_chief:
          # Fetch the latest checkpoint.
          latest_checkpoint = tf.train.latest_checkpoint(self.model_dir)

          # If the chief hits max_steps, it will stop training itself and not
          # increment the iteration number, so this is how the worker knows to
          # exit if it wakes up and the chief is gone. `steps` was already
          # folded into `max_steps` above; with no step bound there is nothing
          # to compare against (and comparing to None would raise TypeError).
          if (max_steps is not None and
              self._checkpoint_global_step(latest_checkpoint) >= max_steps):
            return self

          # In distributed training, a worker may end training before the chief
          # overwrites the checkpoint with the incremented iteration number. If
          # that is the case, it should wait for the chief to do so. Otherwise
          # the worker will get stuck waiting for its weights to be initialized.
          next_iteration = self._checkpoint_iteration_number(latest_checkpoint)
          if next_iteration > current_iteration:
            break
          logging.info("Iteration number in latest checkpoint: %d",
                       next_iteration)

          # Check timeout when waiting for potentially downed chief.
          if timer.secs_remaining() == 0:
            logging.error(
                "Chief job did not prepare iteration %d after %s secs. It "
                "may have been preempted, been turned down, or crashed. This "
                "worker is now exiting training.", current_iteration + 1,
                self._worker_wait_timeout_secs)
            return self
          logging.info("Waiting for chief to prepare iteration %d",
                       current_iteration + 1)
          time.sleep(self._worker_wait_secs)

        # Stagger starting workers to prevent training instability.
        # Mimics behavior of tf.estimator.train_and_evaluate.
        if not self.config.is_chief and self.config.task_type == "worker":
          task_id = self.config.task_id or 0
          # Delay grows linearly with task id, capped at
          # `self._max_worker_delay_secs`.
          delay_secs = min(self._max_worker_delay_secs,
                           (task_id + 1.) * self._delay_secs_per_worker)
          if delay_secs > 0.:
            logging.info("Waiting %d secs before continuing training.",
                         delay_secs)
            time.sleep(delay_secs)

        logging.info("Finished bookkeeping phase for iteration %s",
                     current_iteration)