in tensorflow_gnn/graph/padding_ops.py [0:0]
def _satisfies_total_sizes_internal(
graph_tensor: gt.GraphTensor, total_sizes: preprocessing.SizesConstraints,
check_fn: Callable[[tf.Tensor, str], Any]) -> List[Any]:
"""Checks that the graph tensor could fit in the target sizes.
This operation tests multiple conditions that all must be True for the input
`graph_tensor` to satisfy the `total_sizes`. The evaluated conditions along
with a description string are passed to the caller using `check_fn` callbacks.
The function tries to statically evaluate each condition and pass its result
as tf.constant so that it could be extracted (see `tf.get_static_value()`).
See `assert_satisfies_total_sizes()` for more information on how this might be
useful.
Args:
graph_tensor: a graph tensor to check against the total sizes.
total_sizes: total size constraints for each graph piece.
check_fn: callable with two arguments. The first argument is the evaluation
result for one of the required conditions. It is a boolean scalar tensor
where `True` means the condition is satisfied. If all conditions evaluate
to True, the `graph_tensor` satisfies `total_sizes`. The second argument is
a string description of the condition. All values returned by `check_fn`
are accumulated and returned.

Returns:
List of all results returned by `check_fn`.
"""
# NOTE: For some operations TF implements so-called constant folding: those
# operations are evaluated statically when all their inputs have static
# values. Such operations could also raise an exception statically if their
# arguments are invalid. Constant folding is supported by some operations
# (e.g. tf.fill) but not by others (e.g. tf.debugging.Assert). This could
# break control flow rules (https://www.tensorflow.org/guide/intro_to_graphs).
# See b/205974487 for more examples. This function always attempts to evaluate
# assertions statically by using Python logical operators to test conditions
# in `_fold_constants`. Because those operators are overridden both by
# np.ndarray and tf.Tensor, they are evaluated either statically or at
# runtime, depending on their arguments.
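# For illustration (assuming `_fold_constants(fn, x, y)` simply applies `fn`
# to its two arguments): when both inputs are static, e.g. Python ints or
# np.ndarray values, an expression like `3 <= 5` folds to a concrete True
# that `tf.get_static_value()` can recover; when either input is a symbolic
# tf.Tensor, the same expression builds an op whose value is only known at
# runtime.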
total_num_components = graph_tensor.total_num_components
could_add_new_component = _fold_constants(lambda x, y: x < y,
total_num_components,
total_sizes.total_num_components)
assert_ops = [
check_fn(
_fold_constants(
lambda x, y: x <= y, total_num_components,
tf.convert_to_tensor(
total_sizes.total_num_components,
dtype=total_num_components.dtype)),
('Could not pad graph as it already has more graph components'
' than allowed by `total_sizes.total_num_components`'))
]
def _check_sizes(entity_type: str, entity_name: str, total_size: tf.Tensor,
target_total_size: Optional[int]):
if target_total_size is None:
raise ValueError(
f'The target total number of <{entity_name}> {entity_type} must be'
' specified as'
f' `total_sizes.total_num_{entity_type}[<{entity_name}>]`.')
target_total_size = tf.convert_to_tensor(
target_total_size, dtype=total_size.dtype)
assert_ops.append(
check_fn(
_fold_constants(lambda x, y: x <= y, total_size, target_total_size),
(f'Could not pad <{entity_name}> as it already has more'
f' {entity_type} than allowed by'
f' `total_sizes.total_num_{entity_type}[<{entity_name}>]`.')))
assert_ops.append(
check_fn(
_fold_constants(
lambda x, y: x | y, could_add_new_component,
_fold_constants(lambda x, y: x == y, total_size,
target_total_size)),
(f'Could not pad <{entity_name}> {entity_type}. To do this, at'
' least one graph component must be added to the input graph.'
' The latter is not possible as the input graph already has'
' `total_sizes.total_num_components` graph components.')))
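# Worked example with illustrative numbers: a node set with 3 real nodes and
# a target of 5 total nodes needs 2 fake nodes, and fake nodes can only live
# in a fake (padding) graph component. Hence the second check above passes
# only if the sizes already match exactly or `could_add_new_component` is
# True.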
total_num_nodes = {}
for name, item in graph_tensor.node_sets.items():
total_size = item.total_size
target_total_size = total_sizes.total_num_nodes.get(name, None)
total_num_nodes[name] = total_size
_check_sizes('nodes', name, total_size, target_total_size)
for name, item in graph_tensor.edge_sets.items():
total_size = item.total_size
target_total_size = total_sizes.total_num_edges.get(name, None)
_check_sizes('edges', name, total_size, target_total_size)
assert target_total_size is not None
has_all_edges = _fold_constants(lambda x, y: x == y, total_size,
target_total_size)
indices = item.adjacency.get_indices_dict()
for _, (incident_node_set_name, _) in indices.items():
permits_new_incident_nodes = _fold_constants(
lambda x, y: x < y, total_num_nodes[incident_node_set_name],
total_sizes.total_num_nodes[incident_node_set_name])
assert_ops.append(
check_fn(
_fold_constants(lambda x, y: x | y, has_all_edges,
permits_new_incident_nodes),
('Could not create fake incident edges for the node set'
f' {incident_node_set_name}. This could happen when the'
' total number of real nodes is equal to the target total'
' number of nodes, so there are no fake nodes that could be'
' connected by inserted fake edges.')))
return assert_ops
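
# A minimal sketch (not part of this module; the function name below is
# hypothetical) of how a caller might wire `check_fn` to turn every condition
# into a runtime assertion via `tf.debugging.Assert`, in the spirit of
# `assert_satisfies_total_sizes()` referenced in the docstring above.
def _assert_satisfies_total_sizes_sketch(
    graph_tensor: gt.GraphTensor,
    total_sizes: preprocessing.SizesConstraints) -> List[Any]:
  """Returns assertion ops that fail if `graph_tensor` exceeds `total_sizes`."""

  def check_fn(condition: tf.Tensor, message: str):
    # A caller could first inspect tf.get_static_value(condition) to fail
    # early at trace time; here we always emit a runtime assertion op.
    return tf.debugging.Assert(condition, [message])

  return _satisfies_total_sizes_internal(graph_tensor, total_sizes, check_fn)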