def _validate_eval_metric_ops()

in tensorflow_estimator/python/estimator/model_fn.py [0:0]


def _validate_eval_metric_ops(eval_metric_ops):
  """Validate eval_metric_ops for use in EstimatorSpec.

  Args:
    eval_metric_ops: Dict of metric results keyed by name.
      The values of the dict can be one of the following: (1) instance of
        `Metric` class. (2) Results of calling a metric_function, namely a
        `(metric_tensor, update_op)` tuple. `metric_tensor` should be evaluated
        without any impact on state (typically it is a pure computation based on
        variables.). For example, it should not trigger the `update_op` or
        require any input fetching.

  Returns:
    eval_metric_ops: Dict of metric results keyed by name.

  Raises:
    ValueError:  If:
     - one of the eval_metric_ops `Metric` objects has no updates.
     - there is at least one `Metric` update or result, `Tensor`, or Op that is
       not in the default graph.
    TypeError:   If:
     - eval_metric_ops is not a dict or None.
     - an element of eval_metric_ops is not a `Metric` or a 2-tuple.
     - an element of eval_metric_ops has a sub-element that is not a `Tensor` or
       an Op.
  """
  if eval_metric_ops is None:
    eval_metric_ops = {}
  else:
    if not isinstance(eval_metric_ops, dict):
      raise TypeError(
          'eval_metric_ops must be a dict, given: {}'.format(eval_metric_ops))
    for key, value in six.iteritems(eval_metric_ops):
      # TODO(psv): When we deprecate the old metrics, throw an error here if
      # the value is not an instance of `Metric` class.
      if isinstance(value, tf.keras.metrics.Metric):
        if not value.updates:  # Check if metric updates are available.
          raise ValueError(
              'Please call update_state(...) on the "{metric_name}" metric'
              .format(metric_name=value.name))
      else:
        if not isinstance(value, tuple) or len(value) != 2:
          raise TypeError(
              'Values of eval_metric_ops must be (metric_value, update_op) '
              'tuples, given: {} for key: {}'.format(value, key))
  # Verify all tensors and ops are from default graph.
  default_graph = tf.compat.v1.get_default_graph()
  for key, value in list(six.iteritems(eval_metric_ops)):
    if isinstance(value, tf.keras.metrics.Metric):
      values_to_check = value.updates[:]
      values_to_check.append(value.result())
    else:
      values_to_check = tf.nest.flatten(value)
    for val in values_to_check:
      if not (tf.executing_eagerly() or val.graph is default_graph):
        raise ValueError(
            _default_graph_error_message_template.format(
                'eval_metric_ops', '{0}: {1}'.format(key, val.name)))
  # Metric variables are by default not added to any collections. The variables
  # are appended to the LOCAL_VARIABLES collection for initialization, and
  # METRIC_VARIABLES for TFMA compatibility. Note that although collections are
  # officially deprecated in TensorFlow 2, Estimators will continue using
  # collections as long as it supports V1 graph mode.
  vars_to_add = set()
  for key, value in six.iteritems(eval_metric_ops):
    if isinstance(value, tf.keras.metrics.Metric):
      vars_to_add.update(value.variables)
      # Convert Metric instances to (value_tensor, update_op) tuple.
      eval_metric_ops[key] = (value.result(), value.updates[0])
  _update_variable_collection(tf.compat.v1.GraphKeys.LOCAL_VARIABLES,
                              vars_to_add)
  _update_variable_collection(tf.compat.v1.GraphKeys.METRIC_VARIABLES,
                              vars_to_add)

  return eval_metric_ops