in tensorflow_gnn/graph/adjacency.py [0:0]
def _validate_indices(indices: Indices) -> Indices:
  """Checks that adjacency indices have valid dtypes and compatible shapes.

  Each index is validated on its own, then every index is compared against the
  one with the smallest tag. Checks that cannot be decided from Python-level
  attributes are added as assertion ops that the returned tensors depend on.

  Args:
    indices: A non-empty mapping from a tag to a `(name, index)` pair, where
      `index` is a `tf.Tensor` or a `tf.RaggedTensor` with `tf.int32` or
      `tf.int64` dtype.

  Returns:
    A dict with the same tags (inserted in sorted tag order) mapping to
    `(name, tf.identity(index))` pairs. The identity ops are created under a
    control dependency on all collected assertions.

  Raises:
    ValueError: If `indices` is empty; if an index has a dtype other than
      `tf.int32`/`tf.int64`; if a ragged index has flat values whose rank is
      not 1; or if two indices differ in dtype, in kind (dense vs. ragged), in
      ragged rank, or fail a shape/row-partition assertion when it is created.
  """
  if not indices:
    raise ValueError('`indices` must contain at least one entry.')

  # Assertion ops collected by `check_compatibility`; attached as control
  # dependencies to the returned tensors.
  assert_ops = []

  def check_index(tag, name, index):
    """Validates the dtype and, for ragged input, the rank of flat values."""
    # Report the index that is actually being checked (`tag`, `name`), not the
    # first entry of the enclosing scope.
    if index.dtype not in (tf.int32, tf.int64):
      raise ValueError((f'Adjacency indices ({tag}, {name}) must have '
                        f'tf.int32 or tf.int64 dtype, got {index.dtype}'))
    if isinstance(index, tf.RaggedTensor):
      if index.flat_values.shape.rank != 1:
        raise ValueError(
            (f'Adjacency indices ({tag}, {name}) as ragged tensor must'
             f' have flat values rank 1, got {index.flat_values.shape.rank}'))

  def check_compatibility(tag_0, name_0, index_0, tag_i, name_i, index_i):
    """Checks that `index_i` matches `index_0` in dtype, kind and shape."""
    err_message = ('Adjacency indices are not compatible:'
                   f' ({tag_0}, {name_0}) and ({tag_i}, {name_i})')
    try:
      if index_0.dtype != index_i.dtype:
        raise ValueError(err_message)

      # Dense vs. dense: full shapes must be equal.
      if isinstance(index_0, tf.Tensor) and isinstance(index_i, tf.Tensor):
        assert_ops.append(
            tf.assert_equal(
                tf.shape(index_0), tf.shape(index_i), message=err_message))
        return

      # Ragged vs. ragged: same ragged rank, identical row partitions and
      # equal flat values shapes.
      if isinstance(index_0, tf.RaggedTensor) and isinstance(
          index_i, tf.RaggedTensor):
        if index_0.ragged_rank != index_i.ragged_rank:
          raise ValueError(err_message)
        for partition_0, partition_i in zip(index_0.nested_row_splits,
                                            index_i.nested_row_splits):
          assert_ops.append(
              tf.assert_equal(partition_0, partition_i, message=err_message))

        assert_ops.append(
            tf.assert_equal(
                tf.shape(index_0.flat_values),
                tf.shape(index_i.flat_values),
                message=err_message))
        return
    except Exception as err:  # pylint: disable=broad-except
      # Any failure while building the checks (including assertions that fail
      # immediately) is reported uniformly as a ValueError. `Exception` rather
      # than a bare `except` lets KeyboardInterrupt/SystemExit propagate, and
      # chaining keeps the original cause visible in the traceback.
      raise ValueError(err_message) from err

    # Reached only when one index is dense and the other is ragged.
    raise ValueError(err_message)

  # Sort by tag so that the reference index (the first one) is deterministic.
  sorted_indices = sorted(indices.items(), key=lambda item: item[0])
  tag_0, (name_0, index_0) = sorted_indices[0]
  check_index(tag_0, name_0, index_0)
  for tag_i, (name_i, index_i) in sorted_indices[1:]:
    check_index(tag_i, name_i, index_i)
    check_compatibility(tag_0, name_0, index_0, tag_i, name_i, index_i)

  # Apply identity operations to all index tensors to ensure that assertions are
  # executed in the graph mode.
  with tf.control_dependencies(assert_ops):
    return {
        node_tag: (node_set, tf.identity(index))
        for node_tag, (node_set, index) in sorted_indices
    }