def _ApplyThresholdsAndTopK()

in tensorflow_transform/beam/analyzer_impls.py [0:0]


def _ApplyThresholdsAndTopK(  # pylint: disable=invalid-name
    counts,
    frequency_threshold,
    top_k,
    input_dtype,
    info_threshold=float('-inf'),
    key_fn=None):
  """Filters (metric, term) pairs by thresholds and optionally keeps the top K.

  Args:
    counts: A PCollection of `(metric, term)` pairs. `metric` is either a
      single number (the frequency) or an `(informativeness, frequency)`
      tuple.
    frequency_threshold: Pairs whose frequency is below this value are
      dropped. The threshold filter only runs when
      `frequency_threshold > 0` or `info_threshold > -inf`.
    top_k: If not None, how many of the largest pairs to keep: per key when
      `key_fn` is given, otherwise over the whole collection.
    input_dtype: Dtype name of the terms. Any value other than
      `tf.string.name` causes terms to be converted to bytes.
    info_threshold: Pairs whose informativeness is below this value are
      dropped. Plain-number metrics count as infinitely informative, so they
      always pass this check.
    key_fn: Optional callable applied to each term to get a grouping key.
      When set, top-K selection happens per key instead of globally.

  Returns:
    A PCollection of `(metric, term)` pairs. Each `metric` is reduced to a
    single number (the first element when it was a tuple).
  """
  # TODO(b/117796748): Filter frequency per-key when key feature input enabled.
  # Filtering is cheaper than the TopK computation and the two commute, so the
  # filter is applied first.
  thresholds_active = (
      frequency_threshold > 0 or info_threshold > float('-inf'))
  if thresholds_active:

    def passes_thresholds(element):
      """Returns True when the element's metric meets both thresholds."""
      metric, _ = element
      # A plain number carries only the frequency. A tuple carries
      # (informativeness, frequency).
      if isinstance(metric, tuple):
        informativeness, frequency = metric
      else:
        informativeness, frequency = float('inf'), metric
      return bool(frequency >= frequency_threshold and
                  informativeness >= info_threshold)

    counts = counts | ('FilterByThresholds(%s)' % frequency_threshold >>
                       beam.Filter(passes_thresholds))

  def keep_primary_metric(element):
    """Keeps only the first entry of a tuple metric. Numbers pass through."""
    # The accumulator may have tracked (informativeness, frequency). Only the
    # leading metric is kept for ranking downstream.
    metric, term = element
    if isinstance(metric, tuple):
      metric = metric[0]
    return metric, term

  counts = counts | 'FlattenToSingleMetric' >> beam.Map(keep_primary_metric)

  if input_dtype != tf.string.name:

    def encode_term(metric, term):
      """Converts a non-string term to its bytes representation."""
      return metric, tf.compat.as_bytes(tf.compat.as_str_any(term))

    counts = counts | 'EncodeNumericalTerms' >> beam.MapTuple(encode_term)

  if top_k is None:
    return counts

  # TODO(katsiapis): Perhaps enhance Beam's Top to accept an N that can
  # signify "unlimited" and then we can simplify a lot of our code (though
  # that might come at a performance penalty).
  if key_fn:

    def key_and_count_term(count_and_term):
      """Keys a (count, term) pair by `key_fn(term)`."""
      count, term = count_and_term
      # TODO(b/184196242): Ideally we wouldn't be producing numpy.float64
      # counts in the first place, as opposed to casting to float here. See
      # also b/79751861.
      if isinstance(count, np.float64):
        count = float(count)
      return key_fn(term), (count, term)

    return (
        counts
        | 'MapKeyToCountAndTerm' >> beam.Map(key_and_count_term)
        | 'CoverageTop(%s)' % top_k >> beam.combiners.Top.LargestPerKey(top_k)
        | 'FlattenCoverageTerms' >> beam.FlatMap(lambda kv: kv[1]))

  # LINT.IfChange(top_k_impl)
  # Stages that follow this block rely on the sorted order of `Top.Of`'s
  # output and fusion with the `FlattenList`. If changing this part of
  # implementation, either make sure that these hold true or adjust the
  # appropriate arg of `VocabularyOrderAndWrite` node.
  counts = (
      counts
      | 'Top(%s)' % top_k >> beam.combiners.Top.Of(top_k)
      | 'MaybeAddDummy' >> beam.Map(
          maybe_add_empty_vocabulary_dummy, dtype=input_dtype)
      | 'FlattenList' >> beam.FlatMap(lambda lst: lst))
  # LINT.ThenChange(../analyzers.py:input_is_sorted)
  return counts