def best_eval_metrics_tuple()

in adanet/core/eval_metrics.py


  def best_eval_metrics_tuple(self, best_candidate_index, mode):
    """Returns (metric_fn, tensors) which computes the best ensemble's metrics.

    Specifically, when metric_fn(*tensors) is called, it groups the metric ops
    by metric name. Candidates are not required to share the same metrics;
    whenever all of them report a given metric, an additional metric is added
    which represents that of the best candidate.

    Args:
      best_candidate_index: `Tensor` index of the best candidate in the list.
      mode: Defines whether this is training, evaluation or inference. Eval
        metrics are only defined during evaluation. See `ModeKeys`.

    Returns:
      A (metric_fn, tensors) tuple, or None when mode is not EVAL. Calling
      metric_fn(*tensors) returns a dict of metric results keyed by name.
    """

    if mode != tf.estimator.ModeKeys.EVAL:
      return None

    candidate_args = self._candidates_eval_metrics_store.flatten_args()
    subnetwork_args = self._subnetworks_eval_metrics_store.flatten_args()
    args = candidate_args + subnetwork_args
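    # Append the best candidate index last so _best_eval_metrics_fn can pop it
    # back off the flattened argument list.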
    args.append(tf.reshape(best_candidate_index, [1]))

    def _replay_eval_metrics(best_candidate_idx, eval_metric_ops):
      """Saves replay indices as eval metrics."""
      # _replay_indices_for_all is a dict: {candidate: [list of replay_indices]}.
      # Find the length of the longest replay list.
      pad_value = max(len(v) for v in self._replay_indices_for_all.values())

      # Create a matrix of shape (#candidates, max replay list length), where
      # entry (i, j) is the jth replay index of the ith candidate (ensemble);
      # shorter lists are right-padded with -1.
      replay_indices_as_tensor = tf.constant([
          value + [-1] * (pad_value - len(value))
          for value in self._replay_indices_for_all.values()
      ])
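      # E.g. (illustrative): {A: [3, 1], B: [2]} with pad_value == 2 yields
      # [[3, 1], [2, -1]].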

      # Emit only the entries of the best candidate. Note: we use
      # TensorShape.as_list here so the code works on both TF 1.x and 2.x.
      for iteration in range(replay_indices_as_tensor.get_shape().as_list()[1]):
        index_t = replay_indices_as_tensor[best_candidate_idx, iteration]
        eval_metric_ops["best_ensemble_index_{}".format(iteration)] = (index_t,
                                                                       index_t)

    def _best_eval_metrics_fn(*args):
      """Returns the best eval metrics."""

      with tf_compat.v1.variable_scope("best_eval_metrics"):
        args = list(args)
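        # The last arg is the reshaped best_candidate_index appended above;
        # tf_compat.v1.metrics.mean wraps it in the (value, update_op) form
        # that eval metric ops require.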
        idx, idx_update_op = tf_compat.v1.metrics.mean(args.pop())
        idx = tf.cast(idx, tf.int32)
        metric_fns = self._candidates_eval_metrics_store.metric_fns
        metric_fn_args = self._candidates_eval_metrics_store.pack_args(
            args[:len(candidate_args)])
        candidate_grouped_metrics = self._group_metric_ops(
            metric_fns, metric_fn_args)

        metric_fns = self._subnetworks_eval_metrics_store.metric_fns
        metric_fn_args = self._subnetworks_eval_metrics_store.pack_args(
            args[len(candidate_args):])
        subnetwork_grouped_metrics = self._group_metric_ops(
            metric_fns, metric_fn_args)

        eval_metric_ops = {}
        for metric_name in sorted(candidate_grouped_metrics):
          metric_ops = candidate_grouped_metrics[metric_name]
          if len(metric_ops) != len(self._candidates):
            continue
          if metric_name == "loss":
            continue
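          # E.g. (illustrative): for per-candidate accuracies [0.7, 0.9, 0.8]
          # and idx == 1, tf.stack(values)[idx] selects the best value 0.9.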
          values, ops = list(six.moves.zip(*metric_ops))
          best_value = tf.stack(values)[idx]
          # All tensors in this function have been outfed from the TPU, so we
          # must update them manually; otherwise the TPU will hang
          # indefinitely waiting for the value of idx to update.
          ops = list(ops)
          ops.append(idx_update_op)
          # Bundle subnetwork eval metric ops and ensemble "loss" ops ("loss"
          # is a restricted Estimator keyword) into the other metric ops so
          # that they are computed.
          ensemble_loss_ops = candidate_grouped_metrics.get("loss", tf.no_op())
          all_ops = tf.group(ops, ensemble_loss_ops, subnetwork_grouped_metrics)
          eval_metric_ops[metric_name] = (best_value, all_ops)
        iteration_number = tf.constant(self._iteration_number)
        eval_metric_ops["iteration"] = (iteration_number, iteration_number)

        if self._replay_indices_for_all:
          _replay_eval_metrics(idx, eval_metric_ops)

        # tf.estimator.Estimator does not allow a "loss" key to be present in
        # its eval_metrics.
        assert "loss" not in eval_metric_ops
        return eval_metric_ops

    if not self._use_tpu:
      if not self._best_eval_metrics_tuple:
        best_ops = _call_eval_metrics((_best_eval_metrics_fn, args))
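        # Note: tuple precedence makes the next line assign the pair
        # (lambda: best_ops, []), i.e. a no-arg metric_fn plus an empty
        # tensors list, matching the documented return contract.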
        self._best_eval_metrics_tuple = lambda: best_ops, []
      return self._best_eval_metrics_tuple

    return _best_eval_metrics_fn, args
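
A minimal consumption sketch (hypothetical caller; `iteration_metrics` is an
assumed instance of the class that owns this method, and the unpacked call
mirrors what this module's _call_eval_metrics helper does):

  eval_metrics = iteration_metrics.best_eval_metrics_tuple(
      best_candidate_index, tf.estimator.ModeKeys.EVAL)
  if eval_metrics is not None:
    metric_fn, tensors = eval_metrics
    # On TPU, the (metric_fn, tensors) pair can be passed to
    # TPUEstimatorSpec(eval_metrics=...); on CPU/GPU, tensors is [] and
    # metric_fn() returns the already-built eval_metric_ops dict.
    eval_metric_ops = metric_fn(*tensors)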