library/compiler/visitor.py (124 lines of code) (raw):
# Portions copyright (c) Facebook, Inc. and its affiliates. (http://www.facebook.com)
# pyre-unsafe
import ast
from ast import AST, copy_location
from typing import Any, Sequence, TypeVar, Union
# XXX should probably rename ASTVisitor to ASTWalker
# XXX can it be made even more generic?
class ASTVisitor:
    """Depth-first walker over an AST with name-based dispatch.

    For every node handed to :meth:`visit`, the walker looks on itself for a
    method called ``visit<NodeType>`` (``<NodeType>`` being the node's class
    name, e.g. ``visitClassDef``) and calls it with the node.  When no such
    method exists, :meth:`generic_visit` is used instead.

    Compared with the built-in ``ast.NodeVisitor``:

    * extra positional arguments given to ``visit`` are forwarded to the
      handler, which lets callers flow state through the walk;
    * handlers are named ``visitNodeName`` rather than ``visit_NodeName``;
    * ``visit`` and ``generic_visit`` accept a list of nodes as well as a
      single node.
    """

    VERBOSE = 0

    def __init__(self):
        # Most recently dispatched node (lists are not recorded here).
        self.node = None
        # Node class -> resolved handler, so getattr runs once per class.
        self._cache = {}

    def generic_visit(self, node, *args):
        """Visit the children of ``node`` when it has no dedicated handler.

        ``node`` may be a list, in which case its AST items are visited, or
        a single AST node, in which case every AST found in its fields
        (directly or inside a list-valued field) is visited.
        """
        if isinstance(node, list):
            groups = [node]
        else:
            # Wrap scalar fields in a one-element list so both shapes can be
            # handled by the same inner loop; the generator stays lazy so
            # fields are read in step with the visits.
            groups = (
                value if isinstance(value, list) else [value]
                for _field, value in ast.iter_fields(node)
            )
        for group in groups:
            for child in group:
                if isinstance(child, ast.AST):
                    self.visit(child, *args)

    def walk_list(self, nodes: Sequence[AST], *args):
        """Visit each AST item of ``nodes`` in order, ignoring other items."""
        for candidate in nodes:
            if not isinstance(candidate, ast.AST):
                continue
            self.visit(candidate, *args)

    def skip_visit(self):
        """Hook: return True to make :meth:`visit` a no-op."""
        return False

    def visit(self, node: Union[AST, Sequence[AST]], *args):
        """Dispatch ``node`` to its handler and return the handler's result.

        Lists are routed to :meth:`walk_list`.  Returns ``None`` without
        doing anything when :meth:`skip_visit` is true.
        """
        if self.skip_visit():
            return None
        if isinstance(node, list):
            return self.walk_list(node, *args)

        self.node = node
        node_type = node.__class__
        handler = self._cache.get(node_type)
        if handler is None:
            # First time this node class is seen: resolve and memoize.
            handler = getattr(
                self, f"visit{node_type.__name__}", self.generic_visit
            )
            self._cache[node_type] = handler
        return handler(node, *args)
TAst = TypeVar("TAst", bound=AST)


class ASTRewriter(ASTVisitor):
    """Visitor whose handlers return replacement nodes.

    When a child is replaced, the parent is shallow-cloned and the clone is
    updated, so the helper methods of this class never assign to the node
    they were given.  A handler returns the node unchanged to keep it, a
    different node to replace it, and -- for nodes that live in a list --
    ``None`` to remove it or an iterable of nodes to splice in.
    """

    @staticmethod
    def update_node(node: TAst, **replacement: Any) -> TAst:
        """Return ``node`` with the given fields replaced.

        ``node`` itself is returned when every replacement value is already
        the current field value (identity comparison); otherwise a clone
        carrying the new values is returned.
        """
        res = node
        for name, val in replacement.items():
            existing = getattr(res, name)
            if existing is val:
                continue

            # Clone lazily, and only once, on the first real change.
            if node is res:
                res = ASTRewriter.clone_node(node)

            setattr(res, name, val)
        return res

    @staticmethod
    def clone_node(node: TAst) -> TAst:
        """Return a shallow copy of ``node`` with source location copied.

        List-valued fields are copied into new lists (the items are shared);
        fields missing on ``node`` are passed to the constructor as ``None``.
        """
        attrs = []
        for name in node._fields:
            attr = getattr(node, name, None)
            if isinstance(attr, list):
                attr = list(attr)
            attrs.append(attr)

        new = type(node)(*attrs)
        return copy_location(new, node)

    def walk_list(self, nodes: Sequence[TAst], *args) -> Sequence[TAst]:
        """Visit every AST item of ``nodes`` and return the rewritten list.

        Extra ``args`` are forwarded to each visit.  The original ``nodes``
        object is returned when no item changed; otherwise a new list is
        returned.  Non-AST items are carried over untouched.
        """
        new_values = []
        changed = False
        for value in nodes:
            if isinstance(value, AST):
                # Forward *args so state keeps flowing below the first level.
                new_value = self.visit(value, *args)
                changed |= new_value is not value
                if new_value is None:
                    # Handler asked for the node to be removed.
                    continue
                elif not isinstance(new_value, AST):
                    # Handler returned several nodes: splice them in place.
                    new_values.extend(new_value)
                    continue
                value = new_value

            new_values.append(value)
        return new_values if changed else nodes

    def generic_visit(self, node: TAst, *args) -> TAst:
        """Rewrite the children of ``node`` and return the resulting node.

        Extra ``args`` are forwarded to each child visit.  ``node`` is
        returned as-is when no child changed; otherwise a clone with the
        updated fields is returned.
        """
        ret_node = node
        for field, old_value in ast.iter_fields(node):
            if not isinstance(old_value, (AST, list)):
                continue

            # Forward *args so state keeps flowing below the first level.
            new_node = self.visit(old_value, *args)
            assert (  # noqa: IG01
                new_node is not None
            ), f"can't remove AST nodes that aren't part of a list {old_value!r}"
            if new_node is not old_value:
                # Clone lazily, and only once, on the first changed field.
                if ret_node is node:
                    ret_node = self.clone_node(node)

                setattr(ret_node, field, new_node)

        return ret_node
class ExampleASTVisitor(ASTVisitor):
    """Prints examples of the nodes that aren't visited

    This visitor-driver is only useful for development, when it's
    helpful to develop a visitor incrementally, and get feedback on what
    you still have to do.

    Behaviour depends on ``VERBOSE``:

    * ``> 1``: every dispatch is traced with the node class and handler name;
    * ``> 0``: the first time a node class without a handler is seen, its
      public attributes are printed, and the node's children are then
      walked through ``generic_visit``;
    * ``0``: nodes without a handler are silently ignored and their
      children are not walked.
    """

    # Node classes already reported.  This is a class attribute mutated in
    # place, so the record is shared by all instances that do not override it.
    examples = {}

    def visit(self, node, *args):
        """Dispatch ``node`` to ``visit<NodeType>``, reporting missing handlers.

        Returns ``None`` when a handler exists (its result is discarded) and
        the result of ``generic_visit`` on the verbose fallback path.
        """
        self.node = node
        klass = node.__class__
        className = klass.__name__
        meth = self._cache.get(klass, None)
        if meth is None:
            # 0 is cached for "no handler": falsy for the checks below.
            meth = getattr(self, "visit" + className, 0)
            self._cache[klass] = meth
        if self.VERBOSE > 1:
            print("visit", className, meth.__name__ if meth else "")
        if meth:
            meth(node, *args)
        elif self.VERBOSE > 0:
            if klass not in self.examples:
                self.examples[klass] = klass
                print()
                print(self)
                print(klass)
                for attr in dir(node):
                    if attr[0] != "_":
                        print("\t", "%-12.12s" % attr, getattr(node, attr))
                print()
            # ASTVisitor's fallback is generic_visit; there is no `default`
            # method, so the previous call raised AttributeError.
            return self.generic_visit(node, *args)
# XXX this is an API change
def walk(tree, visitor):
    """Run ``visitor`` over ``tree`` and return whatever its ``visit`` returns."""
    result = visitor.visit(tree)
    return result
def dumpNode(node):
    """Print the class of ``node`` followed by each of its public attributes.

    An attribute is skipped when its name starts with an underscore; the
    rest are printed one per line as a tab, the name padded/truncated to
    ten characters, and the attribute's value.
    """
    print(node.__class__)
    for name in dir(node):
        if name[0] == "_":
            continue
        label = "%-10.10s" % name
        print("\t", label, getattr(node, name))