def _check_size_fields()

in tensorflow_gnn/graph/graph_tensor_io.py [0:0]


def _check_size_fields(spec: gt.GraphTensorSpec,
                       flat_values: gt.Fields) -> List[AssertOp]:
  """Builds runtime checks for the special size fields of node/edge sets.

  For every node set and edge set defined in `spec`, the size field must

  * be present (it may not be missing for any node or edge set);
  * have the same shape as every other size field, so that all sets agree on
    how graph components are partitioned;
  * hold the number of nodes/edges of each graph component, including an
    explicit zero for components that have no items in that set.

  Shape agreement is enforced through `tf.debugging.assert_shapes()`, which is
  called only for its effect. The presence checks are collected and returned.

  Args:
    spec: graph tensor specification.
    flat_values: flattened graph tensor values to check matching the `spec`.

  Returns:
    List of assertion operations.

  Raises:
    ValueError: if a size field is neither a `tf.Tensor` nor a
      `tf.RaggedTensor`.
  """
  missing_size_message = (
      'The `{size_name}` field must always be present when parsing graph '
      'tensor with non-static number of graph components. E.g. it is required '
      'that `nodes/{{node_set}}.{size_name}` and '
      '`edges/{{edge_set}}.{size_name}` features are present for all node and '
      'edge sets in each Tensorflow example.').format(size_name=gc.SIZE_NAME)

  checks = []
  # (tensor, symbolic dims) pairs passed to `tf.debugging.assert_shapes`.
  shape_constraints = []

  def add_ragged_constraints(sizes: tf.RaggedTensor):
    """Records checks for a size field parsed as a ragged tensor."""
    assert isinstance(sizes, tf.RaggedTensor)

    # Descend to the innermost ragged level. The row_splits of every nested
    # level below the outermost one get a shared symbolic dim `R<level>`.
    innermost = sizes
    level = 0
    while isinstance(innermost.values, tf.RaggedTensor):
      innermost = innermost.values
      shape_constraints.append((innermost.row_splits, (f'R{level}',)))
      level += 1

    assert isinstance(innermost.values, tf.Tensor)
    # A zero-length row at the innermost level is reported with the
    # "field must always be present" message.
    checks.append(
        tf.debugging.assert_positive(
            innermost.row_lengths(), message=missing_size_message))
    shape_constraints.append((innermost.values, ('C',)))

  def add_dense_constraints(sizes: tf.Tensor):
    """Records checks for a size field parsed as a dense tensor."""
    assert isinstance(sizes, tf.Tensor)

    # Every dimension after the axis at index `spec.rank` must be static.
    inner_dims = sizes.shape[(spec.rank + 1):]
    assert None not in inner_dims.as_list(), (
        'Undefined inner dimensions for dense `{}` field.'.format(gc.SIZE_NAME))

    # Slicing (rather than indexing) yields `[]` when the tensor has no axis at
    # index `spec.rank`, so the check below is skipped instead of raising.
    # Presumably this axis counts graph components -- confirm against spec.
    component_dim = sizes.shape[spec.rank:(spec.rank + 1)]
    if component_dim.as_list() == [None]:
      checks.append(
          tf.debugging.assert_positive(
              tf.size(sizes), message=missing_size_message))

    dim_names = [f'D{i}' for i in range(sizes.shape.rank)]
    shape_constraints.append((sizes, dim_names))

  for field_name, field in flat_values.items():
    if not _is_size_field(field_name):
      continue
    if isinstance(field, tf.RaggedTensor):
      add_ragged_constraints(field)
    elif isinstance(field, tf.Tensor):
      add_dense_constraints(field)
    else:
      raise ValueError('Unsupported `{}` field type {}'.format(
          gc.SIZE_NAME, type(field).__name__))

  # `tf.debugging.assert_shapes` is invoked only for its effect; it hands back
  # no op to collect into `checks`.
  tf.debugging.assert_shapes(
      shape_constraints,
      message=('All `{}` fields must have identical shapes for all node and '
               'edge sets.').format(gc.SIZE_NAME))

  return checks