def _validate_indices()

in tensorflow_gnn/graph/adjacency.py [0:0]


def _validate_indices(indices: Indices) -> Indices:
  """Checks that indices have compatible shapes.

  Every index must have `tf.int32` or `tf.int64` dtype, and a ragged index
  must have flat values of rank 1. Each index is then compared against the
  index with the smallest tag. Both must have the same dtype, and both must be
  one of:
    * dense `tf.Tensor`s with equal `tf.shape`, or
    * `tf.RaggedTensor`s with equal ragged rank, equal nested row splits and
      equal flat-values shape.
  Mixing dense and ragged (or any other tensor type) is rejected.

  Dtype, tensor-kind and ragged-rank mismatches raise `ValueError` right away.
  The shape comparisons are `tf.assert_equal` ops. Any failure raised while
  they are built (or executed eagerly) is converted to `ValueError`. Otherwise
  they become control dependencies of the returned tensors, so in graph mode
  they are checked when those tensors are evaluated.

  Args:
    indices: Mapping from a tag to a `(node_set, index)` pair, where `index`
      is a `tf.Tensor` or a `tf.RaggedTensor`.

  Returns:
    A dict with the same tags and node sets as `indices`, inserted in sorted
    tag order. Each index is replaced by `tf.identity(index)`, created under
    control dependencies on the collected shape assertions.

  Raises:
    ValueError: if `indices` is empty, if an index has an unsupported dtype,
      if a ragged index has flat values of rank other than 1, or if any index
      is found to be incompatible with the reference (smallest-tag) index.
  """
  if not indices:
    raise ValueError('`indices` must contain at least one entry.')

  # Assertion ops from all pairwise comparisons. They are attached below as
  # control dependencies so that they also run in graph mode.
  assert_ops = []

  def check_index(tag, name, index):
    """Validates a single index on its own (dtype and flat-values rank)."""
    # The messages name *this* index, not the reference index.
    if index.dtype not in (tf.int32, tf.int64):
      raise ValueError((f'Adjacency indices ({tag}, {name}) must have '
                        f'tf.int32 or tf.int64 dtype, got {index.dtype}'))
    if isinstance(index, tf.RaggedTensor):
      if index.flat_values.shape.rank != 1:
        raise ValueError(
            (f'Adjacency indices ({tag}, {name}) as ragged tensor must'
             f' have flat values rank 1, got {index.flat_values.shape.rank}'))

  def check_compatibility(tag_0, name_0, index_0, tag_i, name_i, index_i):
    """Checks `index_i` against the reference `index_0`."""
    err_message = ('Adjacency indices are not compatible:'
                   f' ({tag_0}, {name_0}) and ({tag_i}, {name_i})')

    # Static checks: these raise directly and need no exception translation.
    if index_0.dtype != index_i.dtype:
      raise ValueError(err_message)
    both_dense = (isinstance(index_0, tf.Tensor) and
                  isinstance(index_i, tf.Tensor))
    both_ragged = (isinstance(index_0, tf.RaggedTensor) and
                   isinstance(index_i, tf.RaggedTensor))
    if not (both_dense or both_ragged):
      raise ValueError(err_message)
    if both_ragged and index_0.ragged_rank != index_i.ragged_rank:
      raise ValueError(err_message)

    try:
      if both_dense:
        assert_ops.append(
            tf.assert_equal(
                tf.shape(index_0), tf.shape(index_i), message=err_message))
      else:
        # Equal ragged_rank was checked above, so zip pairs every level.
        for partition_0, partition_i in zip(index_0.nested_row_splits,
                                            index_i.nested_row_splits):
          assert_ops.append(
              tf.assert_equal(partition_0, partition_i, message=err_message))
        assert_ops.append(
            tf.assert_equal(
                tf.shape(index_0.flat_values),
                tf.shape(index_i.flat_values),
                message=err_message))
    except Exception as err:  # pylint: disable=broad-except
      # Building or eagerly running an assertion can fail in several ways
      # when the shapes differ. Report any of them uniformly as an
      # incompatibility, keeping the original error as the cause.
      raise ValueError(err_message) from err

  sorted_items = sorted(indices.items(), key=lambda item: item[0])
  tag_0, (name_0, index_0) = sorted_items[0]
  check_index(tag_0, name_0, index_0)
  for tag_i, (name_i, index_i) in sorted_items[1:]:
    check_index(tag_i, name_i, index_i)
    check_compatibility(tag_0, name_0, index_0, tag_i, name_i, index_i)

  # Apply identity operations to all index tensors to ensure that assertions are
  # executed in the graph mode.
  with tf.control_dependencies(assert_ops):
    return {
        node_tag: (node_set, tf.identity(index))
        for node_tag, (node_set, index) in sorted_items
    }