def fill()

in tensorflow_gnn/graph/tensor_utils.py [0:0]


def fill(spec: ValueSpec, nrows: tf.Tensor, value: tf.Tensor) -> Value:
  """Creates tensor filled with a scalar `value` according to the constraints.

  This function returns a Tensor or RaggedTensor compatible with `spec`.
  Its outermost dimension is `nrows`. Its further dimensions must be dense
  dimensions of a size defined in `spec`, or ragged dimensions for which
  the value contains 0 items. The elements of the tensor (if any) are set to
  `value`.

  Args:
    spec: type spec the result should be compatible with. Must be a
      `tf.TensorSpec` or a `tf.RaggedTensorSpec` of rank >= 1.
    nrows: number of rows in the result tensor. For a dense tensor, this is the
      outermost dimension size. For a ragged tensor, this is the number of rows
      in the outermost split (`tf.RaggedTensor.nrows`).
    value: scalar value to use for filling.

  Returns:
    Tensor filled with `value` that is compatible with `spec` and has `nrows`
    number of rows.

  Raises:
    ValueError: if `spec` is not a `tf.TensorSpec` or `tf.RaggedTensorSpec`,
      if `spec` is a scalar spec, if `value` or `nrows` is not a scalar, if the
      static leading dimension of `spec` contradicts a statically known
      `nrows`, or if the non-ragged inner dimensions of `spec` are not fully
      defined.
  """
  # Reject unsupported specs up front: other spec types may lack `ragged_rank`
  # and would otherwise fail with an AttributeError below.
  if not isinstance(spec, (tf.TensorSpec, tf.RaggedTensorSpec)):
    raise ValueError(f'Unsupported type spec {type(spec).__name__}')

  value = tf.convert_to_tensor(value, dtype=spec.dtype)
  if value.shape.rank != 0:
    raise ValueError('The `value` must be scalar tensor,'
                     f' got rank={value.shape.rank}')

  nrows = tf.convert_to_tensor(nrows)
  if nrows.shape.rank != 0:
    raise ValueError('The `nrows` must be scalar tensor,'
                     f' got rank={nrows.shape.rank}')

  if spec.shape.rank == 0:
    raise ValueError('The `spec` must have rank >= 1 to hold `nrows` rows,'
                     f' got shape={spec.shape}')

  if isinstance(spec, tf.TensorSpec) or spec.ragged_rank == 0:
    inner_dims = spec.shape[1:]
    outer_dim = spec.shape[0]
    # `nrows` may be symbolic (e.g. inside `tf.function`), so compare against
    # its static value when one is available instead of branching on a tensor.
    static_nrows = tf.get_static_value(nrows)
    if (outer_dim is not None and static_nrows is not None and
        outer_dim != int(static_nrows)):
      raise ValueError(f'The leading dimension in `spec` is {outer_dim} and'
                       f' it is not compatible with nrows={int(static_nrows)}.')
    if not inner_dims.is_fully_defined():
      raise ValueError('All except the leading shape dimensions in `spec`'
                       ' must be fully defined,'
                       f' got shape={spec.shape}')
    result_dims = [nrows, *inner_dims.as_list()]
    result = tf.fill(result_dims, value)
    assert result.shape[1:].as_list() == inner_dims.as_list()

  else:
    # Ragged case: dimensions 1..ragged_rank are row partitions (None means a
    # ragged partition, an int means a uniform one); the remaining dimensions
    # are the dense feature dimensions of the flat values.
    row_splits_dtype = spec.row_splits_dtype
    partitioned_dims = spec.shape[1:(spec.ragged_rank + 1)].as_list()
    features_shape = spec.shape[(spec.ragged_rank + 1):]
    if not features_shape.is_fully_defined():
      raise ValueError('All dimensions after the ragged dimensions in `spec`'
                       ' must be fully defined,'
                       f' got shape={spec.shape}')

    # level_nrows[i] is the number of rows split by the partition for
    # dimension i + 1; level_nrows[-1] is the number of flat values. Below a
    # ragged dimension every row is empty, so the count drops to 0 (and stays
    # 0 for any deeper uniform dimensions).
    level_nrows = [nrows]
    for dim in partitioned_dims:
      level_nrows.append(0 if dim is None else level_nrows[-1] * dim)

    result = tf.fill([level_nrows[-1], *features_shape.as_list()], value)
    # Wrap the flat values from the innermost partition outwards.
    for dim, outer_nrows in reversed(
        list(zip(partitioned_dims, level_nrows))):
      if dim is None:
        # Ragged dimension: `outer_nrows` rows, each with 0 items.
        result = tf.RaggedTensor.from_row_lengths(
            result, tf.zeros([outer_nrows], dtype=row_splits_dtype))
      else:
        # Uniform dimension. Passing `nrows` explicitly keeps the row count
        # correct when the uniform row length is 0.
        result = tf.RaggedTensor.from_uniform_row_length(
            result,
            tf.constant(dim, dtype=row_splits_dtype),
            nrows=tf.cast(outer_nrows, row_splits_dtype))

  assert spec.is_compatible_with(
      result), f'{spec}, {tf.type_spec_from_value(result)}'
  return result