# tensorflow_transform/tf_utils.py
def reduce_batch_weighted_cooccurrences(
x_input: common_types.TensorType,
y_input: tf.Tensor,
weights_input: Optional[tf.Tensor] = None,
extend_with_sentinel_counts: bool = True,
filter_regex: Optional[str] = None) -> ReducedBatchWeightedCounts:
"""Performs batch-wise reduction to produce weighted co-occurrences.
Computes the weighted co-occurrence of each feature value in x, for each value
in the range [0, max(y)). If extend_with_sentinel_counts is true, the return
value will include an additional sentinel token (not in the true vocabulary)
that is used to accumulate the global distribution of y values.
Args:
x_input: Input `Tensor` or `CompositeTensor`.
y_input: Integer `Tensor` with which to compute the co-occurrence with
x_input.
weights_input: (Optional) Weights input `Tensor`.
extend_with_sentinel_counts: If True, the reduced batch will be extended
a sentinel value that accumlate the total distribution of y values. Should
be True except when called recursively with the sentinel value as input.
filter_regex: (Optional) Regex that matches tokens that have to be filtered
out. Can only be specified if `x_input` has string dtype.
Returns:
a namedtuple of...
unique_x_values: the unique values in x
summed_weights_per_x: sum of the weights for each unique value in x
summed_positive_per_x_and_y: If tensor y is provided, the sum of
positive weights for each unique y value, for each unique value in x.
If y tensor is not provided, value is None.
counts_per_x: if y is provided, counts of each of the unique values in x,
otherwise, None.
"""
tf.compat.v1.assert_type(y_input, tf.int64)
# TODO(b/134075780): Revisit expected weights shape when input is sparse.
if isinstance(x_input, tf.SparseTensor):
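    # For a `SparseTensor`, the first coordinate of each index is the row of
    # the value within the batch, so gathering y by these row ids pairs every
    # value in x with the y value of its instance.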
batch_indices = x_input.indices[:, 0]
# y and densified x should have the same batch dimension.
assert_eq = tf.compat.v1.assert_equal(
tf.shape(y_input)[0], tf.cast(x_input.dense_shape[0], tf.int32))
with tf.control_dependencies([assert_eq]):
y = tf.gather(y_input, batch_indices)
x = x_input.values
elif isinstance(x_input, tf.RaggedTensor):
# Each batch instance in x corresponds to a single value in y.
x_row_indices = _get_ragged_batch_value_rowids(x_input)
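    # y must have at least as many rows as x so that the gather below is valid
    # for every row id.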
assert_compatible = tf.debugging.assert_greater_equal(
tf.shape(y_input, out_type=tf.int64)[0], x_input.bounding_shape(axis=0))
with tf.control_dependencies([assert_compatible]):
x = tf.ensure_shape(x_input.flat_values, [None])
y = tf.gather(y_input, x_row_indices)
else:
y = y_input
x = x_input
if weights_input is None:
weights = tf.ones_like(x, dtype=tf.float32)
else:
x, weights_input = assert_same_shape(x, weights_input)
weights = weights_input
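  # Align y element-wise with x: broadcast y to the shape of x (a no-op for
  # the sparse and ragged branches above, which already gathered one y value
  # per x value) and flatten x, y and weights to rank 1.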
y = _broadcast_to_x_shape(x, y)
x, y = assert_same_shape(x, y)
x = tf.reshape(x, [-1])
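  # The filter function is built from x, and y and weights are passed through
  # the same function so that all three tensors keep matching elements after
  # filtering.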
filter_fn = _make_regex_filter_fn(x, filter_regex)
x = filter_fn(x)
y = filter_fn(tf.reshape(y, [-1]))
weights = filter_fn(tf.reshape(weights, [-1]))
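  # unique_idx maps each element of x to the index of its unique value; it
  # serves as the segment id for the weighted sums below.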
unique_x_values, unique_idx, unique_count = tf.unique_with_counts(
x, out_idx=tf.int64)
summed_weights_per_x = tf.math.unsorted_segment_sum(
weights, unique_idx, tf.size(input=unique_x_values))
  # For each feature value in x, compute the weighted positive sum for each
  # unique value in y.
max_y_value = tf.cast(tf.reduce_max(input_tensor=y_input), tf.int64)
max_x_idx = tf.cast(tf.size(unique_x_values), tf.int64)
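  # Encode each (x, y) pair as a single segment id:
  # unique_idx * (max_y_value + 1) + y. A single unsorted_segment_sum over
  # these ids then produces the flattened co-occurrence matrix. E.g. with
  # max_y_value=1, unique_idx=[0, 1, 0] and y=[0, 1, 1], the ids are [0, 3, 1].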
dummy_index = (max_y_value + 1) * unique_idx + y
summed_positive_per_x_and_y = tf.cast(
tf.math.unsorted_segment_sum(weights, dummy_index,
max_x_idx * (max_y_value + 1)),
dtype=tf.float32)
summed_positive_per_x_and_y = tf.reshape(summed_positive_per_x_and_y,
[max_x_idx, max_y_value + 1])
reduced_batch = ReducedBatchWeightedCounts(
unique_x=unique_x_values,
summed_weights_per_x=summed_weights_per_x,
summed_positive_per_x_and_y=summed_positive_per_x_and_y,
counts_per_x=unique_count)
# Add a sentinel token tracking the full distribution of y values.
if extend_with_sentinel_counts:
reduced_batch = extend_reduced_batch_with_y_counts(reduced_batch, y_input,
weights_input)
return reduced_batch
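
# A minimal usage sketch (illustrative, assuming eager execution; not part of
# the original source):
#
#   x = tf.constant([b"a", b"b", b"a"])
#   y = tf.constant([0, 1, 1], dtype=tf.int64)
#   batch = reduce_batch_weighted_cooccurrences(
#       x, y, extend_with_sentinel_counts=False)
#   batch.unique_x                     # [b"a", b"b"]
#   batch.summed_positive_per_x_and_y  # [[1., 1.], [0., 1.]]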