Lib/compiler/static/declaration_visitor.py (212 lines of code) (raw):

# Copyright (c) Facebook, Inc. and its affiliates. (http://www.facebook.com)
"""Declaration pass of the static compiler.

``DeclarationVisitor`` walks a module's AST, creates the module's
``ModuleTable`` and forwards the classes, functions, assignments and imports
it finds to the innermost tracked scope.
"""
from __future__ import annotations

from ast import (
    AST,
    AnnAssign,
    Assign,
    AsyncFor,
    AsyncFunctionDef,
    AsyncWith,
    ClassDef,
    For,
    FunctionDef,
    If,
    Import,
    ImportFrom,
    Name,
    Try,
    While,
    With,
)
from contextlib import contextmanager
from typing import Iterator, List, TYPE_CHECKING, Union

from .module_table import ModuleTable
from .types import (
    AwaitableTypeRef,
    Class,
    Function,
    ModuleInstance,
    ResolvedTypeRef,
    DecoratedMethod,
    TypeEnvironment,
    TypeName,
    TypeRef,
    UnknownDecoratedMethod,
)
from .visitor import GenericVisitor

if TYPE_CHECKING:
    from .compiler import Compiler


class NestedScope:
    """Scope whose declarations are deliberately discarded.

    Pushed for bodies the declaration pass does not bind statically (loops,
    ``with``/``try`` blocks, non-TYPE_CHECKING ``if`` bodies, and classes that
    resolve to the dynamic type). Every ``declare_*`` hook is a no-op.
    """

    def declare_class(self, node: AST, klass: Class) -> None:
        pass

    def declare_function(self, func: Function | DecoratedMethod) -> None:
        pass

    def declare_variable(self, node: AnnAssign, module: ModuleTable) -> None:
        pass

    def declare_variables(self, node: Assign, module: ModuleTable) -> None:
        pass


# Everything that can sit on ``DeclarationVisitor.scopes`` and receive
# declarations.
TScopeTypes = Union[ModuleTable, Class, Function, NestedScope]


class DeclarationVisitor(GenericVisitor[None]):
    """Collect a module's declarations ahead of type binding.

    Constructing the visitor registers a fresh ``ModuleTable`` under
    ``symbols[mod_name]``. While visiting, each declaration is forwarded to
    the scope on top of ``self.scopes``; bodies that are not bound statically
    run under a ``NestedScope`` so their declarations are dropped.
    """

    def __init__(
        self, mod_name: str, filename: str, symbols: Compiler, optimize: int
    ) -> None:
        module = symbols[mod_name] = ModuleTable(mod_name, filename, symbols)
        super().__init__(module)
        # Scope stack, innermost last; the module itself is the outermost scope.
        self.scopes: List[TScopeTypes] = [self.module]
        self.optimize = optimize
        self.compiler = symbols
        self.type_env: TypeEnvironment = symbols.type_env

    def finish_bind(self) -> None:
        """Run ``finish_bind`` on this visitor's ``ModuleTable``."""
        self.module.finish_bind()

    def parent_scope(self) -> TScopeTypes:
        """Return the innermost scope, which receives new declarations."""
        return self.scopes[-1]

    def enter_scope(self, scope: TScopeTypes) -> None:
        """Push *scope* onto the scope stack."""
        self.scopes.append(scope)

    def exit_scope(self) -> None:
        """Pop the innermost scope off the scope stack."""
        self.scopes.pop()

    @contextmanager
    def _scope(self, scope: TScopeTypes) -> Iterator[None]:
        """Push *scope* for the duration of a ``with`` block.

        The pop runs in ``finally`` so an exception raised while visiting a
        body cannot leave a stale entry on ``self.scopes``.
        """
        self.enter_scope(scope)
        try:
            yield
        finally:
            self.exit_scope()

    def visitAnnAssign(self, node: AnnAssign) -> None:
        self.parent_scope().declare_variable(node, self.module)

    def visitAssign(self, node: Assign) -> None:
        self.parent_scope().declare_variables(node, self.module)

    def visitClassDef(self, node: ClassDef) -> None:
        """Build the class type for *node* and declare it in the parent scope.

        The class degrades to the dynamic type when it derives from a named
        tuple or a protocol, or when it is not defined at module level; in
        that case its body is visited under a ``NestedScope``.
        """
        bases = [
            self.module.resolve_type(base) or self.type_env.dynamic
            for base in node.bases
        ]
        if not bases:
            bases.append(self.type_env.object)

        with self.compiler.error_sink.error_context(self.filename, node):
            klasses = []
            for base in bases:
                klasses.append(
                    # TODO (self.module_name, node.name) here is wrong for all nested scopes
                    base.make_subclass(TypeName(self.module_name, node.name), bases)
                )
            # Every base must produce the same kind of subclass; report each
            # base that disagrees with the first one.
            first = klasses[0]
            for cur_type in klasses[1:]:
                if type(cur_type) is not type(first):
                    self.syntax_error("Incompatible subtypes", node)
            klass = first

        for base in bases:
            if base is self.type_env.named_tuple:
                # In named tuples, the fields are actually elements
                # of the tuple, so we can't do any advanced binding against it.
                klass = self.type_env.dynamic
                break

            if base is self.type_env.protocol:
                # Protocols aren't guaranteed to exist in the actual MRO, so let's treat
                # them as dynamic to force dynamic dispatch.
                klass = self.type_env.dynamic
                break

            if base.is_final:
                self.syntax_error(
                    f"Class `{klass.instance.name}` cannot subclass a Final class: `{base.instance.name}`",
                    node,
                )

        parent_scope = self.parent_scope()
        # we can't statically load classes nested inside functions, and for now
        # we don't bother with ones nested inside classes (would need to fix
        # the TypeName construction above)
        if not isinstance(parent_scope, ModuleTable):
            klass = self.type_env.dynamic

        # Apply decorators innermost-first; once the class is dynamic there is
        # nothing left to refine.
        for d in reversed(node.decorator_list):
            if klass is self.type_env.dynamic:
                break
            with self.compiler.error_sink.error_context(self.filename, d):
                decorator = self.module.resolve_decorator(d) or self.type_env.dynamic
                klass = decorator.resolve_decorate_class(klass)

        body_scope = NestedScope() if klass is self.type_env.dynamic else klass
        with self._scope(body_scope):
            for item in node.body:
                with self.compiler.error_sink.error_context(self.filename, item):
                    self.visit(item)

            parent_scope.declare_class(node, klass.exact_type())
            # We want the name corresponding to `C` to be the exact type when imported.
            self.module.types[node] = klass.exact_type()

    def _visitFunc(self, node: Union[FunctionDef, AsyncFunctionDef]) -> None:
        """Build a ``Function`` for *node* and declare it in the parent scope."""
        function = self._make_function(node)
        self.parent_scope().declare_function(function)

    def _make_function(self, node: Union[FunctionDef, AsyncFunctionDef]) -> Function:
        """Create the ``Function`` for *node*, visiting its body in its own scope.

        Records the node's provisional type in ``self.module.types`` and
        returns the undecorated ``Function``.
        """
        func = Function(node, self.module, self.type_ref(node))
        with self._scope(func):
            for item in node.body:
                self.visit(item)

        func_type: Union[Function, UnknownDecoratedMethod] = func
        if node.decorator_list:
            # Since we haven't resolved decorators yet (until finish_bind), we
            # don't know what type we should ultimately set for this node;
            # Function.finish_bind() will likely override this.
            func_type = UnknownDecoratedMethod(func)

        self.module.types[node] = func_type
        return func

    def visitFunctionDef(self, node: FunctionDef) -> None:
        self._visitFunc(node)

    def visitAsyncFunctionDef(self, node: AsyncFunctionDef) -> None:
        self._visitFunc(node)

    def type_ref(self, node: Union[FunctionDef, AsyncFunctionDef]) -> TypeRef:
        """Return a reference to *node*'s declared return type.

        A missing annotation yields the dynamic type; for ``async def`` the
        result is wrapped in an ``AwaitableTypeRef``.
        """
        ann = node.returns
        if ann is None:
            res = ResolvedTypeRef(self.type_env.dynamic)
        else:
            res = TypeRef(self.module, ann)
        if isinstance(node, AsyncFunctionDef):
            res = AwaitableTypeRef(res, self.module.compiler)
        return res

    def visitImport(self, node: Import) -> None:
        """Import each module and declare the name it binds.

        ``import a.b`` binds only the top-level package ``a``;
        ``import a.b as c`` binds ``c`` to ``a.b`` itself.
        """
        for name in node.names:
            self.compiler.import_module(name.name, self.optimize)
            asname = name.asname
            if asname is None:
                top_level_module = name.name.split(".")[0]
                self.module.declare_import(
                    top_level_module,
                    None,
                    ModuleInstance(top_level_module, self.compiler),
                )
            else:
                self.module.declare_import(
                    asname, None, ModuleInstance(name.name, self.compiler)
                )

    def visitImportFrom(self, node: ImportFrom) -> None:
        """Declare names brought in by an absolute ``from m import ...``.

        Raises:
            NotImplementedError: for relative imports.
        """
        mod_name = node.module
        if not mod_name or node.level:
            raise NotImplementedError("relative imports aren't supported")
        self.compiler.import_module(mod_name, self.optimize)
        mod = self.compiler.modules.get(mod_name)
        if mod is None:
            return

        for name in node.names:
            val = mod.children.get(name.name)
            child_name = name.asname or name.name
            if val is not None:
                self.module.declare_import(child_name, (mod_name, name.name), val)
                continue

            # We might be facing a module imported as an attribute.
            # NOTE(review): `from m import *` also reaches this point with
            # name "*" and calls import_module("m.*") — confirm that is harmless.
            module_as_attribute = f"{mod_name}.{name.name}"
            self.compiler.import_module(module_as_attribute, self.optimize)
            if module_as_attribute in self.compiler.modules:
                self.module.declare_import(
                    child_name,
                    (mod_name, name.name),
                    ModuleInstance(module_as_attribute, self.compiler),
                )

    # We don't pick up declarations in nested statements
    def _visit_nested(self, node: AST) -> None:
        """Visit *node*'s children under a ``NestedScope`` (declarations dropped)."""
        with self._scope(NestedScope()):
            self.generic_visit(node)

    def visitFor(self, node: For) -> None:
        self._visit_nested(node)

    def visitAsyncFor(self, node: AsyncFor) -> None:
        self._visit_nested(node)

    def visitWhile(self, node: While) -> None:
        self._visit_nested(node)

    def visitIf(self, node: If) -> None:
        """Visit both branches of an ``if``.

        The body of a bare ``if TYPE_CHECKING:`` is declared in the enclosing
        scope; every other body, and any ``else`` branch, is visited under a
        ``NestedScope``.
        """
        test = node.test
        if isinstance(test, Name) and test.id == "TYPE_CHECKING":
            self.visit(node.body)
        else:
            with self._scope(NestedScope()):
                self.visit(node.body)

        if node.orelse:
            with self._scope(NestedScope()):
                self.visit(node.orelse)

    def visitWith(self, node: With) -> None:
        self._visit_nested(node)

    def visitAsyncWith(self, node: AsyncWith) -> None:
        self._visit_nested(node)

    def visitTry(self, node: Try) -> None:
        self._visit_nested(node)