def _confusion_matrix_at_thresholds()

in easy_rec/python/core/easyrec_metrics/distribute_metrics_impl_pai.py [0:0]


def _confusion_matrix_at_thresholds(labels,
                                    predictions,
                                    thresholds,
                                    weights=None,
                                    includes=None):
  """Computes true_positives, false_negatives, true_negatives, false_positives.

  This function creates up to four local variables, `true_positives`,
  `true_negatives`, `false_positives` and `false_negatives`.
  `true_positive[i]` is defined as the total weight of values in `predictions`
  above `thresholds[i]` whose corresponding entry in `labels` is `True`.
  `false_negatives[i]` is defined as the total weight of values in `predictions`
  at most `thresholds[i]` whose corresponding entry in `labels` is `True`.
  `true_negatives[i]` is defined as the total weight of values in `predictions`
  at most `thresholds[i]` whose corresponding entry in `labels` is `False`.
  `false_positives[i]` is defined as the total weight of values in `predictions`
  above `thresholds[i]` whose corresponding entry in `labels` is `False`.

  For estimation of these metrics over a stream of data, for each metric the
  function respectively creates an `update_op` operation that updates the
  variable and returns its value.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
      `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    includes: Tuple of keys to return, from 'tp', 'fn', 'tn', fp'. If `None`,
        default to all four.

  Returns:
    values: Dict of variables of shape `[len(thresholds)]`. Keys are from
        `includes`.
    update_ops: Dict of operations that increments the `values`. Keys are from
        `includes`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      `includes` contains invalid keys.
  """
  all_includes = ('tp', 'fn', 'tn', 'fp')
  if includes is None:
    includes = all_includes
  else:
    for include in includes:
      if include not in all_includes:
        raise ValueError('Invalid key: %s.' % include)

  with ops.control_dependencies([
      check_ops.assert_greater_equal(
          predictions,
          math_ops.cast(0.0, dtype=predictions.dtype),
          message='predictions must be in [0, 1]'),
      check_ops.assert_less_equal(
          predictions,
          math_ops.cast(1.0, dtype=predictions.dtype),
          message='predictions must be in [0, 1]')
  ]):
    predictions, labels, weights = _remove_squeezable_dimensions(
        predictions=math_ops.to_float(predictions),
        labels=math_ops.cast(labels, dtype=dtypes.bool),
        weights=weights)

  num_thresholds = len(thresholds)

  # Reshape predictions and labels.
  predictions_2d = array_ops.reshape(predictions, [-1, 1])
  labels_2d = array_ops.reshape(
      math_ops.cast(labels, dtype=dtypes.bool), [1, -1])

  # Use static shape if known.
  num_predictions = predictions_2d.get_shape().as_list()[0]

  # Otherwise use dynamic shape.
  if num_predictions is None:
    num_predictions = array_ops.shape(predictions_2d)[0]
  thresh_tiled = array_ops.tile(
      array_ops.expand_dims(array_ops.constant(thresholds), [1]),
      array_ops.stack([1, num_predictions]))

  # Tile the predictions after thresholding them across different thresholds.
  pred_is_pos = math_ops.greater(
      array_ops.tile(array_ops.transpose(predictions_2d), [num_thresholds, 1]),
      thresh_tiled)
  if ('fn' in includes) or ('tn' in includes):
    pred_is_neg = math_ops.logical_not(pred_is_pos)

  # Tile labels by number of thresholds
  label_is_pos = array_ops.tile(labels_2d, [num_thresholds, 1])
  if ('fp' in includes) or ('tn' in includes):
    label_is_neg = math_ops.logical_not(label_is_pos)

  if weights is not None:
    weights = weights_broadcast_ops.broadcast_weights(
        math_ops.to_float(weights), predictions)
    weights_tiled = array_ops.tile(
        array_ops.reshape(weights, [1, -1]), [num_thresholds, 1])
    thresh_tiled.get_shape().assert_is_compatible_with(
        weights_tiled.get_shape())
  else:
    weights_tiled = None

  values = {}
  update_ops = {}

  if 'tp' in includes:
    true_p = metric_variable([num_thresholds],
                             dtypes.float32,
                             name='true_positives')
    is_true_positive = math_ops.cast(
        math_ops.logical_and(label_is_pos, pred_is_pos), dtypes.float32)
    if weights_tiled is not None:
      is_true_positive *= weights_tiled
    update_ops['tp'] = state_ops.assign_add(
        true_p, math_ops.reduce_sum(is_true_positive, 1), use_locking=True)
    values['tp'] = true_p

  if 'fn' in includes:
    false_n = metric_variable([num_thresholds],
                              dtypes.float32,
                              name='false_negatives')
    is_false_negative = math_ops.cast(
        math_ops.logical_and(label_is_pos, pred_is_neg), dtypes.float32)
    if weights_tiled is not None:
      is_false_negative *= weights_tiled
    update_ops['fn'] = state_ops.assign_add(
        false_n, math_ops.reduce_sum(is_false_negative, 1), use_locking=True)
    values['fn'] = false_n

  if 'tn' in includes:
    true_n = metric_variable([num_thresholds],
                             dtypes.float32,
                             name='true_negatives')
    is_true_negative = math_ops.cast(
        math_ops.logical_and(label_is_neg, pred_is_neg), dtypes.float32)
    if weights_tiled is not None:
      is_true_negative *= weights_tiled
    update_ops['tn'] = state_ops.assign_add(
        true_n, math_ops.reduce_sum(is_true_negative, 1), use_locking=True)
    values['tn'] = true_n

  if 'fp' in includes:
    false_p = metric_variable([num_thresholds],
                              dtypes.float32,
                              name='false_positives')
    is_false_positive = math_ops.cast(
        math_ops.logical_and(label_is_neg, pred_is_pos), dtypes.float32)
    if weights_tiled is not None:
      is_false_positive *= weights_tiled
    update_ops['fp'] = state_ops.assign_add(
        false_p, math_ops.reduce_sum(is_false_positive, 1), use_locking=True)
    values['fp'] = false_p

  return values, update_ops