def _validate_and_get_dense_value_key_inputs()

in tensorflow_transform/tf_utils.py [0:0]


def _validate_and_get_dense_value_key_inputs(
    x: common_types.TensorType,
    key: common_types.TensorType) -> Tuple[tf.Tensor, tf.Tensor]:
  """Validates x and key and returns dense representations if feasible.

  Checks that a composite x and a composite key line up, or maps a dense key
  onto the values of a composite x. The supported combinations are:

    * dense x, dense key: both are returned unchanged, with no checks.
    * sparse x, sparse key: asserts that `dense_shape` and `indices` are equal,
      then returns `(x.values, key.values)`.
    * sparse x, dense key: each value of x takes the key of its row
      (`x.indices[:, 0]`). Asserts that key has at least
      `x.dense_shape[0]` rows.
    * ragged x, ragged key: statically checks shape compatibility and asserts
      that all nested row splits are equal. Returns both `flat_values`, which
      must be 1-D.
    * ragged x, dense key: each batch instance of x takes one element of key.
      Asserts that key has at least as many rows as x.

  Runtime checks are created as `tf.debugging` asserts. Outputs are built
  inside `tf.control_dependencies` on those asserts, so in graph mode the
  checks run whenever the returned tensors are evaluated.

  Args:
    x: A `Tensor` or `CompositeTensor`.
    key: A `Tensor` or `CompositeTensor`. Must be `Tensor` if x is `Tensor`.

  Returns:
    The values of x and key if both are composite, the values of x and a mapped
    key if only x is composite, or the original x and key if both are dense.

  Raises:
    ValueError: If the combination of x and key types is not supported, or if
      ragged x and key have statically incompatible shapes.
  """

  if isinstance(x, tf.Tensor) and isinstance(key, tf.Tensor):
    return x, key
  elif isinstance(x, tf.Tensor):
    raise ValueError('A dense key is required if x is dense')

  elif isinstance(x, tf.SparseTensor) and isinstance(key, tf.SparseTensor):
    assert_shape = tf.debugging.assert_equal(x.dense_shape, key.dense_shape)
    assert_eq = tf.debugging.assert_equal(x.indices, key.indices)
    with tf.control_dependencies([assert_eq, assert_shape]):
      # tf.identity creates new ops inside the context so both outputs carry
      # the assert dependencies.
      return tf.identity(x.values), tf.identity(key.values)
  elif isinstance(x, tf.SparseTensor) and isinstance(key, tf.Tensor):
    # In this case, the row of x corresponds to the key at that row.
    x_row_indices = x.indices[:, 0]
    # key needs an entry for every row of x; SparseTensor.dense_shape is int64.
    assert_compatible = tf.debugging.assert_greater_equal(
        tf.shape(key, out_type=tf.int64)[0], x.dense_shape[0])
    with tf.control_dependencies([assert_compatible]):
      # Wrap x.values in tf.identity: the existing tensor would otherwise not
      # depend on the assert (only the gather op created here would).
      return tf.identity(x.values), tf.gather(key, x_row_indices)
  elif isinstance(x, tf.SparseTensor):
    raise ValueError('A sparse or dense key is required if x is sparse')

  elif isinstance(x, tf.RaggedTensor) and isinstance(key, tf.RaggedTensor):
    x.shape.assert_is_compatible_with(key.shape)
    assert_ops = [
        tf.debugging.assert_equal(x_split, key_split) for x_split, key_split in
        zip(x.nested_row_splits, key.nested_row_splits)
    ]
    with tf.control_dependencies(assert_ops):
      return (tf.ensure_shape(tf.identity(x.flat_values), [None]),
              tf.ensure_shape(tf.identity(key.flat_values), [None]))
  elif isinstance(x, tf.RaggedTensor) and isinstance(key, tf.Tensor):
    # Each batch instance in x corresponds to a single element in key.
    x_row_indices = _get_ragged_batch_value_rowids(x)
    # bounding_shape is returned in x's row_splits dtype, which may be int32.
    # Cast it so both operands of the assert share the int64 dtype.
    x_num_rows = tf.cast(x.bounding_shape(axis=0), tf.int64)
    assert_compatible = tf.debugging.assert_greater_equal(
        tf.shape(key, out_type=tf.int64)[0], x_num_rows)
    with tf.control_dependencies([assert_compatible]):
      return (tf.ensure_shape(x.flat_values,
                              [None]), tf.gather(key, x_row_indices))
  elif isinstance(x, tf.RaggedTensor):
    raise ValueError('A ragged or dense key is required if x is ragged')
  else:
    raise ValueError(
        'x must be a Tensor, SparseTensor or RaggedTensor, got {}'.format(
            type(x)))