library/compiler/visitor.py (124 lines of code) (raw):

# Portions copyright (c) Facebook, Inc. and its affiliates. (http://www.facebook.com)
# pyre-unsafe

import ast
from ast import AST, copy_location
from typing import Any, Sequence, TypeVar, Union


# XXX should probably rename ASTVisitor to ASTWalker
# XXX can it be made even more generic?


class ASTVisitor:
    """Performs a depth-first walk of the AST

    For each node, visit() looks up a method on this visitor named
    'visit' + the node's class name (e.g. visitFunctionDef for an
    ast.FunctionDef).  If the method exists it is called with the node
    followed by any extra positional arguments given to visit(); otherwise
    generic_visit() walks the node's children.

    This is basically the same as the built-in ast.NodeVisitor except
    for the following differences:
        It accepts extra parameters through the visit methods for flowing state
        It uses "visitNodeName" instead of "visit_NodeName"
        It accepts a list to the generic_visit function rather than just nodes
    """

    VERBOSE = 0

    def __init__(self):
        # The node most recently dispatched by visit().
        self.node = None
        # Node class -> bound visit method, so each class's method is looked
        # up only once per visitor instance.
        self._cache = {}

    def generic_visit(self, node, *args):
        """Called if no explicit visitor function exists for a node.

        Visits every AST child of *node* (or every AST item when *node* is a
        list), forwarding *args unchanged.  Returns None.
        """
        if isinstance(node, list):
            for item in node:
                if isinstance(item, ast.AST):
                    self.visit(item, *args)
            return

        for _field, value in ast.iter_fields(node):
            if isinstance(value, list):
                for item in value:
                    if isinstance(item, ast.AST):
                        self.visit(item, *args)
            elif isinstance(value, ast.AST):
                self.visit(value, *args)

    def walk_list(self, nodes: Sequence[AST], *args):
        """Visit each AST element of *nodes*; non-AST elements are ignored."""
        for item in nodes:
            if isinstance(item, ast.AST):
                self.visit(item, *args)

    def skip_visit(self):
        """Hook for subclasses: return True to make visit() a no-op."""
        return False

    def visit(self, node: Union[AST, Sequence[AST]], *args):
        """Dispatch *node* (or each node of a list) to its visit method.

        Returns whatever the chosen visit method returns, or None when
        skip_visit() is true.
        """
        if self.skip_visit():
            return
        if isinstance(node, list):
            return self.walk_list(node, *args)

        self.node = node
        klass = node.__class__
        meth = self._cache.get(klass, None)
        if meth is None:
            className = klass.__name__
            meth = getattr(self, "visit" + className, self.generic_visit)
            self._cache[klass] = meth
        return meth(node, *args)


TAst = TypeVar("TAst", bound=AST)


class ASTRewriter(ASTVisitor):
    """performs rewrites on the AST, rewriting parent nodes when child nodes
    are replaced.

    A visitNodeType method on a rewriter returns what should take the
    node's place:
      * the same node object -> no change;
      * a different AST node -> the node is replaced;
      * None -> the node is removed (only allowed for nodes held in a list);
      * any other value (e.g. a list of nodes) -> for a node held in a list,
        its elements are spliced into that list in the node's place.

    The generic walk never mutates existing nodes: a parent whose children
    change is shallow-cloned (see clone_node) and the clone is returned.
    Extra positional arguments given to visit() are forwarded to every
    nested visit call.
    """

    @staticmethod
    def update_node(node: TAst, **replacement: Any) -> TAst:
        """Return *node* with the fields in *replacement* set.

        Fields whose new value is the identical object are skipped.  The
        first field that actually changes triggers a clone of *node*, so the
        original is left untouched; if nothing changes, *node* itself is
        returned.
        """
        res = node
        for name, val in replacement.items():
            if getattr(res, name) is val:
                continue
            if res is node:
                res = ASTRewriter.clone_node(node)
            setattr(res, name, val)
        return res

    @staticmethod
    def clone_node(node: TAst) -> TAst:
        """Return a shallow copy of *node*, including its source location.

        Fields are passed positionally in ``_fields`` order; a field that is
        missing on *node* becomes None, and list fields get a fresh list so
        the copy's lists can be edited without affecting *node*.
        """
        attrs = []
        for name in node._fields:
            attr = getattr(node, name, None)
            if isinstance(attr, list):
                attr = list(attr)
            attrs.append(attr)
        new = type(node)(*attrs)
        return copy_location(new, node)

    def walk_list(self, nodes: Sequence[TAst], *args) -> Sequence[TAst]:
        """Rewrite each AST element of *nodes*, forwarding *args to visit().

        Returns *nodes* itself when no element changed.  Otherwise returns a
        new list in which None results are dropped and non-AST results
        (sequences of nodes) are spliced in.  Non-AST elements of *nodes*
        are kept as they are.
        """
        new_values = []
        changed = False

        for value in nodes:
            if isinstance(value, AST):
                # Forward *args so visit methods that take extra state
                # receive it (matching ASTVisitor.walk_list).
                new_value = self.visit(value, *args)
                changed |= new_value is not value
                if new_value is None:
                    continue
                elif not isinstance(new_value, AST):
                    new_values.extend(new_value)
                    continue
                value = new_value
            new_values.append(value)

        return new_values if changed else nodes

    def generic_visit(self, node: TAst, *args) -> TAst:
        """Rewrite the AST-valued and list-valued fields of *node*.

        Returns *node* itself if no child changed, otherwise a clone of
        *node* carrying the new children.
        """
        ret_node = node
        for field, old_value in ast.iter_fields(node):
            if not isinstance(old_value, (AST, list)):
                continue

            new_node = self.visit(old_value, *args)
            # NOTE(review): ASTVisitor.visit returns None when skip_visit()
            # is true, so a rewriter overriding skip_visit would trip this
            # assertion here (and have nodes dropped in walk_list).
            assert (  # noqa: IG01
                new_node is not None
            ), f"can't remove AST nodes that aren't part of a list {old_value!r}"
            if new_node is not old_value:
                if ret_node is node:
                    ret_node = self.clone_node(node)
                setattr(ret_node, field, new_node)
        return ret_node


class ExampleASTVisitor(ASTVisitor):
    """Prints examples of the nodes that aren't visited

    This visitor-driver is only useful for development, when it's
    helpful to develop a visitor incrementally, and get feedback on what
    you still have to do.
    """

    # Class-level dict, so it is shared by every instance (and by subclasses
    # that do not define their own): each node class is reported once.
    examples = {}

    def visit(self, node, *args):
        """Dispatch like ASTVisitor.visit, reporting unhandled node classes.

        When no visitNodeType method exists and VERBOSE > 0, the node's
        class and public attributes are printed the first time that class
        is seen; the node's children are then walked via generic_visit().
        With VERBOSE > 1 every dispatch is printed.
        """
        if isinstance(node, list):
            return self.walk_list(node, *args)

        self.node = node
        klass = node.__class__
        className = klass.__name__
        meth = self._cache.get(klass, None)
        if meth is None:
            # 0 (falsy) marks "no visit method" in the cache.
            meth = getattr(self, "visit" + className, 0)
            self._cache[klass] = meth
        if self.VERBOSE > 1:
            print("visit", className, meth.__name__ if meth else "")
        if meth:
            return meth(node, *args)

        if self.VERBOSE > 0 and klass not in self.examples:
            self.examples[klass] = klass
            print()
            print(self)
            print(klass)
            for attr in dir(node):
                if attr[0] != "_":
                    print("\t", "%-12.12s" % attr, getattr(node, attr))
            print()
        # ASTVisitor has no `default` method; generic_visit is the fallback
        # walker, exactly as in ASTVisitor.visit.
        return self.generic_visit(node, *args)


# XXX this is an API change
def walk(tree, visitor):
    """Run *visitor* over *tree* and return the result of visitor.visit."""
    return visitor.visit(tree)


def dumpNode(node):
    """Print *node*'s class followed by each of its public attributes."""
    print(node.__class__)
    for attr in dir(node):
        if attr[0] != "_":
            print("\t", "%-10.10s" % attr, getattr(node, attr))