def binary_confusion_matrices()

in tensorflow_model_analysis/metrics/binary_confusion_matrices.py [0:0]


def binary_confusion_matrices(
    num_thresholds: Optional[int] = None,
    thresholds: Optional[List[float]] = None,
    name: Optional[str] = None,
    eval_config: Optional[config_pb2.EvalConfig] = None,
    model_name: str = '',
    output_name: str = '',
    sub_key: Optional[metric_types.SubKey] = None,
    aggregation_type: Optional[metric_types.AggregationType] = None,
    class_weights: Optional[Dict[int, float]] = None,
    example_weighted: bool = False,
    use_histogram: Optional[bool] = None,
    extract_label_prediction_and_weight: Optional[Callable[
        ..., Any]] = metric_util.to_label_prediction_example_weight,
    preprocessor: Optional[Callable[..., Any]] = None,
    examples_name: Optional[str] = None,
    example_id_key: Optional[str] = None,
    example_ids_count: Optional[int] = None,
    fractional_labels: bool = True) -> metric_types.MetricComputations:
  """Returns metric computations for computing binary confusion matrices.

  Args:
    num_thresholds: Number of thresholds to use. Thresholds will be calculated
      using linear interpolation between 0.0 and 1.0 with equidistant values and
      boundaries at -epsilon and 1.0+epsilon. Only one of num_thresholds or
      thresholds should be used. If used, num_thresholds must be > 1. If
      neither num_thresholds nor thresholds is set, DEFAULT_NUM_THRESHOLDS is
      used.
    thresholds: A specific set of thresholds to use. The caller is responsible
      for marking the boundaries with +/-epsilon if desired. Only one of
      num_thresholds or thresholds should be used. For metrics computed at top k
      this may be a single negative threshold value (i.e. -inf).
    name: Metric name containing binary_confusion_matrices.Matrices. Defaults
      to a name derived from BINARY_CONFUSION_MATRICES_NAME and the thresholds.
    eval_config: Eval config.
    model_name: Optional model name (if multi-model evaluation).
    output_name: Optional output name (if multi-output model type).
    sub_key: Optional sub key.
    aggregation_type: Optional aggregation type.
    class_weights: Optional class weights to apply to multi-class / multi-label
      labels and predictions prior to flattening (when micro averaging is used).
    example_weighted: True if example weights should be applied.
    use_histogram: If true, matrices will be derived from calibration
      histograms. If None, histograms are used when num_thresholds is in effect
      or when thresholds is a single negative value.
    extract_label_prediction_and_weight: User-provided function argument that
      yields label, prediction, and example weights for use in calculations
      (relevant only when use_histogram flag is not true).
    preprocessor: User-provided preprocessor for including additional extracts
      in StandardMetricInputs (relevant only when use_histogram flag is not
      true).
    examples_name: Metric name containing binary_confusion_matrices.Examples
      (relevant only when use_histogram flag is not true and example_id_key is
      set). Defaults to a name derived from BINARY_CONFUSION_EXAMPLES_NAME and
      the thresholds.
    example_id_key: Feature key containing example id (relevant only when
      use_histogram flag is not true).
    example_ids_count: Max number of example ids to be extracted for false
      positives and false negatives (relevant only when use_histogram flag is
      not true).
    fractional_labels: If true, each incoming tuple of (label, prediction, and
      example weight) will be split into two tuples as follows, where l, p and
      w represent the resulting label, prediction, and example weight values:
        (1) l = 0.0, p = prediction, and w = example_weight * (1.0 - label)
        (2) l = 1.0, p = prediction, and w = example_weight * label
      If enabled, an exception will be raised if labels are not within [0, 1].
      The implementation is such that tuples associated with a weight of zero
      are not yielded. This means it is safe to enable fractional_labels even
      when the labels only take on the values of 0.0 or 1.0. Relevant only when
      use_histogram flag is not true.

  Returns:
    A list of metric computations. It starts with the computations that
    accumulate the raw data: a calibration histogram when use_histogram is
    true, otherwise the binary confusion matrix computation. It ends with a
    derived computation that outputs Matrices under the matrices key and, when
    use_histogram is not true, also Examples under the examples key. The
    matrices key is always the last key of the final computation.

  Raises:
    ValueError: If both num_thresholds and thresholds are set at the same time
      (after reconciling Keras-style interpolated thresholds), if
      num_thresholds <= 1, if example sampling arguments are provided while
      use_histogram is true, if examples_name is provided without both
      example_id_key and example_ids_count, or if only one of example_id_key
      and example_ids_count is provided.
  """
  # TF v1 Keras AUC turns num_thresholds parameters into thresholds which
  # circumvents sharing of settings. If the thresholds match the interpolated
  # version of the thresholds then reset back to num_thresholds.
  if thresholds:
    if (not num_thresholds and
        thresholds == _interpolated_thresholds(len(thresholds))):
      num_thresholds = len(thresholds)
      thresholds = None
    elif (num_thresholds
          in (DEFAULT_NUM_THRESHOLDS, _KERAS_DEFAULT_NUM_THRESHOLDS) and
          len(thresholds) == num_thresholds - 2):
      # Keras drops the two epsilon boundary thresholds, so a list that is two
      # shorter than a default num_thresholds is treated as that default.
      thresholds = None
  if num_thresholds is not None and thresholds is not None:
    raise ValueError(
        'only one of thresholds or num_thresholds can be set at a time: '
        f'num_thresholds={num_thresholds}, thresholds={thresholds}, '
        f'len(thresholds)={len(thresholds)}')
  if num_thresholds is None and thresholds is None:
    num_thresholds = DEFAULT_NUM_THRESHOLDS
  if num_thresholds is not None:
    if num_thresholds <= 1:
      raise ValueError('num_thresholds must be > 1')
    # The interpolation strategy used here matches that used by keras for AUC.
    thresholds = _interpolated_thresholds(num_thresholds)
    thresholds_name_part = str(num_thresholds)
  else:
    thresholds_name_part = str(list(thresholds))

  if use_histogram is None:
    use_histogram = (
        num_thresholds is not None or
        (len(thresholds) == 1 and thresholds[0] < 0))

  if use_histogram and (examples_name or example_id_key or example_ids_count):
    raise ValueError('Example sampling is only performed when not using the '
                     'histogram computation. However, use_histogram is true '
                     f'and one of examples_name ("{examples_name}"), '
                     f'example_id_key ("{example_id_key}"), '
                     f'or example_ids_count ({example_ids_count}) was '
                     'provided, which will have no effect.')

  if examples_name and not (example_id_key and example_ids_count):
    raise ValueError('examples_name provided but either example_id_key or '
                     'example_ids_count was not. Examples will only be '
                     'returned when both example_id_key and '
                     'example_ids_count are provided, and when the '
                     'non-histogram computation is used. '
                     f'example_id_key: "{example_id_key}" '
                     f'example_ids_count: {example_ids_count}')

  if name is None:
    name = f'{BINARY_CONFUSION_MATRICES_NAME}_{thresholds_name_part}'
  if examples_name is None:
    examples_name = f'{BINARY_CONFUSION_EXAMPLES_NAME}_{thresholds_name_part}'
  matrices_key = metric_types.MetricKey(
      name=name,
      model_name=model_name,
      output_name=output_name,
      sub_key=sub_key,
      example_weighted=example_weighted)
  examples_key = metric_types.MetricKey(
      name=examples_name,
      model_name=model_name,
      output_name=output_name,
      sub_key=sub_key,
      example_weighted=example_weighted)

  computations = []
  if use_histogram:
    # Use calibration histogram to calculate matrices. For efficiency (unless
    # all predictions are matched - i.e. thresholds <= 0) we will assume that
    # other metrics will make use of the calibration histogram and re-use the
    # default histogram for the given model_name/output_name/sub_key. This is
    # also required to get accurate counts at the threshold boundaries. If this
    # becomes an issue, then calibration histogram can be updated to support
    # non-linear boundaries.
    computations = calibration_histogram.calibration_histogram(
        eval_config=eval_config,
        num_buckets=(
            # For precision/recall_at_k where a single large negative threshold
            # is used, we only need one bucket. Note that the histogram will
            # actually have 2 buckets: one that we set (which handles
            # predictions > -1.0) and a default catch-all bucket (i.e. bucket 0)
            # that the histogram creates for large negative predictions (i.e.
            # predictions <= -1.0).
            # NOTE(review): this check uses `<= 0` while the use_histogram
            # default above and result() below use `< 0`, so a single
            # threshold of exactly 0.0 gets a one-bucket histogram that
            # result() then rebins. Confirm this is intended.
            1 if len(thresholds) == 1 and thresholds[0] <= 0 else None),
        model_name=model_name,
        output_name=output_name,
        sub_key=sub_key,
        aggregation_type=aggregation_type,
        class_weights=class_weights,
        example_weighted=example_weighted)
    input_metric_key = computations[-1].keys[-1]
    output_metric_keys = [matrices_key]
  else:
    if bool(example_ids_count) != bool(example_id_key):
      raise ValueError('Both of example_ids_count and example_id_key must be '
                       f'set, but got example_id_key: "{example_id_key}" and '
                       f'example_ids_count: {example_ids_count}.')
    computations = _binary_confusion_matrix_computation(
        eval_config=eval_config,
        thresholds=thresholds,
        model_name=model_name,
        output_name=output_name,
        sub_key=sub_key,
        extract_label_prediction_and_weight=extract_label_prediction_and_weight,
        preprocessor=preprocessor,
        example_id_key=example_id_key,
        example_ids_count=example_ids_count,
        aggregation_type=aggregation_type,
        class_weights=class_weights,
        example_weighted=example_weighted,
        fractional_labels=fractional_labels)
    input_metric_key = computations[-1].keys[-1]
    # matrices_key is last for backwards compatibility with code that:
    #   1) used this computation as an input for a derived computation
    #   2) only accessed the matrix counts
    #   3) used computations[-1].keys[-1] to access the input key
    output_metric_keys = [examples_key, matrices_key]

  def result(
      metrics: Dict[metric_types.MetricKey, Any]
  ) -> Dict[metric_types.MetricKey, Union[Matrices, Examples]]:
    """Converts the accumulated input metric into binary confusion matrices.

    Args:
      metrics: Computed metrics, containing the input metric key produced by
        the upstream computation.

    Returns:
      A dict mapping the matrices key to Matrices and, for the non-histogram
      computation, the examples key to Examples.
    """
    if use_histogram:
      if len(thresholds) == 1 and thresholds[0] < 0:
        # This case is used when all positive prediction values are relevant
        # matches (e.g. when calculating top_k for precision/recall where the
        # non-top_k values are expected to have been set to float('-inf')).
        histogram = metrics[input_metric_key]
      else:
        # Calibration histogram uses intervals of the form [start, end) where
        # the prediction >= start. The confusion matrices want intervals of the
        # form (start, end] where the prediction > start. Add a small epsilon so
        # that >= checks don't match. This correction shouldn't be needed in
        # practice but allows for correctness in small tests.
        rebin_thresholds = [t + _EPSILON if t != 0 else t for t in thresholds]
        if thresholds[0] >= 0:
          # Add -epsilon bucket to account for differences in histogram vs
          # confusion matrix intervals mentioned above. If the epsilon bucket is
          # missing the false negatives and false positives will be 0 for the
          # first threshold.
          rebin_thresholds = [-_EPSILON] + rebin_thresholds
        if thresholds[-1] < 1.0:
          # If the last threshold < 1.0, then add a fence post at 1.0 + epsilon
          # otherwise true negatives and true positives will be overcounted.
          rebin_thresholds = rebin_thresholds + [1.0 + _EPSILON]
        histogram = calibration_histogram.rebin(rebin_thresholds,
                                                metrics[input_metric_key])
      matrices = _histogram_to_binary_confusion_matrices(thresholds, histogram)
      return {matrices_key: matrices}
    else:
      matrices, examples = _accumulator_to_matrices_and_examples(
          thresholds, metrics[input_metric_key])
      return {matrices_key: matrices, examples_key: examples}

  derived_computation = metric_types.DerivedMetricComputation(
      keys=output_metric_keys, result=result)
  computations.append(derived_computation)
  return computations