# tensorflow_model_analysis/metrics/metric_specs.py
import collections
from typing import List, NamedTuple, Optional, Tuple, Union

import tensorflow as tf

from tensorflow_model_analysis.metrics import metric_types
from tensorflow_model_analysis.metrics import tf_metric_wrapper
from tensorflow_model_analysis.proto import config_pb2

# Type alias used throughout this module (defined near the top of the file).
_TFMetricOrLoss = Union[tf.keras.metrics.Metric, tf.keras.losses.Loss]

# The other underscore-prefixed helpers referenced below (_private_tf_metric,
# _private_tf_loss, _create_sub_keys, _class_weights, _macro_average_sub_keys,
# _example_weight_options) are module-local functions defined elsewhere in
# this file.


def _process_tf_metrics_specs(
    tf_metrics_specs: List[config_pb2.MetricsSpec],
    per_tf_spec_metric_instances: List[List[_TFMetricOrLoss]],
    eval_config: config_pb2.EvalConfig) -> metric_types.MetricComputations:
"""Processes list of TF MetricsSpecs to create computations."""
# Wrap args into structure that is hashable so we can track unique arg sets.
class UniqueArgs(
NamedTuple('UniqueArgs',
[('model_name', str),
('sub_key', Optional[metric_types.SubKey]),
('aggregation_type', Optional[metric_types.AggregationType]),
('class_weights', Tuple[Tuple[int, float], ...])])):
pass
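
  # Illustrative example (not from the source): binarizing class 0 for an
  # unnamed model with no aggregation is tracked as
  #   UniqueArgs('', metric_types.SubKey(class_id=0), None, ()).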

  def _create_private_tf_metrics(
      metrics: List[_TFMetricOrLoss]) -> List[_TFMetricOrLoss]:
    """Creates private versions of TF metrics."""
    result = []
    for m in metrics:
      if isinstance(m, tf.keras.metrics.Metric):
        result.append(_private_tf_metric(m))
      else:
        result.append(_private_tf_loss(m))
    return result
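
  # Note (explanatory, not from the source): the _private_* helpers return
  # renamed copies of the metrics so that their values can feed the
  # macro-average computation without being reported as user-requested
  # results.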

  #
  # Group TF metrics by sub keys, models, and outputs. The grouping is keyed
  # by model and sub key (rather than by spec) because model and sub key
  # processing is done outside of TF, so each unique sub key combination needs
  # to be run through a separate model instance. Note that output_names are
  # handled by the tf_metric_computations call since all the outputs are
  # batch-calculated in a single model evaluation call.
  #

  # UniqueArgs -> (example-weighted, unweighted) pair of dicts, each mapping
  # output_name -> [_TFMetricOrLoss]. The default factory is never used;
  # entries are initialized explicitly below.
  metrics_by_unique_args = collections.defaultdict(dict)
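
  # Illustrative example (assumed names, not from the source): if two specs
  # request AUC (example-weighted) and Precision (unweighted) for the same
  # binarized class of the same model, they collapse into a single entry:
  #   {UniqueArgs('', SubKey(class_id=0), None, ()):
  #        ({'output_1': [AUC(...)]},        # example-weighted
  #         {'output_1': [Precision(...)]})} # unweighted
  # so each unique (model, sub key) grouping is evaluated only once.
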
  for i, spec in enumerate(tf_metrics_specs):
    metrics = per_tf_spec_metric_instances[i]
    sub_keys_by_aggregation_type = _create_sub_keys(spec)
    # Keep track of metrics that can be shared between macro averaging and
    # binarization. For example, if macro averaging is being performed over 10
    # classes and 5 of the classes are also being binarized, then those 5
    # classes can be re-used by the macro averaging calculation. The remaining
    # 5 classes need to be added as private metrics since those classes were
    # not requested but are still needed for the macro averaging calculation.
    if None in sub_keys_by_aggregation_type:
      shared_sub_keys = set(sub_keys_by_aggregation_type[None])
    else:
      shared_sub_keys = set()
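    # Worked example (illustrative): if macro averaging is requested over
    # class IDs 0-9 and binarization over 0-4, sub keys 0-4 also appear under
    # the None aggregation type and are shared with the macro computation,
    # while sub keys 5-9 will get private metric copies below.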
    for aggregation_type, sub_keys in sub_keys_by_aggregation_type.items():
      if aggregation_type:
        class_weights = tuple(sorted((_class_weights(spec) or {}).items()))
      else:
        class_weights = ()
      is_macro = (
          aggregation_type and (aggregation_type.macro_average or
                                aggregation_type.weighted_macro_average))
      for parent_sub_key in sub_keys:
        if is_macro:
          child_sub_keys = _macro_average_sub_keys(parent_sub_key,
                                                   _class_weights(spec))
        else:
          child_sub_keys = [parent_sub_key]
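        # For illustration (assumed behavior of the module-local helper): a
        # macro-average parent sub key expands into one child sub key per
        # class being averaged, e.g. SubKey(class_id=0), SubKey(class_id=1),
        # and so on.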
        for output_name in spec.output_names or ['']:
          for sub_key in child_sub_keys:
            if is_macro and sub_key not in shared_sub_keys:
              # Create private metrics for all non-shared metrics.
              instances = _create_private_tf_metrics(metrics)
            else:
              instances = metrics
            for model_name in spec.model_names or ['']:
              unique_args = UniqueArgs(
                  model_name, sub_key,
                  aggregation_type if not is_macro else None,
                  class_weights if not is_macro else ())
              if unique_args not in metrics_by_unique_args:
                # Tuple of (example-weighted, unweighted) metrics by output.
                metrics_by_unique_args[unique_args] = (
                    collections.defaultdict(list),
                    collections.defaultdict(list))
              for instance in instances:
                for example_weighted in _example_weight_options(
                    eval_config, spec):
                  if example_weighted:
                    metrics_by_unique_args[unique_args][0][output_name].append(
                        instance)
                  else:
                    metrics_by_unique_args[unique_args][1][output_name].append(
                        instance)

  # Convert the unique args and their grouped outputs into
  # tf_metric_computations calls.
  result = []
  for args, metrics_by_output in metrics_by_unique_args.items():
    class_weights = dict(args.class_weights) if args.class_weights else None
    weighted_metrics_by_output, unweighted_metrics_by_output = metrics_by_output
    if weighted_metrics_by_output:
      result.extend(
          tf_metric_wrapper.tf_metric_computations(
              weighted_metrics_by_output,
              eval_config=eval_config,
              model_name=args.model_name,
              sub_key=args.sub_key,
              aggregation_type=args.aggregation_type,
              class_weights=class_weights,
              example_weighted=True))
    if unweighted_metrics_by_output:
      result.extend(
          tf_metric_wrapper.tf_metric_computations(
              unweighted_metrics_by_output,
              eval_config=eval_config,
              model_name=args.model_name,
              sub_key=args.sub_key,
              aggregation_type=args.aggregation_type,
              class_weights=class_weights,
              example_weighted=False))
  return result
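

# A minimal usage sketch (illustrative, not part of the module). TF metrics
# normally reach _process_tf_metrics_specs indirectly through this module's
# public to_computations() entry point; the proto fields below assume the
# public EvalConfig schema.
if __name__ == '__main__':
  from google.protobuf import text_format

  example_eval_config = text_format.Parse(
      """
      model_specs { label_key: "label" }
      metrics_specs {
        metrics { class_name: "AUC" }
        binarize { class_ids { values: 0 values: 1 } }
      }
      """, config_pb2.EvalConfig())
  computations = to_computations(
      list(example_eval_config.metrics_specs), example_eval_config)
  print('Created %d metric computations' % len(computations))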