def extract_nodes_consuming()

in tensorflow_federated/python/core/impl/compiler/tree_analysis.py [0:0]


def extract_nodes_consuming(tree, predicate):
  """Finds every node of `tree` whose value depends on a `predicate` match.

  By convention, a node that satisfies `predicate` itself counts as consuming
  a match and is therefore part of the result. Any other node is part of the
  result when one of its direct dependencies already is: the result of a
  lambda; a local value or the result of a block; an element of a struct; the
  source of a selection; the function or argument of a call; or, for a
  reference, the value bound to its name in the symbol tree built during the
  traversal.

  Args:
    tree: Instance of `building_blocks.ComputationBuildingBlock` to treat as an
      abstract syntax tree and scan for nodes whose value depends on
      evaluating nodes matching `predicate`.
    predicate: One-arg callable that accepts a
      `building_blocks.ComputationBuildingBlock` and returns a `bool`
      indicating whether that node matches the desired pattern.

  Returns:
    A `set` of `building_blocks.ComputationBuildingBlock` instances from `tree`
    whose values depend on nodes matching `predicate`.
  """
  py_typecheck.check_type(tree, building_blocks.ComputationBuildingBlock)
  py_typecheck.check_callable(predicate)

  # Consumers found so far, keyed by `id()` so membership is decided by object
  # identity rather than by the nodes' own `==`/`hash`. Note that the final
  # `set(...)` conversion does rely on the nodes' `==`/`hash`.
  consumers = {}

  def _is_consumer(node):
    return id(node) in consumers

  def _depends_on_consumer(comp, bindings):
    """Returns whether a direct dependency of `comp` is already a consumer."""
    if (comp.is_intrinsic() or comp.is_data() or comp.is_placement() or
        comp.is_compiled_computation()):
      # These node kinds are never treated as depending on a consumer.
      return False
    if comp.is_lambda():
      return _is_consumer(comp.result)
    if comp.is_block():
      local_is_consumer = any(
          _is_consumer(local[1]) for local in comp.locals)
      return local_is_consumer or _is_consumer(comp.result)
    if comp.is_struct():
      return any(_is_consumer(element) for element in comp)
    if comp.is_selection():
      return _is_consumer(comp.source)
    if comp.is_call():
      return _is_consumer(comp.function) or _is_consumer(comp.argument)
    if comp.is_reference():
      bound = bindings.get_payload_with_name(comp.name)
      if bound is None:
        return False
      # The postorder traversal processes a binding before any reference to
      # it, so the bound value's consumer status is already settled here.
      return _is_consumer(bound.value)
    return False

  def _record_if_consumer(comp, bindings):
    """Traversal callback: records `comp` if it consumes a match.

    The node is returned unchanged and reported as not modified.
    """
    if predicate(comp) or _depends_on_consumer(comp, bindings):
      consumers[id(comp)] = comp
    return comp, False

  bindings_tree = transformation_utils.SymbolTree(
      transformation_utils.ReferenceCounter)
  transformation_utils.transform_postorder_with_symbol_bindings(
      tree, _record_if_consumer, bindings_tree)
  return set(consumers.values())