def _satisfies_total_sizes_internal()

in tensorflow_gnn/graph/padding_ops.py [0:0]


def _satisfies_total_sizes_internal(
    graph_tensor: gt.GraphTensor, total_sizes: preprocessing.SizesConstraints,
    check_fn: Callable[[tf.Tensor, str], Any]) -> List[Any]:
  """Checks that the graph tensor could fit in the target sizes.

  This operation tests multiple conditions that all must be True for the input
  `graph_tensor` to satisfy the `total_sizes`. The evaluated conditions along
  with a description string are passed to the caller using `check_fn` callbacks.

  The function tries to statically evaluate each condition and pass its result
  as tf.constant so that it could be extracted (see `tf.get_static_value()`).
  See `assert_satisfies_total_sizes()` for more information on how this might be
  useful.

  Args:
    graph_tensor: a graph tensor to check against total sizes.
    total_sizes: total sizes constraints for each graph piece.
    check_fn: callable with two arguments. The first argument is an evaluation
      result for one of required conditions. It is a boolean scalar tensor where
      `True` means condition is satisfied. If all conditions result in True,
      the `graph_tensor` satisfies `total_sizes`. The second argument is a
      string description of the condition. All values returned by the `check_fn`
      are accumulated and returned.

  Returns:
    List of all results returned by the `check_fn`.

  Raises:
    ValueError: if `total_sizes.total_num_nodes` or `total_sizes.total_num_edges`
      has no entry for some node set or edge set of `graph_tensor`.
  """
  # NOTE: For some operations (e.g. tf.fill) TF implements so-called constant
  # folding: such operations are evaluated statically if all their inputs have
  # static values, and they could also raise an exception statically if their
  # arguments are invalid. Other operations (e.g. tf.debug.Assert) do not
  # support constant folding. This could break control flow rules
  # (https://www.tensorflow.org/guide/intro_to_graphs); see b/205974487 for more
  # examples. This function therefore always attempts to evaluate conditions
  # statically by testing them with Python operators inside `_fold_constants`.
  # Because those operators are overridden both by np.ndarray and tf.Tensor,
  # each condition is evaluated either statically or at runtime, depending on
  # its arguments.
  total_num_components = graph_tensor.total_num_components
  # Convert the target once so that every comparison against it uses the same
  # dtype as the graph's actual number of components.
  target_num_components = tf.convert_to_tensor(
      total_sizes.total_num_components, dtype=total_num_components.dtype)
  # True if there is room for at least one extra (fake) graph component.
  could_add_new_component = _fold_constants(lambda x, y: x < y,
                                            total_num_components,
                                            target_num_components)

  assert_ops = [
      check_fn(
          _fold_constants(lambda x, y: x <= y, total_num_components,
                          target_num_components),
          ('Could not pad graph as it already has more graph components'
           ' than it is allowed by `total_sizes.total_num_components`'))
  ]

  def _check_sizes(entity_type: str, entity_name: str, total_size: tf.Tensor,
                   target_total_size: Optional[int]) -> tf.Tensor:
    """Appends size checks for one node/edge set; returns its target tensor.

    Args:
      entity_type: 'nodes' or 'edges', used in the error/check messages.
      entity_name: name of the node set or edge set.
      total_size: actual total size of the node set or edge set.
      target_total_size: target total size from `total_sizes`, or None if
        `total_sizes` has no entry for this set.

    Returns:
      `target_total_size` converted to a tensor with the dtype of `total_size`.

    Raises:
      ValueError: if `target_total_size` is None.
    """
    if target_total_size is None:
      raise ValueError(
          f'The target total number of <{entity_name}> {entity_type} must be'
          ' specified as'
          f' `total_sizes.total_num_{entity_type}[<{entity_name}>]`.')
    target_total_size = tf.convert_to_tensor(
        target_total_size, dtype=total_size.dtype)
    assert_ops.append(
        check_fn(
            _fold_constants(lambda x, y: x <= y, total_size, target_total_size),
            (f'Could not pad <{entity_name}> as it already has more'
             f' {entity_type} than it is allowed by the'
             f' `total_sizes.total_num_{entity_type}[<{entity_name}>]`.')))

    # Padding a set that is not already full requires adding at least one
    # fake graph component, which is only possible if there is room for it.
    assert_ops.append(
        check_fn(
            _fold_constants(
                lambda x, y: x | y, could_add_new_component,
                _fold_constants(lambda x, y: x == y, total_size,
                                target_total_size)),
            (f'Could not pad <{entity_name}> {entity_type}. To do this, at'
             ' least one graph component must be added to the input graph.'
             ' The latter is not possible as the input graph has already'
             ' `total_sizes.total_num_components` graph components.')))
    return target_total_size

  # Node sets are processed first: the edge-set checks below look up the
  # actual and target sizes of their incident node sets in these dicts.
  total_num_nodes = {}
  target_num_nodes = {}
  for name, item in graph_tensor.node_sets.items():
    total_size = item.total_size
    total_num_nodes[name] = total_size
    target_num_nodes[name] = _check_sizes(
        'nodes', name, total_size, total_sizes.total_num_nodes.get(name, None))

  for name, item in graph_tensor.edge_sets.items():
    total_size = item.total_size
    target_total_size = _check_sizes(
        'edges', name, total_size, total_sizes.total_num_edges.get(name, None))
    has_all_edges = _fold_constants(lambda x, y: x == y, total_size,
                                    target_total_size)
    # Fake edges (needed unless the edge set is already full) must connect
    # fake nodes, so every incident node set must have room for fake nodes.
    indices = item.adjacency.get_indices_dict()
    for incident_node_set_name, _ in indices.values():
      permits_new_incident_nodes = _fold_constants(
          lambda x, y: x < y, total_num_nodes[incident_node_set_name],
          target_num_nodes[incident_node_set_name])
      assert_ops.append(
          check_fn(
              _fold_constants(lambda x, y: x | y, has_all_edges,
                              permits_new_incident_nodes),
              ('Could not create fake incident edges for the node set'
               f' {incident_node_set_name}. This could happen when the'
               ' total number of real nodes is equal to the target total'
               ' number of nodes, so there are no fake nodes that could be'
               ' connected by inserted fake edges.')))

  return assert_ops