in tensorflow_federated/python/core/impl/compiler/transformation_utils.py [0:0]
def transform_postorder(comp, transform):
  """Rewrites `comp` bottom-up by applying `transform` to every building block.

  The expression tree rooted at `comp` is walked in post-order: the
  constituents of a building block are visited first (left-to-right, in the
  order they appear in the building block's constructor), and only then is the
  building block itself handed to `transform`, with its already-visited (and
  possibly replaced) constituents in place. `transform` is expected to behave
  as the identity on any kind of building block it is not interested in.

  Note: In `Call(f, x)`, both `f` and `x` are constituents of the `Call`. So `f`
  is transformed into `f'` first, then `x` into `x'`, and finally
  `Call(f', x')` is transformed. For a `Block`, the locals are visited in
  order before the result.

  A composite building block is only reconstructed when at least one of its
  constituents was reported as modified; otherwise the original object is
  passed to `transform` unchanged.

  Args:
    comp: A `building_blocks.ComputationBuildingBlock` to traverse and
      transform bottom-up.
    transform: A Python function applied locally to each building block in
      `comp`. It accepts a building block and must return a
      `(building block, bool)` tuple, where the building block is either the
      original or a transformed `building_blocks.ComputationBuildingBlock` and
      the bool indicates whether the building block was modified.

  Returns:
    A `(building block, bool)` tuple: the result of applying `transform` to the
    parts of `comp` bottom-up, and a flag that is `True` if `comp` or any of
    its constituents was transformed and `False` otherwise.

  Raises:
    TypeError: If `comp` is not a `building_blocks.ComputationBuildingBlock`.
    NotImplementedError: If `comp` is a kind of computation building block that
      is currently not recognized.
  """
  py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)

  # Leaves have no constituents to recurse into; hand them straight to
  # `transform` and return whatever it produces.
  is_leaf = (
      comp.is_compiled_computation() or comp.is_data() or
      comp.is_intrinsic() or comp.is_placement() or comp.is_reference())
  if is_leaf:
    return transform(comp)

  # Each composite branch below visits the constituents, rebuilds `comp` if
  # any of them changed, and records that fact in `children_changed`.
  if comp.is_selection():
    new_source, children_changed = transform_postorder(comp.source, transform)
    if children_changed:
      comp = building_blocks.Selection(new_source, comp.name, comp.index)

  elif comp.is_struct():
    new_elements = []
    children_changed = False
    for name, element in structure.iter_elements(comp):
      new_element, element_changed = transform_postorder(element, transform)
      new_elements.append((name, new_element))
      children_changed = children_changed or element_changed
    if children_changed:
      # Keep the original Python container type on the rebuilt struct.
      container = comp.type_signature.python_container
      comp = building_blocks.Struct(new_elements, container_type=container)

  elif comp.is_call():
    new_fn, fn_changed = transform_postorder(comp.function, transform)
    # A call may have no argument; in that case there is nothing to visit.
    new_arg, arg_changed = None, False
    if comp.argument is not None:
      new_arg, arg_changed = transform_postorder(comp.argument, transform)
    children_changed = fn_changed or arg_changed
    if children_changed:
      comp = building_blocks.Call(new_fn, new_arg)

  elif comp.is_lambda():
    new_result, children_changed = transform_postorder(comp.result, transform)
    if children_changed:
      comp = building_blocks.Lambda(
          comp.parameter_name, comp.parameter_type, new_result)

  elif comp.is_block():
    new_locals = []
    locals_changed = False
    for name, value in comp.locals:
      new_value, value_changed = transform_postorder(value, transform)
      new_locals.append((name, new_value))
      locals_changed = locals_changed or value_changed
    # The result is visited only after all of the locals.
    new_result, result_changed = transform_postorder(comp.result, transform)
    children_changed = locals_changed or result_changed
    if children_changed:
      comp = building_blocks.Block(new_locals, new_result)

  else:
    raise NotImplementedError(
        'Unrecognized computation building block: {}'.format(str(comp)))

  # Finally transform the (possibly rebuilt) parent itself; the overall flag is
  # set if either the parent or any constituent was modified.
  comp, comp_changed = transform(comp)
  return comp, comp_changed or children_changed