in tensorflow_gnn/graph/padding_ops.py [0:0]
def pad_to_total_sizes(
graph_tensor: gt.GraphTensor,
target_total_sizes: preprocessing.SizesConstraints,
*,
padding_values: Optional[preprocessing.DefaultValues] = None,
validate: bool = True) -> Tuple[gt.GraphTensor, tf.Tensor]:
"""Pads graph tensor to the total sizes by inserting fake graph components.
Padding is done by inserting "fake" graph components at the end of the input
graph tensor until target total sizes are exactly matched. If that is not
possible (e.g. input already has more nodes than allowed by the constraints)
function raises tf.errors.InvalidArgumentError. Context, node or edge features
of the appended fake components are filled using user-provided scalar values
or with zeros if the latter are not specified. Fake edges are created such
that each fake node has an approximately uniform number of incident edges
(NOTE: this behavior is not guaranteed and may change in the future).
Args:
graph_tensor: scalar graph tensor (rank=0) to pad.
target_total_sizes: target total sizes for each graph piece. Must define the
target number of graph components (`.total_num_components`), target total
number of items for each node set (`.total_num_nodes[node_set_name]`) and
likewise for each edge set (`.total_num_edges[edge_set_name]`).
padding_values: optional mapping from a context, node set or edge set
feature name to a scalar tensor to use for padding. If no value is
specified for some feature, the zero value of its dtype is used (as in
tf.zeros(...)).
validate: If True, then use assertions to check that the input graph tensor
can be padded. NOTE: while these assertions provide more readable error
messages, they incur a runtime cost, since they must be checked for each
input value.

Returns:
Tuple of the padded graph tensor and the padding mask. The mask is a rank-1
dense boolean tensor with size equal to the number of graph components in
the result, containing True for real graph components and False for the
fake ones used for padding.

Raises:
ValueError: if input parameters are invalid.
tf.errors.InvalidArgumentError: if the input graph tensor could not be
padded to the `target_total_sizes`.
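
Example: a minimal usage sketch, assuming `import tensorflow_gnn as tfgnn`
and that `graph` is a scalar GraphTensor with a node set 'nodes' and an
edge set 'edges' (the names and sizes below are illustrative):

  constraints = tfgnn.SizesConstraints(
      total_num_components=2,
      total_num_nodes={'nodes': 100},
      total_num_edges={'edges': 1000})
  padded_graph, mask = tfgnn.pad_to_total_sizes(graph, constraints)
  # `mask` has shape [2]: True for real components, False for fake ones.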
"""
gt.check_scalar_graph_tensor(graph_tensor, 'tfgnn.pad_to_total_sizes()')
def _ifnone(value, default):
return value if value is not None else default
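# If the caller provided no padding values, fall back to an empty
# DefaultValues; every feature then gets the zero value of its dtype (see
# get_default_value below).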
if padding_values is None:
padding_values = preprocessing.DefaultValues()
def get_default_value(
graph_piece_spec: gt._GraphPieceWithFeaturesSpec, # pylint: disable=protected-access
padding_values: gt.Fields,
feature_name: str,
debug_context: str) -> tf.Tensor:
spec = graph_piece_spec.features_spec[feature_name]
value = padding_values.get(feature_name, None)
if value is None:
value = tf.zeros([], dtype=spec.dtype)
else:
value = tf.convert_to_tensor(value, spec.dtype)
if value.shape.rank != 0:
raise ValueError(f'Default value for {debug_context} must be scalar,'
f' got shape={value.shape}')
return value
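# Endpoints of fake edges must point at fake nodes. For a given node set,
# the valid fake node indices run from the current total size (the first
# index past the real nodes) through the target total size minus one.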
def get_min_max_fake_nodes_indices(
node_set_name: str) -> Tuple[tf.Tensor, tf.Tensor]:
min_node_index = graph_tensor.node_sets[node_set_name].total_size
max_node_index = tf.constant(
target_total_sizes.total_num_nodes[node_set_name],
dtype=min_node_index.dtype) - 1
return min_node_index, max_node_index
# Note: we check that the graph tensor can fit into the target sizes before
# running the padding. This simplifies the padding implementation and avoids
# duplicated validation.
if validate:
validation_ops = assert_satisfies_total_sizes(graph_tensor,
target_total_sizes)
else:
validation_ops = []
with tf.control_dependencies(validation_ops):
total_num_components = graph_tensor.total_num_components
target_total_num_components = target_total_sizes.total_num_components
padded_context = _pad_to_total_sizes(
graph_tensor.context,
target_total_num_components=target_total_num_components,
padding_value_fn=functools.partial(
get_default_value,
graph_tensor.context.spec,
_ifnone(padding_values.context, {}),
debug_context='context'))
padded_node_sets = {}
for name, item in graph_tensor.node_sets.items():
padded_node_sets[name] = _pad_to_total_sizes(
item,
target_total_num_components=target_total_num_components,
target_total_size=target_total_sizes.total_num_nodes[name],
padding_value_fn=functools.partial(
get_default_value,
item.spec,
_ifnone(padding_values.node_sets, {}).get(name, {}),
debug_context=f'{name} nodes'))
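# Edge sets additionally receive the index range of fake nodes, so that
# fake edges can be wired to fake endpoints in their incident node sets.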
padded_edge_sets = {}
for name, item in graph_tensor.edge_sets.items():
padded_edge_sets[name] = _pad_to_total_sizes(
item,
target_total_num_components=target_total_num_components,
target_total_size=target_total_sizes.total_num_edges[name],
min_max_node_index_fn=get_min_max_fake_nodes_indices,
padding_value_fn=functools.partial(
get_default_value,
item.spec,
_ifnone(padding_values.edge_sets, {}).get(name, {}),
debug_context=f'{name} edges'))
padded_graph_tensor = gt.GraphTensor.from_pieces(
context=padded_context,
node_sets=padded_node_sets,
edge_sets=padded_edge_sets)
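# Build the padding mask: True for the original (real) components, followed
# by False for each appended fake component. ensure_static_nrows gives the
# mask a static length equal to the target total number of components.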
num_padded = tf.constant(target_total_num_components,
dtype=total_num_components.dtype) - total_num_components
padding_mask = tensor_utils.ensure_static_nrows(
tf.concat(
values=[
tf.ones([total_num_components], dtype=tf.bool),
tf.zeros([num_padded], dtype=tf.bool)
],
axis=0), target_total_num_components)
return padded_graph_tensor, cast(tf.Tensor, padding_mask)