def compute_keypoints()

in tensorflow_lattice/python/premade_lib.py [0:0]


def compute_keypoints(values,
                      num_keypoints,
                      keypoints='quantiles',
                      clip_min=None,
                      clip_max=None,
                      default_value=None,
                      weights=None,
                      weight_reduction='mean',
                      feature_name=''):
  """Calculates keypoints for the given set of values.

  The values are stripped of `default_value`, optionally clipped, deduplicated
  and sorted. Keypoints are then placed either at (optionally weighted)
  quantiles of the unique values, or uniformly between the smallest and the
  largest unique value.

  Args:
    values: Array-like of numeric values to use for keypoint calculation.
    num_keypoints: Number of keypoints to compute.
    keypoints: String `'quantiles'` or `'uniform'`.
    clip_min: Input values are lower clipped by this value. When set, the value
      itself is added to the data so it can become the first keypoint.
    clip_max: Input values are upper clipped by this value. When set, the value
      itself is added to the data so it can become the last keypoint.
    default_value: If provided, occurrences will be removed from values.
    weights: Optional array-like of weights, aligned with `values`, to be used
      for quantile calculation.
    weight_reduction: Reduction applied to weights for repeated values. Must be
      either 'mean' or 'sum'. Only checked when `weights` is provided.
    feature_name: Name to use for error logs.

  Returns:
    A 1-D numpy array of keypoints in increasing order. It has `num_keypoints`
    entries, except in `'quantiles'` mode when fewer than `num_keypoints`
    unique values were observed; in that case the unique values themselves are
    returned.

  Raises:
    ValueError: If NaN values remain after removing `default_value`, if
      `weights` is given with an unknown `weight_reduction`, or if `keypoints`
      is not a supported mode.
  """
  # Coerce to arrays so that the boolean-mask and fancy indexing below are
  # positional and elementwise. A plain list would make `values !=
  # default_value` collapse to a single bool.
  values = np.asarray(values)
  if weights is not None:
    weights = np.asarray(weights)

  # Remove default values before calculating stats.
  non_default_idx = values != default_value
  values = values[non_default_idx]
  if weights is not None:
    weights = weights[non_default_idx]

  # Clip min and max if requested. Note that we add clip bounds to the values
  # so that the first and last keypoints are set to those values. The added
  # bounds get zero weight so they do not shift weighted quantiles.
  if clip_min is not None:
    values = np.maximum(values, clip_min)
    values = np.append(values, clip_min)
    if weights is not None:
      weights = np.append(weights, 0)
  if clip_max is not None:
    values = np.minimum(values, clip_max)
    values = np.append(values, clip_max)
    if weights is not None:
      weights = np.append(weights, 0)

  # We do not allow nans in the data, even as default_value. (A NaN
  # default_value never compares equal, so it is not filtered out above.)
  if np.isnan(values).any():
    raise ValueError(
        'NaN values were observed for numeric feature `{}`. '
        'Consider replacing the values in transform or input_fn.'.format(
            feature_name))

  # Remove duplicates and sort values before calculating stats.
  # This is empirically useful as it lets us use keypoints more efficiently.
  if weights is None:
    sorted_values = np.unique(values)
  else:
    # First sort the values and reorder weights.
    idx = np.argsort(values)
    values = values[idx]
    weights = weights[idx]

    # Set the weight of each unique element to be the sum or average of the
    # weights of repeated instances. Using 'mean' reduction results in parity
    # between unweighted calculation and having equal weights for all values.
    # `idx` marks the start of each run of equal values, so reduceat sums the
    # weights run by run.
    sorted_values, idx, counts = np.unique(
        values, return_index=True, return_counts=True)
    weights = np.add.reduceat(weights, idx)
    if weight_reduction == 'mean':
      weights = weights / counts
    elif weight_reduction != 'sum':
      raise ValueError('Invalid weight reduction: {}'.format(weight_reduction))

  if keypoints == 'quantiles':
    if sorted_values.size < num_keypoints:
      logging.info(
          'Not enough unique values observed for feature `%s` to '
          'construct %d keypoints for pwl calibration. Using %d unique '
          'values as keypoints.', feature_name, num_keypoints,
          sorted_values.size)
      return sorted_values.astype(float)

    quantiles = np.linspace(0., 1., num_keypoints)
    if weights is not None:
      return _weighted_quantile(
          sorted_values=sorted_values, quantiles=quantiles,
          weights=weights).astype(float)

    # NumPy 1.22 renamed `interpolation` to `method` and deprecated the old
    # name. Prefer the new keyword and fall back only for older NumPy.
    try:
      nearest = np.quantile(sorted_values, quantiles, method='nearest')
    except TypeError:
      nearest = np.quantile(
          sorted_values, quantiles, interpolation='nearest')
    return nearest.astype(float)

  elif keypoints == 'uniform':
    return np.linspace(sorted_values[0], sorted_values[-1], num_keypoints)
  else:
    raise ValueError('Invalid keypoint generation mode: {}'.format(keypoints))