in tensorflow_transform/tf_utils.py [0:0]
def _validate_and_get_dense_value_key_inputs(
    x: common_types.TensorType,
    key: common_types.TensorType) -> Tuple[tf.Tensor, tf.Tensor]:
  """Validate x and key and returns dense representations if feasible.

  Checks that `x` and `key` are structurally aligned and returns dense tensors
  in which each value of `x` is paired with the `key` element at the same
  position. Supported combinations:

  * dense `x`, dense `key`: both are returned unchanged.
  * `SparseTensor` `x`, `SparseTensor` `key`: their `dense_shape` and `indices`
    must be equal; the `values` of both are returned.
  * `SparseTensor` `x`, dense `key`: `key` must have at least
    `x.dense_shape[0]` rows; each value of `x` is paired with the `key` element
    at that value's row (`x.indices[:, 0]`).
  * `RaggedTensor` `x`, `RaggedTensor` `key`: static shapes must be compatible
    and all `nested_row_splits` must be equal; the `flat_values` of both are
    returned.
  * `RaggedTensor` `x`, dense `key`: `key` must have at least as many rows as
    `x`'s outermost dimension; each value of `x` is paired with the `key`
    element of its batch instance.

  Runtime checks are built with `tf.debugging` asserts and attached through
  `tf.control_dependencies` to the tensors this function returns.

  Args:
    x: A `Tensor` or `CompositeTensor`.
    key: A `Tensor` or `CompositeTensor`. Must be `Tensor` if x is `Tensor`.

  Returns:
    The values of x and key if both are composite, the values of x and a mapped
    key if only x is composite, or the original x and key if both are dense.

  Raises:
    ValueError: If `key`'s type is not supported for `x`'s type, if `x` is not
      a `Tensor`, `SparseTensor` or `RaggedTensor`, or if ragged `x` and `key`
      have statically incompatible shapes.
  """
  if isinstance(x, tf.Tensor):
    if isinstance(key, tf.Tensor):
      return x, key
    raise ValueError('A dense key is required if x is dense')

  if isinstance(x, tf.SparseTensor):
    if isinstance(key, tf.SparseTensor):
      assert_shape = tf.debugging.assert_equal(x.dense_shape, key.dense_shape)
      assert_eq = tf.debugging.assert_equal(x.indices, key.indices)
      # tf.identity creates new ops inside the scope so the asserts gate them.
      with tf.control_dependencies([assert_eq, assert_shape]):
        return tf.identity(x.values), tf.identity(key.values)
    if isinstance(key, tf.Tensor):
      # In this case, the row of x corresponds to the key at that row.
      x_row_indices = x.indices[:, 0]
      assert_compatible = tf.debugging.assert_greater_equal(
          tf.shape(key, out_type=tf.int64)[0], x.dense_shape[0])
      with tf.control_dependencies([assert_compatible]):
        # Wrap x.values so the returned values, not only the gathered key,
        # depend on the compatibility check.
        return tf.identity(x.values), tf.gather(key, x_row_indices)
    raise ValueError('A sparse or dense key is required if x is sparse')

  if isinstance(x, tf.RaggedTensor):
    if isinstance(key, tf.RaggedTensor):
      # Static check; raises ValueError on incompatible shapes.
      x.shape.assert_is_compatible_with(key.shape)
      assert_ops = [
          tf.debugging.assert_equal(x_split, key_split)
          for x_split, key_split in zip(x.nested_row_splits,
                                        key.nested_row_splits)
      ]
      with tf.control_dependencies(assert_ops):
        # ensure_shape([None]) requires rank-1 flat values.
        return (tf.ensure_shape(tf.identity(x.flat_values), [None]),
                tf.ensure_shape(tf.identity(key.flat_values), [None]))
    if isinstance(key, tf.Tensor):
      # Each batch instance in x corresponds to a single element in key.
      x_row_indices = _get_ragged_batch_value_rowids(x)
      assert_compatible = tf.debugging.assert_greater_equal(
          tf.shape(key, out_type=tf.int64)[0], x.bounding_shape(axis=0))
      with tf.control_dependencies([assert_compatible]):
        return (tf.ensure_shape(x.flat_values, [None]),
                tf.gather(key, x_row_indices))
    raise ValueError('A ragged or dense key is required if x is ragged')

  # Previously this case fell through to the ragged-specific error message,
  # which was misleading for e.g. a tf.Variable or other composite types.
  raise ValueError(
      'x must be a Tensor, SparseTensor or RaggedTensor, got {}'.format(
          type(x)))