def _wrap_confusion_matrix_metric()

in tensorflow_model_analysis/metrics/tf_metric_wrapper.py [0:0]


def _wrap_confusion_matrix_metric(
    metric: tf.keras.metrics.Metric, eval_config: config_pb2.EvalConfig,
    model_name: str, output_name: str, sub_key: Optional[metric_types.SubKey],
    aggregation_type: Optional[metric_types.AggregationType],
    class_weights: Optional[Dict[int, float]],
    example_weighted: bool) -> metric_types.MetricComputations:
  """Wraps a keras confusion matrix metric as a derived metric computation.

  Instead of updating the keras metric directly, the binary confusion matrices
  computations are requested and a derived computation is appended that
  re-creates the metric from its serialized config, loads the matrix counts
  into the metric's variables and reads back `metric.result()`.

  Args:
    metric: Keras metric to wrap.
    eval_config: Eval config, forwarded to the binary confusion matrices.
    model_name: Model name used in the metric key and forwarded to the
      binary confusion matrices.
    output_name: Output name used in the metric key and forwarded to the
      binary confusion matrices.
    sub_key: Optional sub key; passed through `_verify_and_update_sub_key`
      before being used.
    aggregation_type: Optional aggregation type used in the metric key and
      forwarded to the binary confusion matrices.
    class_weights: Optional class weights forwarded to the binary confusion
      matrices. Replaced by the AUC metric's `label_weights` when those are
      set and no class weights were given.
    example_weighted: Whether the metric is example weighted; used in the
      metric key and forwarded to the binary confusion matrices.

  Returns:
    The binary confusion matrices computations followed by one
    `DerivedMetricComputation` that produces the wrapped metric's value.

  Raises:
    ValueError: If weights are set both on the AUC metric (`label_weights`)
      and via `class_weights`.
    NotImplementedError: If the metric is an AUC with `multi_label` enabled.
  """
  # AUC supports aggregation natively (label_weights / multi_label), so its
  # settings have to be reconciled with the ones coming from the MetricsSpecs.
  if (isinstance(metric, tf.keras.metrics.AUC) and
      hasattr(metric, 'label_weights')):
    label_weights = metric.label_weights
    # NOTE(review): this is a truthiness test. If keras stores label_weights
    # as a multi-element tensor/array, bool() may raise instead of evaluating
    # -- confirm the stored type.
    if label_weights:
      if class_weights:
        raise ValueError(
            'class weights are configured in two different places: (1) via the '
            'tf.keras.metrics.AUC class (using "label_weights") and (2) via '
            'the MetricsSpecs (using "aggregate.class_weights"). Either remove '
            'the label_weights settings in the AUC class or remove the '
            'class_weights from the AggregationOptions: metric={}, '
            'class_weights={}'.format(metric, class_weights))
      # Position in label_weights becomes the class id.
      class_weights = dict(enumerate(label_weights))
    if metric.multi_label:
      raise NotImplementedError('AUC.multi_label=True is not implemented yet.')

  sub_key = _verify_and_update_sub_key(model_name, output_name, sub_key, metric)
  result_key = metric_types.MetricKey(
      name=metric.name,
      model_name=model_name,
      output_name=output_name,
      aggregation_type=aggregation_type,
      sub_key=sub_key,
      example_weighted=example_weighted)

  # The serialized form is what the derived computation rebuilds from, and it
  # is also inspected below to see which settings the metric config carries.
  serialized_metric = tf.keras.metrics.serialize(metric)

  # When top_k comes from outside of keras (i.e. BinarizeOptions) and the
  # metric config carries no top_k / thresholds / num_thresholds setting of its
  # own, the special -inf threshold must be supplied here; otherwise the
  # default threshold of 0.5 would be used.
  top_k_set_outside_keras = (
      sub_key and sub_key.top_k is not None and
      _get_config_value(_TOP_K_KEY, serialized_metric) is None and
      _get_config_value(_THRESHOLDS_KEY, serialized_metric) is None and
      _get_config_value(_NUM_THRESHOLDS_KEY, serialized_metric) is None)
  if top_k_set_outside_keras:
    thresholds = [float('-inf')]
  else:
    thresholds = (
        metric.thresholds if hasattr(metric, _THRESHOLDS_KEY) else None)

  # Exactly one of thresholds / num_thresholds is passed on. Keras AUC accepts
  # both, in which case thresholds takes precedence.
  num_thresholds = (
      metric.num_thresholds
      if thresholds is None and hasattr(metric, _NUM_THRESHOLDS_KEY) else None)

  # Request the confusion matrices that the derived computation depends on.
  all_computations = binary_confusion_matrices.binary_confusion_matrices(
      num_thresholds=num_thresholds,
      thresholds=thresholds,
      eval_config=eval_config,
      model_name=model_name,
      output_name=output_name,
      sub_key=sub_key,
      aggregation_type=aggregation_type,
      class_weights=class_weights,
      example_weighted=example_weighted)
  # The last key of the last computation is the one holding the matrices.
  matrices_key = all_computations[-1].keys[-1]

  def result(
      metrics: Dict[metric_types.MetricKey, Any]
  ) -> Dict[metric_types.MetricKey, Any]:
    """Computes the wrapped metric's value from binary confusion matrices."""

    def load(variable, counts):
      # Overwrites a keras metric variable with the given matrix counts.
      variable.assign(np.array(counts))

    matrices = metrics[matrices_key]
    # A fresh metric instance is built on every call, so no state is shared.
    restored = tf.keras.metrics.deserialize(serialized_metric)

    if isinstance(restored, (tf.keras.metrics.AUC,
                             tf.keras.metrics.SpecificityAtSensitivity,
                             tf.keras.metrics.SensitivityAtSpecificity)):
      load(restored.true_positives, matrices.tp)
      load(restored.true_negatives, matrices.tn)
      load(restored.false_positives, matrices.fp)
      load(restored.false_negatives, matrices.fn)
    elif isinstance(restored, tf.keras.metrics.Precision):
      load(restored.true_positives, matrices.tp)
      load(restored.false_positives, matrices.fp)
    elif isinstance(restored, tf.keras.metrics.Recall):
      load(restored.true_positives, matrices.tp)
      load(restored.false_negatives, matrices.fn)
    elif isinstance(restored, tf.keras.metrics.TruePositives):
      load(restored.accumulator, matrices.tp)
    elif isinstance(restored, tf.keras.metrics.FalsePositives):
      load(restored.accumulator, matrices.fp)
    elif isinstance(restored, tf.keras.metrics.TrueNegatives):
      load(restored.accumulator, matrices.tn)
    elif isinstance(restored, tf.keras.metrics.FalseNegatives):
      load(restored.accumulator, matrices.fn)
    # Any other metric type gets no counts loaded; its result is read as is.
    return {result_key: restored.result().numpy()}

  all_computations.append(
      metric_types.DerivedMetricComputation(keys=[result_key], result=result))
  return all_computations