def _precision_recall_at_k()

in tensorflow_model_analysis/post_export_metrics/metrics.py [0:0]


def _precision_recall_at_k(classes: types.TensorType,
                           scores: types.TensorType,
                           labels: types.TensorType,
                           cutoffs: List[int],
                           weights: Optional[types.TensorType] = None,
                           precision: Optional[bool] = True,
                           recall: Optional[bool] = True
                          ) -> Tuple[types.TensorType, types.TensorType]:
  # pyformat: disable
  """Precision and recall at `k`.

  Args:
    classes: Tensor containing class names. Should be a BATCH_SIZE x NUM_CLASSES
      Tensor.
    scores: Tensor containing the associated scores. Should be a
      BATCH_SIZE x NUM_CLASSES Tensor.
    labels: Tensor containing the true labels. Should be a rank-2 Tensor where
      the first dimension is BATCH_SIZE. The second dimension can be anything.
      Empty-string (b'') entries are treated as padding and ignored.
    cutoffs: List containing the values for the `k` at which to compute the
      precision and recall for. Use a value of `k` = 0 (any value <= 0) to
      indicate that all predictions should be considered.
    weights: Optional weights for each of the examples. If None,
      each of the predictions/labels will be assumed to have a weight of 1.0.
      If present, should be a BATCH_SIZE Tensor. Weights are flattened before
      use, so a BATCH_SIZE x 1 Tensor is also accepted.
    precision: True to compute precision.
    recall: True to compute recall.

  The value_op will return a matrix with len(cutoffs) rows and 3 columns:
  [ cutoff 0, precision at cutoff 0, recall at cutoff 0 ]
  [ cutoff 1, precision at cutoff 1, recall at cutoff 1 ]
  [     :                :                  :           ]
  [ cutoff n, precision at cutoff n, recall at cutoff n ]

  If only one of precision or recall is True then the value_op will return only
  2 columns (cutoff and either precision or recall depending on which is True).

  Precision is NaN until at least one prediction has been accumulated, and
  recall is NaN until at least one label has been accumulated (both are then
  computed as 0 / 0).

  Returns:
    (value_op, update_op) for the precision/recall at K metric.

  Raises:
    ValueError: If neither precision nor recall is set.
  """
  # pyformat: enable
  if not precision and not recall:
    raise ValueError('one of either precision or recall must be set')

  num_cutoffs = len(cutoffs)

  if precision and recall:
    scope = 'precision_recall_at_k'
  elif precision:
    scope = 'precision_at_k'
  else:
    scope = 'recall_at_k'

  with tf.compat.v1.variable_scope(scope, [classes, scores, labels]):

    # Predicted positive (one accumulator per cutoff).
    predicted_positives = tf.compat.v1.Variable(
        initial_value=[0.0] * num_cutoffs,
        dtype=tf.float64,
        trainable=False,
        collections=[
            tf.compat.v1.GraphKeys.METRIC_VARIABLES,
            tf.compat.v1.GraphKeys.LOCAL_VARIABLES
        ],
        validate_shape=True,
        name='predicted_positives')

    # Predicted positive, label positive (one accumulator per cutoff).
    true_positives = tf.compat.v1.Variable(
        initial_value=[0.0] * num_cutoffs,
        dtype=tf.float64,
        trainable=False,
        collections=[
            tf.compat.v1.GraphKeys.METRIC_VARIABLES,
            tf.compat.v1.GraphKeys.LOCAL_VARIABLES
        ],
        validate_shape=True,
        name='true_positives')

    # Label positive (a single scalar; independent of the cutoff).
    actual_positives = tf.compat.v1.Variable(
        initial_value=0.0,
        dtype=tf.float64,
        trainable=False,
        collections=[
            tf.compat.v1.GraphKeys.METRIC_VARIABLES,
            tf.compat.v1.GraphKeys.LOCAL_VARIABLES
        ],
        validate_shape=True,
        name='actual_positives')

    if weights is not None:
      weights_f64 = tf.cast(weights, tf.float64)
    else:
      # Default to a weight of 1.0 for every row of `labels`.
      weights_f64 = tf.ones(tf.shape(input=labels)[0], tf.float64)

  def compute_batch_stats(
      classes: np.ndarray, scores: np.ndarray, labels: np.ndarray,
      weights: np.ndarray) -> Tuple[np.ndarray, np.ndarray, float]:
    """Compute precision/recall intermediate stats for a batch.

    This runs inside `tf.compat.v1.py_func`, so every argument arrives as a
    NumPy array.

    Args:
      classes: Array containing class names. Should be BATCH_SIZE x
        NUM_CLASSES.
      scores: Array containing the associated scores. Should have the same
        shape as `classes`.
      labels: Array containing the true labels. Should be rank 2 with the
        first dimension BATCH_SIZE. The second dimension can be anything;
        b'' entries are padding.
      weights: Weights for the associated examples. Flattened before use, after
        which it must contain exactly BATCH_SIZE entries.

    Returns:
      True positives and predicted positives (one entry per cutoff) and the
      actual positives total, computed for the batch of examples.

    Raises:
      ValueError: classes and scores have different shapes; labels has
       a different batch size from classes and scores; or weights does not
       contain exactly one entry per example.
    """

    if classes.shape != scores.shape:
      raise ValueError('classes and scores should have same shape, but got '
                       '%s and %s' % (classes.shape, scores.shape))

    batch_size = classes.shape[0]
    num_classes = classes.shape[1]
    if labels.shape[0] != batch_size:
      raise ValueError('labels should have the same batch size of %d, but got '
                       '%d instead' % (batch_size, labels.shape[0]))

    # Flatten so BATCH_SIZE x 1 weights give a scalar weight per example;
    # otherwise each weight would be a length-1 array and the scalar
    # actual_positives total would become a shape [1] array that cannot be
    # added to its scalar variable.
    weights = np.reshape(weights, [-1])
    if weights.shape[0] != batch_size:
      # Without this check zip() below would silently drop examples.
      raise ValueError('weights should have the same batch size of %d, but '
                       'got %d instead' % (batch_size, weights.shape[0]))

    # Sort classes, by row, by their associated scores, in descending order of
    # score.
    sorted_classes = np.flip(
        classes[np.arange(batch_size)[:, None],
                np.argsort(scores)], axis=1)

    true_positives = np.zeros(num_cutoffs, dtype=np.float64)
    predicted_positives = np.zeros(num_cutoffs, dtype=np.float64)
    actual_positives = 0.0

    for predicted_row, label_row, weight in zip(sorted_classes, labels,
                                                weights):

      label_set = set(label_row)
      label_set.discard(b'')  # Remove filler elements.

      for i, cutoff in enumerate(cutoffs):
        # A non-positive cutoff means "consider every predicted class".
        cutoff_to_use = cutoff if cutoff > 0 else num_classes
        cut_predicted_row = predicted_row[:cutoff_to_use]
        true_pos = set(cut_predicted_row) & label_set
        true_positives[i] += len(true_pos) * weight
        predicted_positives[i] += len(cut_predicted_row) * weight

      actual_positives += len(label_set) * weight

    return true_positives, predicted_positives, actual_positives

  # Value op returns
  # [ K | precision at K | recall at K ]
  # true_positives never exceeds either denominator, so an empty denominator
  # yields 0 / 0 = NaN rather than +/-inf.
  # PyType doesn't like TF operator overloads: b/92797687
  # pytype: disable=unsupported-operands
  precision_op = true_positives / predicted_positives
  recall_op = true_positives / actual_positives
  # pytype: enable=unsupported-operands
  # Stack the cutoffs and the metric vectors as rows, then transpose so each
  # output row corresponds to one cutoff.
  if precision and recall:
    value_op = tf.transpose(
        a=tf.stack([cutoffs, precision_op, recall_op], axis=0))
  elif precision:
    value_op = tf.transpose(a=tf.stack([cutoffs, precision_op], axis=0))
  else:
    value_op = tf.transpose(a=tf.stack([cutoffs, recall_op], axis=0))

  # Per-batch statistics are computed in Python/NumPy and then accumulated
  # into the metric variables.
  true_positives_update, predicted_positives_update, actual_positives_update = (
      tf.compat.v1.py_func(compute_batch_stats,
                           [classes, scores, labels, weights_f64],
                           [tf.float64, tf.float64, tf.float64]))

  update_op = tf.group(
      tf.compat.v1.assign_add(true_positives, true_positives_update),
      tf.compat.v1.assign_add(predicted_positives, predicted_positives_update),
      tf.compat.v1.assign_add(actual_positives, actual_positives_update))

  return value_op, update_op