in tensorflow_gnn/graph/tensor_utils.py [0:0]
def fill(spec: ValueSpec, nrows: tf.Tensor, value: tf.Tensor) -> Value:
  """Creates tensor filled with a scalar `value` according to the constraints.

  This function returns a Tensor or RaggedTensor compatible with `spec`.
  Its outermost dimension is `nrows`. Its further dimensions must be dense
  dimensions of a size defined in `spec`, or ragged dimensions for which
  the value contains 0 items. The elements of the tensor (if any) are set to
  `value`.

  Args:
    spec: type spec the result should be compatible with. Must be a
      `tf.TensorSpec` or a `tf.RaggedTensorSpec`.
    nrows: number of rows in the result tensor. For a dense tensor, this is the
      outermost dimension size. For a ragged tensor, this is the number of rows
      in the outermost split (`tf.RaggedTensor.nrows`).
    value: scalar value to use for filling.

  Returns:
    Tensor filled with `value` that is compatible with `spec` and has `nrows`
    number of rows.

  Raises:
    ValueError: if `value` or `nrows` is not a scalar; if `spec` is not a
      `tf.TensorSpec` or `tf.RaggedTensorSpec`; if `spec` is a scalar spec; if
      the leading dimension of `spec` is statically known and differs from the
      statically known `nrows`; or if the non-ragged inner dimensions of `spec`
      are not fully defined.
  """
  value = tf.convert_to_tensor(value, dtype=spec.dtype)
  if value.shape.rank != 0:
    raise ValueError('The `value` must be scalar tensor,'
                     f' got rank={value.shape.rank}')
  nrows = tf.convert_to_tensor(nrows)
  if nrows.shape.rank != 0:
    raise ValueError('The `nrows` must be scalar tensor,'
                     f' got rank={nrows.shape.rank}')

  # Only check `ragged_rank` once we know the spec actually has it; other
  # spec types (e.g. tf.SparseTensorSpec) are rejected explicitly.
  if not isinstance(spec, (tf.TensorSpec, tf.RaggedTensorSpec)):
    raise ValueError(f'Unsupported type spec {type(spec).__name__}')
  if spec.shape.rank == 0:
    raise ValueError('The `spec` must have at least one dimension,'
                     f' got shape={spec.shape}')

  # Compare the leading dimension only when both sides are statically known:
  # `nrows` is a tf.Tensor, so using it directly in a Python `if` is not
  # graph-safe.
  outer_dim = spec.shape[0]
  static_nrows = tf.get_static_value(nrows)
  if (outer_dim is not None and static_nrows is not None and
      outer_dim != static_nrows):
    raise ValueError(f'The leading dimension in `spec` is {outer_dim} and'
                     f' it is not compatible with nrows={static_nrows}.')

  is_ragged = isinstance(spec, tf.RaggedTensorSpec) and spec.ragged_rank > 0
  if not is_ragged:
    inner_dims = spec.shape[1:]
    if not inner_dims.is_fully_defined():
      raise ValueError('All except the leading shape dimensions in `spec`'
                       ' must be fully defined,'
                       f' got shape={spec.shape}')
    result = tf.fill([nrows, *inner_dims.as_list()], value)
    assert result.shape[1:].as_list() == inner_dims.as_list()
  else:
    row_splits_dtype = spec.row_splits_dtype
    feature_dims = spec.shape[(spec.ragged_rank + 1):]
    if not feature_dims.is_fully_defined():
      raise ValueError('All shape dimensions after the ragged dimensions in'
                       ' `spec` must be fully defined,'
                       f' got shape={spec.shape}')
    # By convention: scalar entries represent uniform row lengths, vector
    # entries represent ragged row lengths. Each partition is stored together
    # with the number of rows it partitions, so that uniform partitions get
    # the correct `nrows` even when their row length is 0.
    partitions = []
    # `num_items` tracks the number of entities partitioned by the continuous
    # sequence of higher-up dimensions. It becomes 0 after the first ragged
    # dimension, since every ragged row is filled with 0 items.
    num_items = nrows
    for dim in spec.shape[1:(spec.ragged_rank + 1)]:
      partition_nrows = num_items
      if dim is None:
        # Ragged dimension: row lengths [0, 0, ..., 0] for empty values that
        # are compatible with the outer dimensions.
        row_lengths = tf.fill([num_items],
                              tf.constant(0, dtype=row_splits_dtype))
        num_items = 0
      else:
        row_lengths = tf.constant(dim, dtype=row_splits_dtype)
        num_items = num_items * dim
      partitions.append((row_lengths, partition_nrows))

    # The flat values hold `num_items` rows: 0 if any ragged dimension was
    # seen, otherwise the product of `nrows` and all uniform dimensions.
    result = tf.fill([num_items, *feature_dims.as_list()], value)
    for row_lengths, partition_nrows in reversed(partitions):
      if row_lengths.shape.rank == 0:
        # Pass `nrows` explicitly: it cannot be inferred from the values when
        # the uniform row length is 0.
        result = tf.RaggedTensor.from_uniform_row_length(
            result, row_lengths,
            nrows=tf.cast(partition_nrows, row_splits_dtype))
      else:
        assert row_lengths.shape.rank == 1, row_lengths.shape
        result = tf.RaggedTensor.from_row_lengths(result, row_lengths)

  assert spec.is_compatible_with(
      result), f'{spec}, {tf.type_spec_from_value(result)}'
  return result