# Lib/compiler/static/module_table.py (272 lines of code) (raw):

# Copyright (c) Facebook, Inc. and its affiliates. (http://www.facebook.com) from __future__ import annotations import ast from ast import ( AST, Attribute, BinOp, Call, ClassDef, Constant, Expression, Subscript, Name, ) from contextlib import nullcontext from enum import Enum from typing import ( cast, ContextManager, Dict, List, Optional, Set, TYPE_CHECKING, Tuple, ) from ..errors import TypedSyntaxError from ..symbols import Scope, ModuleScope from .types import ( Callable, CType, Callable, Class, ClassVar, DynamicClass, ExactClass, Function, FunctionGroup, FinalClass, MethodType, TypeDescr, UnionType, Value, ) from .visitor import GenericVisitor if TYPE_CHECKING: from .compiler import Compiler class ModuleFlag(Enum): CHECKED_DICTS = 1 SHADOW_FRAME = 2 CHECKED_LISTS = 3 class ReferenceVisitor(GenericVisitor[Optional[Value]]): def __init__(self, module: ModuleTable) -> None: super().__init__(module) self.types: Dict[AST, Value] = {} self.subscr_nesting = 0 def visitName(self, node: Name) -> Optional[Value]: return self.module.children.get( node.id ) or self.module.compiler.builtins.children.get(node.id) def visitAttribute(self, node: Attribute) -> Optional[Value]: val = self.visit(node.value) if val is not None: return val.resolve_attr(node, self) class AnnotationVisitor(ReferenceVisitor): def resolve_annotation( self, node: ast.AST, *, is_declaration: bool = False, ) -> Optional[Class]: with self.error_context(node): klass = self.visit(node) if not isinstance(klass, Class): return None if self.subscr_nesting or not is_declaration: if isinstance(klass, FinalClass): raise TypedSyntaxError( "Final annotation is only valid in initial declaration " "of attribute or module-level constant", ) if isinstance(klass, ClassVar): raise TypedSyntaxError( "ClassVar is allowed only in class attribute annotations. " "Class Finals are inferred ClassVar; do not nest with Final." 
) if isinstance(klass, ExactClass): klass = klass.unwrap().exact_type() elif isinstance(klass, FinalClass): pass else: klass = klass.inexact_type() # PEP-484 specifies that ints should be treated as a subclass of floats, # even though they differ in the runtime. We need to maintain the distinction # between the two internally, so we should view user-specified `float` annotations # as `float | int`. This widening of the type prevents us from applying # optimizations to user-specified floats, but does not affect ints. Since we # don't optimize Python floats anyway, we accept this to maintain PEP-484 compatibility. if klass.unwrap() is self.type_env.float: klass = self.compiler.type_env.get_union( (self.type_env.float, self.type_env.int) ) # TODO until we support runtime checking of unions, we must for # safety resolve union annotations to dynamic (except for # optionals, which we can check at runtime) if ( isinstance(klass, UnionType) and klass is not self.type_env.union and klass is not self.type_env.optional and klass.opt_type is None ): return None return klass def visitSubscript(self, node: Subscript) -> Optional[Value]: target = self.resolve_annotation(node.value, is_declaration=True) if target is None: return None self.subscr_nesting += 1 slice = self.visit(node.slice) or self.type_env.DYNAMIC self.subscr_nesting -= 1 return target.resolve_subscr(node, slice, self) or target def visitBinOp(self, node: BinOp) -> Optional[Value]: if isinstance(node.op, ast.BitOr): ltype = self.resolve_annotation(node.left) rtype = self.resolve_annotation(node.right) if ltype is None or rtype is None: return None return self.module.compiler.type_env.get_union((ltype, rtype)) def visitConstant(self, node: Constant) -> Optional[Value]: sval = node.value if sval is None: return self.type_env.none elif isinstance(sval, str): n = cast(Expression, ast.parse(node.value, "", "eval")).body return self.visit(n) class ModuleTable: def __init__( self, name: str, filename: str, compiler: 
Compiler, members: Optional[Dict[str, Value]] = None, ) -> None: self.name = name self.filename = filename self.children: Dict[str, Value] = members or {} self.compiler = compiler self.types: Dict[AST, Value] = {} self.node_data: Dict[Tuple[AST, object], object] = {} self.flags: Set[ModuleFlag] = set() self.decls: List[Tuple[AST, Optional[str], Optional[Value]]] = [] self.compile_non_static: Set[AST] = set() # (source module, source name) for every name this module imports-from # another static module at top level self.imported_from: Dict[str, Tuple[str, str]] = {} # TODO: final constants should be typed to literals, and # this should be removed in the future self.named_finals: Dict[str, ast.Constant] = {} # Have we completed our first pass through the module, populating # imports and types defined in the module? Until we have, resolving # type annotations is not safe. self.first_pass_done = False self.ann_visitor = AnnotationVisitor(self) self.ref_visitor = ReferenceVisitor(self) def syntax_error(self, msg: str, node: AST) -> None: return self.compiler.error_sink.syntax_error(msg, self.filename, node) def error_context(self, node: Optional[AST]) -> ContextManager[None]: if node is None: return nullcontext() return self.compiler.error_sink.error_context(self.filename, node) def declare_class(self, node: ClassDef, klass: Class) -> None: self.decls.append((node, node.name, klass)) self.children[node.name] = klass def declare_function(self, func: Function) -> None: existing = self.children.get(func.func_name) new_member = func if existing is not None: if isinstance(existing, Function): new_member = FunctionGroup([existing, new_member], func.klass.type_env) elif isinstance(existing, FunctionGroup): existing.functions.append(new_member) new_member = existing else: raise TypedSyntaxError( f"function conflicts with other member {func.func_name} in {self.name}" ) self.decls.append((func.node, func.func_name, new_member)) self.children[func.func_name] = new_member def 
_get_inferred_type(self, value: ast.expr) -> Optional[Value]: if not isinstance(value, ast.Name): return None return self.children.get(value.id) def finish_bind(self) -> None: self.first_pass_done = True for node, name, value in self.decls: with self.error_context(node): if value is not None: assert name is not None new_value = value.finish_bind(self, None) if new_value is None: del self.children[name] elif new_value is not value: self.children[name] = new_value if isinstance(node, ast.AnnAssign): typ = self.resolve_annotation(node.annotation, is_declaration=True) if typ is not None: # Special case Final[dynamic] to use inferred type. target = node.target instance = typ.instance value = node.value if ( value is not None and isinstance(typ, FinalClass) and isinstance(typ.unwrap(), DynamicClass) ): instance = self._get_inferred_type(value) or instance if isinstance(target, ast.Name): self.children[target.id] = instance if isinstance(typ, FinalClass): target = node.target value = node.value if not value: raise TypedSyntaxError( "Must assign a value when declaring a Final" ) elif ( not isinstance(typ, CType) and isinstance(target, ast.Name) and isinstance(value, ast.Constant) ): self.named_finals[target.id] = value # We don't need these anymore... 
self.decls.clear() def resolve_type(self, node: ast.AST) -> Optional[Class]: typ = self.ann_visitor.visit(node) if isinstance(typ, Class): return typ def resolve_decorator(self, node: ast.AST) -> Optional[Value]: if isinstance(node, Call): func = self.ref_visitor.visit(node.func) if isinstance(func, Class): return func.instance elif isinstance(func, Callable): return func.return_type.resolved().instance elif isinstance(func, MethodType): return func.function.return_type.resolved().instance return self.ref_visitor.visit(node) def resolve_annotation( self, node: ast.AST, *, is_declaration: bool = False, ) -> Optional[Class]: assert self.first_pass_done, ( "Type annotations cannot be resolved until after initial pass, " "so that all imports and types are available." ) return self.ann_visitor.resolve_annotation(node, is_declaration=is_declaration) def resolve_name_with_descr( self, name: str ) -> Tuple[Optional[Value], Optional[TypeDescr]]: if val := self.children.get(name): return val, (self.name, name) elif val := self.compiler.builtins.children.get(name): return val, None return None, None def resolve_name(self, name: str) -> Optional[Value]: return self.resolve_name_with_descr(name)[0] def get_final_literal(self, node: AST, scope: Scope) -> Optional[ast.Constant]: if not isinstance(node, Name): return None final_val = self.named_finals.get(node.id, None) if ( final_val is not None and isinstance(node.ctx, ast.Load) and ( # Ensure the name is not shadowed in the local scope isinstance(scope, ModuleScope) or node.id not in scope.defs ) ): return final_val def declare_import( self, name: str, source: Tuple[str, str] | None, val: Value ) -> None: """Declare a name imported into this module. `name` is the name in this module's namespace. `source` is a (str, str) tuple of (source_module, source_name) for an `import from`. For a top-level module import, `source` should be `None`. 
""" self.children[name] = val if source is not None: self.imported_from[name] = source def declare_variable(self, node: ast.AnnAssign, module: ModuleTable) -> None: self.decls.append((node, None, None)) def declare_variables(self, node: ast.Assign, module: ModuleTable) -> None: pass