in tensorflow_gnn/graph/tensor_utils.py [0:0]
def ones_like_leading_dims(value: Value, rank: int,
                           dtype: tf.dtypes.DType) -> Value:
  """Creates a tensor of all ones for first `rank` dimensions.

  The result has the shape of the first `rank` dimensions of `value` and is
  filled with ones of type `dtype`. For a ragged `value`, the row partitions
  of the kept dimensions are reused, so ragged dimensions stay ragged and
  uniform dimensions stay uniform.

  Args:
    value: A dense `tf.Tensor` or a `tf.RaggedTensor`.
    rank: Number of leading dimensions of `value` to keep. Must be positive
      and must not exceed `value.shape.rank`.
    dtype: The dtype of the returned ones.

  Returns:
    For a dense `value`, a dense tensor of shape `tf.shape(value)[:rank]`.
    For a ragged `value`, a `tf.RaggedTensor` with the same partitions as the
    first `rank` dimensions of `value`. When `rank == 1`, the result is a
    dense 1-D tensor with one element per row of `value`.

  Raises:
    ValueError: if `rank` is not positive, if `rank` exceeds the rank of
      `value`, or if `value` is neither a dense nor a ragged tensor.
  """
  if rank <= 0:
    raise ValueError(f'Expected rank > 0, got {rank}')
  if rank > value.shape.rank:
    raise ValueError('`rank` is greater then `value` rank,'
                     f' got rank={rank},'
                     f' value.shape.rank={value.shape.rank}')
  if is_dense_tensor(value):
    return tf.ones(tf.shape(value)[:rank], dtype=dtype)
  if not is_ragged_tensor(value):
    raise ValueError(f'Unsupported type {type(value).__name__}')

  def build(piece: Value, depth: int) -> Value:
    # Returns ones covering the leading `depth + 1` dimensions of `piece`.
    if not is_ragged_tensor(piece):
      # A dense piece (e.g. the flat values of a ragged tensor) can be sliced
      # directly; this also covers dense dims nested below the ragged ones.
      return tf.ones(tf.shape(piece)[:depth + 1], dtype=dtype)
    if depth == 0:
      # Only the outermost (row) dimension of this ragged piece is kept.
      return tf.ones(tf.expand_dims(piece.nrows(), -1), dtype=dtype)
    inner = build(piece.values, depth - 1)
    row_length = piece.uniform_row_length  # A Tensor, or None if ragged.
    if row_length is not None:
      # Pass `nrows` explicitly: with a row length of 0 the number of rows
      # cannot be inferred from the (empty) values.
      return tf.RaggedTensor.from_uniform_row_length(
          inner, row_length, nrows=piece.nrows())
    return tf.RaggedTensor.from_row_splits(inner, piece.row_splits)

  return build(value, rank - 1)