in tensorflow_federated/python/core/impl/compiler/transformation_utils.py [0:0]
def transform_postorder_with_symbol_bindings(comp, transform, symbol_tree):
  """Uses symbol binding hooks to execute transformations.

  `transform_postorder_with_symbol_bindings` hooks into the preorder traversal
  that is defined by walking down the tree to its leaves, using
  the variable bindings along this path to push information onto
  the given `SymbolTree`. Once we hit the leaves, we walk back up the
  tree in a postorder fashion, calling `transform` as we go.

  The transformations `transform_postorder_with_symbol_bindings` executes are
  therefore stateful in some sense. Here 'stateful' means that a transformation
  executed on a given AST node in general depends not only on the node itself
  or its immediate vicinity; possibly there is some global information on which
  this transformation depends. `transform_postorder_with_symbol_bindings` is
  functional 'from AST to AST' (where `comp` represents the root of an AST) but
  not 'from node to node'.

  One important fact to note: there are recursion invariants that
  `transform_postorder_with_symbol_bindings` uses the `SymbolTree` data
  structure to enforce. In particular, within a `transform` call the following
  invariants hold:

  *  `symbol_tree.update_payload_with_name` with an argument `name` will call
     `update` on the `BoundVariableTracker` in `symbol_tree` which tracks the
     value of `name` active in the current lexical scope. Will raise a
     `NameError` if none exists.

  *  `symbol_tree.get_payload_with_name` with a string argument `name` will
     return the `BoundVariableTracker` instance from `symbol_tree` which
     corresponds to the computation bound to the variable `name` in the current
     lexical scope. Will raise a `NameError` if none exists.

  These recursion invariants are enforced by the framework, and should be
  relied on when designing new transformations that depend on variable
  bindings.

  Args:
    comp: Instance of `building_blocks.ComputationBuildingBlock` to read
      information from or transform.
    transform: Python function accepting `comp` and `symbol_tree` arguments and
      returning a 2-tuple `(transformed_comp, modified)`, where `modified` is a
      Boolean indicating whether `transform` changed `comp`.
    symbol_tree: Instance of `SymbolTree`, the data structure into which we may
      write information about variable bindings, and from which we may read it.

  Returns:
    Returns a possibly modified version of `comp`, an instance
    of `building_blocks.ComputationBuildingBlock`, along with a
    Boolean with the value `True` if `comp` was transformed and `False` if it
    was not.

  Raises:
    TypeError: If `transform` is not callable. `comp` and `symbol_tree` are
      additionally validated with `py_typecheck.check_type`.
    NotImplementedError: If a computation building block of an unrecognized
      kind is encountered during the traversal.
  """
  py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
  py_typecheck.check_type(symbol_tree, SymbolTree)
  if not callable(transform):
    raise TypeError('Argument `transform` to '
                    '`transform_postorder_with_symbol_bindings` must '
                    'be callable.')
  # Every visited node draws the next id from this shared sequence; nodes that
  # open a scope (lambdas and blocks) use theirs to label that scope in the
  # symbol tree via `drop_scope_down`.
  identifier_seq = itertools.count(start=1)

  def _transform_postorder_with_symbol_bindings_switch(comp, transform_fn,
                                                       ctxt_tree,
                                                       identifier_sequence):
    """Recursive helper function delegated to after binding comp_id sequence."""
    # Always forward `transform_fn` (never the enclosing `transform`) so this
    # helper honors the function it was actually given, for every node kind.
    if (comp.is_compiled_computation() or comp.is_data() or
        comp.is_intrinsic() or comp.is_placement() or comp.is_reference()):
      return _traverse_leaf(comp, transform_fn, ctxt_tree, identifier_sequence)
    elif comp.is_selection():
      return _traverse_selection(comp, transform_fn, ctxt_tree,
                                 identifier_sequence)
    elif comp.is_struct():
      return _traverse_tuple(comp, transform_fn, ctxt_tree,
                             identifier_sequence)
    elif comp.is_call():
      return _traverse_call(comp, transform_fn, ctxt_tree, identifier_sequence)
    elif comp.is_lambda():
      return _traverse_lambda(comp, transform_fn, ctxt_tree,
                              identifier_sequence)
    elif comp.is_block():
      return _traverse_block(comp, transform_fn, ctxt_tree,
                             identifier_sequence)
    else:
      raise NotImplementedError(
          'Unrecognized computation building block: {}'.format(str(comp)))

  def _traverse_leaf(comp, transform, context_tree, identifier_seq):
    """Helper function holding traversal logic for leaf nodes."""
    _ = next(identifier_seq)
    return transform(comp, context_tree)

  def _traverse_selection(comp, transform, context_tree, identifier_seq):
    """Helper function holding traversal logic for selection nodes."""
    _ = next(identifier_seq)
    source, source_modified = _transform_postorder_with_symbol_bindings_switch(
        comp.source, transform, context_tree, identifier_seq)
    if source_modified:
      # Normalize selection to index based on the type signature of the
      # original source. The new source may not have names present.
      if comp.index is not None:
        index = comp.index
      else:
        index = structure.name_to_index_map(
            comp.source.type_signature)[comp.name]
      comp = building_blocks.Selection(source, index=index)
    comp, comp_modified = transform(comp, context_tree)
    return comp, comp_modified or source_modified

  def _traverse_tuple(comp, transform, context_tree, identifier_seq):
    """Helper function holding traversal logic for tuple nodes."""
    _ = next(identifier_seq)
    elements = []
    elements_modified = False
    for key, value in structure.iter_elements(comp):
      value, value_modified = _transform_postorder_with_symbol_bindings_switch(
          value, transform, context_tree, identifier_seq)
      elements.append((key, value))
      elements_modified = elements_modified or value_modified
    if elements_modified:
      comp = building_blocks.Struct(elements)
    comp, comp_modified = transform(comp, context_tree)
    return comp, comp_modified or elements_modified

  def _traverse_call(comp, transform, context_tree, identifier_seq):
    """Helper function holding traversal logic for call nodes."""
    _ = next(identifier_seq)
    fn, fn_modified = _transform_postorder_with_symbol_bindings_switch(
        comp.function, transform, context_tree, identifier_seq)
    if comp.argument is not None:
      arg, arg_modified = _transform_postorder_with_symbol_bindings_switch(
          comp.argument, transform, context_tree, identifier_seq)
    else:
      arg, arg_modified = (None, False)
    if fn_modified or arg_modified:
      comp = building_blocks.Call(fn, arg)
    comp, comp_modified = transform(comp, context_tree)
    return comp, comp_modified or fn_modified or arg_modified

  def _traverse_lambda(comp, transform, context_tree, identifier_seq):
    """Helper function holding traversal logic for lambda nodes."""
    comp_id = next(identifier_seq)
    context_tree.drop_scope_down(comp_id)
    # The parameter is bound with no associated value (`value=None`).
    context_tree.ingest_variable_binding(name=comp.parameter_name, value=None)
    result, result_modified = _transform_postorder_with_symbol_bindings_switch(
        comp.result, transform, context_tree, identifier_seq)
    # The scope opened above is rewound before, and popped only after,
    # `transform` runs on the lambda itself.
    context_tree.walk_to_scope_beginning()
    if result_modified:
      comp = building_blocks.Lambda(comp.parameter_name, comp.parameter_type,
                                    result)
    comp, comp_modified = transform(comp, context_tree)
    context_tree.pop_scope_up()
    return comp, comp_modified or result_modified

  def _traverse_block(comp, transform, context_tree, identifier_seq):
    """Helper function holding traversal logic for block nodes."""
    comp_id = next(identifier_seq)
    context_tree.drop_scope_down(comp_id)
    variables = []
    variables_modified = False
    for key, value in comp.locals:
      # Each local is transformed before it is bound, so later locals and the
      # result see the already-transformed value in the symbol tree.
      value, value_modified = _transform_postorder_with_symbol_bindings_switch(
          value, transform, context_tree, identifier_seq)
      context_tree.ingest_variable_binding(name=key, value=value)
      variables.append((key, value))
      variables_modified = variables_modified or value_modified
    result, result_modified = _transform_postorder_with_symbol_bindings_switch(
        comp.result, transform, context_tree, identifier_seq)
    # As for lambdas: rewind the scope, transform this node, then pop it.
    context_tree.walk_to_scope_beginning()
    if variables_modified or result_modified:
      comp = building_blocks.Block(variables, result)
    comp, comp_modified = transform(comp, context_tree)
    context_tree.pop_scope_up()
    return comp, comp_modified or variables_modified or result_modified

  return _transform_postorder_with_symbol_bindings_switch(
      comp, transform, symbol_tree, identifier_seq)