def random_ragged_tensor()

in tensorflow_gnn/graph/graph_tensor_random.py [0:0]


def random_ragged_tensor(
    shape: List[Union[int, None, tf.Tensor]],
    dtype: tf.dtypes.DType,
    sample_values: Optional[List[Any]] = None,
    row_lengths_range: Tuple[int, int] = (2, 8),
    row_splits_dtype: tf.dtypes.DType = tf.int32,
    validate: bool = True) -> gc.Field:
  """Generate a dense or ragged tensor with random values.

  Note: This is running in pure Python, not building a TensorFlow graph.

  The leading entry of `shape` only contributes to the total number of values;
  every following entry wraps the flat values in one more dimension, either
  uniform (an integer or rank-0 tensor) or ragged (None).

  Args:
    shape: The desired shape of the tensor, as a list of integers, None (for
      ragged dimensions) or a dynamically computed size (as a tensor of rank 0).
      Do not provide a tf.TensorShape here as it cannot hold tensors. An empty
      list requests a scalar.
    dtype: Data type for the values.
    sample_values: List of example values to sample from. If not specified (or
      empty), some simple defaults are produced by `typed_random_values`. The
      data type must match that used to initialize `dtype`.
    row_lengths_range: A `(low, high)` pair bounding each row length of a
      ragged dimension. The values are passed as `minval` and `maxval` to
      `tf.random.uniform`, so for integer dtypes the upper bound is exclusive.
    row_splits_dtype: Data type for row splits.
    validate: If true, then use assertions to check that the arguments form a
      valid RaggedTensor. Note: these assertions incur a runtime cost, since
      they must be checked for each tensor value.

  Returns:
    An instance of either tf.Tensor or tf.RaggedTensor.

  Raises:
    ValueError: If `shape` is not a list.
    AssertionError: If an entry of `shape` is neither an integer, a tensor nor
      None.
  """
  if not isinstance(shape, list):
    raise ValueError(f"Requested shape must be a list of integers, None or "
                     f"rank-0 tensors: {shape}")

  # Allocate partitions for each dimension, generating random row lengths where
  # ragged. This also computes the total number of values to be inserted in the
  # final tensor (`size`). Each entry is a pair `(is_uniform, row_lengths)` so
  # that dynamically sized uniform dimensions (rank-0 tensors) are not confused
  # with ragged row-length vectors (rank-1 tensors) when building the result.
  nested_row_lengths = []
  if shape:
    size = tf.constant(1, tf.int32)
    for dim in shape:
      if isinstance(dim, (int, tf.Tensor)):
        if isinstance(dim, tf.Tensor):
          _assert_rank0_int(dim, "Shape dimension")
          dim = tf.cast(dim, tf.int32)
        nested_row_lengths.append((True, dim))
        size *= dim
      else:
        assert dim is None, f"Invalid type for dimension in shape: {dim}"
        low, high = row_lengths_range
        # One random length per row accumulated so far.
        row_lengths = tf.random.uniform([size], low, high, row_splits_dtype)
        nested_row_lengths.append((False, row_lengths))
        size = tf.cast(tf.math.reduce_sum(row_lengths), tf.int32)
    flat_shape = [size]
  else:
    size = []  # Scalar.
    # A scalar has an empty shape; wrapping `size` in a list would produce the
    # invalid shape `[[]]`.
    flat_shape = []

  # Allocate the total amount of flat values.
  if sample_values:
    # TODO(blais): Assert consistency of types from `sample_values` and `dtype`.
    indices = tf.random.uniform(flat_shape, 0, len(sample_values), tf.int32)
    sample_values_tensor = tf.convert_to_tensor(sample_values, dtype=dtype)
    flat_values = tf.gather(sample_values_tensor, indices)
  else:
    flat_values = typed_random_values(size, dtype)

  # Now, build up the ragged tensor inside out. The leading dimension is
  # skipped: it is already implied by the number of rows of the next one.
  #
  # NOTE(blais,edloper): The future of RaggedTensor will bring support for a
  # RowPartition representation that will make the following dispatching code
  # unnecessary.
  tensor = flat_values
  for is_uniform, row_lengths in reversed(nested_row_lengths[1:]):
    if not is_uniform:
      tensor = tf.RaggedTensor.from_row_lengths(tensor, row_lengths,
                                                validate=validate)
    elif isinstance(tensor, tf.RaggedTensor):
      if isinstance(row_lengths, tf.Tensor):
        # A tensor-valued partition must share the dtype of the row splits of
        # the values it wraps.
        row_lengths = tf.cast(row_lengths, tensor.row_splits.dtype)
      tensor = tf.RaggedTensor.from_uniform_row_length(tensor, row_lengths,
                                                       validate=validate)
    else:
      old_shape = list(tensor.shape)
      tensor = tf.reshape(tensor, [-1, row_lengths] + old_shape[1:])
  return tensor