Lib/compiler/static/declaration_visitor.py (212 lines of code) (raw):
# Copyright (c) Facebook, Inc. and its affiliates. (http://www.facebook.com)
from __future__ import annotations
from ast import (
AST,
AnnAssign,
Assign,
AsyncFor,
AsyncFunctionDef,
AsyncWith,
ClassDef,
For,
FunctionDef,
If,
Import,
ImportFrom,
Name,
Try,
While,
With,
)
from typing import Union, List, TYPE_CHECKING
from .module_table import ModuleTable
from .types import (
AwaitableTypeRef,
Class,
Function,
ModuleInstance,
ResolvedTypeRef,
DecoratedMethod,
TypeEnvironment,
TypeName,
TypeRef,
UnknownDecoratedMethod,
)
from .visitor import GenericVisitor
if TYPE_CHECKING:
from .compiler import Compiler
class NestedScope:
    """Scope stand-in whose declaration hooks all do nothing.

    The declaration visitor pushes one of these when it enters a body whose
    declarations it does not track. Every ``declare_*`` call made while it is
    the current scope is accepted and silently dropped.
    """

    def declare_class(self, node: AST, klass: Class) -> None:
        """Accept and ignore a class declaration."""

    def declare_function(self, func: Function | DecoratedMethod) -> None:
        """Accept and ignore a function or decorated-method declaration."""

    def declare_variable(self, node: AnnAssign, module: ModuleTable) -> None:
        """Accept and ignore an annotated assignment."""

    def declare_variables(self, node: Assign, module: ModuleTable) -> None:
        """Accept and ignore a plain (possibly multi-target) assignment."""
# Every kind of object that can sit on DeclarationVisitor.scopes. The visitor
# calls declare_variable / declare_variables / declare_function /
# declare_class on whichever of these is currently on top of that stack.
TScopeTypes = Union[ModuleTable, Class, Function, NestedScope]
class DeclarationVisitor(GenericVisitor[None]):
    """Walk a module's AST and record its declarations.

    A new ModuleTable is created for the module and registered with the
    compiler. While walking, the visitor keeps a stack of scopes
    (``self.scopes``) whose bottom entry is that module table. Assignments,
    functions and classes are declared on whichever scope is on top.
    Compound statements (loops, ``with``, ``try``, and non-TYPE_CHECKING
    ``if`` bodies) push a NestedScope, so declarations inside them are not
    picked up.
    """

    def __init__(
        self, mod_name: str, filename: str, symbols: Compiler, optimize: int
    ) -> None:
        # Create the table for this module, register it with the compiler
        # under ``mod_name``, and hand it to the base visitor.
        table = ModuleTable(mod_name, filename, symbols)
        symbols[mod_name] = table
        super().__init__(table)
        # The module table sits at the bottom of the scope stack.
        self.scopes: List[TScopeTypes] = [self.module]
        self.optimize = optimize
        self.compiler = symbols
        self.type_env: TypeEnvironment = symbols.type_env

    def finish_bind(self) -> None:
        """Delegate the final binding pass to the module table."""
        self.module.finish_bind()

    def parent_scope(self) -> TScopeTypes:
        """Return the innermost (most recently entered) scope."""
        return self.scopes[-1]

    def enter_scope(self, scope: TScopeTypes) -> None:
        """Push ``scope`` so subsequent declarations land in it."""
        self.scopes.append(scope)

    def exit_scope(self) -> None:
        """Pop the innermost scope."""
        self.scopes.pop()

    def visitAnnAssign(self, node: AnnAssign) -> None:
        """Declare an annotated assignment on the current scope."""
        scope = self.parent_scope()
        scope.declare_variable(node, self.module)

    def visitAssign(self, node: Assign) -> None:
        """Declare the targets of a plain assignment on the current scope."""
        scope = self.parent_scope()
        scope.declare_variables(node, self.module)

    def visitClassDef(self, node: ClassDef) -> None:
        """Build the class type for ``node`` and declare it on the enclosing scope.

        Unresolvable bases are treated as dynamic. A class with no explicit
        bases gets ``object`` as its base. The class degrades to the dynamic
        type if it derives from a named tuple or a protocol, if it is not
        defined directly at module level, or once a decorator resolves it to
        dynamic. The body is visited inside the class's own scope, or inside a
        NestedScope when the class ended up dynamic.
        """
        env = self.type_env
        dynamic = env.dynamic

        bases = [self.module.resolve_type(expr) or dynamic for expr in node.bases]
        if not bases:
            bases.append(env.object)

        with self.compiler.error_sink.error_context(self.filename, node):
            # TODO (self.module_name, node.name) here is wrong for all nested scopes
            klasses = [
                base.make_subclass(TypeName(self.module_name, node.name), bases)
                for base in bases
            ]
            first = klasses[0]
            # Every base must yield the same kind of subclass type; report
            # each one that disagrees with the first.
            for candidate in klasses:
                if type(candidate) != type(first):
                    self.syntax_error("Incompatible subtypes", node)
            klass = first

        for base in bases:
            # Named-tuple fields are really tuple elements, so no advanced
            # binding can be done against them. Protocols aren't guaranteed
            # to exist in the actual MRO. Either way, fall back to dynamic
            # dispatch.
            if base is env.named_tuple or base is env.protocol:
                klass = dynamic
                break
            if base.is_final:
                self.syntax_error(
                    f"Class `{klass.instance.name}` cannot subclass a Final class: `{base.instance.name}`",
                    node,
                )

        parent_scope = self.parent_scope()
        # Classes nested inside functions can't be loaded statically. Classes
        # nested inside other classes aren't supported yet, because the
        # TypeName construction above would need fixing for them.
        if not isinstance(parent_scope, ModuleTable):
            klass = dynamic

        # Decorators apply innermost-first; stop as soon as one yields dynamic.
        for decorator_expr in reversed(node.decorator_list):
            if klass is dynamic:
                break
            with self.compiler.error_sink.error_context(self.filename, decorator_expr):
                decorator = self.module.resolve_decorator(decorator_expr) or dynamic
                klass = decorator.resolve_decorate_class(klass)

        self.enter_scope(NestedScope() if klass is dynamic else klass)
        for stmt in node.body:
            with self.compiler.error_sink.error_context(self.filename, stmt):
                self.visit(stmt)

        parent_scope.declare_class(node, klass.exact_type())
        # We want the name corresponding to `C` to be the exact type when imported.
        self.module.types[node] = klass.exact_type()
        self.exit_scope()

    def _visitFunc(self, node: Union[FunctionDef, AsyncFunctionDef]) -> None:
        """Build the Function for ``node`` and declare it on the current scope."""
        built = self._make_function(node)
        self.parent_scope().declare_function(built)

    def _make_function(self, node: Union[FunctionDef, AsyncFunctionDef]) -> Function:
        """Create a Function for ``node`` and visit its body within it.

        The function's entry in ``self.module.types`` is the Function itself,
        or an UnknownDecoratedMethod wrapper when ``node`` has decorators.
        The undecorated Function is returned in both cases.
        """
        func = Function(node, self.module, self.type_ref(node))
        self.enter_scope(func)
        for stmt in node.body:
            self.visit(stmt)
        self.exit_scope()
        # Since we haven't resolved decorators yet (until finish_bind), we
        # don't know what type we should ultimately set for this node;
        # Function.finish_bind() will likely override this.
        self.module.types[node] = (
            UnknownDecoratedMethod(func) if node.decorator_list else func
        )
        return func

    def visitFunctionDef(self, node: FunctionDef) -> None:
        """Declare a ``def`` on the current scope."""
        self._visitFunc(node)

    def visitAsyncFunctionDef(self, node: AsyncFunctionDef) -> None:
        """Declare an ``async def`` on the current scope."""
        self._visitFunc(node)

    def type_ref(self, node: Union[FunctionDef, AsyncFunctionDef]) -> TypeRef:
        """Return the return-type reference for ``node``.

        Without a return annotation the reference resolves to the dynamic
        type. For ``async def`` the reference is wrapped in AwaitableTypeRef.
        """
        returns = node.returns
        res = (
            TypeRef(self.module, returns)
            if returns
            else ResolvedTypeRef(self.type_env.dynamic)
        )
        if isinstance(node, AsyncFunctionDef):
            res = AwaitableTypeRef(res, self.module.compiler)
        return res

    def visitImport(self, node: Import) -> None:
        """Import each named module and bind it in the module table.

        ``import a.b`` binds the top-level name ``a``. ``import a.b as c``
        binds ``c`` to ``a.b``.
        """
        for alias in node.names:
            self.compiler.import_module(alias.name, self.optimize)
            if alias.asname is not None:
                self.module.declare_import(
                    alias.asname, None, ModuleInstance(alias.name, self.compiler)
                )
                continue
            top_level = alias.name.partition(".")[0]
            self.module.declare_import(
                top_level,
                None,
                ModuleInstance(top_level, self.compiler),
            )

    def visitImportFrom(self, node: ImportFrom) -> None:
        """Bind the names of a ``from mod import ...`` in the module table.

        Names found among the imported module's children are bound directly.
        Any other name is tried as a submodule ``mod.name`` and bound if that
        module is then known to the compiler. Nothing is declared if ``mod``
        itself is not known after importing it.

        Raises:
            NotImplementedError: for relative imports, or when the statement
                has no module name.
        """
        mod_name = node.module
        if node.level or not mod_name:
            raise NotImplementedError("relative imports aren't supported")
        self.compiler.import_module(mod_name, self.optimize)
        mod = self.compiler.modules.get(mod_name)
        if mod is None:
            return
        for alias in node.names:
            val = mod.children.get(alias.name)
            child_name = alias.asname or alias.name
            if val is not None:
                self.module.declare_import(child_name, (mod_name, alias.name), val)
                continue
            # We might be facing a module imported as an attribute.
            submodule = f"{mod_name}.{alias.name}"
            self.compiler.import_module(submodule, self.optimize)
            if submodule in self.compiler.modules:
                self.module.declare_import(
                    child_name,
                    (mod_name, alias.name),
                    ModuleInstance(submodule, self.compiler),
                )

    # We don't pick up declarations in nested statements: each of the
    # compound statements below is walked inside a throwaway NestedScope.

    def _visit_in_nested_scope(self, node: AST) -> None:
        """Visit all children of ``node`` with a NestedScope pushed."""
        self.enter_scope(NestedScope())
        self.generic_visit(node)
        self.exit_scope()

    def _visit_stmts_in_nested_scope(self, stmts: List[AST]) -> None:
        """Visit a statement list with a NestedScope pushed."""
        self.enter_scope(NestedScope())
        self.visit(stmts)
        self.exit_scope()

    def visitFor(self, node: For) -> None:
        self._visit_in_nested_scope(node)

    def visitAsyncFor(self, node: AsyncFor) -> None:
        self._visit_in_nested_scope(node)

    def visitWhile(self, node: While) -> None:
        self._visit_in_nested_scope(node)

    def visitIf(self, node: If) -> None:
        """Visit an ``if``; a bare ``if TYPE_CHECKING:`` body stays in the current scope.

        The ``else`` branch, when present, is always visited in a NestedScope.
        """
        test = node.test
        if isinstance(test, Name) and test.id == "TYPE_CHECKING":
            self.visit(node.body)
        else:
            self._visit_stmts_in_nested_scope(node.body)
        if node.orelse:
            self._visit_stmts_in_nested_scope(node.orelse)

    def visitWith(self, node: With) -> None:
        self._visit_in_nested_scope(node)

    def visitAsyncWith(self, node: AsyncWith) -> None:
        self._visit_in_nested_scope(node)

    def visitTry(self, node: Try) -> None:
        self._visit_in_nested_scope(node)