def _create_iteration()

in adanet/core/estimator.py


  def _create_iteration(self,
                        features,
                        labels,
                        mode,
                        config,
                        is_growing_phase,
                        checkpoint_path,
                        hooks,
                        best_ensemble_index_override=None):
    """Constructs the TF ops and variables for the current iteration.

    Args:
      features: Dictionary of `Tensor` objects keyed by feature name.
      labels: Labels `Tensor` or a dictionary of string label name to `Tensor`
        (for multi-head). Can be `None`.
      mode: Defines whether this is training, evaluation or prediction. See
        `ModeKeys`.
      config: The current `tf.estimator.RunConfig`.
      is_growing_phase: Whether we are in the AdaNet graph growing phase.
      checkpoint_path: Path of the checkpoint to use. When `None`, this method
        uses the latest checkpoint instead.
      hooks: A list of `tf.estimator.SessionRunHook` instances.
      best_ensemble_index_override: Integer index identifying the latest
        iteration's best ensemble candidate, used instead of dynamically
        computing the best ensemble index from the candidates' AdaNet losses.

    Returns:
      A two-tuple of the current `_Iteration` and the list of variables from
        the previous iteration to restore during the graph growing phase.
    """

    # Use the evaluation checkpoint path to get both the iteration number and
    # variable values to avoid any race conditions between the first and second
    # checkpoint reads.
    iteration_number = self._checkpoint_iteration_number(checkpoint_path)

    if mode == tf.estimator.ModeKeys.EVAL and checkpoint_path is None:
      # This should only happen during some tests, so we log instead of
      # asserting here.
      logging.warning("There are no checkpoints available during evaluation. "
                      "Variables will be initialized to their defaults.")

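    # During the graph growing phase, the chief builds the graph for the next
    # iteration on top of the previous ones, so the iteration number is
    # incremented by one.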
    if is_growing_phase:
      assert mode == tf.estimator.ModeKeys.TRAIN
      assert config.is_chief
      iteration_number += 1

    # Only record summaries when training.
    skip_summaries = (mode != tf.estimator.ModeKeys.TRAIN or is_growing_phase)
    base_global_step = 0
    with tf_compat.v1.variable_scope("adanet"):
      previous_iteration = None
      previous_ensemble_spec = None
      previous_ensemble = None
      previous_ensemble_summary = None
      previous_ensemble_subnetwork_builders = None
      architecture = None
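      # Reconstruct the ensembles of all completed iterations from their
      # serialized architectures so that the current iteration can build on
      # top of the latest previous ensemble.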
      for i in range(iteration_number):
        architecture_filename = self._architecture_filename(i)
        if not tf.io.gfile.exists(architecture_filename):
          continue
        architecture = self._read_architecture(architecture_filename)
        logging.info(
            "Importing architecture from %s: [%s].", architecture_filename,
            ", ".join(
                sorted([
                    "'{}:{}'".format(t, n)
                    for t, n in architecture.subnetworks_grouped_by_iteration
                ])))
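        # The stored architecture also records the global step at which it was
        # saved; the latest value becomes the base global step for the current
        # iteration.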
        base_global_step = architecture.global_step
        previous_iteration = self._architecture_ensemble_spec(
            architecture, i, features, mode, labels, previous_ensemble_spec,
            config, previous_iteration, hooks)
        previous_ensemble_spec = previous_iteration.candidates[-1].ensemble_spec
        previous_ensemble = previous_ensemble_spec.ensemble
        previous_ensemble_summary = self._summary_maker(
            namespace="ensemble",
            scope=previous_ensemble_spec.name,
            skip_summary=skip_summaries)
        previous_ensemble_subnetwork_builders = (
            previous_ensemble_spec.subnetwork_builders)
      previous_iteration_vars = None
      if is_growing_phase:
        # Keep track of the previous iteration's variables so we can restore
        # them from the previous checkpoint after growing the graph. After this
        # line, any newly created variables will not have a matching entry in
        # the checkpoint until the checkpoint is overwritten.
        # Note: It's not possible to just create a tf.train.Saver here since
        # this code is also run on TPU, which does not support creating Savers
        # inside model_fn.
        previous_iteration_vars = (
            tf_compat.v1.get_collection(tf_compat.v1.GraphKeys.GLOBAL_VARIABLES)
            + tf_compat.v1.get_collection(
                tf_compat.v1.GraphKeys.SAVEABLE_OBJECTS))
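      # Materialized reports from previous iterations give the subnetwork
      # generator context about how earlier candidates performed.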
      previous_ensemble_reports, all_reports = [], []
      if self._report_materializer:
        previous_ensemble_reports, all_reports = (
            self._collate_subnetwork_reports(iteration_number))

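      # Generate this iteration's candidate subnetworks, then let each ensemble
      # strategy combine them (and optionally the previous ensemble's
      # subnetworks) into ensemble candidates.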
      subnetwork_builders = self._call_generate_candidates(
          previous_ensemble=previous_ensemble,
          iteration_number=iteration_number,
          previous_ensemble_reports=previous_ensemble_reports,
          all_reports=all_reports,
          config=config)
      ensemble_candidates = []
      for ensemble_strategy in self._ensemble_strategies:
        ensemble_candidates += ensemble_strategy.generate_ensemble_candidates(
            subnetwork_builders, previous_ensemble_subnetwork_builders)
      current_iteration = self._iteration_builder.build_iteration(
          base_global_step=base_global_step,
          iteration_number=iteration_number,
          ensemble_candidates=ensemble_candidates,
          subnetwork_builders=subnetwork_builders,
          features=features,
          labels=labels,
          mode=mode,
          config=config,
          previous_ensemble_summary=previous_ensemble_summary,
          best_ensemble_index_override=best_ensemble_index_override,
          previous_iteration=previous_iteration)
    return current_iteration, previous_iteration_vars
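
For illustration, here is a minimal sketch (not part of the library) of how a
`model_fn`-style call site inside the estimator might drive this method. The
name `_example_model_fn` and the `current_iteration.estimator_spec` attribute
access are assumptions for the sketch, not confirmed AdaNet API:

  def _example_model_fn(self, features, labels, mode, config):
    # Read the latest checkpoint once so that the iteration number and the
    # variable values come from the same checkpoint.
    latest_checkpoint = tf.train.latest_checkpoint(self.model_dir)
    current_iteration, _ = self._create_iteration(
        features=features,
        labels=labels,
        mode=mode,
        config=config,
        is_growing_phase=False,
        checkpoint_path=latest_checkpoint,
        hooks=[])
    # Assumed for the sketch: the returned `_Iteration` exposes an
    # `EstimatorSpec` for the requested mode.
    return current_iteration.estimator_spec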