in tensorflow_gnn/graph/batching_utils.py [0:0]
def dynamic_batch(dataset: tf.data.Dataset,
constraints: SizesConstraints) -> tf.data.Dataset:
"""Batches as many consecutive graphs as allowed by the `constraints`.
Each result batch can have a variable number of graphs. Batches are returned
as graph tensors of rank 1 whose first dimension indexes the individual
examples in the batch. The result graphs can be converted to scalar graph
tensors using `.merge_batch_to_components()` and then padded to the target
sizes with `pad_to_total_sizes()`.
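
For example, a minimal sketch of this pipeline (assuming `tfgnn` is the
`tensorflow_gnn` package with public aliases `tfgnn.dynamic_batch` and
`tfgnn.pad_to_total_sizes`, and that `dataset` yields scalar graph tensors
with a node set 'nodes' and an edge set 'edges'):

  constraints = tfgnn.SizesConstraints(
      total_num_components=4,
      total_num_nodes={'nodes': 512},
      total_num_edges={'edges': 2048})
  dataset = tfgnn.dynamic_batch(dataset, constraints)
  dataset = dataset.map(lambda graph: graph.merge_batch_to_components())
  # pad_to_total_sizes() returns the padded graph together with a boolean
  # mask of the real (non-padding) components; keep the graph only.
  dataset = dataset.map(
      lambda graph: tfgnn.pad_to_total_sizes(graph, constraints)[0])
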
TODO(b/212274918): add support for non-scalar input graph tensors.
NOTE: this operation is more expensive than fixed-size batching. The overhead
comes mainly from `tf.data.Dataset.scan()` and grows with the average number
of graphs in the result batches. It is small when only a few examples are
batched on average, but when ~100 examples are combined on average the
operation can become 3-4x slower. In the latter case consider static
batching: by the law of large numbers it should produce comparable batches
over such a large sample size. Another alternative is a mixed strategy: if on
average N >> 10 graphs are batched, first use fixed-size batching with batch
size sqrt(N), convert the rank-1 results into scalar graphs using
`.merge_batch_to_components()` and then apply dynamic batching, as in the
sketch below.
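
For example, a sketch of the mixed strategy (the batch size 10 ~ sqrt(N) for
N ~ 100 is an assumption; tune it for the actual data):

  dataset = dataset.batch(10)
  dataset = dataset.map(lambda graph: graph.merge_batch_to_components())
  dataset = tfgnn.dynamic_batch(dataset, constraints)
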
Args:
dataset: dataset of scalar graph tensors.
constraints: the size constraints for the graph tensor. Must define the
maximum number of graph components (`.total_num_components`), the maximum
total number of nodes in each node set (`.total_num_nodes[node_set_name]`)
and likewise for each edge set (`.total_num_edges[edge_set_name]`).
Returns:
The dataset of rank-1 graph tensors compatible with the `constraints`.
Raises:
ValueError: if the `constraints` are not defined for some node sets or edge
sets defined by the graph tensor's type specification.
tf.errors.InvalidArgumentError: if any of the input graph tensor instances
are not compatible with the `constraints` so batching is not possible. For
example, if some graph tensor has more nodes than allowed.
"""
# pylint: disable=protected-access
#
# The implementation relies on `._to_tensor_list()` from the composite tensor
# API to convert graph pieces into the flat list of variant tensors. The API
# guarantees that those components can be stacked independently and then
# combined with `._from_tensor_list()` into the rank+1 graph piece. The
# stackable components are accumulated in TensorArray containers, so the
# components are stacked only once, using `TensorArray.stack()`, when the
# result batch is finalized.
input_spec = dataset.element_spec
if not isinstance(input_spec, gt.GraphTensorSpec):
raise ValueError('The elements of the dataset must be scalar GraphTensors.')
gt.check_scalar_graph_tensor(
cast(gt.GraphTensorSpec, input_spec), 'dynamic_batch()')
output_spec = input_spec._batch(None)
constraints = _validate_and_prepare_constraints(constraints, input_spec)
# A terminating element is needed at the end of a finite dataset to flush
# accumulated inputs even if the size budget is not exhausted yet. We mark it
# with a boolean flag. This should not create any visible overhead compared to
# the graph tensor itself.
def add_eod_flag(dataset: tf.data.Dataset, flag: bool) -> tf.data.Dataset:
return dataset.map(lambda g: (g, flag))
has_infinite_cardinality = dataset.cardinality(
) == tf.data.INFINITE_CARDINALITY
if has_infinite_cardinality:
# For known-infinite datasets (like the repeated training data), we can take
# a shortcut, because there is no last element.
dataset = add_eod_flag(dataset, False)
else:
# For datasets with known-finite or unknown cardinality, we attach an extra
# copy of the first element with EOD flag at the end. (It's up to the code
# below to not output that.) Repeating a genuine input instead of an empty
# value is meant to avoid potential special cases elsewhere.
dataset_end = dataset.take(1)
dataset = add_eod_flag(dataset, False)
dataset = dataset.concatenate(add_eod_flag(dataset_end, True))
def get_empty_value() -> gt.GraphTensor:
return output_spec._create_empty_value()
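# The scan state pairs the remaining size budget with one TensorArray per
# flat tensor component of the input graph, accumulating the pieces of the
# batch under construction.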
def get_initial_state() -> _ScanState:
accumulator = list()
for spec in input_spec._flat_tensor_specs:
accumulator.append(
tf.TensorArray(
spec.dtype, size=0, dynamic_size=True, clear_after_read=True))
accumulator = tuple(accumulator)
return _ScanState(budget_left=constraints, accumulator=accumulator)
def extract_value(state: _ScanState) -> gt.GraphTensor:
value = tf.nest.map_structure(lambda t: t.stack(), state.accumulator)
value = output_spec._from_tensor_list(list(value))
return value
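# Deducts the sizes of `graph_tensor` from the remaining budget and appends
# its flat tensor components to the accumulator.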
def get_next_state(state: _ScanState,
graph_tensor: gt.GraphTensor) -> _ScanState:
budget_left = tf.nest.map_structure(tf.math.subtract, state.budget_left,
_get_total_sizes(graph_tensor))
accumulator = tf.nest.map_structure(
lambda ta, spec: ta.write(ta.size(), spec), state.accumulator,
tuple(graph_tensor.spec._to_tensor_list(graph_tensor)))
return _ScanState(budget_left=budget_left, accumulator=accumulator)
def exceeds_budget(state: _ScanState,
graph_tensor: gt.GraphTensor) -> tf.Tensor:
within_budget = padding_ops.satisfies_total_sizes(graph_tensor,
state.budget_left)
return tf.math.logical_not(within_budget)
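# Emits one (has_value, value) pair per input element: a finalized batch when
# the incoming graph no longer fits into the remaining budget (or the end of
# the dataset was reached), and an empty placeholder otherwise.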
def scan_func(
state: _ScanState, value: Tuple[gt.GraphTensor, tf.Tensor]
) -> Tuple[_ScanState, Tuple[tf.Tensor, gt.GraphTensor]]:
graph_tensor, eod_flag = value
def flush():
with tf.control_dependencies(
padding_ops.assert_satisfies_total_sizes(
graph_tensor, target_total_sizes=constraints)):
# For simplicity, next_state remembers the graph_tensor in all cases.
# If graph_tensor comes with eod_flag=True, there will be no further
# call to flush(), and this artificially added graph_tensor is omitted
# from the output, as it should.
next_state = get_next_state(get_initial_state(), graph_tensor)
return (next_state, (True, extract_value(state)))
def accumulate():
next_state = get_next_state(state, graph_tensor)
return (next_state, (False, get_empty_value()))
should_flush = tf.math.logical_or(
exceeds_budget(state, graph_tensor), eod_flag)
return tf.cond(should_flush, flush, accumulate)
dataset = dataset.scan(get_initial_state(), scan_func)
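# Keep only the elements for which `scan_func` emitted a finalized batch and
# drop the boolean flag.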
dataset = dataset.filter(lambda has_value, _: has_value)
dataset = dataset.map(lambda _, value: value)
if has_infinite_cardinality and dataset.cardinality(
) != tf.data.INFINITE_CARDINALITY:
# Dataset.filter() always sets the cardinality to UNKNOWN_CARDINALITY. In our
# case the `filter()` above can only drop up to
# `constraints.total_num_components` consecutive elements, so if the input
# dataset has INFINITE_CARDINALITY, so should the output.
dataset = dataset.repeat()
return dataset