def percentile()

in tensorflow_probability/python/stats/quantiles.py [0:0]


def percentile(x,
               q,
               axis=None,
               interpolation=None,
               keepdims=False,
               validate_args=False,
               preserve_gradients=True,
               name=None):
  """Compute the `q`-th percentile(s) of `x`.

  Given a vector `x`, the `q`-th percentile of `x` is the value `q / 100` of the
  way from the minimum to the maximum in a sorted copy of `x`.

  The values and distances of the two nearest neighbors as well as the
  `interpolation` parameter will determine the percentile if the normalized
  ranking does not match the location of `q` exactly.

  This function is the same as the minimum if `q = 0` and the same as the
  maximum if `q = 100`.  With `q = 50` it is the median; for an even number of
  samples it matches the conventional median (mean of the two middle values)
  only with `interpolation='midpoint'` or `interpolation='linear'`.

  Multiple percentiles can be computed at once by using `1-D` vector `q`.
  Dimension zero of the returned `Tensor` will index the different percentiles.

  If `x` has a floating-point dtype and any of the values reduced over is NaN,
  the corresponding output entry is NaN.

  Compare to `numpy.percentile`.

  Args:
    x:  Numeric `N-D` `Tensor` with `N > 0`.  If `axis` is not `None`,
      `x` must have statically known number of dimensions.
    q:  Scalar or vector `Tensor` with values in `[0, 100]`. The percentile(s).
      Cast to `float64` internally so that indices stay exact even when `x`
      is very long along the reduced axes.
    axis:  Optional `0-D` or `1-D` integer `Tensor` with constant values. The
      axes that index independent samples over which to return the desired
      percentile.  If `None` (the default), treat every dimension as a sample
      dimension, returning a `rank(q)` result.
    interpolation:  One of {'nearest', 'linear', 'lower', 'higher', 'midpoint'}.
      Default value: 'nearest'.  This specifies the interpolation method to
      use when the desired quantile lies between two data points `i < j`:
        * linear: i + (j - i) * fraction, where fraction is the fractional part
          of the index surrounded by i and j.
        * lower: `i`.
        * higher: `j`.
        * nearest: `i` or `j`, whichever is nearest.  An exact tie is resolved
          by `tf.round`, which rounds half to even.
        * midpoint: (i + j) / 2.
      `linear` and `midpoint` interpolation do not work with integer dtypes.
    keepdims:  Python `bool`. If `True`, the dimensions reduced over are kept
      in the output with size 1 (for `axis=None`, every dimension of `x`).
      If `False`, those dimensions are removed from the output shape.
    validate_args:  Whether to add runtime checks of argument validity. If
      False, and arguments are incorrect, correct behavior is not guaranteed.
    preserve_gradients:  Python `bool`.  If `True`, ensure that gradient w.r.t
      the percentile `q` is preserved in the case of linear interpolation.
      If `False`, the gradient will be (incorrectly) zero when `q` corresponds
      to a point in `x`.
    name:  A Python string name to give this `Op`.  Default is 'percentile'.

  Returns:
    A `(rank(q) + N - len(axis))` dimensional `Tensor` of same dtype as `x`, or,
      if `axis` is `None`, a `rank(q)` `Tensor`.  The first `rank(q)` dimensions
      index quantiles for different values of `q`.  If `keepdims` is `True`,
      the reduced dimensions are retained with size 1, giving `rank(q) + N`
      dimensions.

  Raises:
    ValueError:  If argument `interpolation` is not one of the allowed values.
    TypeError:  If `interpolation` is 'linear' or 'midpoint' and `x` has an
      integer dtype.

  #### Examples

  ```python
  # Get 30th percentile with default ('nearest') interpolation.
  x = [1., 2., 3., 4.]
  tfp.stats.percentile(x, q=30.)
  ==> 2.0

  # Get 30th percentile with 'linear' interpolation.
  x = [1., 2., 3., 4.]
  tfp.stats.percentile(x, q=30., interpolation='linear')
  ==> 1.9

  # Get 30th and 70th percentiles with 'lower' interpolation
  x = [1., 2., 3., 4.]
  tfp.stats.percentile(x, q=[30., 70.], interpolation='lower')
  ==> [1., 3.]

  # Get 100th percentile (maximum).  By default, this is computed over every dim
  x = [[1., 2.],
       [3., 4.]]
  tfp.stats.percentile(x, q=100.)
  ==> 4.

  # Treat the leading dim as indexing samples, and find the 100th quantile (max)
  # over all such samples.
  x = [[1., 2.],
       [3., 4.]]
  tfp.stats.percentile(x, q=100., axis=[0])
  ==> [3., 4.]
  ```

  """
  name = name or 'percentile'
  allowed_interpolations = {'linear', 'lower', 'higher', 'nearest', 'midpoint'}

  if interpolation is None:
    interpolation = 'nearest'
  elif interpolation not in allowed_interpolations:
    raise ValueError(
        'Argument `interpolation` must be in {}. Found {}.'.format(
            allowed_interpolations, interpolation))

  with tf.name_scope(name):
    x = tf.convert_to_tensor(x, name='x')

    if (interpolation in {'linear', 'midpoint'} and
        dtype_util.is_integer(x.dtype)):
      raise TypeError('{} interpolation not allowed with dtype {}'.format(
          interpolation, x.dtype))

    # Double is needed here and below, else we get the wrong index if the array
    # is huge along axis.
    q = tf.cast(q, tf.float64)
    _get_static_ndims(q, expect_ndims_no_more_than=1)

    if validate_args:
      q = distribution_util.with_dependencies([
          assert_util.assert_rank_in(q, [0, 1]),
          assert_util.assert_greater_equal(q, tf.cast(0., tf.float64)),
          assert_util.assert_less_equal(q, tf.cast(100., tf.float64))
      ], q)

    # Move `axis` dims of `x` to the rightmost, call it `y`.
    if axis is None:
      y = tf.reshape(x, [-1])
    else:
      x_ndims = _get_static_ndims(
          x, expect_static=True, expect_ndims_at_least=1)
      axis = _make_static_axis_non_negative_list(axis, x_ndims)
      y = _move_dims_to_flat_end(x, axis, x_ndims, right_end=True)

    frac_at_q_or_below = q / 100.

    # Sort (in ascending order) everything which allows multiple calls to sort
    # only once (under the hood) and use CSE.
    sorted_y = tf.sort(y, axis=-1, direction='ASCENDING')

    # Number of samples along the (flattened) reduction axis.  `ps.shape`
    # yields a static value when available, else a dynamic shape tensor.
    num_samples = ps.shape(y)[-1]
    max_idx = num_samples - 1
    d = ps.cast(num_samples, tf.float64)
    # Real-valued (0-based) position of the q-th percentile in `sorted_y`.
    exact_idx = (d - 1) * frac_at_q_or_below

    # Maps an index-selecting interpolation type to its rounding of exact_idx.
    rounding_fns = {
        'lower': tf.math.floor,
        'higher': tf.math.ceil,
        'nearest': tf.round,  # Rounds half to even.
    }

    def _get_indices(interp_type):
      """Integer indices into `sorted_y` implied by `interp_type`."""
      indices = rounding_fns[interp_type](exact_idx)
      # d - 1 will be distinct from d in int32, but not necessarily double.
      # So clip to avoid out of bounds errors.
      return tf.clip_by_value(tf.cast(indices, tf.int32), 0, max_idx)

    if interpolation in rounding_fns:
      gathered_y = tf.gather(sorted_y, _get_indices(interpolation), axis=-1)
    elif interpolation == 'midpoint':
      gathered_y = 0.5 * (
          tf.gather(sorted_y, _get_indices('lower'), axis=-1) +
          tf.gather(sorted_y, _get_indices('higher'), axis=-1))
    elif interpolation == 'linear':
      # Copy-paste of docstring on interpolation:
      # linear: i + (j - i) * fraction, where fraction is the fractional part
      # of the index surrounded by i and j.
      larger_y_idx = _get_indices('higher')
      if preserve_gradients:
        # If q corresponds to a point in x, we will initially have
        # larger_y_idx == smaller_y_idx.
        # This results in the gradient w.r.t. fraction being zero (recall `q`
        # enters only through `fraction`...and see that things cancel).
        # The fix is to ensure that smaller_y_idx and larger_y_idx are always
        # separated by exactly 1.
        smaller_y_idx = tf.maximum(larger_y_idx - 1, 0)
        larger_y_idx = tf.minimum(smaller_y_idx + 1, max_idx)
        fraction = tf.cast(larger_y_idx, tf.float64) - exact_idx
      else:
        smaller_y_idx = _get_indices('lower')
        fraction = tf.math.ceil(exact_idx) - exact_idx

      # `fraction` is the weight on the smaller neighbor; 1 - fraction is the
      # weight on the larger one.
      fraction = tf.cast(fraction, y.dtype)
      gathered_y = (
          tf.gather(sorted_y, larger_y_idx, axis=-1) * (1 - fraction) +
          tf.gather(sorted_y, smaller_y_idx, axis=-1) * fraction)

    # Propagate NaNs: any NaN among the reduced values makes that output NaN.
    if x.dtype in (tf.bfloat16, tf.float16, tf.float32, tf.float64):
      # `tf.math.is_nan` is only applied to floating-point dtypes.
      nan_batch_members = tf.reduce_any(tf.math.is_nan(x), axis=axis)
      # Append rank(q) singleton dims so the mask broadcasts against the
      # trailing `q` dims of `gathered_y`.
      right_rank_matched_shape = ps.pad(
          ps.shape(nan_batch_members),
          paddings=[[0, ps.rank(q)]],
          constant_values=1)
      nan_batch_members = tf.reshape(
          nan_batch_members, shape=right_rank_matched_shape)
      nan = np.array(np.nan, dtype_util.as_numpy_dtype(gathered_y.dtype))
      gathered_y = tf.where(nan_batch_members, nan, gathered_y)

    # Expand dimensions if requested
    if keepdims:
      if axis is None:
        # Broadcasting against an all-ones tensor of rank ndims(x) + ndims(q)
        # prepends ndims(x) singleton dims ahead of the `q` dims.
        ones_vec = tf.ones(
            shape=[_get_best_effort_ndims(x) + _get_best_effort_ndims(q)],
            dtype=tf.int32)
        gathered_y *= tf.ones(ones_vec, dtype=x.dtype)
      else:
        gathered_y = _insert_back_keepdims(gathered_y, axis)

    # If q is a scalar, then result has the right shape.
    # If q is a vector, then result has trailing dim of shape q.shape, which
    # needs to be rotated to dim 0.
    return distribution_util.rotate_transpose(gathered_y, ps.rank(q))