def ones_like_leading_dims()

in tensorflow_gnn/graph/tensor_utils.py [0:0]


def ones_like_leading_dims(value: Value, rank: int,
                           dtype: tf.dtypes.DType) -> Value:
  """Creates a tensor of all ones for first `rank` dimensions.

  The result has the same (possibly ragged) structure as the `rank` leading
  dimensions of `value`, with every element equal to one.

  Args:
    value: A dense tensor (as decided by `is_dense_tensor`) or a ragged tensor
      (as decided by `is_ragged_tensor`) whose leading dimensions are copied.
    rank: Number of leading dimensions of `value` to reproduce. Must satisfy
      `0 < rank <= value.shape.rank`.
    dtype: Element type of the result.

  Returns:
    A dense tensor if `value` is dense; otherwise a ragged tensor (or a dense
    vector when `rank` is 1) that reuses the row partitions of `value`.

  Raises:
    ValueError: If `rank` is not positive, if `rank` exceeds the rank of
      `value`, or if `value` is neither a dense nor a ragged tensor.
  """
  # Negative ranks are rejected as well: they would otherwise be treated as
  # "all but the last |rank| dimensions" by the slice below.
  if rank <= 0:
    raise ValueError(f'Expected rank > 0, got {rank}')
  if rank > value.shape.rank:
    raise ValueError('`rank` is greater then `value` rank,'
                     f' got rank={rank},'
                     f' value.shape.rank={value.shape.rank}')

  if is_dense_tensor(value):
    size_shape = tf.shape(value)[:rank]
    return tf.ones(size_shape, dtype=dtype)

  if not is_ragged_tensor(value):
    raise ValueError(f'Unsupported type {type(value).__name__}')

  def iterate(value: Value, rank: int) -> Value:
    """Builds ones for the `rank + 1` leading dimensions of `value`."""
    if not is_ragged_tensor(value):
      # Dense (inner) values: all remaining requested dimensions are dense,
      # so they can be taken directly from the dynamic shape.
      return tf.ones(tf.shape(value)[:rank + 1], dtype=dtype)
    if rank == 0:
      return tf.ones(tf.expand_dims(value.nrows(), -1), dtype=dtype)

    inner_ones = iterate(value.values, rank - 1)
    # `uniform_row_length` is None for truly ragged partitions. Compare with
    # None explicitly: the truth value of a tensor is not usable in graph mode
    # and a uniform length of zero must still take this branch.
    if value.uniform_row_length is not None:
      # `nrows` is passed explicitly because it cannot be inferred from the
      # values when the uniform row length is zero.
      return tf.RaggedTensor.from_uniform_row_length(
          inner_ones, value.uniform_row_length, nrows=value.nrows())
    return tf.RaggedTensor.from_row_splits(inner_ones, value.row_splits)

  return iterate(value, rank - 1)