def transform_postorder_with_symbol_bindings()

in tensorflow_federated/python/core/impl/compiler/transformation_utils.py [0:0]


def transform_postorder_with_symbol_bindings(comp, transform, symbol_tree):
  """Uses symbol binding hooks to execute transformations.

  `transform_postorder_with_symbol_bindings` hooks into the preorder traversal
  that is defined by walking down the tree to its leaves, using the variable
  bindings along this path to push information onto the given `SymbolTree`.
  Once we hit the leaves, we walk back up the tree in a postorder fashion,
  calling `transform` as we go.

  The transformations `transform_postorder_with_symbol_bindings` executes are
  therefore stateful in some sense. Here 'stateful' means that a transformation
  executed on a given AST node in general depends on not only the node itself
  or its immediate vicinity; possibly there is some global information on which
  this transformation depends. `transform_postorder_with_symbol_bindings` is
  functional 'from AST to AST' (where `comp` represents the root of an AST) but
  not 'from node to node'.

  One important fact to note: there are recursion invariants that
  `transform_postorder_with_symbol_bindings` uses the `SymbolTree` data
  structure to enforce. In particular, within a `transform` call the following
  invariants hold:

  *  `symbol_tree.update_payload_with_name` with an argument `name` will call
     `update` on the `BoundVariableTracker` in `symbol_tree` which tracks the
     value of `name` active in the current lexical scope. Will raise a
     `NameError` if none exists.

  *  `symbol_tree.get_payload_with_name` with a string argument `name` will
     return the `BoundVariableTracker` instance from `symbol_tree` which
     corresponds to the computation bound to the variable `name` in the current
     lexical scope. Will raise a `NameError` if none exists.

  These recursion invariants are enforced by the framework, and should be
  relied on when designing new transformations that depend on variable
  bindings.

  Args:
    comp: Instance of `building_blocks.ComputationBuildingBlock` to read
      information from or transform.
    transform: Python function accepting `comp` and `symbol_tree` arguments and
      returning a tuple `(transformed_comp, modified)`, where `modified` is a
      Boolean that is `True` if `transformed_comp` differs from `comp`.
    symbol_tree: Instance of `SymbolTree`, the data structure into which we may
      write information about variable bindings, and from which we may read.

  Returns:
    A tuple `(transformed_comp, modified)`: a possibly modified version of
    `comp` (an instance of `building_blocks.ComputationBuildingBlock`), and a
    Boolean with the value `True` if `comp` or any of its descendants was
    transformed and `False` otherwise.

  Raises:
    TypeError: If `transform` is not callable, or if `comp` or `symbol_tree`
      fail the `py_typecheck.check_type` checks.
    NotImplementedError: If a building block of an unrecognized kind is
      encountered during the traversal.
  """
  py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
  py_typecheck.check_type(symbol_tree, SymbolTree)
  if not callable(transform):
    raise TypeError('Argument `transform` to '
                    '`transform_postorder_with_symbol_bindings` must '
                    'be callable.')
  # Every visited node draws exactly one id, in preorder (before its children
  # are visited), so the ids passed to `drop_scope_down` for lambdas and
  # blocks are deterministic for a given AST.
  identifier_seq = itertools.count(start=1)

  def _transform_postorder_with_symbol_bindings_switch(comp, transform_fn,
                                                       ctxt_tree,
                                                       identifier_sequence):
    """Dispatches `comp` to the traversal helper for its building-block kind."""
    if (comp.is_compiled_computation() or comp.is_data() or
        comp.is_intrinsic() or comp.is_placement() or comp.is_reference()):
      return _traverse_leaf(comp, transform_fn, ctxt_tree, identifier_sequence)
    elif comp.is_selection():
      return _traverse_selection(comp, transform_fn, ctxt_tree,
                                 identifier_sequence)
    elif comp.is_struct():
      return _traverse_tuple(comp, transform_fn, ctxt_tree,
                             identifier_sequence)
    elif comp.is_call():
      return _traverse_call(comp, transform_fn, ctxt_tree, identifier_sequence)
    elif comp.is_lambda():
      return _traverse_lambda(comp, transform_fn, ctxt_tree,
                              identifier_sequence)
    elif comp.is_block():
      return _traverse_block(comp, transform_fn, ctxt_tree,
                             identifier_sequence)
    else:
      raise NotImplementedError(
          'Unrecognized computation building block: {}'.format(str(comp)))

  def _traverse_leaf(comp, transform_fn, ctxt_tree, id_seq):
    """Helper function holding traversal logic for leaf nodes."""
    # Consume an id so numbering stays aligned with the node visit order.
    _ = next(id_seq)
    return transform_fn(comp, ctxt_tree)

  def _traverse_selection(comp, transform_fn, ctxt_tree, id_seq):
    """Helper function holding traversal logic for selection nodes."""
    _ = next(id_seq)
    source, source_modified = _transform_postorder_with_symbol_bindings_switch(
        comp.source, transform_fn, ctxt_tree, id_seq)
    if source_modified:
      # Normalize selection to index based on the type signature of the
      # original source. The new source may not have names present.
      if comp.index is not None:
        index = comp.index
      else:
        index = structure.name_to_index_map(
            comp.source.type_signature)[comp.name]
      comp = building_blocks.Selection(source, index=index)
    comp, comp_modified = transform_fn(comp, ctxt_tree)
    return comp, comp_modified or source_modified

  def _traverse_tuple(comp, transform_fn, ctxt_tree, id_seq):
    """Helper function holding traversal logic for tuple nodes."""
    _ = next(id_seq)
    elements = []
    elements_modified = False
    for key, value in structure.iter_elements(comp):
      value, value_modified = _transform_postorder_with_symbol_bindings_switch(
          value, transform_fn, ctxt_tree, id_seq)
      elements.append((key, value))
      elements_modified = elements_modified or value_modified
    if elements_modified:
      comp = building_blocks.Struct(elements)
    comp, comp_modified = transform_fn(comp, ctxt_tree)
    return comp, comp_modified or elements_modified

  def _traverse_call(comp, transform_fn, ctxt_tree, id_seq):
    """Helper function holding traversal logic for call nodes."""
    _ = next(id_seq)
    fn, fn_modified = _transform_postorder_with_symbol_bindings_switch(
        comp.function, transform_fn, ctxt_tree, id_seq)
    if comp.argument is not None:
      arg, arg_modified = _transform_postorder_with_symbol_bindings_switch(
          comp.argument, transform_fn, ctxt_tree, id_seq)
    else:
      arg, arg_modified = (None, False)
    if fn_modified or arg_modified:
      comp = building_blocks.Call(fn, arg)
    comp, comp_modified = transform_fn(comp, ctxt_tree)
    return comp, comp_modified or fn_modified or arg_modified

  def _traverse_lambda(comp, transform_fn, ctxt_tree, id_seq):
    """Helper function holding traversal logic for lambda nodes."""
    comp_id = next(id_seq)
    # Open a new scope for the lambda and bind its parameter; the parameter
    # has no known value, hence `value=None`.
    ctxt_tree.drop_scope_down(comp_id)
    ctxt_tree.ingest_variable_binding(name=comp.parameter_name, value=None)
    result, result_modified = _transform_postorder_with_symbol_bindings_switch(
        comp.result, transform_fn, ctxt_tree, id_seq)
    # Rewind to the start of this scope so `transform_fn` on the lambda itself
    # runs while the lambda's scope is still active.
    ctxt_tree.walk_to_scope_beginning()
    if result_modified:
      comp = building_blocks.Lambda(comp.parameter_name, comp.parameter_type,
                                    result)
    comp, comp_modified = transform_fn(comp, ctxt_tree)
    ctxt_tree.pop_scope_up()
    return comp, comp_modified or result_modified

  def _traverse_block(comp, transform_fn, ctxt_tree, id_seq):
    """Helper function holding traversal logic for block nodes."""
    comp_id = next(id_seq)
    ctxt_tree.drop_scope_down(comp_id)
    variables = []
    variables_modified = False
    for key, value in comp.locals:
      # Each local is transformed before it is bound, so later locals and the
      # result see the already-transformed value of earlier locals.
      value, value_modified = _transform_postorder_with_symbol_bindings_switch(
          value, transform_fn, ctxt_tree, id_seq)
      ctxt_tree.ingest_variable_binding(name=key, value=value)
      variables.append((key, value))
      variables_modified = variables_modified or value_modified
    result, result_modified = _transform_postorder_with_symbol_bindings_switch(
        comp.result, transform_fn, ctxt_tree, id_seq)
    # Rewind to the start of this scope so `transform_fn` on the block itself
    # runs while the block's scope is still active.
    ctxt_tree.walk_to_scope_beginning()
    if variables_modified or result_modified:
      comp = building_blocks.Block(variables, result)
    comp, comp_modified = transform_fn(comp, ctxt_tree)
    ctxt_tree.pop_scope_up()
    return comp, comp_modified or variables_modified or result_modified

  return _transform_postorder_with_symbol_bindings_switch(
      comp, transform, symbol_tree, identifier_seq)