def _process_tf_metrics_specs()

in tensorflow_model_analysis/metrics/metric_specs.py


def _process_tf_metrics_specs(
    tf_metrics_specs: List[config_pb2.MetricsSpec],
    per_tf_spec_metric_instances: List[List[_TFMetricOrLoss]],
    eval_config: config_pb2.EvalConfig) -> metric_types.MetricComputations:
  """Processes list of TF MetricsSpecs to create computations."""

  # Wrap args into structure that is hashable so we can track unique arg sets.
  class UniqueArgs(
      NamedTuple('UniqueArgs',
                 [('model_name', str),
                  ('sub_key', Optional[metric_types.SubKey]),
                  ('aggregation_type', Optional[metric_types.AggregationType]),
                  ('class_weights', Tuple[Tuple[int, float], ...])])):
    pass
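  # For example, UniqueArgs('candidate', metric_types.SubKey(class_id=2), None,
  # ()) and UniqueArgs('candidate', metric_types.SubKey(class_id=3), None, ())
  # are tracked as separate entries (model name and class IDs are hypothetical).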

  def _create_private_tf_metrics(
      metrics: List[_TFMetricOrLoss]) -> List[_TFMetricOrLoss]:
    """Creates private versions of TF metrics."""
    result = []
    for m in metrics:
      if isinstance(m, tf.keras.metrics.Metric):
        result.append(_private_tf_metric(m))
      else:
        result.append(_private_tf_loss(m))
    return result

  #
  # Group TF metrics by sub keys, models, and outputs. This is done in reverse
  # order because model and sub key processing is done outside of TF, so each
  # unique sub key combination needs to be run through a separate model
  # instance. Note that output_names are handled by
  # tf_metric_wrapper.tf_metric_computations since all the outputs are batch
  # calculated in a single model evaluation call.
  #

  # UniqueArgs -> (weighted, unweighted), each mapping
  # output_name -> [_TFMetricOrLoss]
  metrics_by_unique_args = collections.defaultdict(dict)
  for i, spec in enumerate(tf_metrics_specs):
    metrics = per_tf_spec_metric_instances[i]
    sub_keys_by_aggregation_type = _create_sub_keys(spec)
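    # sub_keys_by_aggregation_type maps Optional[AggregationType] to the list
    # of sub keys requested under that aggregation, e.g. (hypothetically)
    # {None: [SubKey(class_id=0), SubKey(class_id=1)]} when only binarization
    # is configured.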
    # Keep track of metrics that can be shared between macro averaging and
    # binarization. For example, if macro averaging is being performed over 10
    # classes and 5 of the classes are also being binarized, then those 5
    # classes can be re-used by the macro averaging calculation. The remaining
    # 5 classes need to be added as private metrics since those classes were
    # not requested but are still needed for the macro averaging calculation.
    if None in sub_keys_by_aggregation_type:
      shared_sub_keys = set(sub_keys_by_aggregation_type[None])
    else:
      shared_sub_keys = set()
    for aggregation_type, sub_keys in sub_keys_by_aggregation_type.items():
      if aggregation_type:
        class_weights = tuple(sorted((_class_weights(spec) or {}).items()))
      else:
        class_weights = ()
      is_macro = (
          aggregation_type and (aggregation_type.macro_average or
                                aggregation_type.weighted_macro_average))
      for parent_sub_key in sub_keys:
        if is_macro:
          child_sub_keys = _macro_average_sub_keys(parent_sub_key,
                                                   _class_weights(spec))
        else:
          child_sub_keys = [parent_sub_key]
        for output_name in spec.output_names or ['']:
          for sub_key in child_sub_keys:
            if is_macro and sub_key not in shared_sub_keys:
              # Create private metrics for all non-shared metrics.
              instances = _create_private_tf_metrics(metrics)
            else:
              instances = metrics
            for model_name in spec.model_names or ['']:
              unique_args = UniqueArgs(
                  model_name, sub_key,
                  aggregation_type if not is_macro else None,
                  class_weights if not is_macro else ())
              if unique_args not in metrics_by_unique_args:
                # Tuple of (weighted, unweighted) metrics keyed by output_name
                metrics_by_unique_args[unique_args] = (
                    collections.defaultdict(list),
                    collections.defaultdict(list))
              for instance in instances:
                for example_weighted in _example_weight_options(
                    eval_config, spec):
                  if example_weighted:
                    metrics_by_unique_args[unique_args][0][output_name].append(
                        instance)
                  else:
                    metrics_by_unique_args[unique_args][1][output_name].append(
                        instance)
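  # At this point metrics_by_unique_args maps each UniqueArgs to a tuple of
  # (weighted_metrics_by_output, unweighted_metrics_by_output), e.g. a single
  # unweighted, single-output spec with no binarization might (hypothetically)
  # yield {UniqueArgs('', None, None, ()): ({}, {'': [auc_instance]})}.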

  # Convert the unique args and grouped outputs into TF metric computations
  result = []
  for args, metrics_by_output in metrics_by_unique_args.items():
    class_weights = dict(args.class_weights) if args.class_weights else None
    weighted_metrics_by_output, unweighted_metrics_by_output = metrics_by_output
    if weighted_metrics_by_output:
      result.extend(
          tf_metric_wrapper.tf_metric_computations(
              weighted_metrics_by_output,
              eval_config=eval_config,
              model_name=args.model_name,
              sub_key=args.sub_key,
              aggregation_type=args.aggregation_type,
              class_weights=class_weights,
              example_weighted=True))
    if unweighted_metrics_by_output:
      result.extend(
          tf_metric_wrapper.tf_metric_computations(
              unweighted_metrics_by_output,
              eval_config=eval_config,
              model_name=args.model_name,
              sub_key=args.sub_key,
              aggregation_type=args.aggregation_type,
              class_weights=class_weights,
              example_weighted=False))
  return result
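
A minimal sketch of how this private helper might be exercised directly (inside the library it is only called from metric_specs.py itself). It assumes config_pb2 is importable from tensorflow_model_analysis.proto, and the spec contents and metric choices below are illustrative only:

import tensorflow as tf

from tensorflow_model_analysis.metrics import metric_specs
from tensorflow_model_analysis.proto import config_pb2

# One MetricsSpec plus the Keras metric instances the caller would normally
# derive from it. With no binarization or aggregation configured, everything
# should land under a single UniqueArgs('', None, None, ()) entry.
spec = config_pb2.MetricsSpec(model_names=[''], output_names=[''])
eval_config = config_pb2.EvalConfig(metrics_specs=[spec])
instances = [
    tf.keras.metrics.AUC(name='auc'),
    tf.keras.metrics.BinaryAccuracy(name='binary_accuracy'),
]

computations = metric_specs._process_tf_metrics_specs(
    tf_metrics_specs=[spec],
    per_tf_spec_metric_instances=[instances],
    eval_config=eval_config)

Each entry in the returned list comes from a tf_metric_wrapper.tf_metric_computations call covering one unique (model_name, sub_key, aggregation_type, class_weights) combination, split by whether example weighting is applied.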