def pad_to_nrows()

in tensorflow_gnn/graph/tensor_utils.py [0:0]


def pad_to_nrows(value: Value,
                 target_nrows: tf.Tensor,
                 padding_value: tf.Tensor,
                 validate: bool = True) -> Value:
  """Pads `value` to the target number of rows with scalar `padding_value`.

  The padding rows are appended after the existing rows (concatenation along
  axis 0); only the outermost dimension grows.

  Args:
    value: tensor of rank > 0 or ragged tensor to pad.
    target_nrows: number of rows in the result tensor. For a dense tensor, this
      is the outermost dimension size. For a ragged tensor, this is the number
      of rows in the outermost split (`tf.RaggedTensor.nrows`).
    padding_value: scalar value to use for padding.
    validate: if true, adds runtime checks that value could be padded, i.e.
      that `value` does not already have more than `target_nrows` rows.

  Returns:
    Input `value` padded to the target number of rows.

  Raises:
    ValueError: if `value` is a scalar, is neither a dense nor a ragged tensor,
      or has unknown rank.
    tf.errors.InvalidArgumentError: only if `validate` is true, when `value`
      has more rows than `target_nrows`.
  """
  if value.shape.rank == 0:
    raise ValueError('The `value` must have rank>0, got scalar (rank=0)')

  if is_dense_tensor(value):
    nrows = tf.shape(value, tf.int64)[0]
  elif is_ragged_tensor(value):
    nrows = tf.cast(value.nrows(), tf.int64)
  else:
    raise ValueError(f'Unsupported type {type(value).__name__}')

  if value.shape.rank is None:
    # An unknown-rank `TensorShape` cannot be unpacked when building
    # `relaxed_shape` below; fail here with an explicit message instead of a
    # cryptic shape-iteration error.
    raise ValueError('The `value` must have a known rank, got unknown rank')

  target = tf.cast(target_nrows, tf.int64)
  diff_size = target - nrows

  spec = tf.type_spec_from_value(value)
  # The padding only needs to match the inner dimensions of `value`; its
  # outermost dimension is `diff_size`, which is generally not static.
  relaxed_shape = tf.TensorShape([None, *spec.shape[1:]])
  if isinstance(spec, tf.RaggedTensorSpec):
    spec = tf.RaggedTensorSpec(
        shape=relaxed_shape,
        dtype=spec.dtype,
        ragged_rank=spec.ragged_rank,
        row_splits_dtype=spec.row_splits_dtype)
  else:
    assert isinstance(spec, tf.TensorSpec)
    spec = tf.TensorSpec(shape=relaxed_shape, dtype=spec.dtype)

  if validate:
    # Equivalent to `diff_size >= 0`, but `assert_less_equal` reports the
    # actual runtime values of both operands in its error. This stays
    # informative even when `target_nrows` is a symbolic tensor, unlike
    # formatting the tensor into the message string.
    validation_ops = [
        tf.debugging.assert_less_equal(
            nrows,
            target,
            message='The `value` has more rows than the target_nrows.')
    ]
  else:
    validation_ops = []

  with tf.control_dependencies(validation_ops):
    diff = fill(spec, nrows=diff_size, value=padding_value)
    return tf.concat([value, diff], axis=0)