def _get_best_ensemble_index()

in adanet/core/estimator.py [0:0]


  def _get_best_ensemble_index(self,
                               current_iteration,
                               input_hooks,
                               checkpoint_path=None):
    # type: (_Iteration, Sequence[tf_compat.SessionRunHook], Text) -> int
    """Returns the best candidate ensemble's index in this iteration.

    Evaluates the ensembles using an `Evaluator` when provided. Otherwise,
    it returns the index of the best candidate as defined by the `_Iteration`.

    Args:
      current_iteration: Current `_Iteration`.
      input_hooks: List of SessionRunHooks to be included when running.
      checkpoint_path: Checkpoint to use when determining the best index.

    Returns:
      Index of the best ensemble in the iteration's list of `_Candidates`.
    """
    # AdaNet Replay.
    if self._replay_config:
      best_index = self._replay_config.get_best_ensemble_index(
          current_iteration.number)
      if best_index is not None:
        return best_index

    # Skip the evaluation phase when there is only one candidate subnetwork.
    if len(current_iteration.candidates) == 1:
      logging.info("'%s' is the only ensemble",
                   current_iteration.candidates[0].ensemble_spec.name)
      return 0

    # The zero-th index candidate at iteration t>0 is always the
    # previous_ensemble.
    if current_iteration.number > 0 and self._force_grow and (len(
        current_iteration.candidates) == 2):
      logging.info("With `force_grow` enabled, '%s' is the only ensemble",
                   current_iteration.candidates[1].ensemble_spec.name)
      return 1

    logging.info("Starting ensemble evaluation for iteration %s",
                 current_iteration.number)
    for hook in input_hooks:
      hook.begin()
    with tf_compat.v1.Session(config=self.config.session_config) as sess:
      init = tf.group(
          tf_compat.v1.global_variables_initializer(),
          tf_compat.v1.local_variables_initializer(),
          tf_compat.v1.tables_initializer(),
          current_iteration.estimator_spec.scaffold.local_init_op if isinstance(
              current_iteration.estimator_spec,
              tf.estimator.EstimatorSpec) else tf.no_op())
      sess.run(init)

      if self._enable_v2_checkpoint:
        status = current_iteration.checkpoint.restore(checkpoint_path)
        status.expect_partial()  # Optional sanity checks.
        status.initialize_or_restore(sess)
      else:
        saver = tf_compat.v1.train.Saver(sharded=True)
        saver.restore(sess, checkpoint_path)

      coord = tf.train.Coordinator()
      for hook in input_hooks:
        hook.after_create_session(sess, coord)

      tf_compat.v1.train.start_queue_runners(sess=sess, coord=coord)
      ensemble_metrics = []
      for candidate in current_iteration.candidates:
        metrics = candidate.ensemble_spec.eval_metrics.eval_metrics_ops()
        metrics["adanet_loss"] = tf_compat.v1.metrics.mean(
            candidate.ensemble_spec.adanet_loss)
        ensemble_metrics.append(metrics)
      if self._evaluator:
        metric_name = self._evaluator.metric_name
        metrics = self._evaluator.evaluate(sess, ensemble_metrics)
        objective_fn = self._evaluator.objective_fn
      else:
        metric_name = "adanet_loss"
        metrics = sess.run(
            [c.adanet_loss for c in current_iteration.candidates])
        objective_fn = np.nanargmin

      values = []
      for i in range(len(current_iteration.candidates)):
        ensemble_name = current_iteration.candidates[i].ensemble_spec.name
        values.append("{}/{} = {:.6f}".format(metric_name, ensemble_name,
                                              metrics[i]))
      logging.info("Computed ensemble metrics: %s", ", ".join(values))
      if self._force_grow and current_iteration.number > 0:
        logging.info(
            "The `force_grow` override is enabled, so the "
            "the performance of the previous ensemble will be ignored.")
        # NOTE: The zero-th index candidate at iteration t>0 is always the
        # previous_ensemble.
        metrics = metrics[1:]
        index = objective_fn(metrics) + 1
      else:
        index = objective_fn(metrics)
    logging.info("Finished ensemble evaluation for iteration %s",
                 current_iteration.number)
    logging.info("'%s' at index %s is the best ensemble",
                 current_iteration.candidates[index].ensemble_spec.name, index)
    return index