in tensorflow_federated/python/core/impl/compiler/tree_analysis.py [0:0]
def extract_nodes_consuming(tree, predicate):
  """Collects the nodes of `tree` whose value depends on a `predicate` match.

  By convention, a node that itself satisfies `predicate` is considered
  dependent and is therefore part of the result.

  Args:
    tree: Instance of `building_blocks.ComputationBuildingBlock`, treated as
      an abstract syntax tree. The result is the set of nodes in this tree
      whose value depends on evaluating some node matching `predicate`.
    predicate: One-arg callable taking a
      `building_blocks.ComputationBuildingBlock` and returning a `bool` that
      signals whether the node matches the pattern of interest.

  Returns:
    A `set` of `building_blocks.ComputationBuildingBlock` instances: the nodes
    of `tree` that depend on nodes matching `predicate`.
  """
  py_typecheck.check_type(tree, building_blocks.ComputationBuildingBlock)
  py_typecheck.check_callable(predicate)

  # Registry of dependent nodes, keyed by `id()` so that membership is decided
  # by object identity.
  consumers_by_id = {}

  def _is_tracked(node):
    """Returns whether `node` was already recorded as dependent."""
    return id(node) in consumers_by_id

  def _has_tracked_dependency(comp, symbol_tree):
    """Checks whether something `comp` depends on is already recorded."""
    # These kinds are never treated as having dependencies.
    if (comp.is_intrinsic() or comp.is_data() or comp.is_placement() or
        comp.is_compiled_computation()):
      return False
    if comp.is_lambda():
      return _is_tracked(comp.result)
    if comp.is_block():
      # Each entry of `locals` is indexed as (name, value); the value is the
      # part that can be dependent.
      for binding in comp.locals:
        if _is_tracked(binding[1]):
          return True
      return _is_tracked(comp.result)
    if comp.is_struct():
      for element in comp:
        if _is_tracked(element):
          return True
      return False
    if comp.is_selection():
      return _is_tracked(comp.source)
    if comp.is_call():
      return _is_tracked(comp.function) or _is_tracked(comp.argument)
    if comp.is_reference():
      # The traversal is postorder, so a binding has been processed before
      # any reference to it is visited.
      payload = symbol_tree.get_payload_with_name(comp.name)
      return payload is not None and _is_tracked(payload.value)
    return False

  def _record_if_dependent(comp, symbol_tree):
    """Records `comp` when it matches `predicate` or depends on a match."""
    if predicate(comp) or _has_tracked_dependency(comp, symbol_tree):
      consumers_by_id[id(comp)] = comp
    # The node is handed back as-is, flagged as not modified.
    return comp, False

  bindings = transformation_utils.SymbolTree(
      transformation_utils.ReferenceCounter)
  transformation_utils.transform_postorder_with_symbol_bindings(
      tree, _record_if_dependent, bindings)
  return set(consumers_by_id.values())