in tensorflow_gnn/graph/tensor_utils.py [0:0]
def pad_to_nrows(value: Value,
                 target_nrows: tf.Tensor,
                 padding_value: tf.Tensor,
                 validate: bool = True) -> Value:
  """Pads `value` to the target number of rows with scalar `padding_value`.

  The padding rows are concatenated after the existing rows along axis 0, so
  the original rows are kept unchanged at the front of the result.

  Args:
    value: tensor of rank > 0 or ragged tensor to pad.
    target_nrows: number of rows in the result tensor. For a dense tensor, this
      is the outermost dimension size. For a ragged tensor, this is the number
      of rows in the outermost split (`tf.RaggedTensor.nrows`).
    padding_value: scalar value to use for padding.
    validate: if true, adds a runtime check that `value` does not already have
      more than `target_nrows` rows (i.e. that it can be padded at all).

  Returns:
    Input `value` padded to the target number of rows.

  Raises:
    ValueError: if `value` is a scalar (rank 0), or if it is neither a dense
      tensor nor a ragged tensor.
  """
  if value.shape.rank == 0:
    raise ValueError('The `value` must have rank>0, got scalar (rank=0)')
  # Number of rows to append, computed in int64 regardless of the dtype of
  # `target_nrows` or of the ragged row partition.
  if is_dense_tensor(value):
    diff_size = tf.cast(target_nrows, tf.int64) - tf.shape(value, tf.int64)[0]
  elif is_ragged_tensor(value):
    diff_size = tf.cast(target_nrows, tf.int64) - tf.cast(
        value.nrows(), tf.int64)
  else:
    raise ValueError(f'Unsupported type {type(value).__name__}')

  # The padding has a different (dynamic) number of rows than `value`, so its
  # spec keeps every inner dimension but leaves the outermost one unknown.
  spec = tf.type_spec_from_value(value)
  relaxed_shape = tf.TensorShape([None, *spec.shape[1:]])
  if isinstance(spec, tf.RaggedTensorSpec):
    spec = tf.RaggedTensorSpec(
        shape=relaxed_shape,
        dtype=spec.dtype,
        ragged_rank=spec.ragged_rank,
        row_splits_dtype=spec.row_splits_dtype)
  else:
    # Only dense tensors reach this branch (see the type dispatch above).
    assert isinstance(spec, tf.TensorSpec)
    spec = tf.TensorSpec(shape=relaxed_shape, dtype=spec.dtype)

  if validate:
    validation_ops = [
        tf.debugging.assert_non_negative(
            diff_size,
            f'The `value` has more rows than the target_nrows={target_nrows}.')
    ]
  else:
    validation_ops = []
  # In graph mode, make sure the check runs before the padding is built.
  with tf.control_dependencies(validation_ops):
    # NOTE(review): `fill` is defined elsewhere in this module; presumably it
    # builds a value matching `spec` with `nrows` rows of `padding_value`.
    diff = fill(spec, nrows=diff_size, value=padding_value)
    return tf.concat([value, diff], axis=0)