def reduce_batch_weighted_cooccurrences()

in tensorflow_transform/tf_utils.py


def reduce_batch_weighted_cooccurrences(
    x_input: common_types.TensorType,
    y_input: tf.Tensor,
    weights_input: Optional[tf.Tensor] = None,
    extend_with_sentinel_counts: bool = True,
    filter_regex: Optional[str] = None) -> ReducedBatchWeightedCounts:
  """Performs batch-wise reduction to produce weighted co-occurrences.

  Computes the weighted co-occurrence of each feature value in x with each
  value in the range [0, max(y)]. If extend_with_sentinel_counts is True, the
  return value will include an additional sentinel token (not in the true
  vocabulary) that is used to accumulate the global distribution of y values.

  Args:
    x_input: Input `Tensor` or `CompositeTensor`.
    y_input: Integer `Tensor` with which to compute the co-occurrence with
      x_input.
    weights_input: (Optional) Weights input `Tensor`.
    extend_with_sentinel_counts: If True, the reduced batch will be extended
      with a sentinel value that accumulates the total distribution of y
      values. Should be True except when called recursively with the sentinel
      value as input.
    filter_regex: (Optional) Regex that matches tokens that have to be filtered
      out. Can only be specified if `x_input` has string dtype.

  Returns:
    A `ReducedBatchWeightedCounts` namedtuple of:
      unique_x: the unique values in x.
      summed_weights_per_x: the sum of the weights for each unique value in x.
      summed_positive_per_x_and_y: the sum of positive weights for each unique
        y value, for each unique value in x.
      counts_per_x: the count of each of the unique values in x.
  """
  tf.compat.v1.assert_type(y_input, tf.int64)
  # TODO(b/134075780): Revisit expected weights shape when input is sparse.
  if isinstance(x_input, tf.SparseTensor):
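    # indices[:, 0] is the batch (row) index of each value in the sparse
    # tensor.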
    batch_indices = x_input.indices[:, 0]
    # y and densified x should have the same batch dimension.
    assert_eq = tf.compat.v1.assert_equal(
        tf.shape(y_input)[0], tf.cast(x_input.dense_shape[0], tf.int32))
    with tf.control_dependencies([assert_eq]):
      y = tf.gather(y_input, batch_indices)
    x = x_input.values
  elif isinstance(x_input, tf.RaggedTensor):
    # Each batch instance in x corresponds to a single value in y.
    x_row_indices = _get_ragged_batch_value_rowids(x_input)
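    # y must have at least as many entries as x has ragged rows.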
    assert_compatible = tf.debugging.assert_greater_equal(
        tf.shape(y_input, out_type=tf.int64)[0], x_input.bounding_shape(axis=0))
    with tf.control_dependencies([assert_compatible]):
      x = tf.ensure_shape(x_input.flat_values, [None])
      y = tf.gather(y_input, x_row_indices)
  else:
    y = y_input
    x = x_input
  if weights_input is None:
    weights = tf.ones_like(x, dtype=tf.float32)
  else:
    x, weights_input = assert_same_shape(x, weights_input)
    weights = weights_input
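  # Broadcast y so that each value in x is aligned with its instance's y value.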
  y = _broadcast_to_x_shape(x, y)
  x, y = assert_same_shape(x, y)
  x = tf.reshape(x, [-1])
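  # Drop tokens matching filter_regex from x, y, and weights so the three
  # tensors stay aligned.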
  filter_fn = _make_regex_filter_fn(x, filter_regex)
  x = filter_fn(x)
  y = filter_fn(tf.reshape(y, [-1]))
  weights = filter_fn(tf.reshape(weights, [-1]))

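  # unique_with_counts returns the deduplicated values of x, the index of each
  # element into those values, and the per-value occurrence counts.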
  unique_x_values, unique_idx, unique_count = tf.unique_with_counts(
      x, out_idx=tf.int64)

  summed_weights_per_x = tf.math.unsorted_segment_sum(
      weights, unique_idx, tf.size(input=unique_x_values))

  # For each feature value in x, compute the weighted positive sum for each
  # unique value in y.
  max_y_value = tf.cast(tf.reduce_max(input_tensor=y_input), tf.int64)
  max_x_idx = tf.cast(tf.size(unique_x_values), tf.int64)
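  # Map each (x, y) pair to a unique segment id in a flattened
  # [max_x_idx, max_y_value + 1] grid so that a single segment sum yields the
  # weighted count for every (x, y) combination.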
  dummy_index = (max_y_value + 1) * unique_idx + y
  summed_positive_per_x_and_y = tf.cast(
      tf.math.unsorted_segment_sum(weights, dummy_index,
                                   max_x_idx * (max_y_value + 1)),
      dtype=tf.float32)
  summed_positive_per_x_and_y = tf.reshape(summed_positive_per_x_and_y,
                                           [max_x_idx, max_y_value + 1])

  reduced_batch = ReducedBatchWeightedCounts(
      unique_x=unique_x_values,
      summed_weights_per_x=summed_weights_per_x,
      summed_positive_per_x_and_y=summed_positive_per_x_and_y,
      counts_per_x=unique_count)
  # Add a sentinel token tracking the full distribution of y values.
  if extend_with_sentinel_counts:
    reduced_batch = extend_reduced_batch_with_y_counts(reduced_batch, y_input,
                                                       weights_input)
  return reduced_batch
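

A minimal usage sketch (the tensor values below are hypothetical, with unit
weights; the outputs shown omit the sentinel row that
extend_with_sentinel_counts appends):

# Batch of 3 instances with 4 tokens total and labels in {0, 1}.
x = tf.SparseTensor(
    indices=[[0, 0], [0, 1], [1, 0], [2, 0]],
    values=tf.constant([b'a', b'b', b'a', b'c']),
    dense_shape=[3, 2])
y = tf.constant([1, 1, 0], dtype=tf.int64)
batch = reduce_batch_weighted_cooccurrences(x, y)
# batch.unique_x -> [b'a', b'b', b'c'] (plus a sentinel entry, since
#   extend_with_sentinel_counts defaults to True).
# batch.summed_weights_per_x -> [2., 1., 1.]
# batch.summed_positive_per_x_and_y -> [[0., 2.], [0., 1.], [1., 0.]], i.e.
#   shape [num_unique_x, max(y) + 1]: both b'a' tokens co-occur with y=1.
# batch.counts_per_x -> [2, 1, 1]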