in tensorflow_gnn/graph/graph_tensor_io.py [0:0]
def _check_size_fields(spec: gt.GraphTensorSpec,
                       flat_values: gt.Fields) -> List[AssertOp]:
  """Validates the special size fields of all node sets and edge sets.

  Every size field is expected to
    * be present for each node set and edge set defined in the `spec`;
    * have the same shape as every other size field (that is, the same
      partitioning into graph components);
    * always hold the number of nodes/edges per graph component, in
      particular, be zero for components without node/edge instances in
      some set.

  Args:
    spec: graph tensor specification. Only `spec.rank` is read here.
    flat_values: flattened graph tensor values matching the `spec`, keyed by
      field name. Only entries accepted by `_is_size_field` are inspected.

  Returns:
    List of assertion operations created with `tf.debugging.assert_positive`.
    The shape-equality check over all size fields is issued separately and is
    not part of the returned list.

  Raises:
    ValueError: if a size field is neither a `tf.RaggedTensor` nor a
      `tf.Tensor`.
    AssertionError: if a dense size field has a statically undefined dimension
      at index `spec.rank + 1` or beyond.
  """
  missing_size_message = (
      'The `{size_name}` field must always be present when parsing graph '
      'tensor with non-static number of graph components. E.g. it is required '
      'that `nodes/{{node_set}}.{size_name}` and '
      '`edges/{{edge_set}}.{size_name}` features are present for all node and '
      'edge sets in each Tensorflow example.').format(size_name=gc.SIZE_NAME)

  # Runtime "is positive" assertions handed back to the caller.
  positive_checks = []
  # (tensor, symbolic dimension names) pairs for `tf.debugging.assert_shapes`.
  # Using the same symbols for every size field forces their shapes to agree.
  shape_constraints = []

  def collect_ragged(size: tf.RaggedTensor):
    """Records shape constraints and assertions for a ragged size field."""
    assert isinstance(size, tf.RaggedTensor)
    level = 0
    inner = size.values
    # Descend through the nested ragged levels, registering the row splits of
    # each level below the outermost one under its own symbol.
    while isinstance(inner, tf.RaggedTensor):
      size = inner
      shape_constraints.append((size.row_splits, (f'R{level}',)))
      level += 1
      inner = size.values

    assert isinstance(inner, tf.Tensor)
    # Every row of the innermost ragged level must be non-empty; an empty row
    # is reported with the "field must be present" message.
    positive_checks.append(
        tf.debugging.assert_positive(
            size.row_lengths(), message=missing_size_message))
    shape_constraints.append((inner, ('C',)))

  def collect_dense(size: tf.Tensor):
    """Records shape constraints and assertions for a dense size field."""
    assert isinstance(size, tf.Tensor)
    # NOTE(review): `spec.rank` presumably counts the leading batch dimensions,
    # making index `rank` the graph components dimension - confirm.
    rank = spec.rank
    trailing_dims = size.shape[(rank + 1):].as_list()
    assert None not in trailing_dims, (
        'Undefined inner dimensions for dense `{}` field.'.format(gc.SIZE_NAME))

    # When the dimension at index `rank` is statically unknown, check at
    # runtime that the tensor is not empty.
    if size.shape[rank:(rank + 1)].as_list() == [None]:
      positive_checks.append(
          tf.debugging.assert_positive(
              tf.size(size), message=missing_size_message))

    dim_names = [f'D{axis}' for axis in range(size.shape.rank)]
    shape_constraints.append((size, dim_names))

  for fname, value in flat_values.items():
    if not _is_size_field(fname):
      continue
    if isinstance(value, tf.RaggedTensor):
      collect_ragged(value)
    elif isinstance(value, tf.Tensor):
      collect_dense(value)
    else:
      raise ValueError('Unsupported `{}` field type {}'.format(
          gc.SIZE_NAME, type(value).__name__))

  # Note: This does not return an op, so there is nothing to add to the
  # returned list.
  tf.debugging.assert_shapes(
      shape_constraints,
      message=('All `{}` fields must have identical shapes for all node and '
               'edge sets.').format(gc.SIZE_NAME))

  return positive_checks