in adanet/core/eval_metrics.py [0:0]
  def best_eval_metrics_tuple(self, best_candidate_index, mode):
    """Returns (metric_fn, tensors) which computes the best ensemble's metrics.

    Specifically, when metric_fn(tensors) is called, it separates the metric ops
    by metric name. Candidates are not required to all have the same metrics;
    when they all share a given metric, an additional metric is added which
    represents that of the best candidate.

    Args:
      best_candidate_index: `Tensor` index of the best candidate in the list.
      mode: Defines whether this is training, evaluation or inference. Eval
        metrics are only defined during evaluation. See `ModeKeys`.

    Returns:
      Tuple of (metric_fn, tensors), where calling metric_fn on the tensors
      yields a dict of metric results keyed by name and each value is the
      result of calling a metric function. Returns None when not in EVAL mode.
    """
    if mode != tf.estimator.ModeKeys.EVAL:
      return None

    candidate_args = self._candidates_eval_metrics_store.flatten_args()
    subnetwork_args = self._subnetworks_eval_metrics_store.flatten_args()
    args = candidate_args + subnetwork_args
    args.append(tf.reshape(best_candidate_index, [1]))
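    # The best candidate's index is appended as the last flattened arg so that
    # _best_eval_metrics_fn below can recover it with args.pop() after the
    # tensors have passed through the (possibly TPU) metric_fn machinery.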

    def _replay_eval_metrics(best_candidate_idx, eval_metric_ops):
      """Saves replay indices as eval metrics."""
      # _replay_indices_for_all is a dict: {candidate: [list of replay_indices]}.
      # We are finding the max-length replay list.
      pad_value = max([len(v) for _, v in self._replay_indices_for_all.items()])

      # Create a matrix of (#candidates) x (max length of replay indices).
      # Entry i,j is the jth replay index of the ith candidate (ensemble).
      replay_indices_as_tensor = tf.constant([
          value + [-1] * (pad_value - len(value))
          for _, value in self._replay_indices_for_all.items()
      ])
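      # For example (hypothetical values): replay lists {c0: [3, 7], c1: [3, 7, 9]}
      # give pad_value == 3 and the padded matrix [[3, 7, -1], [3, 7, 9]].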

      # Passing the right entries (entries of the best candidate). Note: we use
      # TensorShape.as_list here so the code works on both TF 1.0 and 2.0.
      for iteration in range(replay_indices_as_tensor.get_shape().as_list()[1]):
        index_t = replay_indices_as_tensor[best_candidate_idx, iteration]
        eval_metric_ops["best_ensemble_index_{}".format(iteration)] = (index_t,
                                                                       index_t)
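      # Continuing the hypothetical example above: if best_candidate_idx == 1,
      # this records best_ensemble_index_0 == 3, best_ensemble_index_1 == 7,
      # and best_ensemble_index_2 == 9, each stored as an (index_t, index_t) pair.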

    def _best_eval_metrics_fn(*args):
      """Returns the best eval metrics."""

      with tf_compat.v1.variable_scope("best_eval_metrics"):
        args = list(args)
        idx, idx_update_op = tf_compat.v1.metrics.mean(args.pop())
        idx = tf.cast(idx, tf.int32)
        metric_fns = self._candidates_eval_metrics_store.metric_fns
        metric_fn_args = self._candidates_eval_metrics_store.pack_args(
            args[:len(candidate_args)])
        candidate_grouped_metrics = self._group_metric_ops(
            metric_fns, metric_fn_args)
        metric_fns = self._subnetworks_eval_metrics_store.metric_fns
        metric_fn_args = self._subnetworks_eval_metrics_store.pack_args(
            args[(len(args) - len(subnetwork_args)):])
        subnetwork_grouped_metrics = self._group_metric_ops(
            metric_fns, metric_fn_args)
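        # Each grouped dict maps a metric name to a list of (value, update_op)
        # pairs, one per candidate (or subnetwork); the zip(*metric_ops) call
        # below relies on that pairing.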
        eval_metric_ops = {}
        for metric_name in sorted(candidate_grouped_metrics):
          metric_ops = candidate_grouped_metrics[metric_name]
          if len(metric_ops) != len(self._candidates):
            continue
          if metric_name == "loss":
            continue
          values, ops = list(six.moves.zip(*metric_ops))
          best_value = tf.stack(values)[idx]
          # All tensors in this function have been outfed from the TPU, so we
          # must update them manually, otherwise the TPU will hang indefinitely
          # waiting for the value of idx to update.
          ops = list(ops)
          ops.append(idx_update_op)
          # Bundle subnetwork eval metric ops and ensemble "loss" ops (which
          # is a restricted Estimator keyword) into other metric ops so that
          # they are computed.
          ensemble_loss_ops = candidate_grouped_metrics.get("loss", tf.no_op())
          all_ops = tf.group(ops, ensemble_loss_ops, subnetwork_grouped_metrics)
          eval_metric_ops[metric_name] = (best_value, all_ops)
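          # Each eval_metric_ops entry follows the Estimator contract of a
          # (value_tensor, update_op) pair; running update_op here also
          # refreshes the shared idx estimate via idx_update_op.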
        iteration_number = tf.constant(self._iteration_number)
        eval_metric_ops["iteration"] = (iteration_number, iteration_number)

        if self._replay_indices_for_all:
          _replay_eval_metrics(idx, eval_metric_ops)

        # tf.estimator.Estimator does not allow a "loss" key to be present in
        # its eval_metrics.
        assert "loss" not in eval_metric_ops
        return eval_metric_ops

    if not self._use_tpu:
      if not self._best_eval_metrics_tuple:
        best_ops = _call_eval_metrics((_best_eval_metrics_fn, args))
        # Note: this assigns the tuple (lambda: best_ops, []), i.e. a zero-arg
        # metric_fn plus an empty list of tensors, matching the
        # (metric_fn, tensors) contract of this method.
        self._best_eval_metrics_tuple = lambda: best_ops, []
      return self._best_eval_metrics_tuple
    return _best_eval_metrics_fn, args
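
  # Rough usage sketch (the call-site names below are assumptions, not code
  # from this file): on TPU the returned pair is typically forwarded as
  # TPUEstimatorSpec's eval_metrics=(metric_fn, tensors); off TPU the cached
  # (zero-arg callable, []) tuple can be consumed the same way, e.g.
  #
  #   result = iteration_metrics.best_eval_metrics_tuple(best_candidate_index,
  #                                                      mode)
  #   if result is not None:
  #     metric_fn, tensors = result
  #     eval_metric_ops = metric_fn(*tensors)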