# pad_to_total_sizes() — excerpt from tensorflow_gnn/graph/padding_ops.py


def pad_to_total_sizes(
    graph_tensor: gt.GraphTensor,
    target_total_sizes: preprocessing.SizesConstraints,
    *,
    padding_values: Optional[preprocessing.DefaultValues] = None,
    validate: bool = True) -> Tuple[gt.GraphTensor, tf.Tensor]:
  """Pads graph tensor to the total sizes by inserting fake graph components.

  Padding is done by inserting "fake" graph components at the end of the input
  graph tensor until target total sizes are exactly matched. If that is not
  possible (e.g. input already has more nodes than allowed by the constraints)
  function raises tf.errors.InvalidArgumentError. Context, node or edge features
  of the appended fake components are filled using user-provided scalar values
  or with zeros if the latter are not specified. Fake edges are created such
  that each fake node has an approximately uniform number of incident edges
  (NOTE: this behavior is not guaranteed and may change in the future).

  Args:
    graph_tensor: scalar graph tensor (rank=0) to pad.
    target_total_sizes: target total sizes for each graph piece. Must define the
      target number of graph components (`.total_num_components`), target total
      number of items for each node set (`.total_num_nodes[node_set_name]`) and
      likewise for each edge set (`.total_num_edges[edge_set_name]`).
    padding_values: optional mapping from a context, node set or edge set
      feature name to a scalar tensor to use for padding. If no value is
      specified for some feature, the zero value of its type is used (as in
      `tf.zeros(...)`).
    validate: If true, then use assertions to check that the input graph tensor
      could be padded. NOTE: while these assertions provide more readable error
      messages, they incur a runtime cost, since assertions must be checked for
      each input value.

  Returns:
    Tuple of padded graph tensor and padding mask. The mask is a rank-1 dense
    boolean tensor with size equal to the number of graph components in the
    result, containing True for real graph components and False for fake ones
    used for padding.

  Raises:
    ValueError: if input parameters are invalid.
    tf.errors.InvalidArgumentError: if input graph tensor could not be padded to
      the `target_total_sizes`.
  """
  # Padding is only defined for scalar (rank-0) graph tensors.
  gt.check_scalar_graph_tensor(graph_tensor, 'tfgnn.pad_to_total_sizes()')

  def _ifnone(value, default):
    # Returns `value` unless it is None, in which case returns `default`.
    return value if value is not None else default

  if padding_values is None:
    padding_values = preprocessing.DefaultValues()

  def get_default_value(
      graph_piece_spec: gt._GraphPieceWithFeaturesSpec,  # pylint: disable=protected-access
      padding_values: gt.Fields,
      feature_name: str,
      debug_context: str) -> tf.Tensor:
    """Returns the scalar padding value for `feature_name`.

    Uses the user-supplied value from `padding_values` if present (converted to
    the feature's dtype and required to be scalar), otherwise a zero scalar of
    the feature's dtype.
    """
    spec = graph_piece_spec.features_spec[feature_name]
    value = padding_values.get(feature_name, None)
    if value is None:
      value = tf.zeros([], dtype=spec.dtype)
    else:
      value = tf.convert_to_tensor(value, spec.dtype)
      if value.shape.rank != 0:
        raise ValueError(f'Default value for {debug_context} must be scalar,'
                         f' got shape={value.shape}')
    return value

  def get_min_max_fake_nodes_indices(
      node_set_name: str) -> Tuple[tf.Tensor, tf.Tensor]:
    """Returns the inclusive [min, max] index range available for fake nodes.

    Fake nodes occupy indices from the current total node count up to the
    target total size minus one.
    """
    min_node_index = graph_tensor.node_sets[node_set_name].total_size
    max_node_index = tf.constant(
        target_total_sizes.total_num_nodes[node_set_name],
        dtype=min_node_index.dtype) - 1
    return min_node_index, max_node_index

  # Note: we check that graph tensor could potentially fit into the target sizes
  # before running padding. This simplifies padding implementation and removes
  # duplicative validations.
  if validate:
    validation_ops = assert_satisfies_total_sizes(graph_tensor,
                                                  target_total_sizes)
  else:
    validation_ops = []

  # All padding ops are created inside this scope so they only run after the
  # size-constraint assertions (if any) have passed.
  with tf.control_dependencies(validation_ops):
    total_num_components = graph_tensor.total_num_components
    target_total_num_components = target_total_sizes.total_num_components

    # Pad the context piece to the target number of components.
    padded_context = _pad_to_total_sizes(
        graph_tensor.context,
        target_total_num_components=target_total_num_components,
        padding_value_fn=functools.partial(
            get_default_value,
            graph_tensor.context.spec,
            _ifnone(padding_values.context, {}),
            debug_context='context'))

    # Pad each node set to its target total number of nodes.
    padded_node_sets = {}
    for name, item in graph_tensor.node_sets.items():
      padded_node_sets[name] = _pad_to_total_sizes(
          item,
          target_total_num_components=target_total_num_components,
          target_total_size=target_total_sizes.total_num_nodes[name],
          padding_value_fn=functools.partial(
              get_default_value,
              item.spec,
              _ifnone(padding_values.node_sets, {}).get(name, {}),
              debug_context=f'{name} nodes'))

    # Pad each edge set to its target total number of edges; fake edges are
    # wired to the fake-node index range of their incident node sets.
    padded_edge_sets = {}
    for name, item in graph_tensor.edge_sets.items():
      padded_edge_sets[name] = _pad_to_total_sizes(
          item,
          target_total_num_components=target_total_num_components,
          target_total_size=target_total_sizes.total_num_edges[name],
          min_max_node_index_fn=get_min_max_fake_nodes_indices,
          padding_value_fn=functools.partial(
              get_default_value,
              item.spec,
              _ifnone(padding_values.edge_sets, {}).get(name, {}),
              debug_context=f'{name} edges'))

  # Reassemble the padded pieces into a single graph tensor.
  padded_graph_tensor = gt.GraphTensor.from_pieces(
      context=padded_context,
      node_sets=padded_node_sets,
      edge_sets=padded_edge_sets)

  # Mask: True for each real (input) component, False for each fake component
  # appended by padding. Its static length equals the target component count.
  num_padded = tf.constant(target_total_num_components,
                           total_num_components.dtype) - total_num_components
  padding_mask = tensor_utils.ensure_static_nrows(
      tf.concat(
          values=[
              tf.ones([total_num_components], dtype=tf.bool),
              tf.zeros([num_padded], dtype=tf.bool)
          ],
          axis=0), target_total_num_components)
  return padded_graph_tensor, cast(tf.Tensor, padding_mask)