in tensorflow_estimator/python/estimator/model_fn.py [0:0]
def _validate_eval_metric_ops(eval_metric_ops):
  """Validate eval_metric_ops for use in EstimatorSpec.

  Args:
    eval_metric_ops: Dict of metric results keyed by name.
      The values of the dict can be one of the following: (1) instance of
      `Metric` class. (2) Results of calling a metric_function, namely a
      `(metric_tensor, update_op)` tuple. `metric_tensor` should be evaluated
      without any impact on state (typically it is a pure computation based on
      variables.). For example, it should not trigger the `update_op` or
      require any input fetching.

  Returns:
    A dict of metric results keyed by name in which every `Metric` instance
    has been replaced by a `(result_tensor, update_op)` tuple. This is a
    shallow copy of the input (or a new empty dict when the input is `None`);
    the caller's dict is not modified.

  Raises:
    ValueError: If:
      - one of the eval_metric_ops `Metric` objects has no updates.
      - there is at least one `Metric` update or result, `Tensor`, or Op that
        is not in the default graph.
    TypeError: If:
      - eval_metric_ops is not a dict or None.
      - an element of eval_metric_ops is not a `Metric` or a 2-tuple.
      - an element of eval_metric_ops has a sub-element that is not a `Tensor`
        or an Op (only checked when not executing eagerly).
  """
  # Phase 1: structural validation of every entry.
  if eval_metric_ops is None:
    eval_metric_ops = {}
  else:
    if not isinstance(eval_metric_ops, dict):
      raise TypeError(
          'eval_metric_ops must be a dict, given: {}'.format(eval_metric_ops))
    for key, value in six.iteritems(eval_metric_ops):
      # TODO(psv): When we deprecate the old metrics, throw an error here if
      # the value is not an instance of `Metric` class.
      if isinstance(value, tf.keras.metrics.Metric):
        if not value.updates:  # Check if metric updates are available.
          raise ValueError(
              'Please call update_state(...) on the "{metric_name}" metric'
              .format(metric_name=value.name))
      else:
        if not isinstance(value, tuple) or len(value) != 2:
          raise TypeError(
              'Values of eval_metric_ops must be (metric_value, update_op) '
              'tuples, given: {} for key: {}'.format(value, key))

  # Phase 2: verify all tensors and ops are from the default graph. The check
  # is skipped when executing eagerly, as in the original condition.
  default_graph = tf.compat.v1.get_default_graph()
  executing_eagerly = tf.executing_eagerly()
  # Build each Metric's result tensor exactly once; it is reused in phase 3 so
  # graph mode does not end up with a duplicate result op per metric.
  metric_results = {}
  for key, value in six.iteritems(eval_metric_ops):
    if isinstance(value, tf.keras.metrics.Metric):
      metric_results[key] = value.result()
      values_to_check = value.updates[:]
      values_to_check.append(metric_results[key])
    else:
      values_to_check = tf.nest.flatten(value)
    if executing_eagerly:
      continue
    for val in values_to_check:
      # Surface the documented TypeError instead of an AttributeError when a
      # sub-element is not graph-bound (e.g. a Python number or None).
      if not hasattr(val, 'graph'):
        raise TypeError(
            'Values of eval_metric_ops must contain only Tensors or Ops, '
            'given: {} for key: {}'.format(val, key))
      if val.graph is not default_graph:
        raise ValueError(
            _default_graph_error_message_template.format(
                'eval_metric_ops', '{0}: {1}'.format(key, val.name)))

  # Phase 3: metric variables are by default not added to any collections.
  # The variables are appended to the LOCAL_VARIABLES collection for
  # initialization, and METRIC_VARIABLES for TFMA compatibility. Note that
  # although collections are officially deprecated in TensorFlow 2, Estimators
  # will continue using collections as long as it supports V1 graph mode.
  # Write conversions into a copy so the caller's dict keeps its Metric
  # objects; .copy() also preserves common dict subclasses like OrderedDict.
  converted_ops = eval_metric_ops.copy()
  vars_to_add = set()
  for key, value in six.iteritems(eval_metric_ops):
    if isinstance(value, tf.keras.metrics.Metric):
      vars_to_add.update(value.variables)
      # Convert Metric instances to (value_tensor, update_op) tuple.
      converted_ops[key] = (metric_results[key], value.updates[0])
  _update_variable_collection(tf.compat.v1.GraphKeys.LOCAL_VARIABLES,
                              vars_to_add)
  _update_variable_collection(tf.compat.v1.GraphKeys.METRIC_VARIABLES,
                              vars_to_add)
  return converted_ops