# Excerpt: _process_tfma_metrics_specs()
# From: tensorflow_model_analysis/metrics/metric_specs.py

def _process_tfma_metrics_specs(
    tfma_metrics_specs: List[config_pb2.MetricsSpec],
    per_tfma_spec_metric_instances: List[List[metric_types.Metric]],
    eval_config: config_pb2.EvalConfig,
    schema: Optional[schema_pb2.Schema]) -> metric_types.MetricComputations:
  """Processes list of TFMA MetricsSpecs to create computations.

  Args:
    tfma_metrics_specs: MetricsSpecs containing TFMA metric configs.
    per_tfma_spec_metric_instances: Metric instances parallel to
      `tfma_metrics_specs` — one inner list per spec, one instance per entry
      in that spec's `metrics` field.
    eval_config: Overall eval config (used to derive example-weight options).
    schema: Optional schema forwarded to each metric's `computations` call.

  Returns:
    Metric computations for all metrics across all specs.
  """

  #
  # Computations are created per metric, so first group the specs by the
  # metric config they share.
  #

  # Dict[bytes,List[config_pb2.MetricsSpec]] (hash(MetricConfig)->[MetricsSpec])
  tfma_specs_by_metric_config = {}
  # Dict[bytes,metric_types.Metric] (hash(MetricConfig)->Metric)
  hashed_metrics = {}
  for spec, metric_instances in zip(tfma_metrics_specs,
                                    per_tfma_spec_metric_instances):
    for metric_config, metric in zip(spec.metrics, metric_instances):
      # Note that hashing by SerializeToString() is only safe if used within
      # the same process.
      config_hash = metric_config.SerializeToString()
      # setdefault keeps the first metric instance seen for a given config,
      # matching the original "insert only if absent" semantics.
      hashed_metrics.setdefault(config_hash, metric)
      tfma_specs_by_metric_config.setdefault(config_hash, []).append(spec)

  #
  # Create computations for each metric.
  #

  result = []
  for config_hash, specs in tfma_specs_by_metric_config.items():
    metric = hashed_metrics[config_hash]
    for spec in specs:
      sub_keys_by_aggregation_type = _create_sub_keys(spec)
      # Keep track of sub-keys that can be shared between macro averaging and
      # binarization. For example, if macro averaging is being performed over
      # 10 classes and 5 of the classes are also being binarized, then those 5
      # classes can be re-used by the macro averaging calculation. The
      # remaining 5 classes need to be added as private metrics since those
      # classes were not requested but are still needed for the macro
      # averaging calculation. The None aggregation type holds the binarized
      # (non-aggregated) sub-keys.
      shared_sub_keys = set(sub_keys_by_aggregation_type.get(None, ()))
      for aggregation_type, sub_keys in sub_keys_by_aggregation_type.items():
        class_weights = _class_weights(spec) if aggregation_type else None
        is_macro = bool(
            aggregation_type and (aggregation_type.macro_average or
                                  aggregation_type.weighted_macro_average))
        if is_macro:
          # Only add per-class sub-keys that are not already computed for
          # binarization; these are added as private metrics because the
          # caller did not request them directly.
          updated_sub_keys = [
              key for sub_key in sub_keys
              for key in _macro_average_sub_keys(sub_key, class_weights)
              if key not in shared_sub_keys
          ]
          if not updated_sub_keys:
            continue
          # The macro average itself is computed elsewhere; the computations
          # created here are the plain per-class ones, so drop the
          # aggregation settings.
          aggregation_type = None
          class_weights = None
          sub_keys = updated_sub_keys
          instance = _private_tfma_metric(metric)
        else:
          instance = metric
        for example_weighted in _example_weight_options(eval_config, spec):
          result.extend(
              instance.computations(
                  eval_config=eval_config,
                  schema=schema,
                  model_names=list(spec.model_names) or [''],
                  output_names=list(spec.output_names) or [''],
                  sub_keys=sub_keys,
                  aggregation_type=aggregation_type,
                  # Normalize an empty weights mapping to None.
                  class_weights=class_weights or None,
                  example_weighted=example_weighted,
                  query_key=spec.query_key))
  return result