library/compiler/static.py (5,776 lines of code) (raw):
# Copyright (c) Facebook, Inc. and its affiliates. (http://www.facebook.com)
from __future__ import annotations
import ast
import linecache
import sys
from ast import (
AST,
And,
AnnAssign,
Assign,
AsyncFor,
AsyncFunctionDef,
AsyncWith,
Attribute,
AugAssign,
Await,
BinOp,
BoolOp,
Bytes,
Call,
ClassDef,
Compare,
Constant,
DictComp,
Ellipsis,
For,
FormattedValue,
FunctionDef,
GeneratorExp,
If,
IfExp,
Import,
ImportFrom,
Index,
Is,
IsNot,
JoinedStr,
Lambda,
ListComp,
Module,
Name,
NameConstant,
Num,
Return,
SetComp,
Slice,
Starred,
Str,
Subscript,
Try,
UnaryOp,
While,
With,
Yield,
YieldFrom,
cmpop,
expr,
)
from contextlib import contextmanager, nullcontext
from enum import IntEnum
from functools import partial
from types import BuiltinFunctionType, CodeType, MethodDescriptorType
from typing import (
Callable as typingCallable,
Collection,
Dict,
Generator,
Generic,
Iterable,
List,
Mapping,
NoReturn,
Optional,
Sequence,
Set,
Tuple,
Type,
TypeVar,
Union,
cast,
)
from __static__ import chkdict # pyre-ignore[21]: unknown module
from _static import ( # pyre-fixme[21]: Could not find module `_static`.
TYPED_BOOL,
TYPED_INT_8BIT,
TYPED_INT_16BIT,
TYPED_INT_32BIT,
TYPED_INT_64BIT,
TYPED_OBJECT,
TYPED_ARRAY,
TYPED_INT_UNSIGNED,
TYPED_INT_SIGNED,
TYPED_INT8,
TYPED_INT16,
TYPED_INT32,
TYPED_INT64,
TYPED_UINT8,
TYPED_UINT16,
TYPED_UINT32,
TYPED_UINT64,
SEQ_LIST,
SEQ_TUPLE,
SEQ_LIST_INEXACT,
SEQ_ARRAY_INT8,
SEQ_ARRAY_INT16,
SEQ_ARRAY_INT32,
SEQ_ARRAY_INT64,
SEQ_ARRAY_UINT8,
SEQ_ARRAY_UINT16,
SEQ_ARRAY_UINT32,
SEQ_ARRAY_UINT64,
SEQ_SUBSCR_UNCHECKED,
SEQ_REPEAT_INEXACT_SEQ,
SEQ_REPEAT_INEXACT_NUM,
SEQ_REPEAT_REVERSED,
SEQ_REPEAT_PRIMITIVE_NUM,
PRIM_OP_EQ_INT,
PRIM_OP_NE_INT,
PRIM_OP_LT_INT,
PRIM_OP_LE_INT,
PRIM_OP_GT_INT,
PRIM_OP_GE_INT,
PRIM_OP_LT_UN_INT,
PRIM_OP_LE_UN_INT,
PRIM_OP_GT_UN_INT,
PRIM_OP_GE_UN_INT,
PRIM_OP_ADD_INT,
PRIM_OP_SUB_INT,
PRIM_OP_MUL_INT,
PRIM_OP_DIV_INT,
PRIM_OP_DIV_UN_INT,
PRIM_OP_MOD_INT,
PRIM_OP_MOD_UN_INT,
PRIM_OP_LSHIFT_INT,
PRIM_OP_RSHIFT_INT,
PRIM_OP_RSHIFT_UN_INT,
PRIM_OP_XOR_INT,
PRIM_OP_OR_INT,
PRIM_OP_AND_INT,
PRIM_OP_NEG_INT,
PRIM_OP_INV_INT,
PRIM_OP_ADD_DBL,
PRIM_OP_SUB_DBL,
PRIM_OP_MUL_DBL,
PRIM_OP_DIV_DBL,
PRIM_OP_MOD_DBL,
PROM_OP_POW_DBL,
FAST_LEN_INEXACT,
FAST_LEN_LIST,
FAST_LEN_DICT,
FAST_LEN_SET,
FAST_LEN_TUPLE,
FAST_LEN_ARRAY,
FAST_LEN_STR,
TYPED_DOUBLE,
RAND_MAX,
rand,
)
from . import symbols, opcode38static
from .consts import SC_LOCAL, SC_GLOBAL_EXPLICIT, SC_GLOBAL_IMPLICIT
from .opcodebase import Opcode
from .optimizer import AstOptimizer
from .pyassem import Block, PyFlowGraph, PyFlowGraphCinder, IndexedSet
from .pycodegen import (
AugAttribute,
AugName,
AugSubscript,
CodeGenerator,
CinderCodeGenerator,
Delegator,
compile,
wrap_aug,
FOR_LOOP,
)
from .symbols import Scope, SymbolVisitor, ModuleScope, ClassScope
from .unparse import to_expr
from .visitor import ASTVisitor, ASTRewriter, TAst
try:
import xxclassloader # pyre-ignore[21]: unknown module
from xxclassloader import spamobj
except ImportError:
spamobj = None
def exec_static(
source: str,
locals: Dict[str, object],
globals: Dict[str, object],
modname: str = "<module>",
) -> None:
code = compile(
source, "<module>", "exec", compiler=StaticCodeGenerator, modname=modname
)
exec(code, locals, globals) # noqa: P204
CBOOL_TYPE: CIntType
INT8_TYPE: CIntType
INT16_TYPE: CIntType
INT32_TYPE: CIntType
INT64_TYPE: CIntType
INT64_VALUE: CIntInstance
SIGNED_CINT_TYPES: Sequence[CIntType]
INT_TYPE: NumClass
INT_EXACT_TYPE: NumClass
FLOAT_TYPE: NumClass
COMPLEX_TYPE: NumClass
BOOL_TYPE: Class
ARRAY_TYPE: Class
DICT_TYPE: Class
LIST_TYPE: Class
TUPLE_TYPE: Class
SET_TYPE: Class
OBJECT_TYPE: Class
OBJECT: Value
DYNAMIC_TYPE: DynamicClass
DYNAMIC: DynamicInstance
FUNCTION_TYPE: Class
METHOD_TYPE: Class
MEMBER_TYPE: Class
NONE_TYPE: Class
TYPE_TYPE: Class
ARG_TYPE: Class
SLICE_TYPE: Class
CHAR_TYPE: CIntType
DOUBLE_TYPE: CDoubleType
# Prefix for temporary var names. It's illegal in normal
# Python, so there's no chance it will ever clash with a
# user defined name.
_TMP_VAR_PREFIX = "_pystatic_.0._tmp__"
CMPOP_SIGILS: Mapping[Type[cmpop], str] = {
ast.Lt: "<",
ast.Gt: ">",
ast.Eq: "==",
ast.NotEq: "!=",
ast.LtE: "<=",
ast.GtE: ">=",
ast.Is: "is",
ast.IsNot: "is",
}
def syntax_error(msg: str, filename: str, node: AST) -> TypedSyntaxError:
lineno, offset, source_line = error_location(filename, node)
return TypedSyntaxError(msg, (filename, lineno, offset, source_line))
def error_location(filename: str, node: AST) -> Tuple[int, int, Optional[str]]:
source_line = linecache.getline(filename, node.lineno)
return (node.lineno, node.col_offset, source_line or None)
@contextmanager
def error_context(filename: str, node: AST) -> Generator[None, None, None]:
"""Add error location context to any TypedSyntaxError raised in with block."""
try:
yield
except TypedSyntaxError as exc:
if exc.filename is None:
exc.filename = filename
if (exc.lineno, exc.offset) == (None, None):
exc.lineno, exc.offset, exc.text = error_location(filename, node)
raise
class TypeRef:
"""Stores unresolved typed references, capturing the referring module
as well as the annotation"""
def __init__(self, module: ModuleTable, ref: ast.expr) -> None:
self.module = module
self.ref = ref
def resolved(self, is_declaration: bool = False) -> Class:
res = self.module.resolve_annotation(self.ref, is_declaration=is_declaration)
if res is None:
return DYNAMIC_TYPE
return res
def __repr__(self) -> str:
return f"TypeRef({self.module.name}, {ast.dump(self.ref)})"
class ResolvedTypeRef(TypeRef):
def __init__(self, type: Class) -> None:
self._resolved = type
def resolved(self, is_declaration: bool = False) -> Class:
return self._resolved
def __repr__(self) -> str:
return f"ResolvedTypeRef({self.resolved()})"
# Pyre doesn't support recursive generics, so we can't represent the recursively
# nested tuples that make up a type_descr. Fortunately we don't need to, since
# we don't parse them in Python, we just generate them and emit them as
# constants. So just call them `Tuple[object, ...]`
TypeDescr = Tuple[object, ...]
class TypeName:
def __init__(self, module: str, name: str) -> None:
self.module = module
self.name = name
@property
def type_descr(self) -> TypeDescr:
"""The metadata emitted into the const pool to describe a type.
For normal types this is just the fully qualified type name as a tuple
('mypackage', 'mymod', 'C'). For optional types we have an extra '?'
element appended. For generic types we append a tuple of the generic
args' type_descrs.
"""
return (self.module, self.name)
@property
def friendly_name(self) -> str:
if self.module and self.module not in ("builtins", "__static__", "typing"):
return f"{self.module}.{self.name}"
return self.name
class GenericTypeName(TypeName):
def __init__(self, module: str, name: str, args: Tuple[Class, ...]) -> None:
super().__init__(module, name)
self.args = args
@property
def type_descr(self) -> TypeDescr:
gen_args: List[TypeDescr] = []
for arg in self.args:
gen_args.append(arg.type_descr)
return (self.module, self.name, tuple(gen_args))
@property
def friendly_name(self) -> str:
args = ", ".join(arg.instance.name for arg in self.args)
return f"{super().friendly_name}[{args}]"
GenericTypeIndex = Tuple["Class", ...]
GenericTypesDict = Dict["Class", Dict[GenericTypeIndex, "Class"]]
class SymbolTable:
def __init__(self) -> None:
self.modules: Dict[str, ModuleTable] = {}
builtins_children = {
"object": OBJECT_TYPE,
"type": TYPE_TYPE,
"None": NONE_TYPE.instance,
"int": INT_EXACT_TYPE,
"complex": COMPLEX_EXACT_TYPE,
"str": STR_EXACT_TYPE,
"bytes": BYTES_TYPE,
"bool": BOOL_TYPE,
"float": FLOAT_EXACT_TYPE,
"len": LenFunction(FUNCTION_TYPE, boxed=True),
"min": ExtremumFunction(FUNCTION_TYPE, is_min=True),
"max": ExtremumFunction(FUNCTION_TYPE, is_min=False),
"list": LIST_EXACT_TYPE,
"tuple": TUPLE_EXACT_TYPE,
"set": SET_EXACT_TYPE,
"sorted": SortedFunction(FUNCTION_TYPE),
"Exception": EXCEPTION_TYPE,
"BaseException": BASE_EXCEPTION_TYPE,
"isinstance": IsInstanceFunction(),
"issubclass": IsSubclassFunction(),
"staticmethod": STATIC_METHOD_TYPE,
"reveal_type": RevealTypeFunction(),
}
strict_builtins = StrictBuiltins(builtins_children)
typing_children = {
# TODO: Need typed members for dict
"Dict": DICT_TYPE,
"List": LIST_TYPE,
"Final": FINAL_TYPE,
"final": FINAL_METHOD_TYPE,
"NamedTuple": NAMED_TUPLE_TYPE,
"Optional": OPTIONAL_TYPE,
"Union": UNION_TYPE,
"Tuple": TUPLE_TYPE,
"TYPE_CHECKING": BOOL_TYPE.instance,
}
builtins_children["<builtins>"] = strict_builtins
builtins_children["<fixed-modules>"] = StrictBuiltins(
{"typing": StrictBuiltins(typing_children)}
)
self.builtins = self.modules["builtins"] = ModuleTable(
"builtins",
"<builtins>",
self,
builtins_children,
)
self.typing = self.modules["typing"] = ModuleTable(
"typing", "<typing>", self, typing_children
)
self.statics = self.modules["__static__"] = ModuleTable(
"__static__",
"<__static__>",
self,
{
"Array": ARRAY_EXACT_TYPE,
"CheckedDict": CHECKED_DICT_EXACT_TYPE,
"allow_weakrefs": ALLOW_WEAKREFS_TYPE,
"box": BoxFunction(FUNCTION_TYPE),
"cast": CastFunction(FUNCTION_TYPE),
"clen": LenFunction(FUNCTION_TYPE, boxed=False),
"dynamic_return": DYNAMIC_RETURN_TYPE,
"size_t": UINT64_TYPE,
"ssize_t": INT64_TYPE,
"cbool": CBOOL_TYPE,
"inline": INLINE_TYPE,
# This is a way to disable the static compiler for
# individual functions/methods
"_donotcompile": DONOTCOMPILE_TYPE,
"int8": INT8_TYPE,
"int16": INT16_TYPE,
"int32": INT32_TYPE,
"int64": INT64_TYPE,
"uint8": UINT8_TYPE,
"uint16": UINT16_TYPE,
"uint32": UINT32_TYPE,
"uint64": UINT64_TYPE,
"char": CHAR_TYPE,
"double": DOUBLE_TYPE,
"unbox": UnboxFunction(FUNCTION_TYPE),
"nonchecked_dicts": BOOL_TYPE.instance,
"pydict": DICT_TYPE,
"PyDict": DICT_TYPE,
"Vector": VECTOR_TYPE,
"RAND_MAX": NumClass(
TypeName("builtins", "int"), pytype=int, literal_value=RAND_MAX
).instance,
"rand": reflect_builtin_function(rand),
},
)
if SPAM_OBJ is not None:
self.modules["xxclassloader"] = ModuleTable(
"xxclassloader",
"<xxclassloader>",
self,
{
"spamobj": SPAM_OBJ,
"XXGeneric": XX_GENERIC_TYPE,
"foo": reflect_builtin_function(xxclassloader.foo),
"bar": reflect_builtin_function(xxclassloader.bar),
"neg": reflect_builtin_function(xxclassloader.neg),
},
)
# We need to clone the dictionaries for each type so that as we populate
# generic instantations that we don't store them in the global dict for
# built-in types
self.generic_types: GenericTypesDict = {
k: dict(v) for k, v in BUILTIN_GENERICS.items()
}
def __getitem__(self, name: str) -> ModuleTable:
return self.modules[name]
def __setitem__(self, name: str, value: ModuleTable) -> None:
self.modules[name] = value
def add_module(self, name: str, filename: str, tree: AST) -> None:
decl_visit = DeclarationVisitor(name, filename, self)
decl_visit.visit(tree)
decl_visit.finish_bind()
def compile(
self, name: str, filename: str, tree: AST, optimize: int = 0
) -> CodeType:
if name not in self.modules:
self.add_module(name, filename, tree)
tree = AstOptimizer(optimize=optimize > 0).visit(tree)
# Analyze variable scopes
s = SymbolVisitor()
s.visit(tree)
# Analyze the types of objects within local scopes
type_binder = TypeBinder(s, filename, self, name, optimize)
type_binder.visit(tree)
# Compile the code w/ the static compiler
graph = StaticCodeGenerator.flow_graph(
name, filename, s.scopes[tree], peephole_enabled=True
)
graph.setFlag(StaticCodeGenerator.consts.CO_STATICALLY_COMPILED)
code_gen = StaticCodeGenerator(
None, tree, s, graph, self, name, flags=0, optimization_lvl=optimize
)
code_gen.visit(tree)
return code_gen.getCode()
def import_module(self, name: str) -> None:
pass
TType = TypeVar("TType")
class ModuleTable:
def __init__(
self,
name: str,
filename: str,
symtable: SymbolTable,
members: Optional[Dict[str, Value]] = None,
) -> None:
self.name = name
self.filename = filename
self.children: Dict[str, Value] = members or {}
self.symtable = symtable
self.types: Dict[Union[AST, Delegator], Value] = {}
self.node_data: Dict[Tuple[Union[AST, Delegator], object], object] = {}
self.nonchecked_dicts = False
self.noframe = False
self.decls: List[Tuple[AST, Optional[Value]]] = []
# TODO: final constants should be typed to literals, and
# this should be removed in the future
self.named_finals: Dict[str, ast.Constant] = {}
# Functions in this module that have been decorated with
# `dynamic_return`. We actually store their `.args` node in here, not
# the `FunctionDef` node itself, since strict modules rewriter will
# replace the latter in between decls visit and type binding / codegen.
self.dynamic_returns: Set[ast.AST] = set()
# Have we completed our first pass through the module, populating
# imports and types defined in the module? Until we have, resolving
# type annotations is not safe.
self.first_pass_done = False
def finish_bind(self) -> None:
self.first_pass_done = True
for node, value in self.decls:
with error_context(self.filename, node):
if value is not None:
value.finish_bind(self)
elif isinstance(node, ast.AnnAssign):
typ = self.resolve_annotation(node.annotation, is_declaration=True)
if isinstance(typ, FinalClass):
target = node.target
value = node.value
if not value:
raise TypedSyntaxError(
"Must assign a value when declaring a Final"
)
elif (
not isinstance(typ, CType)
and isinstance(target, ast.Name)
and isinstance(value, ast.Constant)
):
self.named_finals[target.id] = value
# We don't need these anymore...
self.decls.clear()
def resolve_type(self, node: ast.AST) -> Optional[Class]:
# TODO handle Call
return self._resolve(node, self.resolve_type)
def _resolve(
self,
node: ast.AST,
_resolve: typingCallable[[ast.AST], Optional[Class]],
_resolve_subscr_target: Optional[
typingCallable[[ast.AST], Optional[Class]]
] = None,
) -> Optional[Class]:
if isinstance(node, ast.Name):
res = self.resolve_name(node.id)
if isinstance(res, Class):
return res
elif isinstance(node, Subscript):
slice = node.slice
if isinstance(slice, Index):
val = (_resolve_subscr_target or _resolve)(node.value)
if val is not None:
value = slice.value
if isinstance(value, ast.Tuple):
anns = []
for elt in value.elts:
ann = _resolve(elt) or DYNAMIC_TYPE
anns.append(ann)
values = tuple(anns)
gen = val.make_generic_type(values, self.symtable.generic_types)
return gen or val
else:
index = _resolve(value) or DYNAMIC_TYPE
gen = val.make_generic_type(
(index,), self.symtable.generic_types
)
return gen or val
# TODO handle Attribute
def resolve_annotation(
self,
node: ast.AST,
is_declaration: bool = False,
) -> Optional[Class]:
assert self.first_pass_done, (
"Type annotations cannot be resolved until after initial pass, "
"so that all imports and types are available."
)
with error_context(self.filename, node):
klass = self._resolve_annotation(node)
if isinstance(klass, FinalClass) and not is_declaration:
raise TypedSyntaxError(
"Final annotation is only valid in initial declaration "
"of attribute or module-level constant",
)
# TODO until we support runtime checking of unions, we must for
# safety resolve union annotations to dynamic (except for
# optionals, which we can check at runtime)
if (
isinstance(klass, UnionType)
and klass is not UNION_TYPE
and klass is not OPTIONAL_TYPE
and klass.opt_type is None
):
return None
# Even if we know that e.g. `builtins.str` is the exact `str` type and
# not a subclass, and it's useful to track that knowledge, when we
# annotate `x: str` that annotation should not exclude subclasses.
return inexact_type(klass) if klass else None
def _resolve_annotation(self, node: ast.AST) -> Optional[Class]:
# First try to resolve non-annotation-specific forms. For resolving the
# outer target of a subscript (e.g. `Final` in `Final[int]`) we pass
# `is_declaration=True` to allow `Final` in that position; if in fact
# we are not resolving a declaration, the outer `resolve_annotation`
# (our caller) will still catch the generic Final that we end up
# returning.
typ = self._resolve(
node,
self.resolve_annotation,
_resolve_subscr_target=partial(
self.resolve_annotation, is_declaration=True
),
)
if typ:
return typ
elif isinstance(node, ast.Str):
# pyre-ignore[16]: `AST` has no attribute `body`.
return self.resolve_annotation(ast.parse(node.s, "", "eval").body)
elif isinstance(node, ast.Constant):
sval = node.value
if sval is None:
return NONE_TYPE
elif isinstance(sval, str):
return self.resolve_annotation(ast.parse(node.value, "", "eval").body)
elif isinstance(node, NameConstant) and node.value is None:
return NONE_TYPE
elif isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
ltype = self.resolve_annotation(node.left)
rtype = self.resolve_annotation(node.right)
if ltype is None or rtype is None:
return None
return UNION_TYPE.make_generic_type(
(ltype, rtype), self.symtable.generic_types
)
def resolve_name(self, name: str) -> Optional[Value]:
return self.children.get(name) or self.symtable.builtins.children.get(name)
def get_final_literal(self, node: AST, scope: Scope) -> Optional[ast.Constant]:
if not isinstance(node, Name):
return None
final_val = self.named_finals.get(node.id, None)
if (
final_val is not None
and isinstance(node.ctx, ast.Load)
and (
# Ensure the name is not shadowed in the local scope
isinstance(scope, ModuleScope)
or node.id not in scope.defs
)
):
return final_val
TClass = TypeVar("TClass", bound="Class", covariant=True)
TClassInv = TypeVar("TClassInv", bound="Class")
class Value:
"""base class for all values tracked at compile time."""
def __init__(self, klass: Class) -> None:
"""name: the name of the value, for instances this is used solely for
debug/reporting purposes. In Class subclasses this will be the
qualified name (e.g. module.Foo).
klass: the Class of this object"""
self.klass = klass
@property
def name(self) -> str:
return type(self).__name__
def finish_bind(self, module: ModuleTable) -> None:
pass
def make_generic_type(
self, index: GenericTypeIndex, generic_types: GenericTypesDict
) -> Optional[Class]:
pass
def get_iter_type(self, node: ast.expr, visitor: TypeBinder) -> Value:
"""returns the type that is produced when iterating over this value"""
raise visitor.syntax_error(f"cannot iterate over {self.name}", node)
def as_oparg(self) -> int:
raise TypeError(f"{self.name} not valid here")
def bind_attr(
self, node: ast.Attribute, visitor: TypeBinder, type_ctx: Optional[Class]
) -> None:
raise visitor.syntax_error(f"cannot load attribute from {self.name}", node)
def bind_call(
self, node: ast.Call, visitor: TypeBinder, type_ctx: Optional[Class]
) -> NarrowingEffect:
raise visitor.syntax_error(f"cannot call {self.name}", node)
def check_args_for_primitives(self, node: ast.Call, visitor: TypeBinder) -> None:
for arg in node.args:
if isinstance(visitor.get_type(arg), CInstance):
raise visitor.syntax_error("Call argument cannot be a primitive", arg)
for arg in node.keywords:
if isinstance(visitor.get_type(arg.value), CInstance):
raise visitor.syntax_error(
"Call argument cannot be a primitive", arg.value
)
def bind_descr_get(
self,
node: ast.Attribute,
inst: Optional[Object[TClassInv]],
ctx: TClassInv,
visitor: TypeBinder,
type_ctx: Optional[Class],
) -> None:
raise visitor.syntax_error(f"cannot get descriptor {self.name}", node)
def bind_decorate_function(
self, visitor: DeclarationVisitor, fn: Function | StaticMethod
) -> Optional[Value]:
return None
def bind_decorate_class(self, klass: Class) -> Class:
return DYNAMIC_TYPE
def bind_subscr(
self, node: ast.Subscript, type: Value, visitor: TypeBinder
) -> None:
raise visitor.syntax_error(f"cannot index {self.name}", node)
def emit_subscr(
self, node: ast.Subscript, aug_flag: bool, code_gen: Static38CodeGenerator
) -> None:
code_gen.defaultVisit(node, aug_flag)
def emit_store_subscr(
self, node: ast.Subscript, code_gen: Static38CodeGenerator
) -> None:
code_gen.emit("ROT_THREE")
code_gen.emit("STORE_SUBSCR")
def emit_call(self, node: ast.Call, code_gen: Static38CodeGenerator) -> None:
code_gen.defaultVisit(node)
def emit_attr(
self, node: Union[ast.Attribute, AugAttribute], code_gen: Static38CodeGenerator
) -> None:
if isinstance(node.ctx, ast.Store):
code_gen.emit("STORE_ATTR", code_gen.mangle(node.attr))
elif isinstance(node.ctx, ast.Del):
code_gen.emit("DELETE_ATTR", code_gen.mangle(node.attr))
else:
code_gen.emit("LOAD_ATTR", code_gen.mangle(node.attr))
def bind_compare(
self,
node: ast.Compare,
left: expr,
op: cmpop,
right: expr,
visitor: TypeBinder,
type_ctx: Optional[Class],
) -> bool:
raise visitor.syntax_error(f"cannot compare with {self.name}", node)
def bind_reverse_compare(
self,
node: ast.Compare,
left: expr,
op: cmpop,
right: expr,
visitor: TypeBinder,
type_ctx: Optional[Class],
) -> bool:
raise visitor.syntax_error(f"cannot reverse with {self.name}", node)
def emit_compare(self, op: cmpop, code_gen: Static38CodeGenerator) -> None:
code_gen.defaultEmitCompare(op)
def bind_binop(
self, node: ast.BinOp, visitor: TypeBinder, type_ctx: Optional[Class]
) -> bool:
raise visitor.syntax_error(f"cannot bin op with {self.name}", node)
def bind_reverse_binop(
self, node: ast.BinOp, visitor: TypeBinder, type_ctx: Optional[Class]
) -> bool:
raise visitor.syntax_error(f"cannot reverse bin op with {self.name}", node)
def bind_unaryop(
self, node: ast.UnaryOp, visitor: TypeBinder, type_ctx: Optional[Class]
) -> None:
raise visitor.syntax_error(f"cannot reverse unary op with {self.name}", node)
def emit_binop(self, node: ast.BinOp, code_gen: Static38CodeGenerator) -> None:
code_gen.defaultVisit(node)
def emit_forloop(self, node: ast.For, code_gen: Static38CodeGenerator) -> None:
start = code_gen.newBlock("default_forloop_start")
anchor = code_gen.newBlock("default_forloop_anchor")
after = code_gen.newBlock("default_forloop_after")
code_gen.set_lineno(node)
code_gen.push_loop(FOR_LOOP, start, after)
code_gen.visit(node.iter)
code_gen.emit("GET_ITER")
code_gen.nextBlock(start)
code_gen.emit("FOR_ITER", anchor)
code_gen.visit(node.target)
code_gen.visit(node.body)
code_gen.emit("JUMP_ABSOLUTE", start)
code_gen.nextBlock(anchor)
code_gen.pop_loop()
if node.orelse:
code_gen.visit(node.orelse)
code_gen.nextBlock(after)
def emit_unaryop(self, node: ast.UnaryOp, code_gen: Static38CodeGenerator) -> None:
code_gen.defaultVisit(node)
def emit_augassign(
self, node: ast.AugAssign, code_gen: Static38CodeGenerator
) -> None:
code_gen.defaultVisit(node)
def emit_augname(
self, node: AugName, code_gen: Static38CodeGenerator, mode: str
) -> None:
code_gen.defaultVisit(node, mode)
def bind_constant(self, node: ast.Constant, visitor: TypeBinder) -> None:
raise visitor.syntax_error(f"cannot constant with {self.name}", node)
def emit_constant(
self, node: ast.Constant, code_gen: Static38CodeGenerator
) -> None:
return code_gen.defaultVisit(node)
def emit_name(self, node: ast.Name, code_gen: Static38CodeGenerator) -> None:
return code_gen.defaultVisit(node)
def emit_jumpif(
self, test: AST, next: Block, is_if_true: bool, code_gen: Static38CodeGenerator
) -> None:
CinderCodeGenerator.compileJumpIf(code_gen, test, next, is_if_true)
def emit_jumpif_pop(
self, test: AST, next: Block, is_if_true: bool, code_gen: Static38CodeGenerator
) -> None:
CinderCodeGenerator.compileJumpIfPop(code_gen, test, next, is_if_true)
def emit_box(self, node: expr, code_gen: Static38CodeGenerator) -> None:
raise RuntimeError(f"Unsupported box type: {code_gen.get_type(node)}")
def emit_unbox(self, node: expr, code_gen: Static38CodeGenerator) -> None:
raise RuntimeError("Unsupported unbox type")
def get_fast_len_type(self) -> Optional[int]:
return None
def emit_len(
self, node: ast.Call, code_gen: Static38CodeGenerator, boxed: bool
) -> None:
if not boxed:
raise RuntimeError("Unsupported type for clen()")
return self.emit_call(node, code_gen)
def make_generic(
self, new_type: Class, name: GenericTypeName, generic_types: GenericTypesDict
) -> Value:
return self
def emit_convert(self, to_type: Value, code_gen: Static38CodeGenerator) -> None:
pass
class Object(Value, Generic[TClass]):
"""Represents an instance of a type at compile time"""
klass: TClass
@property
def name(self) -> str:
return self.klass.instance_name
def as_oparg(self) -> int:
return TYPED_OBJECT
def bind_call(
self, node: ast.Call, visitor: TypeBinder, type_ctx: Optional[Class]
) -> NarrowingEffect:
visitor.set_type(node, DYNAMIC)
for arg in node.args:
visitor.visit(arg)
for arg in node.keywords:
visitor.visit(arg.value)
self.check_args_for_primitives(node, visitor)
return NO_EFFECT
def bind_attr(
self, node: ast.Attribute, visitor: TypeBinder, type_ctx: Optional[Class]
) -> None:
for base in self.klass.mro:
member = base.members.get(node.attr)
if member is not None:
member.bind_descr_get(node, self, self.klass, visitor, type_ctx)
return
if node.attr == "__class__":
visitor.set_type(node, self.klass)
else:
visitor.set_type(node, DYNAMIC)
def emit_attr(
self, node: Union[ast.Attribute, AugAttribute], code_gen: Static38CodeGenerator
) -> None:
for base in self.klass.mro:
member = base.members.get(node.attr)
if member is not None and isinstance(member, Slot):
type_descr = member.container_type.type_descr
type_descr += (member.slot_name,)
if isinstance(node.ctx, ast.Store):
code_gen.emit("STORE_FIELD", type_descr)
elif isinstance(node.ctx, ast.Del):
code_gen.emit("DELETE_ATTR", node.attr)
else:
code_gen.emit("LOAD_FIELD", type_descr)
return
super().emit_attr(node, code_gen)
def bind_descr_get(
self,
node: ast.Attribute,
inst: Optional[Object[TClass]],
ctx: Class,
visitor: TypeBinder,
type_ctx: Optional[Class],
) -> None:
visitor.set_type(node, DYNAMIC)
def bind_subscr(
self, node: ast.Subscript, type: Value, visitor: TypeBinder
) -> None:
visitor.check_can_assign_from(DYNAMIC_TYPE, type.klass, node)
visitor.set_type(node, DYNAMIC)
def bind_compare(
self,
node: ast.Compare,
left: expr,
op: cmpop,
right: expr,
visitor: TypeBinder,
type_ctx: Optional[Class],
) -> bool:
visitor.set_type(op, DYNAMIC)
visitor.set_type(node, DYNAMIC)
return False
def bind_reverse_compare(
self,
node: ast.Compare,
left: expr,
op: cmpop,
right: expr,
visitor: TypeBinder,
type_ctx: Optional[Class],
) -> bool:
visitor.set_type(op, DYNAMIC)
visitor.set_type(node, DYNAMIC)
return False
def bind_binop(
self, node: ast.BinOp, visitor: TypeBinder, type_ctx: Optional[Class]
) -> bool:
return False
def bind_reverse_binop(
self, node: ast.BinOp, visitor: TypeBinder, type_ctx: Optional[Class]
) -> bool:
# we'll set the type in case we're the only one called
visitor.set_type(node, DYNAMIC)
return False
def bind_unaryop(
self, node: ast.UnaryOp, visitor: TypeBinder, type_ctx: Optional[Class]
) -> None:
if isinstance(node.op, ast.Not):
visitor.set_type(node, BOOL_TYPE.instance)
else:
visitor.set_type(node, DYNAMIC)
def bind_constant(self, node: ast.Constant, visitor: TypeBinder) -> None:
node_type = CONSTANT_TYPES[type(node.value)]
visitor.set_type(node, node_type)
visitor.check_can_assign_from(self.klass, node_type.klass, node)
def get_iter_type(self, node: ast.expr, visitor: TypeBinder) -> Value:
"""returns the type that is produced when iterating over this value"""
return DYNAMIC
def __repr__(self) -> str:
return f"<{self.name}>"
class Class(Object["Class"]):
"""Represents a type object at compile time"""
suppress_exact = False
def __init__(
self,
type_name: TypeName,
bases: Optional[List[Class]] = None,
instance: Optional[Value] = None,
klass: Optional[Class] = None,
members: Optional[Dict[str, Value]] = None,
is_exact: bool = False,
pytype: Optional[Type[object]] = None,
) -> None:
super().__init__(klass or TYPE_TYPE)
assert isinstance(bases, (type(None), list))
self.type_name = type_name
self.instance: Value = instance or Object(self)
self.bases: List[Class] = bases or []
self._mro: Optional[List[Class]] = None
self._mro_type_descrs: Optional[Set[TypeDescr]] = None
self.members: Dict[str, Value] = members or {}
self.is_exact = is_exact
self.is_final = False
self.allow_weakrefs = False
self.donotcompile = False
if pytype:
self.members.update(make_type_dict(self, pytype))
# store attempted slot redefinitions during type declaration, for resolution in finish_bind
self._slot_redefs: Dict[str, List[TypeRef]] = {}
@property
def name(self) -> str:
return f"Type[{self.instance_name}]"
@property
def instance_name(self) -> str:
name = self.qualname
if self.is_exact and not self.suppress_exact:
name = f"Exact[{name}]"
return name
@property
def qualname(self) -> str:
return self.type_name.friendly_name
@property
def is_generic_parameter(self) -> bool:
"""Returns True if this Class represents a generic parameter"""
return False
@property
def contains_generic_parameters(self) -> bool:
"""Returns True if this class contains any generic parameters"""
return False
@property
def is_generic_type(self) -> bool:
"""Returns True if this class is a generic type"""
return False
@property
def is_generic_type_definition(self) -> bool:
"""Returns True if this class is a generic type definition.
It'll be a generic type which still has unbound generic type
parameters"""
return False
@property
def generic_type_def(self) -> Optional[Class]:
"""Gets the generic type definition that defined this class"""
return None
def make_generic_type(
self,
index: Tuple[Class, ...],
generic_types: GenericTypesDict,
) -> Optional[Class]:
"""Binds the generic type parameters to a generic type definition"""
return None
def bind_attr(
self, node: ast.Attribute, visitor: TypeBinder, type_ctx: Optional[Class]
) -> None:
for base in self.mro:
member = base.members.get(node.attr)
if member is not None:
member.bind_descr_get(node, None, self, visitor, type_ctx)
return
super().bind_attr(node, visitor, type_ctx)
def bind_binop(
self, node: ast.BinOp, visitor: TypeBinder, type_ctx: Optional[Class]
) -> bool:
if isinstance(node.op, ast.BitOr):
rtype = visitor.get_type(node.right)
if rtype is NONE_TYPE.instance:
rtype = NONE_TYPE
if rtype is DYNAMIC:
rtype = DYNAMIC_TYPE
if not isinstance(rtype, Class):
raise visitor.syntax_error(
f"unsupported operand type(s) for |: {self.name} and {rtype.name}",
node,
)
union = UNION_TYPE.make_generic_type(
(self, rtype), visitor.symtable.generic_types
)
visitor.set_type(node, union)
return True
return super().bind_binop(node, visitor, type_ctx)
@property
def can_be_narrowed(self) -> bool:
return True
@property
def type_descr(self) -> TypeDescr:
return self.type_name.type_descr
def bind_call(
self, node: ast.Call, visitor: TypeBinder, type_ctx: Optional[Class]
) -> NarrowingEffect:
visitor.set_type(node, self.instance)
for arg in node.args:
visitor.visit(arg)
for arg in node.keywords:
visitor.visit(arg.value)
self.check_args_for_primitives(node, visitor)
return NO_EFFECT
def can_assign_from(self, src: Class) -> bool:
"""checks to see if the src value can be assigned to this value. Currently
you can assign a derived type to a base type. You cannot assign a primitive
type to an object type.
At some point we may also support some form of interfaces via protocols if we
implement a more efficient form of interface dispatch than doing the dictionary
lookup for the member."""
return src is self or (
not self.is_exact and not isinstance(src, CType) and self.issubclass(src)
)
def __repr__(self) -> str:
return f"<{self.name} class>"
def isinstance(self, src: Value) -> bool:
return self.issubclass(src.klass)
def issubclass(self, src: Class) -> bool:
return self.type_descr in src.mro_type_descrs
def finish_bind(self, module: ModuleTable) -> None:
for name, new_type_refs in self._slot_redefs.items():
cur_slot = self.members[name]
assert isinstance(cur_slot, Slot)
cur_type = cur_slot.decl_type
if any(tr.resolved() != cur_type for tr in new_type_refs):
raise TypedSyntaxError(
f"conflicting type definitions for slot {name} in {self.name}"
)
self._slot_redefs = {}
inherited = set()
for name, my_value in self.members.items():
for base in self.mro[1:]:
value = base.members.get(name)
if value is not None and type(my_value) != type(value):
# TODO: There's more checking we should be doing to ensure
# this is a compatible override
raise TypedSyntaxError(
f"class cannot hide inherited member: {value!r}"
)
elif isinstance(value, Slot):
inherited.add(name)
elif isinstance(value, (Function, StaticMethod)):
if value.is_final:
raise TypedSyntaxError(
f"Cannot assign to a Final attribute of {self.instance.name}:{name}"
)
if (
isinstance(my_value, Slot)
and my_value.is_final
and not my_value.assignment
):
raise TypedSyntaxError(
f"Final attribute not initialized: {self.instance.name}:{name}"
)
for name in inherited:
assert type(self.members[name]) is Slot
del self.members[name]
def define_slot(
self,
name: str,
type_ref: Optional[TypeRef] = None,
assignment: Optional[AST] = None,
) -> None:
existing = self.members.get(name)
if existing is None:
self.members[name] = Slot(
type_ref or ResolvedTypeRef(DYNAMIC_TYPE), name, self, assignment
)
elif isinstance(existing, Slot):
if not existing.assignment:
existing.assignment = assignment
if type_ref is not None:
self._slot_redefs.setdefault(name, []).append(type_ref)
else:
raise TypedSyntaxError(
f"slot conflicts with other member {name} in {self.name}"
)
def define_function(
self,
name: str,
func: Function | StaticMethod,
visitor: DeclarationVisitor,
) -> None:
if name in self.members:
raise TypedSyntaxError(
f"function conflicts with other member {name} in {self.name}"
)
func.set_container_type(self)
self.members[name] = func
@property
def mro(self) -> Sequence[Class]:
mro = self._mro
if mro is None:
if not all(self.bases):
# TODO: We can't compile w/ unknown bases
mro = []
else:
mro = _mro(self)
self._mro = mro
return mro
@property
def mro_type_descrs(self) -> Collection[TypeDescr]:
cached = self._mro_type_descrs
if cached is None:
self._mro_type_descrs = cached = {b.type_descr for b in self.mro}
return cached
def bind_generics(
self,
name: GenericTypeName,
generic_types: Dict[Class, Dict[Tuple[Class, ...], Class]],
) -> Class:
return self
def get_own_member(self, name: str) -> Optional[Value]:
return self.members.get(name)
def get_parent_member(self, name: str) -> Optional[Value]:
# the first entry of mro is the class itself
for b in self.mro[1:]:
slot = b.members.get(name, None)
if slot:
return slot
def get_member(self, name: str) -> Optional[Value]:
member = self.get_own_member(name)
if member:
return member
return self.get_parent_member(name)
class GenericClass(Class):
type_name: GenericTypeName
is_variadic = False
def __init__(
self,
name: GenericTypeName,
bases: Optional[List[Class]] = None,
instance: Optional[Object[Class]] = None,
klass: Optional[Class] = None,
members: Optional[Dict[str, Value]] = None,
type_def: Optional[GenericClass] = None,
is_exact: bool = False,
pytype: Optional[Type[object]] = None,
) -> None:
super().__init__(name, bases, instance, klass, members, is_exact, pytype)
self.gen_name = name
self.type_def = type_def
def bind_call(
self, node: ast.Call, visitor: TypeBinder, type_ctx: Optional[Class]
) -> NarrowingEffect:
if self.contains_generic_parameters:
raise visitor.syntax_error(
f"cannot create instances of a generic {self.name}", node
)
return super().bind_call(node, visitor, type_ctx)
def bind_subscr(
self, node: ast.Subscript, type: Value, visitor: TypeBinder
) -> None:
slice = node.slice
if not isinstance(slice, ast.Index):
raise visitor.syntax_error("can't slice generic types", node)
visitor.visit(node.slice)
val = slice.value
if isinstance(val, ast.Tuple):
multiple: List[Class] = []
for elt in val.elts:
klass = visitor.cur_mod.resolve_annotation(elt)
if klass is None:
visitor.set_type(node, DYNAMIC)
return
multiple.append(klass)
index = tuple(multiple)
if (not self.is_variadic) and len(val.elts) != len(self.gen_name.args):
raise visitor.syntax_error(
"incorrect number of generic arguments", node
)
else:
if (not self.is_variadic) and len(self.gen_name.args) != 1:
raise visitor.syntax_error(
"incorrect number of generic arguments", node
)
single = visitor.cur_mod.resolve_annotation(val)
if single is None:
visitor.set_type(node, DYNAMIC)
return
index = (single,)
klass = self.make_generic_type(index, visitor.symtable.generic_types)
visitor.set_type(node, klass)
@property
def type_args(self) -> Sequence[Class]:
return self.type_name.args
@property
def contains_generic_parameters(self) -> bool:
for arg in self.gen_name.args:
if arg.is_generic_parameter:
return True
return False
@property
def is_generic_type(self) -> bool:
return True
@property
def is_generic_type_definition(self) -> bool:
return self.type_def is None
@property
def generic_type_def(self) -> Optional[Class]:
"""Gets the generic type definition that defined this class"""
return self.type_def
def make_generic_type(
self,
index: Tuple[Class, ...],
generic_types: GenericTypesDict,
) -> Class:
instantiations = generic_types.get(self)
if instantiations is not None:
instance = instantiations.get(index)
if instance is not None:
return instance
else:
generic_types[self] = instantiations = {}
type_args = index
type_name = GenericTypeName(
self.type_name.module, self.type_name.name, type_args
)
generic_bases: List[Optional[Class]] = [
(
base.make_generic_type(index, generic_types)
if base.contains_generic_parameters
else base
)
for base in self.bases
]
bases: List[Class] = [base for base in generic_bases if base is not None]
InstanceType = type(self.instance)
instance = InstanceType.__new__(InstanceType)
instance.__dict__.update(self.instance.__dict__)
concrete = type(self)(
type_name,
bases,
instance,
self.klass,
{},
is_exact=self.is_exact,
type_def=self,
)
instance.klass = concrete
instantiations[index] = concrete
concrete.members.update(
{
k: v.make_generic(concrete, type_name, generic_types)
for k, v in self.members.items()
}
)
return concrete
def bind_generics(
self,
name: GenericTypeName,
generic_types: Dict[Class, Dict[Tuple[Class, ...], Class]],
) -> Class:
if self.contains_generic_parameters:
type_args = [
arg for arg in self.type_name.args if isinstance(arg, GenericParameter)
]
assert len(type_args) == len(self.type_name.args)
# map the generic type parameters for the type to the parameters provided
bind_args = tuple(name.args[arg.index] for arg in type_args)
# We don't yet support generic methods, so all of the generic parameters are coming from the
# type definition.
return self.make_generic_type(bind_args, generic_types)
return self
class GenericParameter(Class):
def __init__(self, name: str, index: int) -> None:
super().__init__(TypeName("", name), [], None, None, {})
self.index = index
@property
def name(self) -> str:
return self.type_name.name
@property
def is_generic_parameter(self) -> bool:
return True
def bind_generics(
self,
name: GenericTypeName,
generic_types: Dict[Class, Dict[Tuple[Class, ...], Class]],
) -> Class:
return name.args[self.index]
class CType(Class):
"""base class for primitives that aren't heap allocated"""
suppress_exact = True
def __init__(
self,
type_name: TypeName,
bases: Optional[List[Class]] = None,
instance: Optional[CInstance[Class]] = None,
klass: Optional[Class] = None,
members: Optional[Dict[str, Value]] = None,
is_exact: bool = True,
pytype: Optional[Type[object]] = None,
) -> None:
super().__init__(type_name, bases, instance, klass, members, is_exact, pytype)
@property
def can_be_narrowed(self) -> bool:
return False
def bind_call(
self, node: ast.Call, visitor: TypeBinder, type_ctx: Optional[Class]
) -> NarrowingEffect:
"""
Almost the same as the base class method, but this allows args to be primitives
so we can write something like (explicit conversions):
x = int32(int8(5))
"""
visitor.set_type(node, self.instance)
for arg in node.args:
visitor.visit(arg, self.instance)
return NO_EFFECT
class DynamicClass(Class):
instance: DynamicInstance
def __init__(self) -> None:
super().__init__(
# any references to dynamic at runtime are object
TypeName("builtins", "object"),
bases=[OBJECT_TYPE],
instance=DynamicInstance(self),
)
@property
def qualname(self) -> str:
return "dynamic"
def can_assign_from(self, src: Class) -> bool:
# No automatic boxing to the dynamic type
return not isinstance(src, CType)
class DynamicInstance(Object[DynamicClass]):
def __init__(self, klass: DynamicClass) -> None:
super().__init__(klass)
def bind_constant(self, node: ast.Constant, visitor: TypeBinder) -> None:
n = node.value
inst = CONSTANT_TYPES.get(type(n), DYNAMIC_TYPE.instance)
visitor.set_type(node, inst)
def emit_binop(self, node: ast.BinOp, code_gen: Static38CodeGenerator) -> None:
if maybe_emit_sequence_repeat(node, code_gen):
return
code_gen.defaultVisit(node)
class NoneType(Class):
suppress_exact = True
def __init__(self) -> None:
super().__init__(
TypeName("builtins", "None"),
[OBJECT_TYPE],
NoneInstance(self),
is_exact=True,
)
UNARY_SYMBOLS: Mapping[Type[ast.unaryop], str] = {
ast.UAdd: "+",
ast.USub: "-",
ast.Invert: "~",
}
class NoneInstance(Object[NoneType]):
def bind_attr(
self, node: ast.Attribute, visitor: TypeBinder, type_ctx: Optional[Class]
) -> None:
raise visitor.syntax_error(
f"'NoneType' object has no attribute '{node.attr}'", node
)
def bind_call(
self, node: ast.Call, visitor: TypeBinder, type_ctx: Optional[Class]
) -> NarrowingEffect:
raise visitor.syntax_error("'NoneType' object is not callable", node)
def bind_subscr(
self, node: ast.Subscript, type: Value, visitor: TypeBinder
) -> None:
raise visitor.syntax_error("'NoneType' object is not subscriptable", node)
def bind_unaryop(
self, node: ast.UnaryOp, visitor: TypeBinder, type_ctx: Optional[Class]
) -> None:
if not isinstance(node.op, ast.Not):
raise visitor.syntax_error(
f"bad operand type for unary {UNARY_SYMBOLS[type(node.op)]}: 'NoneType'",
node,
)
visitor.set_type(node, BOOL_TYPE.instance)
def bind_binop(
self, node: ast.BinOp, visitor: TypeBinder, type_ctx: Optional[Class]
) -> bool:
# support `None | int` as a union type; None is special in that it is
# not a type but can be used synonymously with NoneType for typing.
if isinstance(node.op, ast.BitOr):
return self.klass.bind_binop(node, visitor, type_ctx)
else:
return super().bind_binop(node, visitor, type_ctx)
def bind_compare(
self,
node: ast.Compare,
left: expr,
op: cmpop,
right: expr,
visitor: TypeBinder,
type_ctx: Optional[Class],
) -> bool:
if isinstance(op, (ast.Eq, ast.NotEq, ast.Is, ast.IsNot)):
return super().bind_compare(node, left, op, right, visitor, type_ctx)
ltype = visitor.get_type(left)
rtype = visitor.get_type(right)
raise visitor.syntax_error(
f"'{CMPOP_SIGILS[type(op)]}' not supported between '{ltype.name}' and '{rtype.name}'",
node,
)
def bind_reverse_compare(
self,
node: ast.Compare,
left: expr,
op: cmpop,
right: expr,
visitor: TypeBinder,
type_ctx: Optional[Class],
) -> bool:
if isinstance(op, (ast.Eq, ast.NotEq, ast.Is, ast.IsNot)):
return super().bind_reverse_compare(
node, left, op, right, visitor, type_ctx
)
ltype = visitor.get_type(left)
rtype = visitor.get_type(right)
raise visitor.syntax_error(
f"'{CMPOP_SIGILS[type(op)]}' not supported between '{ltype.name}' and '{rtype.name}'",
node,
)
# https://www.python.org/download/releases/2.3/mro/
def _merge(seqs: Iterable[List[Class]]) -> List[Class]:
res = []
i = 0
while True:
nonemptyseqs = [seq for seq in seqs if seq]
if not nonemptyseqs:
return res
i += 1
cand = None
for seq in nonemptyseqs: # find merge candidates among seq heads
cand = seq[0]
nothead = [s for s in nonemptyseqs if cand in s[1:]]
if nothead:
cand = None # reject candidate
else:
break
if not cand:
types = {seq[0]: None for seq in nonemptyseqs}
raise SyntaxError(
"Cannot create a consistent method resolution order (MRO) for bases: "
+ ", ".join(t.name for t in types)
)
res.append(cand)
for seq in nonemptyseqs: # remove cand
if seq[0] == cand:
del seq[0]
def _mro(C: Class) -> List[Class]:
"Compute the class precedence list (mro) according to C3"
return _merge([[C]] + list(map(_mro, C.bases)) + [list(C.bases)])
class Parameter:
def __init__(
self,
name: str,
idx: int,
type_ref: TypeRef,
has_default: bool,
default_val: object,
is_kwonly: bool,
) -> None:
self.name = name
self.type_ref = type_ref
self.index = idx
self.has_default = has_default
self.default_val = default_val
self.is_kwonly = is_kwonly
def __repr__(self) -> str:
return (
f"<Parameter name={self.name}, ref={self.type_ref}, "
f"index={self.index}, has_default={self.has_default}>"
)
def bind_generics(
self,
name: GenericTypeName,
generic_types: Dict[Class, Dict[Tuple[Class, ...], Class]],
) -> Parameter:
klass = self.type_ref.resolved().bind_generics(name, generic_types)
if klass is not self.type_ref.resolved():
return Parameter(
self.name,
self.index,
ResolvedTypeRef(klass),
self.has_default,
self.default_val,
self.is_kwonly,
)
return self
def is_subsequence(a: Iterable[object], b: Iterable[object]) -> bool:
# for loops go brrrr :)
# https://ericlippert.com/2020/03/27/new-grad-vs-senior-dev/
itr = iter(a)
for each in b:
if each not in itr:
return False
return True
class ArgMapping:
def __init__(
self,
callable: Callable[TClass],
call: ast.Call,
self_arg: Optional[ast.expr],
) -> None:
self.callable = callable
self.call = call
pos_args: List[ast.expr] = []
if self_arg is not None:
pos_args.append(self_arg)
pos_args.extend(call.args)
self.args: List[ast.expr] = pos_args
self.kwargs: List[Tuple[Optional[str], ast.expr]] = [
(kwarg.arg, kwarg.value) for kwarg in call.keywords
]
self.self_arg = self_arg
self.emitters: List[ArgEmitter] = []
self.nvariadic = 0
self.nseen = 0
self.spills: Dict[int, SpillArg] = {}
def bind_args(self, visitor: TypeBinder) -> None:
# TODO: handle duplicate args and other weird stuff a-la
# https://fburl.com/diffusion/q6tpinw8
# Process provided position arguments to expected parameters
for idx, (param, arg) in enumerate(zip(self.callable.args, self.args)):
if param.is_kwonly:
raise visitor.syntax_error(
f"{self.callable.qualname} takes {idx} positional args but "
f"{len(self.args)} {'was' if len(self.args) == 1 else 'were'} given",
self.call,
)
elif isinstance(arg, Starred):
# Skip type verification here, f(a, b, *something)
# TODO: add support for this by implementing type constrained tuples
self.nvariadic += 1
star_params = self.callable.args[idx:]
self.emitters.append(StarredArg(arg.value, star_params))
self.nseen = len(self.callable.args)
for arg in self.args[idx:]:
visitor.visit(arg)
break
resolved_type = self.visit_arg(visitor, param, arg, "positional")
self.emitters.append(PositionArg(arg, resolved_type))
self.nseen += 1
self.bind_kwargs(visitor)
for argname, argvalue in self.kwargs:
if argname is None:
visitor.visit(argvalue)
continue
if argname not in self.callable.args_by_name:
raise visitor.syntax_error(
f"Given argument {argname} "
f"does not exist in the definition of {self.callable.qualname}",
self.call,
)
# nseen must equal number of defined args if no variadic args are used
if self.nvariadic == 0 and (self.nseen != len(self.callable.args)):
raise visitor.syntax_error(
f"Mismatched number of args for {self.callable.name}. "
f"Expected {len(self.callable.args)}, got {self.nseen}",
self.call,
)
def bind_kwargs(self, visitor: TypeBinder) -> None:
spill_start = len(self.emitters)
seen_variadic = False
# Process unhandled arguments which can be populated via defaults,
# keyword arguments, or **mapping.
cur_kw_arg = 0
for idx in range(self.nseen, len(self.callable.args)):
param = self.callable.args[idx]
name = param.name
if (
cur_kw_arg is not None
and cur_kw_arg < len(self.kwargs)
and self.kwargs[cur_kw_arg][0] == name
):
# keyword arg hit, with the keyword arguments still in order...
arg = self.kwargs[cur_kw_arg][1]
resolved_type = self.visit_arg(visitor, param, arg, "keyword")
cur_kw_arg += 1
self.emitters.append(KeywordArg(arg, resolved_type))
self.nseen += 1
continue
variadic_idx = None
for candidate_kw in range(len(self.kwargs)):
if name == self.kwargs[candidate_kw][0]:
arg = self.kwargs[candidate_kw][1]
tmp_name = f"{_TMP_VAR_PREFIX}{name}"
self.spills[candidate_kw] = SpillArg(arg, tmp_name)
if cur_kw_arg is not None:
cur_kw_arg = None
spill_start = len(self.emitters)
resolved_type = self.visit_arg(visitor, param, arg, "keyword")
self.emitters.append(SpilledKeywordArg(tmp_name, resolved_type))
break
elif self.kwargs[candidate_kw][0] == None:
variadic_idx = candidate_kw
else:
if variadic_idx is not None:
# We have a f(**something), if the arg is unavailable, we
# load it from the mapping
if variadic_idx not in self.spills:
self.spills[variadic_idx] = SpillArg(
self.kwargs[variadic_idx][1], f"{_TMP_VAR_PREFIX}**"
)
if cur_kw_arg is not None:
cur_kw_arg = None
spill_start = len(self.emitters)
self.emitters.append(
KeywordMappingArg(param, f"{_TMP_VAR_PREFIX}**")
)
elif param.has_default:
self.emitters.append(DefaultArg(param.default_val))
else:
# It's an error if this arg did not have a default value in the definition
raise visitor.syntax_error(
f"Function {self.callable.qualname} expects a value for "
f"argument {param.name}",
self.call,
)
self.nseen += 1
if self.spills:
self.emitters[spill_start:spill_start] = [
x[1] for x in sorted(self.spills.items())
]
def visit_arg(
self, visitor: TypeBinder, param: Parameter, arg: expr, arg_style: str
) -> Class:
resolved_type = param.type_ref.resolved()
exc = None
try:
visitor.visit(arg, resolved_type.instance if resolved_type else None)
except TypedSyntaxError as e:
# We may report a better error message below...
exc = e
visitor.check_can_assign_from(
resolved_type,
visitor.get_type(arg).klass,
arg,
f"{arg_style} argument type mismatch",
)
if exc is not None:
raise exc
return resolved_type
class ArgEmitter:
def __init__(self, argument: expr, type: Class) -> None:
self.argument = argument
self.type = type
def emit(self, node: Call, code_gen: Static38CodeGenerator) -> None:
pass
class PositionArg(ArgEmitter):
def emit(self, node: Call, code_gen: Static38CodeGenerator) -> None:
arg_type = code_gen.get_type(self.argument)
code_gen.visit(self.argument)
code_gen.emit_type_check(
self.type,
arg_type.klass,
node,
)
def __repr__(self) -> str:
return f"PositionArg({to_expr(self.argument)}, {self.type})"
class StarredArg(ArgEmitter):
def __init__(self, argument: expr, params: List[Parameter]) -> None:
self.argument = argument
self.params = params
def emit(self, node: Call, code_gen: Static38CodeGenerator) -> None:
code_gen.visit(self.argument)
for idx, param in enumerate(self.params):
code_gen.emit("LOAD_ITERABLE_ARG", idx)
if (
param.type_ref.resolved() is not None
and param.type_ref.resolved() is not DYNAMIC
):
code_gen.emit("ROT_TWO")
code_gen.emit("CAST", param.type_ref.resolved().type_descr)
code_gen.emit("ROT_TWO")
# Remove the tuple from TOS
code_gen.emit("POP_TOP")
class SpillArg(ArgEmitter):
def __init__(self, argument: expr, temporary: str) -> None:
self.argument = argument
self.temporary = temporary
def emit(self, node: Call, code_gen: Static38CodeGenerator) -> None:
code_gen.visit(self.argument)
code_gen.emit("STORE_FAST", self.temporary)
def __repr__(self) -> str:
return f"SpillArg(..., {self.temporary})"
class SpilledKeywordArg(ArgEmitter):
def __init__(self, temporary: str, type: Class) -> None:
self.temporary = temporary
self.type = type
def emit(self, node: Call, code_gen: Static38CodeGenerator) -> None:
code_gen.emit("LOAD_FAST", self.temporary)
code_gen.emit_type_check(
self.type,
DYNAMIC_TYPE,
node,
)
def __repr__(self) -> str:
return f"SpilledKeywordArg({self.temporary})"
class KeywordArg(ArgEmitter):
def __init__(self, argument: expr, type: Class) -> None:
self.argument = argument
self.type = type
def emit(self, node: Call, code_gen: Static38CodeGenerator) -> None:
code_gen.visit(self.argument)
code_gen.emit_type_check(
self.type,
code_gen.get_type(self.argument).klass,
node,
)
class KeywordMappingArg(ArgEmitter):
def __init__(self, param: Parameter, variadic: str) -> None:
self.param = param
self.variadic = variadic
def emit(self, node: Call, code_gen: Static38CodeGenerator) -> None:
if self.param.has_default:
code_gen.emit("LOAD_CONST", self.param.default_val)
code_gen.emit("LOAD_FAST", self.variadic)
code_gen.emit("LOAD_CONST", self.param.name)
if self.param.has_default:
code_gen.emit("LOAD_MAPPING_ARG", 3)
else:
code_gen.emit("LOAD_MAPPING_ARG", 2)
code_gen.emit_type_check(
self.param.type_ref.resolved() or DYNAMIC_TYPE, DYNAMIC_TYPE, node
)
class DefaultArg(ArgEmitter):
def __init__(self, value: object) -> None:
self.value = value
def emit(self, node: Call, code_gen: Static38CodeGenerator) -> None:
code_gen.emit("LOAD_CONST", self.value)
class Callable(Object[TClass]):
def __init__(
self,
klass: Class,
func_name: str,
module_name: str,
args: List[Parameter],
args_by_name: Dict[str, Parameter],
num_required_args: int,
vararg: Optional[Parameter],
kwarg: Optional[Parameter],
return_type: TypeRef,
) -> None:
super().__init__(klass)
self.func_name = func_name
self.module_name = module_name
self.container_type: Optional[Class] = None
self.args = args
self.args_by_name = args_by_name
self.num_required_args = num_required_args
self.has_vararg: bool = vararg is not None
self.has_kwarg: bool = kwarg is not None
self.return_type = return_type
self.is_final = False
@property
def qualname(self) -> str:
cont = self.container_type
if cont:
return f"{cont.qualname}.{self.func_name}"
return f"{self.module_name}.{self.func_name}"
@property
def type_descr(self) -> TypeDescr:
cont = self.container_type
if cont:
return cont.type_descr + (self.func_name,)
return (self.module_name, self.func_name)
def set_container_type(self, klass: Optional[Class]) -> None:
self.container_type = klass
def bind_call(
self, node: ast.Call, visitor: TypeBinder, type_ctx: Optional[Class]
) -> NarrowingEffect:
# Careful adding logic here, MethodType.bind_call() will bypass it
return self.bind_call_self(node, visitor, type_ctx)
def bind_call_self(
self,
node: ast.Call,
visitor: TypeBinder,
type_ctx: Optional[Class],
self_expr: Optional[ast.expr] = None,
) -> NarrowingEffect:
if self.has_vararg or self.has_kwarg:
return super().bind_call(node, visitor, type_ctx)
if type_ctx is not None:
visitor.check_can_assign_from(
type_ctx.klass,
self.return_type.resolved(),
node,
"is an invalid return type, expected",
)
arg_mapping = ArgMapping(self, node, self_expr)
arg_mapping.bind_args(visitor)
visitor.set_type(node, self.return_type.resolved().instance)
visitor.set_node_data(node, ArgMapping, arg_mapping)
return NO_EFFECT
def _emit_kwarg_temps(
self, keywords: List[ast.keyword], code_gen: Static38CodeGenerator
) -> Dict[str, str]:
temporaries = {}
for each in keywords:
name = each.arg
if name is not None:
code_gen.visit(each.value)
temp_var_name = f"{_TMP_VAR_PREFIX}{name}"
code_gen.emit("STORE_FAST", temp_var_name)
temporaries[name] = temp_var_name
return temporaries
def _find_provided_kwargs(
self, node: ast.Call
) -> Tuple[Dict[int, int], Optional[int]]:
# This is a mapping of indices from index in the function definition --> node.keywords
provided_kwargs: Dict[int, int] = {}
# Index of `**something` in the call
variadic_idx: Optional[int] = None
for idx, argument in enumerate(node.keywords):
name = argument.arg
if name is not None:
provided_kwargs[self.args_by_name[name].index] = idx
else:
# Because of the constraints above, we will only ever reach here once
variadic_idx = idx
return provided_kwargs, variadic_idx
def can_call_self(self, node: ast.Call, has_self: bool) -> bool:
if self.has_vararg or self.has_kwarg:
return False
has_default_args = self.num_required_args < len(self.args)
has_star_args = False
for a in node.args:
if isinstance(a, ast.Starred):
if has_star_args:
# We don't support f(*a, *b)
return False
has_star_args = True
elif has_star_args:
# We don't support f(*a, b)
return False
num_star_args = [isinstance(a, ast.Starred) for a in node.args].count(True)
num_dstar_args = [(a.arg is None) for a in node.keywords].count(True)
num_kwonly = len([arg for arg in self.args if arg.is_kwonly])
start = 1 if has_self else 0
for arg in self.args[start + len(node.args) :]:
if arg.has_default and isinstance(arg.default_val, ast.expr):
for kw_arg in node.keywords:
if kw_arg.arg == arg.name:
break
else:
return False
if (
# We don't support f(**a, **b)
num_dstar_args > 1
# We don't support f(1, 2, *a) if f has any default arg values
or (has_default_args and has_star_args)
or num_kwonly
):
return False
return True
def emit_call_self(self, node: ast.Call, code_gen: Static38CodeGenerator) -> None:
arg_mapping: ArgMapping = code_gen.get_node_data(node, ArgMapping)
for emitter in arg_mapping.emitters:
emitter.emit(node, code_gen)
self_expr = arg_mapping.self_arg
if (
self_expr is None
or code_gen.get_type(self_expr).klass.is_exact
or code_gen.get_type(self_expr).klass.is_final
):
code_gen.emit("EXTENDED_ARG", 0)
code_gen.emit("INVOKE_FUNCTION", (self.type_descr, len(self.args)))
else:
code_gen.emit_invoke_method(self.type_descr, len(self.args) - 1)
class ContainerTypeRef(TypeRef):
def __init__(self, func: Function) -> None:
self.func = func
def resolved(self, is_declaration: bool = False) -> Class:
res = self.func.container_type
if res is None:
return DYNAMIC_TYPE
return res
class InlineRewriter(ASTRewriter):
def __init__(self, replacements: Dict[str, ast.expr]) -> None:
super().__init__()
self.replacements = replacements
def visit(
self, node: Union[TAst, Sequence[AST]], *args: object
) -> Union[AST, Sequence[AST]]:
res = super().visit(node, *args)
if res is node:
if isinstance(node, AST):
return self.clone_node(node)
return list(node)
return res
def visitName(self, node: ast.Name) -> AST:
res = self.replacements.get(node.id)
if res is None:
return self.clone_node(node)
return res
class InlinedCall:
def __init__(
self,
expr: ast.expr,
replacements: Dict[ast.expr, ast.expr],
spills: Dict[str, Tuple[ast.expr, ast.Name]],
) -> None:
self.expr = expr
self.replacements = replacements
self.spills = spills
class Function(Callable[Class]):
def __init__(
self,
node: Union[AsyncFunctionDef, FunctionDef],
module: ModuleTable,
ret_type: TypeRef,
) -> None:
super().__init__(
FUNCTION_TYPE,
node.name,
module.name,
[],
{},
0,
None,
None,
ret_type,
)
self.node = node
self.module = module
self.process_args(module)
self.inline = False
self.donotcompile = False
@property
def name(self) -> str:
return f"function {self.qualname}"
def bind_call(
self, node: ast.Call, visitor: TypeBinder, type_ctx: Optional[Class]
) -> NarrowingEffect:
res = super().bind_call(node, visitor, type_ctx)
if self.inline and visitor.optimize == 2:
assert isinstance(self.node.body[0], ast.Return)
return self.bind_inline_call(node, visitor, type_ctx) or res
return res
def emit_call(self, node: ast.Call, code_gen: Static38CodeGenerator) -> None:
if not self.can_call_self(node, False):
return super().emit_call(node, code_gen)
if self.inline and code_gen.optimization_lvl == 2:
return self.emit_inline_call(node, code_gen)
return self.emit_call_self(node, code_gen)
def bind_inline_call(
self, node: ast.Call, visitor: TypeBinder, type_ctx: Optional[Class]
) -> Optional[NarrowingEffect]:
args = visitor.get_node_data(node, ArgMapping)
arg_replacements = {}
spills = {}
if visitor.inline_depth > 20:
visitor.set_node_data(node, Optional[InlinedCall], None)
return None
visitor.inline_depth += 1
for idx, arg in enumerate(args.emitters):
name = self.node.args.args[idx].arg
if isinstance(arg, DefaultArg):
arg_replacements[name] = ast.Constant(arg.value)
continue
elif not isinstance(arg, (PositionArg, KeywordArg)):
# We don't support complicated calls to inline functions
visitor.set_node_data(node, Optional[InlinedCall], None)
return None
if (
isinstance(arg.argument, ast.Constant)
or visitor.get_final_literal(arg.argument) is not None
):
arg_replacements[name] = arg.argument
continue
# store to a temporary...
tmp_name = f"{_TMP_VAR_PREFIX}{visitor.inline_depth}{name}"
cur_scope = visitor.symbols.scopes[visitor.scope]
cur_scope.add_def(tmp_name)
store = ast.Name(tmp_name, ast.Store())
visitor.set_type(store, visitor.get_type(arg.argument))
spills[tmp_name] = arg.argument, store
replacement = ast.Name(tmp_name, ast.Load())
visitor.assign_value(replacement, visitor.get_type(arg.argument))
arg_replacements[name] = replacement
# re-write node body with replacements...
return_stmt = self.node.body[0]
assert isinstance(return_stmt, Return)
ret_value = return_stmt.value
if ret_value is not None:
new_node = InlineRewriter(arg_replacements).visit(ret_value)
else:
new_node = ast.Constant(None)
new_node = AstOptimizer().visit(new_node)
inlined_call = InlinedCall(new_node, arg_replacements, spills)
visitor.visit(new_node)
visitor.set_node_data(node, Optional[InlinedCall], inlined_call)
visitor.inline_depth -= 1
def emit_inline_call(self, node: ast.Call, code_gen: Static38CodeGenerator) -> None:
assert isinstance(self.node.body[0], ast.Return)
inlined_call = code_gen.get_node_data(node, Optional[InlinedCall])
if inlined_call is None:
return self.emit_call_self(node, code_gen)
for name, (arg, store) in inlined_call.spills.items():
code_gen.visit(arg)
code_gen.get_type(store).emit_name(store, code_gen)
code_gen.visit(inlined_call.expr)
def bind_descr_get(
self,
node: ast.Attribute,
inst: Optional[Object[TClassInv]],
ctx: TClassInv,
visitor: TypeBinder,
type_ctx: Optional[Class],
) -> None:
if inst is None:
visitor.set_type(node, self)
else:
visitor.set_type(node, MethodType(ctx.type_name, self.node, node, self))
def register_arg(
self,
name: str,
idx: int,
ref: TypeRef,
has_default: bool,
default_val: object,
is_kwonly: bool,
) -> None:
parameter = Parameter(name, idx, ref, has_default, default_val, is_kwonly)
self.args.append(parameter)
self.args_by_name[name] = parameter
if not has_default:
self.num_required_args += 1
def process_args(
self: Function,
module: ModuleTable,
) -> None:
"""
Register type-refs for each function argument, assume DYNAMIC if annotation is missing.
"""
arguments = self.node.args
nrequired = len(arguments.args) - len(arguments.defaults)
no_defaults = cast(List[Optional[ast.expr]], [None] * nrequired)
defaults = no_defaults + cast(List[Optional[ast.expr]], arguments.defaults)
idx = 0
for idx, (argument, default) in enumerate(zip(arguments.args, defaults)):
annotation = argument.annotation
default_val = None
has_default = False
if default is not None:
has_default = True
default_val = get_default_value(default)
if annotation:
ref = TypeRef(module, annotation)
elif idx == 0:
ref = ContainerTypeRef(self)
else:
ref = ResolvedTypeRef(DYNAMIC_TYPE)
self.register_arg(argument.arg, idx, ref, has_default, default_val, False)
base_idx = idx
vararg = arguments.vararg
if vararg:
base_idx += 1
self.has_vararg = True
for argument, default in zip(arguments.kwonlyargs, arguments.kw_defaults):
annotation = argument.annotation
default_val = None
has_default = default is not None
if default is not None:
default_val = get_default_value(default)
if annotation:
ref = TypeRef(module, annotation)
else:
ref = ResolvedTypeRef(DYNAMIC_TYPE)
base_idx += 1
self.register_arg(
argument.arg, base_idx, ref, has_default, default_val, True
)
kwarg = arguments.kwarg
if kwarg:
self.has_kwarg = True
def __repr__(self) -> str:
return f"<{self.name} '{self.name}' instance, args={self.args}>"
class MethodType(Object[Class]):
def __init__(
self,
bound_type_name: TypeName,
node: Union[AsyncFunctionDef, FunctionDef],
target: ast.Attribute,
function: Function,
) -> None:
super().__init__(METHOD_TYPE)
# TODO currently this type (the type the bound method was accessed
# from) is unused, and we just end up deferring to the type where the
# function was defined. This is fine until we want to fully support a
# method defined in one class being also referenced as a method in
# another class.
self.bound_type_name = bound_type_name
self.node = node
self.target = target
self.function = function
@property
def name(self) -> str:
return "method " + self.function.qualname
def bind_call(
self, node: ast.Call, visitor: TypeBinder, type_ctx: Optional[Class]
) -> NarrowingEffect:
result = self.function.bind_call_self(
node, visitor, type_ctx, self.target.value
)
self.check_args_for_primitives(node, visitor)
return result
def emit_call(self, node: ast.Call, code_gen: Static38CodeGenerator) -> None:
if not self.function.can_call_self(node, True):
return super().emit_call(node, code_gen)
code_gen.update_lineno(node)
self.function.emit_call_self(node, code_gen)
class StaticMethod(Object[Class]):
def __init__(
self,
function: Function,
) -> None:
super().__init__(STATIC_METHOD_TYPE)
self.function = function
@property
def name(self) -> str:
return "staticmethod " + self.function.qualname
@property
def func_name(self) -> str:
return self.function.func_name
@property
def is_final(self) -> bool:
return self.function.is_final
def set_container_type(self, container_type: Optional[Class]) -> None:
self.function.set_container_type(container_type)
def bind_descr_get(
self,
node: ast.Attribute,
inst: Optional[Object[TClassInv]],
ctx: TClassInv,
visitor: TypeBinder,
type_ctx: Optional[Class],
) -> None:
visitor.set_type(node, self.function)
class TypingFinalDecorator(Class):
def bind_decorate_function(
self, visitor: DeclarationVisitor, fn: Function | StaticMethod
) -> Value:
if isinstance(fn, StaticMethod):
fn.function.is_final = True
else:
fn.is_final = True
return fn
def bind_decorate_class(self, klass: Class) -> Class:
klass.is_final = True
return klass
class AllowWeakrefsDecorator(Class):
def bind_decorate_class(self, klass: Class) -> Class:
klass.allow_weakrefs = True
return klass
class DynamicReturnDecorator(Class):
def bind_decorate_function(
self, visitor: DeclarationVisitor, fn: Function | StaticMethod
) -> Value:
real_fn = fn.function if isinstance(fn, StaticMethod) else fn
real_fn.return_type = ResolvedTypeRef(DYNAMIC_TYPE)
real_fn.module.dynamic_returns.add(real_fn.node.args)
return fn
class StaticMethodDecorator(Class):
def bind_decorate_function(
self, visitor: DeclarationVisitor, fn: Function | StaticMethod
) -> Value:
if isinstance(fn, StaticMethod):
# no-op
return fn
return StaticMethod(fn)
class InlineFunctionDecorator(Class):
def bind_decorate_function(
self, visitor: DeclarationVisitor, fn: Function | StaticMethod
) -> Value:
real_fn = fn.function if isinstance(fn, StaticMethod) else fn
if not isinstance(real_fn.node.body[0], ast.Return):
raise visitor.syntax_error(
"@inline only supported on functions with simple return", real_fn.node
)
real_fn.inline = True
return fn
class DoNotCompileDecorator(Class):
def bind_decorate_function(
self, visitor: DeclarationVisitor, fn: Function | StaticMethod
) -> Optional[Value]:
real_fn = fn.function if isinstance(fn, StaticMethod) else fn
real_fn.donotcompile = True
return fn
def bind_decorate_class(self, klass: Class) -> Class:
klass.donotcompile = True
return klass
class BuiltinFunction(Callable[Class]):
def __init__(
self,
func_name: str,
module: str,
args: Optional[Tuple[Parameter, ...]] = None,
return_type: Optional[TypeRef] = None,
) -> None:
assert isinstance(return_type, (TypeRef, type(None)))
super().__init__(
BUILTIN_METHOD_DESC_TYPE,
func_name,
module,
args,
{},
0,
None,
None,
return_type or ResolvedTypeRef(DYNAMIC_TYPE),
)
def emit_call(self, node: ast.Call, code_gen: Static38CodeGenerator) -> None:
if node.keywords or (
self.args is not None and not self.can_call_self(node, True)
):
return super().emit_call(node, code_gen)
code_gen.update_lineno(node)
self.emit_call_self(node, code_gen)
class BuiltinMethodDescriptor(Callable[Class]):
def __init__(
self,
func_name: str,
container_type: Class,
args: Optional[Tuple[Parameter, ...]] = None,
return_type: Optional[TypeRef] = None,
) -> None:
assert isinstance(return_type, (TypeRef, type(None)))
super().__init__(
BUILTIN_METHOD_DESC_TYPE,
func_name,
container_type.type_name.module,
args,
{},
0,
None,
None,
return_type or ResolvedTypeRef(DYNAMIC_TYPE),
)
self.set_container_type(container_type)
def bind_call_self(
self,
node: ast.Call,
visitor: TypeBinder,
type_ctx: Optional[Class],
self_expr: Optional[expr] = None,
) -> NarrowingEffect:
if self.args is not None:
return super().bind_call_self(node, visitor, type_ctx, self_expr)
elif node.keywords:
return super().bind_call(node, visitor, type_ctx)
visitor.set_type(node, DYNAMIC)
for arg in node.args:
visitor.visit(arg)
return NO_EFFECT
def bind_descr_get(
self,
node: ast.Attribute,
inst: Optional[Object[TClassInv]],
ctx: TClassInv,
visitor: TypeBinder,
type_ctx: Optional[Class],
) -> None:
if inst is None:
visitor.set_type(node, self)
else:
visitor.set_type(node, BuiltinMethod(self, node))
def make_generic(
self, new_type: Class, name: GenericTypeName, generic_types: GenericTypesDict
) -> Value:
cur_args = self.args
cur_ret_type = self.return_type
if cur_args is not None and cur_ret_type is not None:
new_args = tuple(arg.bind_generics(name, generic_types) for arg in cur_args)
new_ret_type = cur_ret_type.resolved().bind_generics(name, generic_types)
return BuiltinMethodDescriptor(
self.func_name,
new_type,
new_args,
ResolvedTypeRef(new_ret_type),
)
else:
return BuiltinMethodDescriptor(self.func_name, new_type)
class BuiltinMethod(Callable[Class]):
def __init__(self, desc: BuiltinMethodDescriptor, target: ast.Attribute) -> None:
super().__init__(
BUILTIN_METHOD_TYPE,
desc.func_name,
desc.module_name,
desc.args,
{},
0,
None,
None,
desc.return_type,
)
self.desc = desc
self.target = target
self.set_container_type(desc.container_type)
@property
def name(self) -> str:
return self.qualname
def bind_call(
self, node: ast.Call, visitor: TypeBinder, type_ctx: Optional[Class]
) -> NarrowingEffect:
if self.args:
return super().bind_call_self(node, visitor, type_ctx, self.target.value)
if node.keywords:
return Object.bind_call(self, node, visitor, type_ctx)
visitor.set_type(node, self.return_type.resolved().instance)
visitor.visit(self.target.value)
for arg in node.args:
visitor.visit(arg)
self.check_args_for_primitives(node, visitor)
return NO_EFFECT
def emit_call(self, node: ast.Call, code_gen: Static38CodeGenerator) -> None:
if node.keywords or (
self.args is not None and not self.desc.can_call_self(node, True)
):
return super().emit_call(node, code_gen)
code_gen.update_lineno(node)
if self.args is not None:
self.desc.emit_call_self(node, code_gen)
else:
# Untyped method, we can still do an INVOKE_METHOD
code_gen.visit(self.target.value)
code_gen.update_lineno(node)
for arg in node.args:
code_gen.visit(arg)
if code_gen.get_type(self.target.value).klass.is_exact:
code_gen.emit("INVOKE_FUNCTION", (self.type_descr, len(node.args) + 1))
else:
code_gen.emit_invoke_method(self.type_descr, len(node.args))
class StrictBuiltins(Object[Class]):
def __init__(self, builtins: Dict[str, Value]) -> None:
super().__init__(DICT_TYPE)
self.builtins = builtins
def bind_subscr(
self, node: ast.Subscript, type: Value, visitor: TypeBinder
) -> None:
slice = node.slice
type = DYNAMIC
if isinstance(slice, ast.Index):
val = slice.value
if isinstance(val, ast.Str):
builtin = self.builtins.get(val.s)
if builtin is not None:
type = builtin
elif isinstance(val, ast.Constant):
svalue = val.value
if isinstance(svalue, str):
builtin = self.builtins.get(svalue)
if builtin is not None:
type = builtin
visitor.set_type(node, type)
def get_default_value(default: expr) -> object:
if not isinstance(default, (Constant, Str, Num, Bytes, NameConstant, ast.Ellipsis)):
default = AstOptimizer().visit(default)
if isinstance(default, Str):
return default.s
elif isinstance(default, Num):
return default.n
elif isinstance(default, Bytes):
return default.s
elif isinstance(default, ast.Ellipsis):
return ...
elif isinstance(default, (ast.Constant, ast.NameConstant)):
return default.value
else:
return default
# Bringing up the type system is a little special as we have dependencies
# amongst type and object
TYPE_TYPE = Class.__new__(Class)
TYPE_TYPE.type_name = TypeName("builtins", "type")
TYPE_TYPE.klass = TYPE_TYPE
TYPE_TYPE.instance = TYPE_TYPE
TYPE_TYPE.members = {}
TYPE_TYPE.is_exact = False
TYPE_TYPE.is_final = False
TYPE_TYPE._mro = None
TYPE_TYPE._mro_type_descrs = None
class Slot(Object[TClassInv]):
def __init__(
self,
type_ref: TypeRef,
name: str,
container_type: Class,
assignment: Optional[AST] = None,
) -> None:
super().__init__(MEMBER_TYPE)
self.container_type = container_type
self.slot_name = name
self._type_ref = type_ref
self.assignment = assignment
def bind_descr_get(
self,
node: ast.Attribute,
inst: Optional[Object[TClassInv]],
ctx: TClassInv,
visitor: TypeBinder,
type_ctx: Optional[Class],
) -> None:
if inst is None:
visitor.set_type(node, self)
return
visitor.set_type(node, self.decl_type.instance)
@property
def decl_type(self) -> Class:
type = self._type_ref.resolved(is_declaration=True)
if isinstance(type, FinalClass):
return type.inner_type()
return type
@property
def is_final(self) -> bool:
return isinstance(self._type_ref.resolved(is_declaration=True), FinalClass)
@property
def type_descr(self) -> TypeDescr:
return self.decl_type.type_descr
# TODO (aniketpanse): move these to a better place
OBJECT_TYPE = Class(TypeName("builtins", "object"))
OBJECT = OBJECT_TYPE.instance
DYNAMIC_TYPE = DynamicClass()
DYNAMIC = DYNAMIC_TYPE.instance
class BoxFunction(Object[Class]):
def bind_call(
self, node: ast.Call, visitor: TypeBinder, type_ctx: Optional[Class]
) -> NarrowingEffect:
if len(node.args) != 1:
raise visitor.syntax_error("box only accepts a single argument", node)
arg = node.args[0]
visitor.visit(arg)
arg_type = visitor.get_type(arg)
if isinstance(arg_type, CIntInstance):
typ = BOOL_TYPE if arg_type.constant == TYPED_BOOL else INT_EXACT_TYPE
visitor.set_type(node, typ.instance)
elif isinstance(arg_type, CDoubleInstance):
visitor.set_type(node, FLOAT_EXACT_TYPE.instance)
else:
raise visitor.syntax_error(
f"can't box non-primitive: {arg_type.name}", node
)
return NO_EFFECT
def emit_call(self, node: ast.Call, code_gen: Static38CodeGenerator) -> None:
code_gen.get_type(node.args[0]).emit_box(node.args[0], code_gen)
class UnboxFunction(Object[Class]):
def bind_call(
self, node: ast.Call, visitor: TypeBinder, type_ctx: Optional[Class]
) -> NarrowingEffect:
if len(node.args) != 1:
raise visitor.syntax_error("unbox only accepts a single argument", node)
for arg in node.args:
visitor.visit(arg, DYNAMIC)
self.check_args_for_primitives(node, visitor)
visitor.set_type(node, type_ctx or INT64_VALUE)
return NO_EFFECT
def emit_call(self, node: ast.Call, code_gen: Static38CodeGenerator) -> None:
code_gen.get_type(node).emit_unbox(node.args[0], code_gen)
class LenFunction(Object[Class]):
def __init__(self, klass: Class, boxed: bool) -> None:
super().__init__(klass)
self.boxed = boxed
@property
def name(self) -> str:
return f"{'' if self.boxed else 'c'}len function"
def bind_call(
self, node: ast.Call, visitor: TypeBinder, type_ctx: Optional[Class]
) -> NarrowingEffect:
if len(node.args) != 1:
visitor.syntax_error(
f"len() does not accept more than one arguments ({len(node.args)} given)",
node,
)
arg = node.args[0]
visitor.visit(arg)
arg_type = visitor.get_type(arg)
if not self.boxed and arg_type.get_fast_len_type() is None:
raise visitor.syntax_error(
f"bad argument type '{arg_type.name}' for clen()", arg
)
self.check_args_for_primitives(node, visitor)
output_type = INT_EXACT_TYPE.instance if self.boxed else INT64_TYPE.instance
visitor.set_type(node, output_type)
return NO_EFFECT
def emit_call(self, node: ast.Call, code_gen: Static38CodeGenerator) -> None:
code_gen.get_type(node.args[0]).emit_len(node, code_gen, boxed=self.boxed)
class SortedFunction(Object[Class]):
@property
def name(self) -> str:
return "sorted function"
def bind_call(
self, node: ast.Call, visitor: TypeBinder, type_ctx: Optional[Class]
) -> NarrowingEffect:
if len(node.args) != 1:
visitor.syntax_error(
f"sorted() accepts one positional argument ({len(node.args)} given)",
node,
)
visitor.visit(node.args[0])
for kw in node.keywords:
visitor.visit(kw.value)
self.check_args_for_primitives(node, visitor)
visitor.set_type(node, LIST_EXACT_TYPE.instance)
return NO_EFFECT
def emit_call(self, node: ast.Call, code_gen: Static38CodeGenerator) -> None:
super().emit_call(node, code_gen)
code_gen.emit("REFINE_TYPE", LIST_EXACT_TYPE.type_descr)
class ExtremumFunction(Object[Class]):
def __init__(self, klass: Class, is_min: bool) -> None:
super().__init__(klass)
self.is_min = is_min
@property
def _extremum(self) -> str:
return "min" if self.is_min else "max"
@property
def name(self) -> str:
return f"{self._extremum} function"
def emit_call(self, node: ast.Call, code_gen: Static38CodeGenerator) -> None:
if (
# We only specialize for two args
len(node.args) != 2
# We don't support specialization if any kwargs are present
or len(node.keywords) > 0
# If we have any *args, we skip specialization
or any(isinstance(a, ast.Starred) for a in node.args)
):
return super().emit_call(node, code_gen)
# Compile `min(a, b)` to a ternary expression, `a if a <= b else b`.
# Similar for `max(a, b).
endblock = code_gen.newBlock(f"{self._extremum}_end")
elseblock = code_gen.newBlock(f"{self._extremum}_else")
for a in node.args:
code_gen.visit(a)
if self.is_min:
op = "<="
else:
op = ">="
code_gen.emit("DUP_TOP_TWO")
code_gen.emit("COMPARE_OP", op)
code_gen.emit("POP_JUMP_IF_FALSE", elseblock)
# Remove `b` from stack, `a` was the minimum
code_gen.emit("POP_TOP")
code_gen.emit("JUMP_FORWARD", endblock)
code_gen.nextBlock(elseblock)
# Remove `a` from the stack, `b` was the minimum
code_gen.emit("ROT_TWO")
code_gen.emit("POP_TOP")
code_gen.nextBlock(endblock)
class IsInstanceFunction(Object[Class]):
def __init__(self) -> None:
super().__init__(FUNCTION_TYPE)
@property
def name(self) -> str:
return "isinstance function"
def bind_call(
self, node: ast.Call, visitor: TypeBinder, type_ctx: Optional[Class]
) -> NarrowingEffect:
if node.keywords:
visitor.syntax_error("isinstance() does not accept keyword arguments", node)
for arg in node.args:
visitor.visit(arg)
self.check_args_for_primitives(node, visitor)
visitor.set_type(node, BOOL_TYPE.instance)
if len(node.args) == 2:
arg0 = node.args[0]
if not isinstance(arg0, ast.Name):
return NO_EFFECT
arg1 = node.args[1]
klass_type = None
if isinstance(arg1, ast.Tuple):
types = tuple(visitor.get_type(el) for el in arg1.elts)
if all(isinstance(t, Class) for t in types):
klass_type = UNION_TYPE.make_generic_type(
types, visitor.symtable.generic_types
)
else:
arg1_type = visitor.get_type(node.args[1])
if isinstance(arg1_type, Class):
klass_type = inexact(arg1_type)
if klass_type is not None:
return IsInstanceEffect(
arg0.id,
visitor.get_type(arg0),
inexact(klass_type.instance),
visitor,
)
return NO_EFFECT
class IsSubclassFunction(Object[Class]):
def __init__(self) -> None:
super().__init__(FUNCTION_TYPE)
@property
def name(self) -> str:
return "issubclass function"
def bind_call(
self, node: ast.Call, visitor: TypeBinder, type_ctx: Optional[Class]
) -> NarrowingEffect:
if node.keywords:
raise visitor.syntax_error(
"issubclass() does not accept keyword arguments", node
)
for arg in node.args:
visitor.visit(arg)
visitor.set_type(node, BOOL_TYPE.instance)
self.check_args_for_primitives(node, visitor)
return NO_EFFECT
class RevealTypeFunction(Object[Class]):
def __init__(self) -> None:
super().__init__(FUNCTION_TYPE)
@property
def name(self) -> str:
return "reveal_type function"
def bind_call(
self, node: ast.Call, visitor: TypeBinder, type_ctx: Optional[Class]
) -> NarrowingEffect:
if node.keywords:
raise visitor.syntax_error(
"reveal_type() does not accept keyword arguments", node
)
if len(node.args) != 1:
raise visitor.syntax_error(
"reveal_type() accepts exactly one argument", node
)
arg = node.args[0]
visitor.visit(arg)
arg_type = visitor.get_type(arg)
msg = f"reveal_type({to_expr(arg)}): '{arg_type.name}'"
if isinstance(arg, ast.Name) and arg.id in visitor.decl_types:
decl_type = visitor.decl_types[arg.id].type
local_type = visitor.local_types[arg.id]
msg += f", '{arg.id}' has declared type '{decl_type.name}' and local type '{local_type.name}'"
raise visitor.syntax_error(msg, node)
return NO_EFFECT
class NumClass(Class):
def __init__(
self,
name: TypeName,
pytype: Optional[Type[object]] = None,
is_exact: bool = False,
literal_value: Optional[int] = None,
) -> None:
bases: List[Class] = [OBJECT_TYPE]
if literal_value is not None:
is_exact = True
bases = [INT_EXACT_TYPE]
instance = NumExactInstance(self) if is_exact else NumInstance(self)
super().__init__(
name,
bases,
instance,
pytype=pytype,
is_exact=is_exact,
)
self.literal_value = literal_value
def can_assign_from(self, src: Class) -> bool:
if isinstance(src, NumClass):
if self.literal_value is not None:
return src.literal_value == self.literal_value
if self.is_exact and src.is_exact and self.type_descr == src.type_descr:
return True
return super().can_assign_from(src)
class NumInstance(Object[NumClass]):
def bind_unaryop(
self, node: ast.UnaryOp, visitor: TypeBinder, type_ctx: Optional[Class]
) -> None:
if isinstance(node.op, (ast.USub, ast.Invert, ast.UAdd)):
visitor.set_type(node, self)
else:
assert isinstance(node.op, ast.Not)
visitor.set_type(node, BOOL_TYPE.instance)
def bind_constant(self, node: ast.Constant, visitor: TypeBinder) -> None:
self._bind_constant(node.value, node, visitor)
def _bind_constant(
self, value: object, node: ast.expr, visitor: TypeBinder
) -> None:
value_inst = CONSTANT_TYPES.get(type(value), self)
visitor.set_type(node, value_inst)
visitor.check_can_assign_from(self.klass, value_inst.klass, node)
class NumExactInstance(NumInstance):
@property
def name(self) -> str:
if self.klass.literal_value is not None:
return f"Literal[{self.klass.literal_value}]"
return super().name
def bind_binop(
self, node: ast.BinOp, visitor: TypeBinder, type_ctx: Optional[Class]
) -> bool:
ltype = visitor.get_type(node.left)
rtype = visitor.get_type(node.right)
if INT_EXACT_TYPE.can_assign_from(
ltype.klass
) and INT_EXACT_TYPE.can_assign_from(rtype.klass):
if isinstance(node.op, ast.Div):
visitor.set_type(node, FLOAT_EXACT_TYPE.instance)
else:
visitor.set_type(node, INT_EXACT_TYPE.instance)
return True
return False
def parse_param(info: Dict[str, object], idx: int) -> Parameter:
name = info.get("name", "")
assert isinstance(name, str)
return Parameter(
name,
idx,
ResolvedTypeRef(parse_type(info)),
"default" in info,
info.get("default"),
False,
)
def parse_typed_signature(
sig: Dict[str, object], klass: Optional[Class] = None
) -> Tuple[Tuple[Parameter, ...], Class]:
args = sig["args"]
assert isinstance(args, list)
if klass is not None:
signature = [Parameter("self", 0, ResolvedTypeRef(klass), False, None, False)]
else:
signature = []
for idx, arg in enumerate(args):
signature.append(parse_param(arg, idx + 1))
return_info = sig["return"]
assert isinstance(return_info, dict)
return_type = parse_type(return_info)
return tuple(signature), return_type
def reflect_builtin_function(obj: BuiltinFunctionType) -> BuiltinFunction:
sig = getattr(obj, "__typed_signature__", None)
if sig is not None:
signature, return_type = parse_typed_signature(sig)
method = BuiltinFunction(
obj.__name__,
obj.__module__,
signature,
ResolvedTypeRef(return_type),
)
else:
method = BuiltinFunction(obj.__name__, obj.__module__)
return method
def reflect_method_desc(
obj: MethodDescriptorType, klass: Class
) -> BuiltinMethodDescriptor:
sig = getattr(obj, "__typed_signature__", None)
if sig is not None:
signature, return_type = parse_typed_signature(sig, klass)
method = BuiltinMethodDescriptor(
obj.__name__,
klass,
signature,
ResolvedTypeRef(return_type),
)
else:
method = BuiltinMethodDescriptor(obj.__name__, klass)
return method
def make_type_dict(klass: Class, t: Type[object]) -> Dict[str, Value]:
ret: Dict[str, Value] = {}
for k in t.__dict__.keys():
obj = getattr(t, k)
if isinstance(obj, MethodDescriptorType):
ret[k] = reflect_method_desc(obj, klass)
return ret
def common_sequence_emit_len(
node: ast.Call, code_gen: Static38CodeGenerator, oparg: int, boxed: bool
) -> None:
if len(node.args) != 1:
raise code_gen.syntax_error(
f"Can only pass a single argument when checking sequence length", node
)
code_gen.visit(node.args[0])
code_gen.emit("FAST_LEN", oparg)
if boxed:
signed = True
code_gen.emit("PRIMITIVE_BOX", int(signed))
def common_sequence_emit_jumpif(
test: AST,
next: Block,
is_if_true: bool,
code_gen: Static38CodeGenerator,
oparg: int,
) -> None:
code_gen.visit(test)
code_gen.emit("FAST_LEN", oparg)
code_gen.emit("POP_JUMP_IF_NONZERO" if is_if_true else "POP_JUMP_IF_ZERO", next)
def common_sequence_emit_forloop(
node: ast.For, code_gen: Static38CodeGenerator, oparg: int
) -> None:
descr = ("__static__", "int64")
start = code_gen.newBlock(f"seq_forloop_start")
anchor = code_gen.newBlock(f"seq_forloop_anchor")
after = code_gen.newBlock(f"seq_forloop_after")
with code_gen.new_loopidx() as loop_idx:
code_gen.set_lineno(node)
code_gen.push_loop(FOR_LOOP, start, after)
code_gen.visit(node.iter)
code_gen.emit("PRIMITIVE_LOAD_CONST", (0, TYPED_INT64))
code_gen.emit("STORE_LOCAL", (loop_idx, descr))
code_gen.nextBlock(start)
code_gen.emit("DUP_TOP") # used for SEQUENCE_GET
code_gen.emit("DUP_TOP") # used for FAST_LEN
code_gen.emit("FAST_LEN", oparg)
code_gen.emit("LOAD_LOCAL", (loop_idx, descr))
code_gen.emit("INT_COMPARE_OP", PRIM_OP_GT_INT)
code_gen.emit("POP_JUMP_IF_ZERO", anchor)
code_gen.emit("LOAD_LOCAL", (loop_idx, descr))
if oparg == FAST_LEN_LIST:
code_gen.emit("SEQUENCE_GET", SEQ_LIST | SEQ_SUBSCR_UNCHECKED)
else:
# todo - we need to implement TUPLE_GET which supports primitive index
code_gen.emit("PRIMITIVE_BOX", 1) # 1 is for signed
code_gen.emit("BINARY_SUBSCR", 2)
code_gen.emit("LOAD_LOCAL", (loop_idx, descr))
code_gen.emit("PRIMITIVE_LOAD_CONST", (1, TYPED_INT64))
code_gen.emit("PRIMITIVE_BINARY_OP", PRIM_OP_ADD_INT)
code_gen.emit("STORE_LOCAL", (loop_idx, descr))
code_gen.visit(node.target)
code_gen.visit(node.body)
code_gen.emit("JUMP_ABSOLUTE", start)
code_gen.nextBlock(anchor)
code_gen.emit("POP_TOP") # Pop loop index
code_gen.emit("POP_TOP") # Pop list
code_gen.pop_loop()
if node.orelse:
code_gen.visit(node.orelse)
code_gen.nextBlock(after)
class TupleClass(Class):
def __init__(self, is_exact: bool = False) -> None:
instance = TupleExactInstance(self) if is_exact else TupleInstance(self)
super().__init__(
type_name=TypeName("builtins", "tuple"),
bases=[OBJECT_TYPE],
instance=instance,
is_exact=is_exact,
pytype=tuple,
)
class TupleInstance(Object[TupleClass]):
def get_fast_len_type(self) -> int:
return FAST_LEN_TUPLE | ((not self.klass.is_exact) << 4)
def emit_len(
self, node: ast.Call, code_gen: Static38CodeGenerator, boxed: bool
) -> None:
return common_sequence_emit_len(
node, code_gen, self.get_fast_len_type(), boxed=boxed
)
def emit_jumpif(
self, test: AST, next: Block, is_if_true: bool, code_gen: Static38CodeGenerator
) -> None:
return common_sequence_emit_jumpif(
test, next, is_if_true, code_gen, self.get_fast_len_type()
)
def emit_binop(self, node: ast.BinOp, code_gen: Static38CodeGenerator) -> None:
if maybe_emit_sequence_repeat(node, code_gen):
return
code_gen.defaultVisit(node)
class TupleExactInstance(TupleInstance):
def bind_binop(
self, node: ast.BinOp, visitor: TypeBinder, type_ctx: Optional[Class]
) -> bool:
rtype = visitor.get_type(node.right).klass
if isinstance(node.op, ast.Mult) and (
INT_TYPE.can_assign_from(rtype) or rtype in SIGNED_CINT_TYPES
):
visitor.set_type(node, TUPLE_EXACT_TYPE.instance)
return True
return super().bind_binop(node, visitor, type_ctx)
def bind_reverse_binop(
self, node: ast.BinOp, visitor: TypeBinder, type_ctx: Optional[Class]
) -> bool:
ltype = visitor.get_type(node.left).klass
if isinstance(node.op, ast.Mult) and (
INT_TYPE.can_assign_from(ltype) or ltype in SIGNED_CINT_TYPES
):
visitor.set_type(node, TUPLE_EXACT_TYPE.instance)
return True
return super().bind_reverse_binop(node, visitor, type_ctx)
def emit_forloop(self, node: ast.For, code_gen: Static38CodeGenerator) -> None:
if not isinstance(node.target, ast.Name):
# We don't yet support `for a, b in my_tuple: ...`
return super().emit_forloop(node, code_gen)
return common_sequence_emit_forloop(node, code_gen, FAST_LEN_TUPLE)
class SetClass(Class):
def __init__(self, is_exact: bool = False) -> None:
super().__init__(
type_name=TypeName("builtins", "set"),
bases=[OBJECT_TYPE],
instance=SetInstance(self),
is_exact=is_exact,
pytype=tuple,
)
class SetInstance(Object[SetClass]):
def get_fast_len_type(self) -> int:
return FAST_LEN_SET | ((not self.klass.is_exact) << 4)
def emit_len(
self, node: ast.Call, code_gen: Static38CodeGenerator, boxed: bool
) -> None:
if len(node.args) != 1:
raise code_gen.syntax_error(
"Can only pass a single argument when checking set length", node
)
code_gen.visit(node.args[0])
code_gen.emit("FAST_LEN", self.get_fast_len_type())
if boxed:
signed = True
code_gen.emit("PRIMITIVE_BOX", int(signed))
def emit_jumpif(
self, test: AST, next: Block, is_if_true: bool, code_gen: Static38CodeGenerator
) -> None:
code_gen.visit(test)
code_gen.emit("FAST_LEN", self.get_fast_len_type())
code_gen.emit("POP_JUMP_IF_NONZERO" if is_if_true else "POP_JUMP_IF_ZERO", next)
def maybe_emit_sequence_repeat(
node: ast.BinOp, code_gen: Static38CodeGenerator
) -> bool:
if not isinstance(node.op, ast.Mult):
return False
for seq, num, rev in [
(node.left, node.right, 0),
(node.right, node.left, SEQ_REPEAT_REVERSED),
]:
seq_type = code_gen.get_type(seq).klass
num_type = code_gen.get_type(num).klass
oparg = None
if TUPLE_TYPE.can_assign_from(seq_type):
oparg = SEQ_TUPLE
elif LIST_TYPE.can_assign_from(seq_type):
oparg = SEQ_LIST
if oparg is None:
continue
if num_type in SIGNED_CINT_TYPES:
oparg |= SEQ_REPEAT_PRIMITIVE_NUM
elif not INT_TYPE.can_assign_from(num_type):
continue
if not seq_type.is_exact:
oparg |= SEQ_REPEAT_INEXACT_SEQ
if not num_type.is_exact:
oparg |= SEQ_REPEAT_INEXACT_NUM
oparg |= rev
code_gen.visit(seq)
code_gen.visit(num)
code_gen.emit("SEQUENCE_REPEAT", oparg)
return True
return False
class ListAppendMethod(BuiltinMethodDescriptor):
def bind_descr_get(
self,
node: ast.Attribute,
inst: Optional[Object[TClassInv]],
ctx: TClassInv,
visitor: TypeBinder,
type_ctx: Optional[Class],
) -> None:
if inst is None:
visitor.set_type(node, self)
else:
visitor.set_type(node, ListAppendBuiltinMethod(self, node))
class ListAppendBuiltinMethod(BuiltinMethod):
def emit_call(self, node: ast.Call, code_gen: Static38CodeGenerator) -> None:
if len(node.args) == 1 and not node.keywords:
code_gen.visit(self.target.value)
code_gen.visit(node.args[0])
code_gen.emit("LIST_APPEND", 1)
return
return super().emit_call(node, code_gen)
class ListClass(Class):
def __init__(self, is_exact: bool = False) -> None:
instance = ListExactInstance(self) if is_exact else ListInstance(self)
super().__init__(
type_name=TypeName("builtins", "list"),
bases=[OBJECT_TYPE],
instance=instance,
is_exact=is_exact,
pytype=list,
)
if is_exact:
self.members["append"] = ListAppendMethod("append", self)
class ListInstance(Object[ListClass]):
def get_fast_len_type(self) -> int:
return FAST_LEN_LIST | ((not self.klass.is_exact) << 4)
def get_subscr_type(self) -> int:
return SEQ_LIST_INEXACT
def emit_len(
self, node: ast.Call, code_gen: Static38CodeGenerator, boxed: bool
) -> None:
return common_sequence_emit_len(
node, code_gen, self.get_fast_len_type(), boxed=boxed
)
def emit_jumpif(
self, test: AST, next: Block, is_if_true: bool, code_gen: Static38CodeGenerator
) -> None:
return common_sequence_emit_jumpif(
test, next, is_if_true, code_gen, self.get_fast_len_type()
)
def bind_subscr(
self, node: ast.Subscript, type: Value, visitor: TypeBinder
) -> None:
if type.klass not in SIGNED_CINT_TYPES:
super().bind_subscr(node, type, visitor)
visitor.set_type(node, DYNAMIC)
def emit_subscr(
self, node: ast.Subscript, aug_flag: bool, code_gen: Static38CodeGenerator
) -> None:
index_type = code_gen.get_type(node.slice)
if index_type.klass not in SIGNED_CINT_TYPES:
return super().emit_subscr(node, aug_flag, code_gen)
code_gen.update_lineno(node)
code_gen.visit(node.value)
code_gen.visit(node.slice)
if isinstance(node.ctx, ast.Load):
code_gen.emit("SEQUENCE_GET", self.get_subscr_type())
elif isinstance(node.ctx, ast.Store):
code_gen.emit("SEQUENCE_SET", self.get_subscr_type())
elif isinstance(node.ctx, ast.Del):
code_gen.emit("LIST_DEL")
def emit_binop(self, node: ast.BinOp, code_gen: Static38CodeGenerator) -> None:
if maybe_emit_sequence_repeat(node, code_gen):
return
code_gen.defaultVisit(node)
class ListExactInstance(ListInstance):
def get_subscr_type(self) -> int:
return SEQ_LIST
def bind_binop(
self, node: ast.BinOp, visitor: TypeBinder, type_ctx: Optional[Class]
) -> bool:
rtype = visitor.get_type(node.right).klass
if isinstance(node.op, ast.Mult) and (
INT_TYPE.can_assign_from(rtype) or rtype in SIGNED_CINT_TYPES
):
visitor.set_type(node, LIST_EXACT_TYPE.instance)
return True
return super().bind_binop(node, visitor, type_ctx)
def bind_reverse_binop(
self, node: ast.BinOp, visitor: TypeBinder, type_ctx: Optional[Class]
) -> bool:
ltype = visitor.get_type(node.left).klass
if isinstance(node.op, ast.Mult) and (
INT_TYPE.can_assign_from(ltype) or ltype in SIGNED_CINT_TYPES
):
visitor.set_type(node, LIST_EXACT_TYPE.instance)
return True
return super().bind_reverse_binop(node, visitor, type_ctx)
def emit_forloop(self, node: ast.For, code_gen: Static38CodeGenerator) -> None:
if not isinstance(node.target, ast.Name):
# We don't yet support `for a, b in my_list: ...`
return super().emit_forloop(node, code_gen)
return common_sequence_emit_forloop(node, code_gen, FAST_LEN_LIST)
class StrClass(Class):
def __init__(self, is_exact: bool = False) -> None:
super().__init__(
type_name=TypeName("builtins", "str"),
bases=[OBJECT_TYPE],
instance=StrInstance(self),
is_exact=is_exact,
pytype=str,
)
class StrInstance(Object[StrClass]):
def get_fast_len_type(self) -> int:
return FAST_LEN_STR | ((not self.klass.is_exact) << 4)
def emit_len(
self, node: ast.Call, code_gen: Static38CodeGenerator, boxed: bool
) -> None:
return common_sequence_emit_len(
node, code_gen, self.get_fast_len_type(), boxed=boxed
)
def emit_jumpif(
self, test: AST, next: Block, is_if_true: bool, code_gen: Static38CodeGenerator
) -> None:
return common_sequence_emit_jumpif(
test, next, is_if_true, code_gen, self.get_fast_len_type()
)
class DictClass(Class):
def __init__(self, is_exact: bool = False) -> None:
super().__init__(
type_name=TypeName("builtins", "dict"),
bases=[OBJECT_TYPE],
instance=DictInstance(self),
is_exact=is_exact,
pytype=dict,
)
class DictInstance(Object[DictClass]):
def get_fast_len_type(self) -> int:
return FAST_LEN_DICT | ((not self.klass.is_exact) << 4)
def emit_len(
self, node: ast.Call, code_gen: Static38CodeGenerator, boxed: bool
) -> None:
if len(node.args) != 1:
raise code_gen.syntax_error(
"Can only pass a single argument when checking dict length", node
)
code_gen.visit(node.args[0])
code_gen.emit("FAST_LEN", self.get_fast_len_type())
if boxed:
signed = True
code_gen.emit("PRIMITIVE_BOX", int(signed))
def emit_jumpif(
self, test: AST, next: Block, is_if_true: bool, code_gen: Static38CodeGenerator
) -> None:
code_gen.visit(test)
code_gen.emit("FAST_LEN", self.get_fast_len_type())
code_gen.emit("POP_JUMP_IF_NONZERO" if is_if_true else "POP_JUMP_IF_ZERO", next)
FUNCTION_TYPE = Class(TypeName("types", "FunctionType"))
METHOD_TYPE = Class(TypeName("types", "MethodType"))
MEMBER_TYPE = Class(TypeName("types", "MemberDescriptorType"))
BUILTIN_METHOD_DESC_TYPE = Class(TypeName("types", "MethodDescriptorType"))
BUILTIN_METHOD_TYPE = Class(TypeName("types", "BuiltinMethodType"))
ARG_TYPE = Class(TypeName("builtins", "arg"))
SLICE_TYPE = Class(TypeName("builtins", "slice"))
# builtin types
NONE_TYPE = NoneType()
STR_TYPE = StrClass()
STR_EXACT_TYPE = StrClass(is_exact=True)
INT_TYPE = NumClass(TypeName("builtins", "int"), pytype=int)
INT_EXACT_TYPE = NumClass(TypeName("builtins", "int"), pytype=int, is_exact=True)
FLOAT_TYPE = NumClass(TypeName("builtins", "float"), pytype=float)
FLOAT_EXACT_TYPE = NumClass(TypeName("builtins", "float"), pytype=float, is_exact=True)
COMPLEX_TYPE = NumClass(TypeName("builtins", "complex"), pytype=complex)
COMPLEX_EXACT_TYPE = NumClass(
TypeName("builtins", "complex"), pytype=complex, is_exact=True
)
BYTES_TYPE = Class(TypeName("builtins", "bytes"), [OBJECT_TYPE], pytype=bytes)
BOOL_TYPE = Class(TypeName("builtins", "bool"), [OBJECT_TYPE], pytype=bool)
ELLIPSIS_TYPE = Class(TypeName("builtins", "ellipsis"), [OBJECT_TYPE], pytype=type(...))
DICT_TYPE = DictClass(is_exact=False)
DICT_EXACT_TYPE = DictClass(is_exact=True)
TUPLE_TYPE = TupleClass()
TUPLE_EXACT_TYPE = TupleClass(is_exact=True)
SET_TYPE = SetClass()
SET_EXACT_TYPE = SetClass(is_exact=True)
LIST_TYPE = ListClass()
LIST_EXACT_TYPE = ListClass(is_exact=True)
BASE_EXCEPTION_TYPE = Class(TypeName("builtins", "BaseException"), pytype=BaseException)
EXCEPTION_TYPE = Class(
TypeName("builtins", "Exception"),
bases=[BASE_EXCEPTION_TYPE],
pytype=Exception,
)
STATIC_METHOD_TYPE = StaticMethodDecorator(
TypeName("builtins", "staticmethod"),
bases=[OBJECT_TYPE],
pytype=staticmethod,
)
FINAL_METHOD_TYPE = TypingFinalDecorator(TypeName("typing", "final"))
ALLOW_WEAKREFS_TYPE = AllowWeakrefsDecorator(TypeName("__static__", "allow_weakrefs"))
DYNAMIC_RETURN_TYPE = DynamicReturnDecorator(TypeName("__static__", "dynamic_return"))
INLINE_TYPE = InlineFunctionDecorator(TypeName("__static__", "inline"))
DONOTCOMPILE_TYPE = DoNotCompileDecorator(TypeName("__static__", "_donotcompile"))
RESOLVED_INT_TYPE = ResolvedTypeRef(INT_TYPE)
RESOLVED_STR_TYPE = ResolvedTypeRef(STR_TYPE)
RESOLVED_NONE_TYPE = ResolvedTypeRef(NONE_TYPE)
TYPE_TYPE.bases = [OBJECT_TYPE]
CONSTANT_TYPES: Mapping[Type[object], Value] = {
str: STR_EXACT_TYPE.instance,
int: INT_EXACT_TYPE.instance,
float: FLOAT_EXACT_TYPE.instance,
complex: COMPLEX_EXACT_TYPE.instance,
bytes: BYTES_TYPE.instance,
bool: BOOL_TYPE.instance,
type(None): NONE_TYPE.instance,
tuple: TUPLE_EXACT_TYPE.instance,
type(...): ELLIPSIS_TYPE.instance,
}
NAMED_TUPLE_TYPE = Class(TypeName("typing", "NamedTuple"))
class FinalClass(GenericClass):
is_variadic = True
def make_generic_type(
self,
index: Tuple[Class, ...],
generic_types: GenericTypesDict,
) -> Class:
if len(index) > 1:
raise TypedSyntaxError(
f"Final types can only have a single type arg. Given: {str(index)}"
)
return super(FinalClass, self).make_generic_type(index, generic_types)
def inner_type(self) -> Class:
if self.type_args:
return self.type_args[0]
else:
return DYNAMIC_TYPE
class UnionTypeName(GenericTypeName):
@property
def opt_type(self) -> Optional[Class]:
"""If we're an Optional (i.e. Union[T, None]), return T, otherwise None."""
# Assumes well-formed union (no duplicate elements, >1 element)
opt_type = None
if len(self.args) == 2:
if self.args[0] is NONE_TYPE:
opt_type = self.args[1]
elif self.args[1] is NONE_TYPE:
opt_type = self.args[0]
return opt_type
@property
def type_descr(self) -> TypeDescr:
opt_type = self.opt_type
if opt_type is not None:
return opt_type.type_descr + ("?",)
# the runtime does not support unions beyond optional, so just fall back
# to dynamic for runtime purposes
return DYNAMIC_TYPE.type_descr
@property
def friendly_name(self) -> str:
opt_type = self.opt_type
if opt_type is not None:
return f"Optional[{opt_type.instance.name}]"
return super().friendly_name
class UnionType(GenericClass):
type_name: UnionTypeName
# Union is a variadic generic, so we don't give the unbound Union any
# GenericParameters, and we allow it to accept any number of type args.
is_variadic = True
def __init__(
self,
type_name: Optional[UnionTypeName] = None,
type_def: Optional[GenericClass] = None,
instance_type: Optional[Type[Object[Class]]] = None,
generic_types: Optional[GenericTypesDict] = None,
) -> None:
instance_type = instance_type or UnionInstance
super().__init__(
type_name or UnionTypeName("typing", "Union", ()),
bases=[],
instance=instance_type(self),
type_def=type_def,
)
self.generic_types = generic_types
@property
def opt_type(self) -> Optional[Class]:
return self.type_name.opt_type
def issubclass(self, src: Class) -> bool:
if isinstance(src, UnionType):
return all(self.issubclass(t) for t in src.type_args)
return any(t.issubclass(src) for t in self.type_args)
def make_generic_type(
self,
index: Tuple[Class, ...],
generic_types: GenericTypesDict,
) -> Class:
instantiations = generic_types.get(self)
if instantiations is not None:
instance = instantiations.get(index)
if instance is not None:
return instance
else:
generic_types[self] = instantiations = {}
type_args = self._simplify_args(index)
if len(type_args) == 1 and not type_args[0].is_generic_parameter:
return type_args[0]
type_name = UnionTypeName(self.type_name.module, self.type_name.name, type_args)
if any(isinstance(a, CType) for a in type_args):
raise TypedSyntaxError(
f"invalid union type {type_name.friendly_name}; unions cannot include primitive types"
)
ThisUnionType = type(self)
if type_name.opt_type is not None:
ThisUnionType = OptionalType
instantiations[index] = concrete = ThisUnionType(
type_name,
type_def=self,
generic_types=generic_types,
)
return concrete
def _simplify_args(self, args: Sequence[Class]) -> Tuple[Class, ...]:
args = self._flatten_args(args)
remove = set()
for i, arg1 in enumerate(args):
if i in remove:
continue
for j, arg2 in enumerate(args):
# TODO this should be is_subtype_of once we split that from can_assign_from
if i != j and arg1.can_assign_from(arg2):
remove.add(j)
return tuple(arg for i, arg in enumerate(args) if i not in remove)
def _flatten_args(self, args: Sequence[Class]) -> Sequence[Class]:
new_args = []
for arg in args:
if isinstance(arg, UnionType):
new_args.extend(self._flatten_args(arg.type_args))
else:
new_args.append(arg)
return new_args
class UnionInstance(Object[UnionType]):
def _generic_bind(
self,
node: ast.AST,
callback: typingCallable[[Class], object],
description: str,
visitor: TypeBinder,
) -> List[object]:
if self.klass.is_generic_type_definition:
raise visitor.syntax_error(f"cannot {description} unbound Union", node)
result_types: List[Class] = []
ret_types: List[object] = []
try:
for el in self.klass.type_args:
ret_types.append(callback(el))
result_types.append(visitor.get_type(node).klass)
except TypedSyntaxError as e:
raise visitor.syntax_error(f"{self.name}: {e.msg}", node)
union = UNION_TYPE.make_generic_type(
tuple(result_types), visitor.symtable.generic_types
)
visitor.set_type(node, union.instance)
return ret_types
def bind_attr(
self, node: ast.Attribute, visitor: TypeBinder, type_ctx: Optional[Class]
) -> None:
def cb(el: Class) -> None:
return el.instance.bind_attr(node, visitor, type_ctx)
self._generic_bind(
node,
cb,
"access attribute from",
visitor,
)
def bind_call(
self, node: ast.Call, visitor: TypeBinder, type_ctx: Optional[Class]
) -> NarrowingEffect:
def cb(el: Class) -> NarrowingEffect:
return el.instance.bind_call(node, visitor, type_ctx)
self._generic_bind(node, cb, "call", visitor)
return NO_EFFECT
def bind_subscr(
self, node: ast.Subscript, type: Value, visitor: TypeBinder
) -> None:
def cb(el: Class) -> None:
return el.instance.bind_subscr(node, type, visitor)
self._generic_bind(node, cb, "subscript", visitor)
def bind_unaryop(
self, node: ast.UnaryOp, visitor: TypeBinder, type_ctx: Optional[Class]
) -> None:
def cb(el: Class) -> None:
return el.instance.bind_unaryop(node, visitor, type_ctx)
self._generic_bind(
node,
cb,
"unary op",
visitor,
)
def bind_compare(
self,
node: ast.Compare,
left: expr,
op: cmpop,
right: expr,
visitor: TypeBinder,
type_ctx: Optional[Class],
) -> bool:
def cb(el: Class) -> bool:
return el.instance.bind_compare(node, left, op, right, visitor, type_ctx)
rets = self._generic_bind(node, cb, "compare", visitor)
return all(rets)
def bind_reverse_compare(
self,
node: ast.Compare,
left: expr,
op: cmpop,
right: expr,
visitor: TypeBinder,
type_ctx: Optional[Class],
) -> bool:
def cb(el: Class) -> bool:
return el.instance.bind_reverse_compare(
node, left, op, right, visitor, type_ctx
)
rets = self._generic_bind(node, cb, "compare", visitor)
return all(rets)
class OptionalType(UnionType):
"""UnionType for instantiations with [T, None], and to support Optional[T] special form."""
is_variadic = False
def __init__(
self,
type_name: Optional[UnionTypeName] = None,
type_def: Optional[GenericClass] = None,
generic_types: Optional[GenericTypesDict] = None,
) -> None:
super().__init__(
type_name
or UnionTypeName("typing", "Optional", (GenericParameter("T", 0),)),
type_def=type_def,
instance_type=OptionalInstance,
generic_types=generic_types,
)
@property
def opt_type(self) -> Class:
opt_type = self.type_name.opt_type
if opt_type is None:
params = ", ".join(t.name for t in self.type_args)
raise TypeError(f"OptionalType has invalid type parameters {params}")
return opt_type
def make_generic_type(
self, index: Tuple[Class, ...], generic_types: GenericTypesDict
) -> Class:
assert len(index) == 1
if not index[0].is_generic_parameter:
# Optional[T] is syntactic sugar for Union[T, None]
index = index + (NONE_TYPE,)
return super().make_generic_type(index, generic_types)
class OptionalInstance(UnionInstance):
"""Only exists for typing purposes (so we know .klass is OptionalType)."""
klass: OptionalType
class ArrayInstance(Object["ArrayClass"]):
def _seq_type(self) -> int:
idx = self.klass.index
if not isinstance(idx, CIntType):
# should never happen
raise SyntaxError(f"Invalid Array type: {idx}")
size = idx.size
if size == 0:
return SEQ_ARRAY_INT8 if idx.signed else SEQ_ARRAY_UINT8
elif size == 1:
return SEQ_ARRAY_INT16 if idx.signed else SEQ_ARRAY_UINT16
elif size == 2:
return SEQ_ARRAY_INT32 if idx.signed else SEQ_ARRAY_UINT32
elif size == 3:
return SEQ_ARRAY_INT64 if idx.signed else SEQ_ARRAY_UINT64
else:
raise SyntaxError(f"Invalid Array size: {size}")
def bind_subscr(
self, node: ast.Subscript, type: Value, visitor: TypeBinder
) -> None:
if type == SLICE_TYPE.instance:
# Slicing preserves type
return visitor.set_type(node, self)
visitor.set_type(node, self.klass.index.instance)
def emit_subscr(
self, node: ast.Subscript, aug_flag: bool, code_gen: Static38CodeGenerator
) -> None:
index_type = code_gen.get_type(node.slice)
is_del = isinstance(node.ctx, ast.Del)
index_is_python_int = INT_TYPE.can_assign_from(index_type.klass)
index_is_primitive_int = isinstance(index_type.klass, CIntType)
# ARRAY_{GET,SET} support only integer indices and don't support del;
# otherwise defer to the usual bytecode
if is_del or not (index_is_python_int or index_is_primitive_int):
return super().emit_subscr(node, aug_flag, code_gen)
code_gen.update_lineno(node)
code_gen.visit(node.value)
code_gen.visit(node.slice)
if index_is_python_int:
# If the index is not a primitive, unbox its value to an int64, our implementation of
# SEQUENCE_{GET/SET} expects the index to be a primitive int.
code_gen.emit("PRIMITIVE_UNBOX", INT64_TYPE.instance.as_oparg())
if isinstance(node.ctx, ast.Store) and not aug_flag:
code_gen.emit("SEQUENCE_SET", self._seq_type())
elif isinstance(node.ctx, ast.Load) or aug_flag:
if aug_flag:
code_gen.emit("DUP_TOP_TWO")
code_gen.emit("SEQUENCE_GET", self._seq_type())
def emit_store_subscr(
self, node: ast.Subscript, code_gen: Static38CodeGenerator
) -> None:
code_gen.emit("ROT_THREE")
code_gen.emit("SEQUENCE_SET", self._seq_type())
def __repr__(self) -> str:
return f"{self.klass.type_name.name}[{self.klass.index.name!r}]"
def get_fast_len_type(self) -> int:
return FAST_LEN_ARRAY | ((not self.klass.is_exact) << 4)
def emit_len(
self, node: ast.Call, code_gen: Static38CodeGenerator, boxed: bool
) -> None:
if len(node.args) != 1:
raise code_gen.syntax_error(
"Can only pass a single argument when checking array length", node
)
code_gen.visit(node.args[0])
code_gen.emit("FAST_LEN", self.get_fast_len_type())
if boxed:
signed = True
code_gen.emit("PRIMITIVE_BOX", int(signed))
class ArrayClass(GenericClass):
def __init__(
self,
name: GenericTypeName,
bases: Optional[List[Class]] = None,
instance: Optional[Object[Class]] = None,
klass: Optional[Class] = None,
members: Optional[Dict[str, Value]] = None,
type_def: Optional[GenericClass] = None,
is_exact: bool = False,
pytype: Optional[Type[object]] = None,
) -> None:
default_bases: List[Class] = [OBJECT_TYPE]
default_instance: Object[Class] = ArrayInstance(self)
super().__init__(
name,
bases or default_bases,
instance or default_instance,
klass,
members,
type_def,
is_exact,
pytype,
)
@property
def index(self) -> Class:
return self.type_args[0]
def make_generic_type(
self, index: Tuple[Class, ...], generic_types: GenericTypesDict
) -> Class:
for tp in index:
if tp not in ALLOWED_ARRAY_TYPES:
raise TypedSyntaxError(
f"Invalid {self.gen_name.name} element type: {tp.instance.name}"
)
return super().make_generic_type(index, generic_types)
class VectorClass(ArrayClass):
def __init__(
self,
name: GenericTypeName,
bases: Optional[List[Class]] = None,
instance: Optional[Object[Class]] = None,
klass: Optional[Class] = None,
members: Optional[Dict[str, Value]] = None,
type_def: Optional[GenericClass] = None,
is_exact: bool = False,
pytype: Optional[Type[object]] = None,
) -> None:
super().__init__(
name,
bases,
instance,
klass,
members,
type_def,
is_exact,
pytype,
)
self.members["append"] = BuiltinMethodDescriptor(
"append",
self,
(
Parameter("self", 0, ResolvedTypeRef(self), False, None, False),
Parameter(
"v",
0,
ResolvedTypeRef(VECTOR_TYPE_PARAM),
False,
None,
False,
),
),
)
BUILTIN_GENERICS: Dict[Class, Dict[GenericTypeIndex, Class]] = {}
UNION_TYPE = UnionType()
OPTIONAL_TYPE = OptionalType()
FINAL_TYPE = FinalClass(GenericTypeName("typing", "Final", ()))
CHECKED_DICT_TYPE_NAME = GenericTypeName(
"__static__", "chkdict", (GenericParameter("K", 0), GenericParameter("V", 1))
)
class CheckedDict(GenericClass):
def __init__(
self,
name: GenericTypeName,
bases: Optional[List[Class]] = None,
instance: Optional[Object[Class]] = None,
klass: Optional[Class] = None,
members: Optional[Dict[str, Value]] = None,
type_def: Optional[GenericClass] = None,
is_exact: bool = False,
pytype: Optional[Type[object]] = None,
) -> None:
if instance is None:
instance = CheckedDictInstance(self)
super().__init__(
name,
bases,
instance,
klass,
members,
type_def,
is_exact,
pytype,
)
class CheckedDictInstance(Object[CheckedDict]):
def bind_subscr(
self, node: ast.Subscript, type: Value, visitor: TypeBinder
) -> None:
visitor.visit(node.slice, self.klass.gen_name.args[0].instance)
visitor.set_type(node, self.klass.gen_name.args[1].instance)
def emit_subscr(
self, node: ast.Subscript, aug_flag: bool, code_gen: Static38CodeGenerator
) -> None:
if isinstance(node.ctx, ast.Load):
code_gen.visit(node.value)
code_gen.visit(node.slice)
dict_descr = self.klass.type_descr
update_descr = dict_descr + ("__getitem__",)
code_gen.emit_invoke_method(update_descr, 1)
elif isinstance(node.ctx, ast.Store):
code_gen.visit(node.value)
code_gen.emit("ROT_TWO")
code_gen.visit(node.slice)
code_gen.emit("ROT_TWO")
dict_descr = self.klass.type_descr
setitem_descr = dict_descr + ("__setitem__",)
code_gen.emit_invoke_method(setitem_descr, 2)
code_gen.emit("POP_TOP")
else:
code_gen.defaultVisit(node, aug_flag)
def get_fast_len_type(self) -> int:
return FAST_LEN_DICT | ((not self.klass.is_exact) << 4)
def emit_len(
self, node: ast.Call, code_gen: Static38CodeGenerator, boxed: bool
) -> None:
if len(node.args) != 1:
raise code_gen.syntax_error(
"Can only pass a single argument when checking dict length", node
)
code_gen.visit(node.args[0])
code_gen.emit("FAST_LEN", self.get_fast_len_type())
if boxed:
signed = True
code_gen.emit("PRIMITIVE_BOX", int(signed))
def emit_jumpif(
self, test: AST, next: Block, is_if_true: bool, code_gen: Static38CodeGenerator
) -> None:
code_gen.visit(test)
code_gen.emit("FAST_LEN", self.get_fast_len_type())
code_gen.emit("POP_JUMP_IF_NONZERO" if is_if_true else "POP_JUMP_IF_ZERO", next)
class CastFunction(Object[Class]):
def bind_call(
self, node: ast.Call, visitor: TypeBinder, type_ctx: Optional[Class]
) -> NarrowingEffect:
if len(node.args) != 2:
raise visitor.syntax_error(
"cast requires two parameters: type and value", node
)
for arg in node.args:
visitor.visit(arg)
self.check_args_for_primitives(node, visitor)
cast_type = visitor.cur_mod.resolve_annotation(node.args[0])
if cast_type is None:
raise visitor.syntax_error("cast to unknown type", node)
visitor.set_type(node, cast_type.instance)
return NO_EFFECT
def emit_call(self, node: ast.Call, code_gen: Static38CodeGenerator) -> None:
code_gen.visit(node.args[1])
code_gen.emit("CAST", code_gen.get_type(node).klass.type_descr)
prim_name_to_type: Mapping[str, int] = {
"int8": TYPED_INT8,
"int16": TYPED_INT16,
"int32": TYPED_INT32,
"int64": TYPED_INT64,
"uint8": TYPED_UINT8,
"uint16": TYPED_UINT16,
"uint32": TYPED_UINT32,
"uint64": TYPED_UINT64,
}
class CInstance(Value, Generic[TClass]):
_op_name: Dict[Type[ast.operator], str] = {
ast.Add: "add",
ast.Sub: "subtract",
ast.Mult: "multiply",
ast.FloorDiv: "divide",
ast.Div: "divide",
ast.Mod: "modulus",
ast.LShift: "left shift",
ast.RShift: "right shift",
ast.BitOr: "bitwise or",
ast.BitXor: "xor",
ast.BitAnd: "bitwise and",
}
@property
def name(self) -> str:
return self.klass.instance_name
def binop_error(self, left: Value, right: Value, op: ast.operator) -> str:
return f"cannot {self._op_name[type(op)]} {left.name} and {right.name}"
def bind_reverse_binop(
self, node: ast.BinOp, visitor: TypeBinder, type_ctx: Optional[Class]
) -> bool:
try:
visitor.visit(node.left, self)
except TypedSyntaxError:
raise visitor.syntax_error(
self.binop_error(visitor.get_type(node.left), self, node.op), node
)
visitor.set_type(node, self)
return True
def get_op_id(self, op: AST) -> int:
raise NotImplementedError("Must be implemented in the subclass")
def emit_binop(self, node: ast.BinOp, code_gen: Static38CodeGenerator) -> None:
code_gen.update_lineno(node)
common_type = code_gen.get_type(node)
code_gen.visit(node.left)
ltype = code_gen.get_type(node.left)
if ltype != common_type:
common_type.emit_convert(ltype, code_gen)
code_gen.visit(node.right)
rtype = code_gen.get_type(node.right)
if rtype != common_type:
common_type.emit_convert(rtype, code_gen)
op = self.get_op_id(node.op)
code_gen.emit("PRIMITIVE_BINARY_OP", op)
def emit_augassign(
self, node: ast.AugAssign, code_gen: Static38CodeGenerator
) -> None:
code_gen.set_lineno(node)
aug_node = wrap_aug(node.target)
code_gen.visit(aug_node, "load")
code_gen.visit(node.value)
code_gen.emit("PRIMITIVE_BINARY_OP", self.get_op_id(node.op))
code_gen.visit(aug_node, "store")
class CIntInstance(CInstance["CIntType"]):
def __init__(self, klass: CIntType, constant: int, size: int, signed: bool) -> None:
super().__init__(klass)
self.constant = constant
self.size = size
self.signed = signed
def as_oparg(self) -> int:
return self.constant
_int_binary_opcode_signed: Mapping[Type[ast.AST], int] = {
ast.Lt: PRIM_OP_LT_INT,
ast.Gt: PRIM_OP_GT_INT,
ast.Eq: PRIM_OP_EQ_INT,
ast.NotEq: PRIM_OP_NE_INT,
ast.LtE: PRIM_OP_LE_INT,
ast.GtE: PRIM_OP_GE_INT,
ast.Add: PRIM_OP_ADD_INT,
ast.Sub: PRIM_OP_SUB_INT,
ast.Mult: PRIM_OP_MUL_INT,
ast.FloorDiv: PRIM_OP_DIV_INT,
ast.Div: PRIM_OP_DIV_INT,
ast.Mod: PRIM_OP_MOD_INT,
ast.LShift: PRIM_OP_LSHIFT_INT,
ast.RShift: PRIM_OP_RSHIFT_INT,
ast.BitOr: PRIM_OP_OR_INT,
ast.BitXor: PRIM_OP_XOR_INT,
ast.BitAnd: PRIM_OP_AND_INT,
}
_int_binary_opcode_unsigned: Mapping[Type[ast.AST], int] = {
ast.Lt: PRIM_OP_LT_UN_INT,
ast.Gt: PRIM_OP_GT_UN_INT,
ast.Eq: PRIM_OP_EQ_INT,
ast.NotEq: PRIM_OP_NE_INT,
ast.LtE: PRIM_OP_LE_UN_INT,
ast.GtE: PRIM_OP_GE_UN_INT,
ast.Add: PRIM_OP_ADD_INT,
ast.Sub: PRIM_OP_SUB_INT,
ast.Mult: PRIM_OP_MUL_INT,
ast.FloorDiv: PRIM_OP_DIV_UN_INT,
ast.Div: PRIM_OP_DIV_UN_INT,
ast.Mod: PRIM_OP_MOD_UN_INT,
ast.LShift: PRIM_OP_LSHIFT_INT,
ast.RShift: PRIM_OP_RSHIFT_INT,
ast.RShift: PRIM_OP_RSHIFT_UN_INT,
ast.BitOr: PRIM_OP_OR_INT,
ast.BitXor: PRIM_OP_XOR_INT,
ast.BitAnd: PRIM_OP_AND_INT,
}
def get_op_id(self, op: AST) -> int:
return (
self._int_binary_opcode_signed[type(op)]
if self.signed
else (self._int_binary_opcode_unsigned[type(op)])
)
def validate_mixed_math(self, other: Value) -> Optional[Value]:
if self.constant == TYPED_BOOL:
return None
if other is self:
return self
elif isinstance(other, CIntInstance):
if other.constant == TYPED_BOOL:
return None
if self.signed == other.signed:
# signs match, we can just treat this as a comparison of the larger type
if self.size > other.size:
return self
else:
return other
else:
new_size = max(
self.size if self.signed else self.size + 1,
other.size if other.signed else other.size + 1,
)
if new_size <= TYPED_INT_64BIT:
# signs don't match, but we can promote to the next highest data type
return SIGNED_CINT_TYPES[new_size].instance
return None
def bind_compare(
self,
node: ast.Compare,
left: expr,
op: cmpop,
right: expr,
visitor: TypeBinder,
type_ctx: Optional[Class],
) -> bool:
rtype = visitor.get_type(right)
if rtype != self and not isinstance(rtype, CIntInstance):
try:
visitor.visit(right, self)
except TypedSyntaxError:
# Report a better error message than the generic can't be used
raise visitor.syntax_error(
f"can't compare {self.name} to {visitor.get_type(right).name}",
node,
)
compare_type = self.validate_mixed_math(visitor.get_type(right))
if compare_type is None:
raise visitor.syntax_error(
f"can't compare {self.name} to {visitor.get_type(right).name}", node
)
visitor.set_type(op, compare_type)
visitor.set_type(node, CBOOL_TYPE.instance)
return True
def bind_reverse_compare(
self,
node: ast.Compare,
left: expr,
op: cmpop,
right: expr,
visitor: TypeBinder,
type_ctx: Optional[Class],
) -> bool:
if not isinstance(visitor.get_type(left), CIntInstance):
try:
visitor.visit(left, self)
except TypedSyntaxError:
# Report a better error message than the generic can't be used
raise visitor.syntax_error(
f"can't compare {self.name} to {visitor.get_type(right).name}", node
)
compare_type = self.validate_mixed_math(visitor.get_type(left))
if compare_type is None:
raise visitor.syntax_error(
f"can't compare {visitor.get_type(left).name} to {self.name}", node
)
visitor.set_type(op, compare_type)
visitor.set_type(node, CBOOL_TYPE.instance)
return True
return False
def emit_compare(self, op: cmpop, code_gen: Static38CodeGenerator) -> None:
code_gen.emit("INT_COMPARE_OP", self.get_op_id(op))
def emit_augname(
self, node: AugName, code_gen: Static38CodeGenerator, mode: str
) -> None:
if mode == "load":
code_gen.emit("LOAD_LOCAL", (node.id, self.klass.type_descr))
elif mode == "store":
code_gen.emit("STORE_LOCAL", (node.id, self.klass.type_descr))
def validate_int(self, val: object, node: ast.AST, visitor: TypeBinder) -> None:
if not isinstance(val, int):
raise visitor.syntax_error(
f"{type(val).__name__} cannot be used in a context where an int is expected",
node,
)
bits = 8 << self.size
if self.signed:
low = -(1 << (bits - 1))
high = (1 << (bits - 1)) - 1
else:
low = 0
high = (1 << bits) - 1
if not low <= val <= high:
# We set a type here so that when call handles the syntax error and tries to
# improve the error message to "positional argument type mismatch" it can
# successfully get the type
visitor.set_type(node, INT_TYPE.instance)
raise visitor.syntax_error(
f"constant {val} is outside of the range {low} to {high} for {self.name}",
node,
)
def bind_constant(self, node: ast.Constant, visitor: TypeBinder) -> None:
self.validate_int(node.value, node, visitor)
visitor.set_type(node, self)
def emit_constant(
self, node: ast.Constant, code_gen: Static38CodeGenerator
) -> None:
val = node.value
if self.constant == TYPED_BOOL:
val = bool(val)
code_gen.emit("PRIMITIVE_LOAD_CONST", (val, self.as_oparg()))
def emit_name(self, node: ast.Name, code_gen: Static38CodeGenerator) -> None:
if isinstance(node.ctx, ast.Load):
code_gen.emit("LOAD_LOCAL", (node.id, self.klass.type_descr))
elif isinstance(node.ctx, ast.Store):
code_gen.emit("STORE_LOCAL", (node.id, self.klass.type_descr))
else:
raise TypedSyntaxError("unsupported op")
def emit_jumpif(
self, test: AST, next: Block, is_if_true: bool, code_gen: Static38CodeGenerator
) -> None:
code_gen.visit(test)
code_gen.emit("POP_JUMP_IF_NONZERO" if is_if_true else "POP_JUMP_IF_ZERO", next)
def emit_jumpif_pop(
self, test: AST, next: Block, is_if_true: bool, code_gen: Static38CodeGenerator
) -> None:
code_gen.visit(test)
code_gen.emit(
"JUMP_IF_NONZERO_OR_POP" if is_if_true else "JUMP_IF_ZERO_OR_POP", next
)
def bind_binop(
self, node: ast.BinOp, visitor: TypeBinder, type_ctx: Optional[Class]
) -> bool:
if self.constant == TYPED_BOOL:
raise TypedSyntaxError(
f"cbool is not a valid operand type for {self._op_name[type(node.op)]}"
)
rinst = visitor.get_type(node.right)
if rinst != self:
if rinst.klass == LIST_EXACT_TYPE:
visitor.set_type(node, LIST_EXACT_TYPE.instance)
return True
if rinst.klass == TUPLE_EXACT_TYPE:
visitor.set_type(node, TUPLE_EXACT_TYPE.instance)
return True
try:
visitor.visit(node.right, type_ctx or INT64_VALUE)
except TypedSyntaxError:
# Report a better error message than the generic can't be used
raise visitor.syntax_error(
self.binop_error(self, visitor.get_type(node.right), node.op),
node,
)
if type_ctx is None:
type_ctx = self.validate_mixed_math(visitor.get_type(node.right))
if type_ctx is None:
raise visitor.syntax_error(
self.binop_error(self, visitor.get_type(node.right), node.op),
node,
)
visitor.set_type(node, type_ctx)
return True
def emit_box(self, node: expr, code_gen: Static38CodeGenerator) -> None:
code_gen.visit(node)
type = code_gen.get_type(node)
if isinstance(type, CIntInstance):
code_gen.emit("PRIMITIVE_BOX", self.as_oparg())
else:
raise RuntimeError("unsupported box type: " + type.name)
def emit_unbox(self, node: expr, code_gen: Static38CodeGenerator) -> None:
final_val = code_gen.get_final_literal(node)
if final_val is not None:
return self.emit_constant(final_val, code_gen)
typ = code_gen.get_type(node).klass
if isinstance(typ, NumClass) and typ.literal_value is not None:
code_gen.emit("PRIMITIVE_LOAD_CONST", (typ.literal_value, self.as_oparg()))
return
code_gen.visit(node)
code_gen.emit("PRIMITIVE_UNBOX", self.as_oparg())
def bind_unaryop(
self, node: ast.UnaryOp, visitor: TypeBinder, type_ctx: Optional[Class]
) -> None:
if isinstance(node.op, (ast.USub, ast.Invert, ast.UAdd)):
visitor.set_type(node, self)
else:
assert isinstance(node.op, ast.Not)
visitor.set_type(node, BOOL_TYPE.instance)
def emit_unaryop(self, node: ast.UnaryOp, code_gen: Static38CodeGenerator) -> None:
code_gen.update_lineno(node)
if isinstance(node.op, ast.USub):
code_gen.visit(node.operand)
code_gen.emit("PRIMITIVE_UNARY_OP", PRIM_OP_NEG_INT)
elif isinstance(node.op, ast.Invert):
code_gen.visit(node.operand)
code_gen.emit("PRIMITIVE_UNARY_OP", PRIM_OP_INV_INT)
elif isinstance(node.op, ast.UAdd):
code_gen.visit(node.operand)
elif isinstance(node.op, ast.Not):
raise NotImplementedError()
def emit_convert(self, to_type: Value, code_gen: Static38CodeGenerator) -> None:
assert isinstance(to_type, CIntInstance)
# Lower nibble is type-from, higher nibble is type-to.
code_gen.emit("CONVERT_PRIMITIVE", (self.as_oparg() << 4) | to_type.as_oparg())
class CIntType(CType):
instance: CIntInstance
def __init__(self, constant: int, name_override: Optional[str] = None) -> None:
self.constant = constant
# See TYPED_SIZE macro
self.size: int = (constant >> 1) & 3
self.signed: bool = bool(constant & 1)
if name_override is None:
name = ("" if self.signed else "u") + "int" + str(8 << self.size)
else:
name = name_override
super().__init__(
TypeName("__static__", name),
[],
CIntInstance(self, self.constant, self.size, self.signed),
)
def can_assign_from(self, src: Class) -> bool:
if isinstance(src, CIntType):
if src.size <= self.size and src.signed == self.signed:
# assignment to same or larger size, with same sign
# is allowed
return True
if src.size < self.size and self.signed:
# assignment to larger signed size from unsigned is
# allowed
return True
return super().can_assign_from(src)
def bind_call(
self, node: ast.Call, visitor: TypeBinder, type_ctx: Optional[Class]
) -> NarrowingEffect:
if len(node.args) != 1:
raise visitor.syntax_error(
f"{self.name} requires a single argument ({len(node.args)} given)", node
)
visitor.set_type(node, self.instance)
arg = node.args[0]
try:
visitor.visit(arg, self.instance)
except TypedSyntaxError:
visitor.visit(arg)
arg_type = visitor.get_type(arg)
if (
arg_type is not INT_TYPE.instance
and arg_type is not DYNAMIC
and arg_type is not OBJECT
):
raise
return NO_EFFECT
def emit_call(self, node: ast.Call, code_gen: Static38CodeGenerator) -> None:
if len(node.args) != 1:
raise code_gen.syntax_error(
f"{self.name} requires a single argument ({len(node.args)} given)", node
)
arg = node.args[0]
arg_type = code_gen.get_type(arg)
if isinstance(arg_type, CIntInstance):
code_gen.visit(arg)
if arg_type != self.instance:
self.instance.emit_convert(arg_type, code_gen)
else:
self.instance.emit_unbox(arg, code_gen)
class CDoubleInstance(CInstance["CDoubleType"]):
_double_binary_opcode_signed: Mapping[Type[ast.AST], int] = {
ast.Add: PRIM_OP_ADD_DBL,
ast.Sub: PRIM_OP_SUB_DBL,
ast.Mult: PRIM_OP_MUL_DBL,
ast.Div: PRIM_OP_DIV_DBL,
}
def get_op_id(self, op: AST) -> int:
return self._double_binary_opcode_signed[type(op)]
def as_oparg(self) -> int:
return TYPED_DOUBLE
def emit_name(self, node: ast.Name, code_gen: Static38CodeGenerator) -> None:
if isinstance(node.ctx, ast.Load):
code_gen.emit("LOAD_LOCAL", (node.id, self.klass.type_descr))
elif isinstance(node.ctx, ast.Store):
code_gen.emit("STORE_LOCAL", (node.id, self.klass.type_descr))
else:
raise TypedSyntaxError("unsupported op")
def bind_binop(
self, node: ast.BinOp, visitor: TypeBinder, type_ctx: Optional[Class]
) -> bool:
rtype = visitor.get_type(node.right)
if rtype != self or type(node.op) not in self._double_binary_opcode_signed:
raise visitor.syntax_error(self.binop_error(self, rtype, node.op), node)
visitor.set_type(node, self)
return True
def bind_constant(self, node: ast.Constant, visitor: TypeBinder) -> None:
visitor.set_type(node, self)
def emit_constant(
self, node: ast.Constant, code_gen: Static38CodeGenerator
) -> None:
code_gen.emit("PRIMITIVE_LOAD_CONST", (float(node.value), self.as_oparg()))
def emit_box(self, node: expr, code_gen: Static38CodeGenerator) -> None:
code_gen.visit(node)
type = code_gen.get_type(node)
if isinstance(type, CDoubleInstance):
code_gen.emit("PRIMITIVE_BOX", self.as_oparg())
else:
raise RuntimeError("unsupported box type: " + type.name)
class CDoubleType(CType):
def __init__(self) -> None:
super().__init__(
TypeName("__static__", "double"),
[OBJECT_TYPE],
CDoubleInstance(self),
)
CBOOL_TYPE = CIntType(TYPED_BOOL, name_override="cbool")
INT8_TYPE = CIntType(TYPED_INT8)
INT16_TYPE = CIntType(TYPED_INT16)
INT32_TYPE = CIntType(TYPED_INT32)
INT64_TYPE = CIntType(TYPED_INT64)
UINT8_TYPE = CIntType(TYPED_UINT8)
UINT16_TYPE = CIntType(TYPED_UINT16)
UINT32_TYPE = CIntType(TYPED_UINT32)
UINT64_TYPE = CIntType(TYPED_UINT64)
INT64_VALUE = INT64_TYPE.instance
CHAR_TYPE = CIntType(TYPED_INT8, name_override="char")
DOUBLE_TYPE = CDoubleType()
ARRAY_TYPE = ArrayClass(
GenericTypeName("__static__", "Array", (GenericParameter("T", 0),))
)
ARRAY_EXACT_TYPE = ArrayClass(
GenericTypeName("__static__", "Array", (GenericParameter("T", 0),)), is_exact=True
)
# Vectors are just currently a special type of array that support
# methods that resize them.
VECTOR_TYPE_PARAM = GenericParameter("T", 0)
VECTOR_TYPE_NAME = GenericTypeName("__static__", "Vector", (VECTOR_TYPE_PARAM,))
VECTOR_TYPE = VectorClass(VECTOR_TYPE_NAME, is_exact=True)
ALLOWED_ARRAY_TYPES: List[Class] = [
INT8_TYPE,
INT16_TYPE,
INT32_TYPE,
INT64_TYPE,
UINT8_TYPE,
UINT16_TYPE,
UINT32_TYPE,
UINT64_TYPE,
CHAR_TYPE,
DOUBLE_TYPE,
FLOAT_TYPE,
]
SIGNED_CINT_TYPES = [INT8_TYPE, INT16_TYPE, INT32_TYPE, INT64_TYPE]
UNSIGNED_CINT_TYPES: List[CIntType] = [
UINT8_TYPE,
UINT16_TYPE,
UINT32_TYPE,
UINT64_TYPE,
]
ALL_CINT_TYPES: Sequence[CIntType] = SIGNED_CINT_TYPES + UNSIGNED_CINT_TYPES
NAME_TO_TYPE: Mapping[object, Class] = {
"NoneType": NONE_TYPE,
"object": OBJECT_TYPE,
"str": STR_TYPE,
"__static__.int8": INT8_TYPE,
"__static__.int16": INT16_TYPE,
"__static__.int32": INT32_TYPE,
"__static__.int64": INT64_TYPE,
"__static__.uint8": UINT8_TYPE,
"__static__.uint16": UINT16_TYPE,
"__static__.uint32": UINT32_TYPE,
"__static__.uint64": UINT64_TYPE,
}
def parse_type(info: Dict[str, object]) -> Class:
optional = info.get("optional", False)
type = info.get("type")
if type:
klass = NAME_TO_TYPE.get(type)
if klass is None:
raise NotImplementedError("unsupported type: " + str(type))
else:
type_param = info.get("type_param")
assert isinstance(type_param, int)
klass = GenericParameter("T" + str(type_param), type_param)
if optional:
return OPTIONAL_TYPE.make_generic_type((klass,), BUILTIN_GENERICS)
return klass
CHECKED_DICT_TYPE = CheckedDict(CHECKED_DICT_TYPE_NAME, [OBJECT_TYPE], pytype=chkdict)
CHECKED_DICT_EXACT_TYPE = CheckedDict(
CHECKED_DICT_TYPE_NAME, [OBJECT_TYPE], pytype=chkdict, is_exact=True
)
EXACT_TYPES: Mapping[Class, Class] = {
ARRAY_TYPE: ARRAY_EXACT_TYPE,
LIST_TYPE: LIST_EXACT_TYPE,
TUPLE_TYPE: TUPLE_EXACT_TYPE,
INT_TYPE: INT_EXACT_TYPE,
FLOAT_TYPE: FLOAT_EXACT_TYPE,
COMPLEX_TYPE: COMPLEX_EXACT_TYPE,
DICT_TYPE: DICT_EXACT_TYPE,
CHECKED_DICT_TYPE: CHECKED_DICT_EXACT_TYPE,
SET_TYPE: SET_EXACT_TYPE,
STR_TYPE: STR_EXACT_TYPE,
}
EXACT_INSTANCES: Mapping[Value, Value] = {
k.instance: v.instance for k, v in EXACT_TYPES.items()
}
INEXACT_TYPES: Mapping[Class, Class] = {v: k for k, v in EXACT_TYPES.items()}
INEXACT_INSTANCES: Mapping[Value, Value] = {v: k for k, v in EXACT_INSTANCES.items()}
def exact(maybe_inexact: Value) -> Value:
if isinstance(maybe_inexact, UnionInstance):
return exact_type(maybe_inexact.klass).instance
exact = EXACT_INSTANCES.get(maybe_inexact)
return exact or maybe_inexact
def inexact(maybe_exact: Value) -> Value:
if isinstance(maybe_exact, UnionInstance):
return inexact_type(maybe_exact.klass).instance
inexact = INEXACT_INSTANCES.get(maybe_exact)
return inexact or maybe_exact
def exact_type(maybe_inexact: Class) -> Class:
if isinstance(maybe_inexact, UnionType):
generic_types = maybe_inexact.generic_types
if generic_types is not None:
return UNION_TYPE.make_generic_type(
tuple(exact_type(a) for a in maybe_inexact.type_args), generic_types
)
exact = EXACT_TYPES.get(maybe_inexact)
return exact or maybe_inexact
def inexact_type(maybe_exact: Class) -> Class:
if isinstance(maybe_exact, UnionType):
generic_types = maybe_exact.generic_types
if generic_types is not None:
return UNION_TYPE.make_generic_type(
tuple(inexact_type(a) for a in maybe_exact.type_args), generic_types
)
inexact = INEXACT_TYPES.get(maybe_exact)
return inexact or maybe_exact
if spamobj is not None:
SPAM_OBJ = GenericClass(
GenericTypeName("xxclassloader", "spamobj", (GenericParameter("T", 0),)),
pytype=spamobj,
)
XXGENERIC_T = GenericParameter("T", 0)
XXGENERIC_U = GenericParameter("U", 1)
XXGENERIC_TYPE_NAME = GenericTypeName(
"xxclassloader", "XXGeneric", (XXGENERIC_T, XXGENERIC_U)
)
class XXGeneric(GenericClass):
def __init__(
self,
name: GenericTypeName,
bases: Optional[List[Class]] = None,
instance: Optional[Object[Class]] = None,
klass: Optional[Class] = None,
members: Optional[Dict[str, Value]] = None,
type_def: Optional[GenericClass] = None,
is_exact: bool = False,
pytype: Optional[Type[object]] = None,
) -> None:
super().__init__(
name,
bases,
instance,
klass,
members,
type_def,
is_exact,
pytype,
)
self.members["foo"] = BuiltinMethodDescriptor(
"foo",
self,
(
Parameter("self", 0, ResolvedTypeRef(self), False, None, False),
Parameter(
"t",
0,
ResolvedTypeRef(XXGENERIC_T),
False,
None,
False,
),
Parameter(
"u",
0,
ResolvedTypeRef(XXGENERIC_U),
False,
None,
False,
),
),
)
XX_GENERIC_TYPE = XXGeneric(XXGENERIC_TYPE_NAME)
else:
SPAM_OBJ: Optional[GenericClass] = None
class GenericVisitor(ASTVisitor):
def __init__(self, module_name: str, filename: str) -> None:
super().__init__()
self.module_name = module_name
self.filename = filename
def visit(self, node: Union[AST, Sequence[AST]], *args: object) -> Optional[object]:
# if we have a sequence of nodes, don't catch TypedSyntaxError here;
# walk_list will call us back with each individual node in turn and we
# can catch errors and add node info then.
ctx = (
error_context(self.filename, node)
if isinstance(node, AST)
else nullcontext()
)
with ctx:
return super().visit(node, *args)
def syntax_error(self, msg: str, node: AST) -> TypedSyntaxError:
return syntax_error(msg, self.filename, node)
class InitVisitor(ASTVisitor):
def __init__(
self, module: ModuleTable, klass: Class, init_func: FunctionDef
) -> None:
super().__init__()
self.module = module
self.klass = klass
self.init_func = init_func
def visitAnnAssign(self, node: AnnAssign) -> None:
target = node.target
if isinstance(target, Attribute):
value = target.value
if (
isinstance(value, ast.Name)
and value.id == self.init_func.args.args[0].arg
):
attr = target.attr
self.klass.define_slot(
attr,
TypeRef(self.module, node.annotation),
assignment=node,
)
def visitAssign(self, node: Assign) -> None:
for target in node.targets:
if not isinstance(target, Attribute):
continue
value = target.value
if (
isinstance(value, ast.Name)
and value.id == self.init_func.args.args[0].arg
):
attr = target.attr
self.klass.define_slot(attr, assignment=node)
class DeclarationVisitor(GenericVisitor):
def __init__(self, mod_name: str, filename: str, symbols: SymbolTable) -> None:
super().__init__(mod_name, filename)
self.symbols = symbols
self.module = symbols[mod_name] = ModuleTable(mod_name, filename, symbols)
def finish_bind(self) -> None:
self.module.finish_bind()
def visitAnnAssign(self, node: AnnAssign) -> None:
self.module.decls.append((node, None))
def visitClassDef(self, node: ClassDef) -> None:
bases = [self.module.resolve_type(base) or DYNAMIC_TYPE for base in node.bases]
if not bases:
bases.append(OBJECT_TYPE)
klass = Class(TypeName(self.module_name, node.name), bases)
self.module.decls.append((node, klass))
for item in node.body:
with error_context(self.filename, item):
if isinstance(item, (AsyncFunctionDef, FunctionDef)):
function = self._make_function(item)
if not function:
continue
klass.define_function(item.name, function, self)
if (
item.name != "__init__"
or not item.args.args
or not isinstance(item, FunctionDef)
):
continue
InitVisitor(self.module, klass, item).visit(item.body)
elif isinstance(item, AnnAssign):
# class C:
# x: foo
target = item.target
if isinstance(target, ast.Name):
klass.define_slot(
target.id,
TypeRef(self.module, item.annotation),
# Note down whether the slot has been assigned a value.
assignment=item if item.value else None,
)
for base in bases:
if base is NAMED_TUPLE_TYPE:
# In named tuples, the fields are actually elements
# of the tuple, so we can't do any advanced binding against it.
klass = DYNAMIC_TYPE
break
if base.is_final:
raise self.syntax_error(
f"Class `{klass.instance.name}` cannot subclass a Final class: `{base.instance.name}`",
node,
)
for d in node.decorator_list:
if klass is DYNAMIC_TYPE:
break
with error_context(self.filename, d):
decorator = self.module.resolve_type(d) or DYNAMIC_TYPE
klass = decorator.bind_decorate_class(klass)
self.module.children[node.name] = klass
def _visitFunc(self, node: Union[FunctionDef, AsyncFunctionDef]) -> None:
function = self._make_function(node)
if function:
self.module.children[function.func_name] = function
def _make_function(
self, node: Union[FunctionDef, AsyncFunctionDef]
) -> Function | StaticMethod | None:
func = Function(node, self.module, self.type_ref(node.returns))
for decorator in node.decorator_list:
decorator_type = self.module.resolve_type(decorator) or DYNAMIC_TYPE
func = decorator_type.bind_decorate_function(self, func)
if not isinstance(func, (Function, StaticMethod)):
return None
return func
def visitFunctionDef(self, node: FunctionDef) -> None:
self._visitFunc(node)
def visitAsyncFunctionDef(self, node: AsyncFunctionDef) -> None:
self._visitFunc(node)
def type_ref(self, ann: Optional[expr]) -> TypeRef:
if not ann:
return ResolvedTypeRef(DYNAMIC_TYPE)
return TypeRef(self.module, ann)
def visitImport(self, node: Import) -> None:
for name in node.names:
self.symbols.import_module(name.name)
def visitImportFrom(self, node: ImportFrom) -> None:
mod_name = node.module
if not mod_name or node.level:
raise NotImplementedError("relative imports aren't supported")
self.symbols.import_module(mod_name)
mod = self.symbols.modules.get(mod_name)
if mod is not None:
for name in node.names:
val = mod.children.get(name.name)
if val is not None:
self.module.children[name.asname or name.name] = val
# We don't pick up declarations in nested statements
def visitFor(self, node: For) -> None:
pass
def visitAsyncFor(self, node: AsyncFor) -> None:
pass
def visitWhile(self, node: While) -> None:
pass
def visitIf(self, node: If) -> None:
test = node.test
if isinstance(test, Name) and test.id == "TYPE_CHECKING":
self.visit(node.body)
def visitWith(self, node: With) -> None:
pass
def visitAsyncWith(self, node: AsyncWith) -> None:
pass
def visitTry(self, node: Try) -> None:
pass
class TypedSyntaxError(SyntaxError):
pass
class LocalsBranch:
"""Handles branching and merging local variable types"""
def __init__(self, scope: BindingScope) -> None:
self.scope = scope
self.entry_locals: Dict[str, Value] = dict(scope.local_types)
def copy(self) -> Dict[str, Value]:
"""Make a copy of the current local state"""
return dict(self.scope.local_types)
def restore(self, state: Optional[Dict[str, Value]] = None) -> None:
"""Restore the locals to the state when we entered"""
self.scope.local_types.clear()
self.scope.local_types.update(state or self.entry_locals)
def merge(self, entry_locals: Optional[Dict[str, Value]] = None) -> None:
"""Merge the entry locals, or a specific copy, into the current locals"""
# TODO: What about del's?
if entry_locals is None:
entry_locals = self.entry_locals
local_types = self.scope.local_types
for key, value in entry_locals.items():
if key in local_types:
if value != local_types[key]:
widest = self._widest_type(value, local_types[key])
local_types[key] = widest or self.scope.decl_types[key].type
continue
for key in local_types.keys():
# If a value isn't definitely assigned we can safely turn it
# back into the declared type
if key not in entry_locals and key in self.scope.decl_types:
local_types[key] = self.scope.decl_types[key].type
def _widest_type(self, *types: Value) -> Optional[Value]:
# TODO: this should be a join, rather than just reverting to decl_type
# if neither type is greater than the other
if len(types) == 1:
return types[0]
widest_type = None
for src in types:
if src == DYNAMIC:
return DYNAMIC
if widest_type is None or src.klass.can_assign_from(widest_type.klass):
widest_type = src
elif widest_type is not None and not widest_type.klass.can_assign_from(
src.klass
):
return None
return widest_type
class TypeDeclaration:
def __init__(self, typ: Value, is_final: bool = False) -> None:
self.type = typ
self.is_final = is_final
class BindingScope:
def __init__(self, node: AST) -> None:
self.node = node
self.local_types: Dict[str, Value] = {}
self.decl_types: Dict[str, TypeDeclaration] = {}
def branch(self) -> LocalsBranch:
return LocalsBranch(self)
def declare(self, name: str, typ: Value, is_final: bool = False) -> TypeDeclaration:
decl = TypeDeclaration(typ, is_final)
self.decl_types[name] = decl
self.local_types[name] = typ
return decl
class ModuleBindingScope(BindingScope):
def __init__(self, node: ast.Module, module: ModuleTable) -> None:
super().__init__(node)
self.module = module
for name, typ in self.module.children.items():
self.declare(name, typ)
def declare(self, name: str, typ: Value, is_final: bool = False) -> TypeDeclaration:
self.module.children[name] = typ
return super().declare(name, typ, is_final)
class NarrowingEffect:
"""captures type narrowing effects on variables"""
def and_(self, other: NarrowingEffect) -> NarrowingEffect:
if other is NoEffect:
return self
return AndEffect(self, other)
def or_(self, other: NarrowingEffect) -> NarrowingEffect:
if other is NoEffect:
return self
return OrEffect(self, other)
def not_(self) -> NarrowingEffect:
return NegationEffect(self)
def apply(self, local_types: Dict[str, Value]) -> None:
"""applies the given effect in the target scope"""
pass
def undo(self, local_types: Dict[str, Value]) -> None:
"""restores the type to its original value"""
pass
def reverse(self, local_types: Dict[str, Value]) -> None:
"""applies the reverse of the scope or reverts it if
there is no reverse"""
self.undo(local_types)
class AndEffect(NarrowingEffect):
def __init__(self, *effects: NarrowingEffect) -> None:
self.effects: Sequence[NarrowingEffect] = effects
def and_(self, other: NarrowingEffect) -> NarrowingEffect:
if other is NoEffect:
return self
elif isinstance(other, AndEffect):
return AndEffect(*self.effects, *other.effects)
return AndEffect(*self.effects, other)
def apply(self, local_types: Dict[str, Value]) -> None:
for effect in self.effects:
effect.apply(local_types)
def undo(self, local_types: Dict[str, Value]) -> None:
"""restores the type to its original value"""
for effect in self.effects:
effect.undo(local_types)
class OrEffect(NarrowingEffect):
def __init__(self, *effects: NarrowingEffect) -> None:
self.effects: Sequence[NarrowingEffect] = effects
def and_(self, other: NarrowingEffect) -> NarrowingEffect:
if other is NoEffect:
return self
elif isinstance(other, OrEffect):
return OrEffect(*self.effects, *other.effects)
return OrEffect(*self.effects, other)
def reverse(self, local_types: Dict[str, Value]) -> None:
for effect in self.effects:
effect.reverse(local_types)
def undo(self, local_types: Dict[str, Value]) -> None:
"""restores the type to its original value"""
for effect in self.effects:
effect.undo(local_types)
class NoEffect(NarrowingEffect):
def union(self, other: NarrowingEffect) -> NarrowingEffect:
return other
# Singleton instance for no effects
NO_EFFECT = NoEffect()
class NegationEffect(NarrowingEffect):
def __init__(self, negated: NarrowingEffect) -> None:
self.negated = negated
def not_(self) -> NarrowingEffect:
return self.negated
def apply(self, local_types: Dict[str, Value]) -> None:
self.negated.reverse(local_types)
def undo(self, local_types: Dict[str, Value]) -> None:
self.negated.undo(local_types)
def reverse(self, local_types: Dict[str, Value]) -> None:
self.negated.apply(local_types)
class IsInstanceEffect(NarrowingEffect):
def __init__(self, var: str, prev: Value, inst: Value, visitor: TypeBinder) -> None:
self.var = var
self.prev = prev
self.inst = inst
reverse = prev
if isinstance(prev, UnionInstance):
type_args = tuple(
ta for ta in prev.klass.type_args if not inst.klass.can_assign_from(ta)
)
reverse = UNION_TYPE.make_generic_type(
type_args, visitor.symtable.generic_types
).instance
self.rev: Value = reverse
def apply(self, local_types: Dict[str, Value]) -> None:
local_types[self.var] = self.inst
def undo(self, local_types: Dict[str, Value]) -> None:
local_types[self.var] = self.prev
def reverse(self, local_types: Dict[str, Value]) -> None:
local_types[self.var] = self.rev
class TerminalKind(IntEnum):
NonTerminal = 0
BreakOrContinue = 1
Return = 2
class TypeBinder(GenericVisitor):
"""Walks an AST and produces an optionally strongly typed AST, reporting errors when
operations are occuring that are not sound. Strong types are based upon places where
annotations occur which opt-in the strong typing"""
def __init__(
self,
symbols: SymbolVisitor,
filename: str,
symtable: SymbolTable,
module_name: str,
optimize: int = 0,
) -> None:
super().__init__(module_name, filename)
self.symbols = symbols
self.scopes: List[BindingScope] = []
self.symtable = symtable
self.cur_mod: ModuleTable = symtable[module_name]
self.optimize = optimize
self.terminals: Dict[AST, TerminalKind] = {}
self.inline_depth = 0
@property
def local_types(self) -> Dict[str, Value]:
return self.binding_scope.local_types
@property
def decl_types(self) -> Dict[str, TypeDeclaration]:
return self.binding_scope.decl_types
@property
def binding_scope(self) -> BindingScope:
return self.scopes[-1]
@property
def scope(self) -> AST:
return self.binding_scope.node
def maybe_set_local_type(self, name: str, local_type: Value) -> Value:
decl_type = self.decl_types[name].type
if local_type is DYNAMIC or not decl_type.klass.can_be_narrowed:
local_type = decl_type
self.local_types[name] = local_type
return local_type
def maybe_get_current_class(self) -> Optional[Class]:
scope = self.scope
if isinstance(scope, ClassDef):
klass = self.cur_mod.resolve_name(scope.name)
assert isinstance(klass, Class)
return klass
def visit(
self, node: Union[AST, Sequence[AST]], *args: object
) -> Optional[NarrowingEffect]:
"""This override is only here to give Pyre the return type information."""
ret = super().visit(node, *args)
if ret is not None:
assert isinstance(ret, NarrowingEffect)
return ret
return None
def get_final_literal(self, node: AST) -> Optional[ast.Constant]:
return self.cur_mod.get_final_literal(node, self.symbols.scopes[self.scope])
def declare_local(
self, target: ast.Name, typ: Value, is_final: bool = False
) -> None:
if target.id in self.decl_types:
raise self.syntax_error(
f"Cannot redefine local variable {target.id}", target
)
if isinstance(typ, CInstance):
self.check_primitive_scope(target)
self.binding_scope.declare(target.id, typ, is_final)
def check_static_import_flags(self, node: Module) -> None:
saw_doc_str = False
for stmt in node.body:
if isinstance(stmt, ast.Expr):
val = stmt.value
if isinstance(val, ast.Constant) and isinstance(val.value, str):
if saw_doc_str:
break
saw_doc_str = True
else:
break
elif isinstance(stmt, ast.Import):
continue
elif isinstance(stmt, ast.ImportFrom):
if stmt.module == "__static__.compiler_flags":
for name in stmt.names:
if name.name == "nonchecked_dicts":
self.cur_mod.nonchecked_dicts = True
elif name.name == "noframe":
self.cur_mod.noframe = True
def visitModule(self, node: Module) -> None:
self.scopes.append(ModuleBindingScope(node, self.cur_mod))
self.check_static_import_flags(node)
for stmt in node.body:
self.visit(stmt)
self.scopes.pop()
def set_param(self, arg: ast.arg, arg_type: Class, scope: BindingScope) -> None:
scope.declare(arg.arg, arg_type.instance)
self.set_type(arg, arg_type.instance)
def _visitFunc(self, node: Union[FunctionDef, AsyncFunctionDef]) -> None:
scope = BindingScope(node)
for decorator in node.decorator_list:
self.visit(decorator)
cur_scope = self.scope
if (
not node.decorator_list
and isinstance(cur_scope, ClassDef)
and node.args.args
):
# Handle type of "self"
klass = self.cur_mod.resolve_name(cur_scope.name)
if isinstance(klass, Class):
self.set_param(node.args.args[0], klass, scope)
else:
self.set_param(node.args.args[0], DYNAMIC_TYPE, scope)
for arg in node.args.posonlyargs:
ann = arg.annotation
if ann:
self.visit(ann)
arg_type = self.cur_mod.resolve_annotation(ann) or DYNAMIC_TYPE
elif arg.arg in scope.decl_types:
# Already handled self
continue
else:
arg_type = DYNAMIC_TYPE
self.set_param(arg, arg_type, scope)
for arg in node.args.args:
ann = arg.annotation
if ann:
self.visit(ann)
arg_type = self.cur_mod.resolve_annotation(ann) or DYNAMIC_TYPE
elif arg.arg in scope.decl_types:
# Already handled self
continue
else:
arg_type = DYNAMIC_TYPE
self.set_param(arg, arg_type, scope)
if node.args.defaults:
for default in node.args.defaults:
self.visit(default)
if node.args.kw_defaults:
for default in node.args.kw_defaults:
if default is not None:
self.visit(default)
vararg = node.args.vararg
if vararg:
ann = vararg.annotation
if ann:
self.visit(ann)
self.set_param(vararg, TUPLE_EXACT_TYPE, scope)
for arg in node.args.kwonlyargs:
ann = arg.annotation
if ann:
self.visit(ann)
arg_type = self.cur_mod.resolve_annotation(ann) or DYNAMIC_TYPE
else:
arg_type = DYNAMIC_TYPE
self.set_param(arg, arg_type, scope)
kwarg = node.args.kwarg
if kwarg:
ann = kwarg.annotation
if ann:
self.visit(ann)
self.set_param(kwarg, DICT_EXACT_TYPE, scope)
returns = None if node.args in self.cur_mod.dynamic_returns else node.returns
if returns:
# We store the return type on the node for the function as we otherwise
# don't need to store type information for it
expected = self.cur_mod.resolve_annotation(returns) or DYNAMIC_TYPE
self.set_type(node, expected.instance)
self.visit(returns)
else:
self.set_type(node, DYNAMIC)
self.scopes.append(scope)
for stmt in node.body:
self.visit(stmt)
self.scopes.pop()
def visitFunctionDef(self, node: FunctionDef) -> None:
self._visitFunc(node)
def visitAsyncFunctionDef(self, node: AsyncFunctionDef) -> None:
self._visitFunc(node)
def visitClassDef(self, node: ClassDef) -> None:
for decorator in node.decorator_list:
self.visit(decorator)
for kwarg in node.keywords:
self.visit(kwarg.value)
for base in node.bases:
self.visit(base)
self.scopes.append(BindingScope(node))
for stmt in node.body:
self.visit(stmt)
self.scopes.pop()
def set_type(self, node: AST, type: Value) -> None:
self.cur_mod.types[node] = type
def get_type(self, node: AST) -> Value:
assert node in self.cur_mod.types, f"node not found: {node}, {node.lineno}"
return self.cur_mod.types[node]
def get_node_data(
self, key: Union[AST, Delegator], data_type: Type[TType]
) -> TType:
return cast(TType, self.cur_mod.node_data[key, data_type])
def set_node_data(
self, key: Union[AST, Delegator], data_type: Type[TType], value: TType
) -> None:
self.cur_mod.node_data[key, data_type] = value
def check_primitive_scope(self, node: Name) -> None:
cur_scope = self.symbols.scopes[self.scope]
var_scope = cur_scope.check_name(node.id)
if var_scope != SC_LOCAL or isinstance(self.scope, Module):
raise self.syntax_error(
"cannot use primitives in global or closure scope", node
)
def get_var_scope(self, var_id: str) -> Optional[int]:
cur_scope = self.symbols.scopes[self.scope]
var_scope = cur_scope.check_name(var_id)
return var_scope
def _check_final_attribute_reassigned(
self,
target: AST,
assignment: Optional[AST],
) -> None:
member = None
klass = None
member_name = None
# Try to look up the Class and associated Slot
scope = self.scope
if isinstance(target, ast.Name) and isinstance(scope, ast.ClassDef):
klass = self.cur_mod.resolve_name(scope.name)
assert isinstance(klass, Class)
member_name = target.id
member = klass.get_member(member_name)
elif isinstance(target, ast.Attribute):
klass = self.get_type(target.value).klass
member_name = target.attr
member = klass.get_member(member_name)
# Ensure we don't reassign to Finals
if (
klass is not None
and member is not None
and (
(
isinstance(member, Slot)
and member.is_final
and member.assignment != assignment
)
or (isinstance(member, Function) and member.is_final)
)
):
raise self.syntax_error(
f"Cannot assign to a Final attribute of {klass.instance.name}:{member_name}",
target,
)
def visitAnnAssign(self, node: AnnAssign) -> None:
self.visit(node.annotation)
target = node.target
comp_type = (
self.cur_mod.resolve_annotation(node.annotation, is_declaration=True)
or DYNAMIC_TYPE
)
is_final = False
if isinstance(comp_type, FinalClass):
is_final = True
comp_type = comp_type.inner_type()
if isinstance(target, Name):
self.declare_local(target, comp_type.instance, is_final)
self.set_type(target, comp_type.instance)
self.visit(target)
value = node.value
if value:
self.visit(value, comp_type.instance)
if isinstance(target, Name):
# We could be narrowing the type after the assignment, so we update it here
# even though we assigned it above (but we never narrow primtives)
new_type = self.get_type(value)
local_type = self.maybe_set_local_type(target.id, new_type)
self.set_type(target, local_type)
self.check_can_assign_from(comp_type, self.get_type(value).klass, node)
self._check_final_attribute_reassigned(target, node)
def visitAugAssign(self, node: AugAssign) -> None:
self.visit(node.target)
target_type = inexact(self.get_type(node.target))
self.visit(node.value, target_type)
self.set_type(node, target_type)
def visitAssign(self, node: Assign) -> None:
# Sometimes, we need to propagate types from the target to the value to allow primitives to be handled
# correctly. So we compute the narrowest target type. (Other checks do happen later).
# e.g: `x: int8 = 1` means we need `1` to be of type `int8`
narrowest_target_type = None
for target in reversed(node.targets):
cur_type = None
if isinstance(target, ast.Name):
# This is a name, it could be unassigned still
decl_type = self.decl_types.get(target.id)
if decl_type is not None:
cur_type = decl_type.type
elif isinstance(target, (ast.Tuple, ast.List)):
# TODO: We should walk into the tuple/list and use it to infer
# types down on the RHS if we can
self.visit(target)
else:
# This is an attribute or subscript, the assignment can't change the type
self.visit(target)
cur_type = self.get_type(target)
if cur_type is not None and (
narrowest_target_type is None
or narrowest_target_type.klass.can_assign_from(cur_type.klass)
):
narrowest_target_type = cur_type
self.visit(node.value, narrowest_target_type)
value_type = self.get_type(node.value)
for target in reversed(node.targets):
self.assign_value(target, value_type, src=node.value, assignment=node)
self.set_type(node, value_type)
def check_can_assign_from(
self, dest: Class, src: Class, node: AST, reason: str = "cannot be assigned to"
) -> None:
if not dest.can_assign_from(src) and src is not DYNAMIC_TYPE:
raise self.syntax_error(
f"type mismatch: {src.instance.name} {reason} {dest.instance.name} ",
node,
)
def visitBoolOp(
self, node: BoolOp, type_ctx: Optional[Class] = None
) -> NarrowingEffect:
effect = NO_EFFECT
final_type = None
if isinstance(node.op, And):
for value in node.values:
new_effect = self.visit(value) or NO_EFFECT
effect = effect.and_(new_effect)
final_type = self.widen(final_type, self.get_type(value))
# apply the new effect as short circuiting would
# eliminate it.
new_effect.apply(self.local_types)
# we undo the effect as we have no clue what context we're in
# but then we return the combined effect in case we're being used
# in a conditional context
effect.undo(self.local_types)
elif isinstance(node.op, ast.Or):
for value in node.values:
new_effect = self.visit(value) or NO_EFFECT
effect = effect.or_(new_effect)
final_type = self.widen(final_type, self.get_type(value))
new_effect.reverse(self.local_types)
effect.undo(self.local_types)
else:
for value in node.values:
self.visit(value)
final_type = self.widen(final_type, self.get_type(value))
self.set_type(node, final_type or DYNAMIC)
return effect
def visitBinOp(
self, node: BinOp, type_ctx: Optional[Class] = None
) -> NarrowingEffect:
# In order to interpret numeric literals as primitives within a
# primitive type context, we want to try to pass the type context down
# to each side, but we can't require this, otherwise things like `List:
# List * int` would fail.
try:
self.visit(node.left, type_ctx)
except TypedSyntaxError:
self.visit(node.left)
try:
self.visit(node.right, type_ctx)
except TypedSyntaxError:
self.visit(node.right)
ltype = self.get_type(node.left)
rtype = self.get_type(node.right)
tried_right = False
if ltype.klass in rtype.klass.mro[1:]:
if rtype.bind_reverse_binop(node, self, type_ctx):
return NO_EFFECT
tried_right = True
if ltype.bind_binop(node, self, type_ctx):
return NO_EFFECT
if not tried_right:
rtype.bind_reverse_binop(node, self, type_ctx)
return NO_EFFECT
def visitUnaryOp(
self, node: UnaryOp, type_ctx: Optional[Class] = None
) -> NarrowingEffect:
effect = self.visit(node.operand, type_ctx)
self.get_type(node.operand).bind_unaryop(node, self, type_ctx)
if (
effect is not None
and effect is not NO_EFFECT
and isinstance(node.op, ast.Not)
):
return effect.not_()
return NO_EFFECT
def visitLambda(
self, node: Lambda, type_ctx: Optional[Class] = None
) -> NarrowingEffect:
self.visit(node.body)
self.set_type(node, DYNAMIC)
return NO_EFFECT
def visitIfExp(
self, node: IfExp, type_ctx: Optional[Class] = None
) -> NarrowingEffect:
effect = self.visit(node.test) or NO_EFFECT
effect.apply(self.local_types)
self.visit(node.body)
effect.reverse(self.local_types)
self.visit(node.orelse)
effect.undo(self.local_types)
# Select the most compatible types that we can, or fallback to
# dynamic if we can coerce to dynamic, otherwise report an error.
body_t = self.get_type(node.body)
else_t = self.get_type(node.orelse)
if body_t.klass.can_assign_from(else_t.klass):
self.set_type(node, body_t)
elif else_t.klass.can_assign_from(body_t.klass):
self.set_type(node, else_t)
elif DYNAMIC_TYPE.can_assign_from(
body_t.klass
) and DYNAMIC_TYPE.can_assign_from(else_t.klass):
self.set_type(node, DYNAMIC)
else:
raise self.syntax_error(
f"if expression has incompatible types: {body_t.name} and {else_t.name}",
node,
)
return NO_EFFECT
def visitSlice(
self, node: Slice, type_ctx: Optional[Class] = None
) -> NarrowingEffect:
lower = node.lower
if lower:
self.visit(lower, type_ctx)
upper = node.upper
if upper:
self.visit(upper, type_ctx)
step = node.step
if step:
self.visit(step, type_ctx)
self.set_type(node, SLICE_TYPE.instance)
return NO_EFFECT
def widen(self, existing: Optional[Value], new: Value) -> Value:
if existing is None or new.klass.can_assign_from(existing.klass):
return new
elif existing.klass.can_assign_from(new.klass):
return existing
res = UNION_TYPE.make_generic_type(
(existing.klass, new.klass), self.symtable.generic_types
).instance
return res
def visitDict(
self, node: ast.Dict, type_ctx: Optional[Class] = None
) -> NarrowingEffect:
key_type: Optional[Value] = None
value_type: Optional[Value] = None
for k, v in zip(node.keys, node.values):
if k:
self.visit(k)
key_type = self.widen(key_type, self.get_type(k))
self.visit(v)
value_type = self.widen(value_type, self.get_type(v))
else:
self.visit(v, type_ctx)
d_type = self.get_type(v).klass
if (
d_type.generic_type_def is CHECKED_DICT_TYPE
or d_type.generic_type_def is CHECKED_DICT_EXACT_TYPE
):
assert isinstance(d_type, GenericClass)
key_type = self.widen(key_type, d_type.type_args[0].instance)
value_type = self.widen(value_type, d_type.type_args[1].instance)
elif d_type in (DICT_TYPE, DICT_EXACT_TYPE, DYNAMIC_TYPE):
key_type = DYNAMIC
value_type = DYNAMIC
self.set_dict_type(node, key_type, value_type, type_ctx, is_exact=True)
return NO_EFFECT
def set_dict_type(
self,
node: ast.expr,
key_type: Optional[Value],
value_type: Optional[Value],
type_ctx: Optional[Class],
is_exact: bool = False,
) -> Value:
if self.cur_mod.nonchecked_dicts or not isinstance(
type_ctx, CheckedDictInstance
):
# This is not a checked dict, or the user opted out of checked dicts
if type_ctx in (DICT_TYPE.instance, DICT_EXACT_TYPE.instance):
typ = type_ctx
elif is_exact:
typ = DICT_EXACT_TYPE.instance
else:
typ = DICT_TYPE.instance
assert typ is not None
self.set_type(node, typ)
return typ
# Calculate the type that is inferred by the keys and values
if key_type is None:
key_type = OBJECT_TYPE.instance
if value_type is None:
value_type = OBJECT_TYPE.instance
checked_dict_typ = CHECKED_DICT_EXACT_TYPE if is_exact else CHECKED_DICT_TYPE
gen_type = checked_dict_typ.make_generic_type(
(key_type.klass, value_type.klass), self.symtable.generic_types
)
if type_ctx is not None:
type_class = type_ctx.klass
if type_class.generic_type_def in (
CHECKED_DICT_EXACT_TYPE,
CHECKED_DICT_TYPE,
):
assert isinstance(type_class, GenericClass)
self.set_type(node, type_ctx)
# We can use the type context to have a type which is wider than the
# inferred types. But we need to make sure that the keys/values are compatible
# with the wider type, and if not, we'll report that the inferred type isn't
# compatible.
if not type_class.type_args[0].can_assign_from(
key_type.klass
) or not type_class.type_args[1].can_assign_from(value_type.klass):
self.check_can_assign_from(type_class, gen_type, node)
return type_ctx
else:
# Otherwise we allow something that would assign to dynamic, but not something
# that would assign to an unrelated type (e.g. int)
self.set_type(node, gen_type.instance)
self.check_can_assign_from(type_class, gen_type, node)
else:
self.set_type(node, gen_type.instance)
return gen_type.instance
def visitSet(
self, node: ast.Set, type_ctx: Optional[Class] = None
) -> NarrowingEffect:
for elt in node.elts:
self.visit(elt)
self.set_type(node, SET_EXACT_TYPE.instance)
return NO_EFFECT
def visitGeneratorExp(
self, node: GeneratorExp, type_ctx: Optional[Class] = None
) -> NarrowingEffect:
self.visit_comprehension(node, node.generators, node.elt)
self.set_type(node, DYNAMIC)
return NO_EFFECT
def visitListComp(
self, node: ListComp, type_ctx: Optional[Class] = None
) -> NarrowingEffect:
self.visit_comprehension(node, node.generators, node.elt)
self.set_type(node, LIST_EXACT_TYPE.instance)
return NO_EFFECT
def visitSetComp(
self, node: SetComp, type_ctx: Optional[Class] = None
) -> NarrowingEffect:
self.visit_comprehension(node, node.generators, node.elt)
self.set_type(node, SET_EXACT_TYPE.instance)
return NO_EFFECT
def assign_value(
self,
target: expr,
value: Value,
src: Optional[expr] = None,
assignment: Optional[AST] = None,
) -> None:
if isinstance(target, Name):
decl_type = self.decl_types.get(target.id)
if decl_type is None:
# This var is not declared in the current scope, but it might be a
# global or nonlocal. In that case, we need to check whether it's a Final.
scope_type = self.get_var_scope(target.id)
if scope_type == SC_GLOBAL_EXPLICIT or scope_type == SC_GLOBAL_IMPLICIT:
declared_type = self.scopes[0].decl_types.get(target.id, None)
if declared_type is not None and declared_type.is_final:
raise self.syntax_error(
"Cannot assign to a Final variable", target
)
# For an inferred exact type, we want to declare the inexact
# type; the exact type is useful local inference information,
# but we should still allow assignment of a subclass later.
self.declare_local(target, inexact(value))
else:
if decl_type.is_final:
raise self.syntax_error("Cannot assign to a Final variable", target)
self.check_can_assign_from(decl_type.type.klass, value.klass, target)
local_type = self.maybe_set_local_type(target.id, value)
self.set_type(target, local_type)
elif isinstance(target, (ast.Tuple, ast.List)):
if isinstance(src, (ast.Tuple, ast.List)) and len(target.elts) == len(
src.elts
):
for target, inner_value in zip(target.elts, src.elts):
self.assign_value(
target, self.get_type(inner_value), src=inner_value
)
elif isinstance(src, ast.Constant):
t = src.value
if isinstance(t, tuple) and len(t) == len(target.elts):
for target, inner_value in zip(target.elts, t):
self.assign_value(target, CONSTANT_TYPES[type(inner_value)])
else:
for val in target.elts:
self.assign_value(val, DYNAMIC)
else:
for val in target.elts:
self.assign_value(val, DYNAMIC)
else:
self.check_can_assign_from(self.get_type(target).klass, value.klass, target)
self._check_final_attribute_reassigned(target, assignment)
def visitDictComp(
self, node: DictComp, type_ctx: Optional[Class] = None
) -> NarrowingEffect:
self.visit(node.generators[0].iter)
scope = BindingScope(node)
self.scopes.append(scope)
iter_type = self.get_type(node.generators[0].iter).get_iter_type(
node.generators[0].iter, self
)
self.assign_value(node.generators[0].target, iter_type)
for if_ in node.generators[0].ifs:
self.visit(if_)
for gen in node.generators[1:]:
self.visit(gen.iter)
iter_type = self.get_type(gen.iter).get_iter_type(gen.iter, self)
self.assign_value(gen.target, iter_type)
for if_ in node.generators[0].ifs:
self.visit(if_)
self.visit(node.key)
self.visit(node.value)
self.scopes.pop()
key_type = self.get_type(node.key)
value_type = self.get_type(node.value)
self.set_dict_type(node, key_type, value_type, type_ctx, is_exact=True)
return NO_EFFECT
def visit_comprehension(
self, node: ast.expr, generators: List[ast.comprehension], *elts: ast.expr
) -> None:
self.visit(generators[0].iter)
scope = BindingScope(node)
self.scopes.append(scope)
iter_type = self.get_type(generators[0].iter).get_iter_type(
generators[0].iter, self
)
self.assign_value(generators[0].target, iter_type)
for if_ in generators[0].ifs:
self.visit(if_)
for gen in generators[1:]:
self.visit(gen.iter)
iter_type = self.get_type(gen.iter).get_iter_type(gen.iter, self)
self.assign_value(gen.target, iter_type)
for if_ in generators[0].ifs:
self.visit(if_)
for elt in elts:
self.visit(elt)
self.scopes.pop()
def visitAwait(
self, node: Await, type_ctx: Optional[Class] = None
) -> NarrowingEffect:
self.visit(node.value)
self.set_type(node, DYNAMIC)
return NO_EFFECT
def visitYield(
self, node: Yield, type_ctx: Optional[Class] = None
) -> NarrowingEffect:
value = node.value
if value is not None:
self.visit(value)
self.set_type(node, DYNAMIC)
return NO_EFFECT
def visitYieldFrom(
self, node: YieldFrom, type_ctx: Optional[Class] = None
) -> NarrowingEffect:
self.visit(node.value)
self.set_type(node, DYNAMIC)
return NO_EFFECT
def visitIndex(
self, node: Index, type_ctx: Optional[Class] = None
) -> NarrowingEffect:
self.visit(node.value, type_ctx)
self.set_type(node, self.get_type(node.value))
return NO_EFFECT
def visitCompare(
self, node: Compare, type_ctx: Optional[Class] = None
) -> NarrowingEffect:
if len(node.ops) == 1 and isinstance(node.ops[0], (Is, IsNot)):
left = node.left
right = node.comparators[0]
other = None
self.set_type(node, BOOL_TYPE.instance)
self.set_type(node.ops[0], BOOL_TYPE.instance)
self.visit(left)
self.visit(right)
if isinstance(left, (Constant, NameConstant)) and left.value is None:
other = right
elif isinstance(right, (Constant, NameConstant)) and right.value is None:
other = left
if other is not None and isinstance(other, Name):
var_type = self.get_type(other)
if (
isinstance(var_type, UnionInstance)
and not var_type.klass.is_generic_type_definition
):
effect = IsInstanceEffect(
other.id, var_type, NONE_TYPE.instance, self
)
if isinstance(node.ops[0], IsNot):
effect = effect.not_()
return effect
self.visit(node.left)
left = node.left
ltype = self.get_type(node.left)
node.ops = [type(op)() for op in node.ops]
for comparator, op in zip(node.comparators, node.ops):
self.visit(comparator)
rtype = self.get_type(comparator)
tried_right = False
if ltype.klass in rtype.klass.mro[1:]:
if ltype.bind_reverse_compare(
node, left, op, comparator, self, type_ctx
):
continue
tried_right = True
if ltype.bind_compare(node, left, op, comparator, self, type_ctx):
continue
if not tried_right:
rtype.bind_reverse_compare(node, left, op, comparator, self, type_ctx)
ltype = rtype
right = comparator
return NO_EFFECT
def visitCall(
self, node: Call, type_ctx: Optional[Class] = None
) -> NarrowingEffect:
self.visit(node.func)
result = self.get_type(node.func).bind_call(node, self, type_ctx)
return result
def visitFormattedValue(
self, node: FormattedValue, type_ctx: Optional[Class] = None
) -> NarrowingEffect:
self.visit(node.value)
self.set_type(node, DYNAMIC)
return NO_EFFECT
def visitJoinedStr(
self, node: JoinedStr, type_ctx: Optional[Class] = None
) -> NarrowingEffect:
for value in node.values:
self.visit(value)
self.set_type(node, STR_EXACT_TYPE.instance)
return NO_EFFECT
def visitConstant(
self, node: Constant, type_ctx: Optional[Class] = None
) -> NarrowingEffect:
if type_ctx is not None:
type_ctx.bind_constant(node, self)
else:
DYNAMIC.bind_constant(node, self)
return NO_EFFECT
def visitAttribute(
self, node: Attribute, type_ctx: Optional[Class] = None
) -> NarrowingEffect:
self.visit(node.value)
self.get_type(node.value).bind_attr(node, self, type_ctx)
return NO_EFFECT
def visitSubscript(
self, node: Subscript, type_ctx: Optional[Class] = None
) -> NarrowingEffect:
self.visit(node.value)
self.visit(node.slice)
val_type = self.get_type(node.value)
val_type.bind_subscr(node, self.get_type(node.slice), self)
return NO_EFFECT
def visitStarred(
self, node: Starred, type_ctx: Optional[Class] = None
) -> NarrowingEffect:
self.visit(node.value)
self.set_type(node, DYNAMIC)
return NO_EFFECT
def visitName(
self, node: Name, type_ctx: Optional[Class] = None
) -> NarrowingEffect:
cur_scope = self.symbols.scopes[self.scope]
scope = cur_scope.check_name(node.id)
if scope == SC_LOCAL and not isinstance(self.scope, Module):
var_type = self.local_types.get(node.id, DYNAMIC)
self.set_type(node, var_type)
if type_ctx is not None:
self.check_can_assign_from(type_ctx.klass, var_type.klass, node)
else:
self.set_type(node, self.cur_mod.resolve_name(node.id) or DYNAMIC)
type = self.get_type(node)
if (
isinstance(type, UnionInstance)
and not type.klass.is_generic_type_definition
):
effect = IsInstanceEffect(node.id, type, NONE_TYPE.instance, self)
return effect.not_()
return NO_EFFECT
def visitList(
self, node: ast.List, type_ctx: Optional[Class] = None
) -> NarrowingEffect:
for elt in node.elts:
self.visit(elt, DYNAMIC)
self.set_type(node, LIST_EXACT_TYPE.instance)
return NO_EFFECT
def visitTuple(
self, node: ast.Tuple, type_ctx: Optional[Class] = None
) -> NarrowingEffect:
for elt in node.elts:
self.visit(elt, DYNAMIC)
self.set_type(node, TUPLE_EXACT_TYPE.instance)
return NO_EFFECT
def set_terminal_kind(self, node: AST, level: TerminalKind) -> None:
current = self.terminals.get(node, TerminalKind.NonTerminal)
if current < level:
self.terminals[node] = level
def visitContinue(self, node: ast.Continue) -> None:
self.set_terminal_kind(node, TerminalKind.BreakOrContinue)
def visitBreak(self, node: ast.Break) -> None:
self.set_terminal_kind(node, TerminalKind.BreakOrContinue)
def visitReturn(self, node: Return) -> None:
self.set_terminal_kind(node, TerminalKind.Return)
value = node.value
if value is not None:
cur_scope = self.binding_scope
func = cur_scope.node
expected = DYNAMIC
if isinstance(func, (ast.FunctionDef, ast.AsyncFunctionDef)):
func_returns = func.returns
if func_returns:
expected = (
self.cur_mod.resolve_annotation(func_returns) or DYNAMIC_TYPE
).instance
self.visit(value, expected)
returned = self.get_type(value).klass
if returned is not DYNAMIC_TYPE and not expected.klass.can_assign_from(
returned
):
raise self.syntax_error(
f"return type must be {expected.name}, not "
+ str(self.get_type(value).name),
node,
)
def visitImportFrom(self, node: ImportFrom) -> None:
mod_name = node.module
if node.level or not mod_name:
raise NotImplementedError("relative imports aren't supported")
if mod_name == "__static__":
for alias in node.names:
name = alias.name
if name == "*":
raise self.syntax_error(
"from __static__ import * is disallowed", node
)
elif name not in self.symtable.statics.children:
raise self.syntax_error(f"unsupported static import {name}", node)
def visit_until_terminates(self, nodes: List[ast.stmt]) -> TerminalKind:
for stmt in nodes:
self.visit(stmt)
if stmt in self.terminals:
return self.terminals[stmt]
return TerminalKind.NonTerminal
def visitIf(self, node: If) -> None:
branch = self.binding_scope.branch()
effect = self.visit(node.test) or NO_EFFECT
effect.apply(self.local_types)
terminates = self.visit_until_terminates(node.body)
if node.orelse:
if_end = branch.copy()
branch.restore()
effect.reverse(self.local_types)
else_terminates = self.visit_until_terminates(node.orelse)
if else_terminates:
if terminates:
# We're the least severe terminal of our two children
self.terminals[node] = min(terminates, else_terminates)
else:
branch.restore(if_end)
elif not terminates:
# Merge end of orelse with end of if
branch.merge(if_end)
elif terminates:
effect.reverse(self.local_types)
else:
# Merge end of if w/ opening (with test effect reversed)
branch.merge(effect.reverse(branch.entry_locals))
def visitTry(self, node: Try) -> None:
branch = self.binding_scope.branch()
self.visit(node.body)
branch.merge()
post_try = branch.copy()
merges = []
if node.orelse:
self.visit(node.orelse)
merges.append(branch.copy())
for handler in node.handlers:
branch.restore(post_try)
self.visit(handler)
merges.append(branch.copy())
branch.restore(post_try)
for merge in merges:
branch.merge(merge)
if node.finalbody:
self.visit(node.finalbody)
def visitExceptHandler(self, node: ast.ExceptHandler) -> None:
htype = node.type
hname = None
if htype:
self.visit(htype)
handler_type = self.get_type(htype)
hname = node.name
if hname:
if handler_type is DYNAMIC or not isinstance(handler_type, Class):
handler_type = DYNAMIC_TYPE
decl_type = self.decl_types.get(hname)
if decl_type and decl_type.is_final:
raise self.syntax_error("Cannot assign to a Final variable", node)
self.binding_scope.declare(hname, handler_type.instance)
self.visit(node.body)
if hname is not None:
del self.decl_types[hname]
del self.local_types[hname]
def visitWhile(self, node: While) -> None:
branch = self.scopes[-1].branch()
effect = self.visit(node.test) or NO_EFFECT
effect.apply(self.local_types)
while_returns = self.visit_until_terminates(node.body) == TerminalKind.Return
if while_returns:
branch.restore()
effect.reverse(self.local_types)
else:
branch.merge(effect.reverse(branch.entry_locals))
if node.orelse:
# The or-else can happen after the while body, or without executing
# it, but it can only happen after the while condition evaluates to
# False.
effect.reverse(self.local_types)
self.visit(node.orelse)
branch.merge()
def visitFor(self, node: For) -> None:
self.visit(node.iter)
target_type = self.get_type(node.iter).get_iter_type(node.iter, self)
self.visit(node.target)
self.assign_value(node.target, target_type)
self.visit(node.body)
self.visit(node.orelse)
def visitwithitem(self, node: ast.withitem) -> None:
self.visit(node.context_expr)
optional_vars = node.optional_vars
if optional_vars:
self.visit(optional_vars)
self.assign_value(optional_vars, DYNAMIC)
class PyFlowGraph38Static(PyFlowGraphCinder):
opcode: Opcode = opcode38static.opcode
class Static38CodeGenerator(CinderCodeGenerator):
flow_graph = PyFlowGraph38Static
_default_cache: Dict[Type[ast.AST], typingCallable[[...], None]] = {}
def __init__(
self,
parent: Optional[CodeGenerator],
node: AST,
symbols: SymbolVisitor,
graph: PyFlowGraph,
symtable: SymbolTable,
modname: str,
flags: int = 0,
optimization_lvl: int = 0,
) -> None:
super().__init__(parent, node, symbols, graph, flags, optimization_lvl)
self.symtable = symtable
self.modname = modname
# Use this counter to allocate temporaries for loop indices
self._tmpvar_loopidx_count = 0
self.cur_mod: ModuleTable = self.symtable.modules[modname]
def _is_static_compiler_disabled(self, node: AST) -> bool:
if not isinstance(node, (AsyncFunctionDef, FunctionDef, ClassDef)):
# Static compilation can only be disabled for functions and classes.
return False
scope = self.scope
fn = None
if isinstance(scope, ClassScope):
klass = self.cur_mod.resolve_name(scope.name)
if klass:
assert isinstance(klass, Class)
if klass.donotcompile:
# If static compilation is disabled on the entire class, it's skipped for all contained
# methods too.
return True
fn = klass.get_own_member(node.name)
if fn is None:
# Wasn't a method, let's check if it's a module level function
fn = self.cur_mod.resolve_name(node.name)
if isinstance(fn, (Function, StaticMethod)):
return (
fn.donotcompile
if isinstance(fn, Function)
else fn.function.donotcompile
)
return False
def make_child_codegen(
self,
tree: AST,
graph: PyFlowGraph,
codegen_type: Optional[Type[CinderCodeGenerator]] = None,
) -> CodeGenerator:
if self._is_static_compiler_disabled(tree):
return super().make_child_codegen(
tree, graph, codegen_type=CinderCodeGenerator
)
graph.setFlag(self.consts.CO_STATICALLY_COMPILED)
if self.cur_mod.noframe:
graph.setFlag(self.consts.CO_NO_FRAME)
gen = StaticCodeGenerator(
self,
tree,
self.symbols,
graph,
symtable=self.symtable,
modname=self.modname,
optimization_lvl=self.optimization_lvl,
)
if not isinstance(tree, ast.ClassDef):
self._processArgTypes(tree, gen)
return gen
def _processArgTypes(self, node: AST, gen: Static38CodeGenerator) -> None:
arg_checks = []
cellvars = gen.graph.cellvars
# pyre-fixme[16]: When node is a comprehension (i.e., not a FunctionDef
# or Lambda), our caller manually adds an args attribute.
args: ast.arguments = node.args
is_comprehension = not isinstance(
node, (ast.AsyncFunctionDef, ast.FunctionDef, ast.Lambda)
)
for i, arg in enumerate(args.posonlyargs):
t = self.get_type(arg)
if t is not DYNAMIC and t is not OBJECT:
arg_checks.append(self._calculate_idx(arg.arg, i, cellvars))
arg_checks.append(t.klass.type_descr)
for i, arg in enumerate(args.args):
# Comprehension nodes don't have arguments when they're typed; make
# up for that here.
t = DYNAMIC if is_comprehension else self.get_type(arg)
if t is not DYNAMIC and t is not OBJECT:
arg_checks.append(
self._calculate_idx(arg.arg, i + len(args.posonlyargs), cellvars)
)
arg_checks.append(t.klass.type_descr)
for i, arg in enumerate(args.kwonlyargs):
t = self.get_type(arg)
if t is not DYNAMIC and t is not OBJECT:
arg_checks.append(
self._calculate_idx(
arg.arg,
i + len(args.posonlyargs) + len(args.args),
cellvars,
)
)
arg_checks.append(t.klass.type_descr)
# we should never emit arg checks for object
assert not any(td == ("builtins", "object") for td in arg_checks[1::2])
gen.emit("CHECK_ARGS", tuple(arg_checks))
def get_type(self, node: Union[AST, Delegator]) -> Value:
return self.cur_mod.types[node]
def get_node_data(
self, key: Union[AST, Delegator], data_type: Type[TType]
) -> TType:
return cast(TType, self.cur_mod.node_data[key, data_type])
def set_node_data(
self, key: Union[AST, Delegator], data_type: Type[TType], value: TType
) -> None:
self.cur_mod.node_data[key, data_type] = value
@classmethod
# pyre-fixme[14]: `make_code_gen` overrides method defined in
# `Python37CodeGenerator` inconsistently.
def make_code_gen(
cls,
module_name: str,
tree: AST,
filename: str,
flags: int,
optimize: int,
peephole_enabled: bool = True,
ast_optimizer_enabled: bool = True,
) -> Static38CodeGenerator:
# TODO: Parsing here should really be that we run declaration visitor over all nodes,
# and then perform post processing on the symbol table, and then proceed to analysis
# and compilation
symtable = SymbolTable()
decl_visit = DeclarationVisitor(module_name, filename, symtable)
decl_visit.visit(tree)
for module in symtable.modules.values():
module.finish_bind()
if ast_optimizer_enabled:
tree = AstOptimizer(optimize=optimize > 0).visit(tree)
s = symbols.SymbolVisitor()
s.visit(tree)
graph = cls.flow_graph(
module_name, filename, s.scopes[tree], peephole_enabled=peephole_enabled
)
graph.setFlag(cls.consts.CO_STATICALLY_COMPILED)
type_binder = TypeBinder(s, filename, symtable, module_name, optimize)
type_binder.visit(tree)
code_gen = cls(None, tree, s, graph, symtable, module_name, flags, optimize)
code_gen.visit(tree)
return code_gen
def make_function_graph(
self,
func: FunctionDef,
filename: str,
scopes: Dict[AST, Scope],
class_name: str,
name: str,
first_lineno: int,
) -> PyFlowGraph:
graph = super().make_function_graph(
func, filename, scopes, class_name, name, first_lineno
)
# we tagged the graph as CO_STATICALLY_COMPILED, and the last co_const entry
# will inform the runtime of the return type for the code object.
ret_type = self.get_type(func)
type_descr = ret_type.klass.type_descr
graph.extra_consts.append(type_descr)
return graph
@contextmanager
def new_loopidx(self) -> Generator[str, None, None]:
self._tmpvar_loopidx_count += 1
try:
yield f"{_TMP_VAR_PREFIX}.{self._tmpvar_loopidx_count}"
finally:
self._tmpvar_loopidx_count -= 1
def store_type_name_and_flags(self, node: ClassDef) -> None:
self.emit("INVOKE_FUNCTION", (("_static", "set_type_static"), 1))
self.storeName(node.name)
def walkClassBody(self, node: ClassDef, gen: CodeGenerator) -> None:
super().walkClassBody(node, gen)
cur_mod = self.symtable.modules[self.modname]
klass = cur_mod.resolve_name(node.name)
if not isinstance(klass, Class) or klass is DYNAMIC_TYPE:
return
class_mems = [
name for name, value in klass.members.items() if isinstance(value, Slot)
]
if klass.allow_weakrefs:
class_mems.append("__weakref__")
# In the future we may want a compatibility mode where we add
# __dict__ and __weakref__
gen.emit("LOAD_CONST", tuple(class_mems))
gen.emit("STORE_NAME", "__slots__")
count = 0
for name, value in klass.members.items():
if not isinstance(value, Slot):
continue
if value.decl_type is DYNAMIC_TYPE:
continue
gen.emit("LOAD_CONST", name)
gen.emit("LOAD_CONST", value.type_descr)
count += 1
if count:
gen.emit("BUILD_MAP", count)
gen.emit("STORE_NAME", "__slot_types__")
def visitModule(self, node: Module) -> None:
if not self.cur_mod.nonchecked_dicts:
self.emit("LOAD_CONST", 0)
self.emit("LOAD_CONST", ("chkdict",))
self.emit("IMPORT_NAME", "_static")
self.emit("IMPORT_FROM", "chkdict")
self.emit("STORE_NAME", "dict")
super().visitModule(node)
def emit_module_return(self, node: ast.Module) -> None:
self.emit("LOAD_CONST", tuple(self.cur_mod.named_finals.keys()))
self.emit("STORE_NAME", "__final_constants__")
super().emit_module_return(node)
def visitAugAttribute(self, node: AugAttribute, mode: str) -> None:
if mode == "load":
self.visit(node.value)
self.emit("DUP_TOP")
load = ast.Attribute(node.value, node.attr, ast.Load())
load.lineno = node.lineno
load.col_offset = node.col_offset
self.get_type(node.value).emit_attr(load, self)
elif mode == "store":
self.emit("ROT_TWO")
self.get_type(node.value).emit_attr(node, self)
def visitAugSubscript(self, node: AugSubscript, mode: str) -> None:
if mode == "load":
self.get_type(node.value).emit_subscr(node.obj, 1, self)
elif mode == "store":
self.get_type(node.value).emit_store_subscr(node.obj, self)
def visitAttribute(self, node: Attribute) -> None:
self.update_lineno(node)
if isinstance(node.ctx, ast.Load) and self._is_super_call(node.value):
self.emit("LOAD_GLOBAL", "super")
load_arg = self._emit_args_for_super(node.value, node.attr)
self.emit("LOAD_ATTR_SUPER", load_arg)
else:
self.visit(node.value)
self.get_type(node.value).emit_attr(node, self)
def emit_type_check(self, dest: Class, src: Class, node: AST) -> None:
if src is DYNAMIC_TYPE and dest is not OBJECT_TYPE and dest is not DYNAMIC_TYPE:
if isinstance(dest, CType):
# TODO raise this in type binding instead
raise syntax_error(
f"Cannot assign a {src.instance.name} to {dest.instance.name}",
self.graph.filename,
node,
)
self.emit("CAST", dest.type_descr)
elif not dest.can_assign_from(src):
# TODO raise this in type binding instead
raise syntax_error(
f"Cannot assign a {src.instance.name} to {dest.instance.name}",
self.graph.filename,
node,
)
def visitAssignTarget(
self, elt: expr, stmt: AST, value: Optional[expr] = None
) -> None:
if isinstance(elt, (ast.Tuple, ast.List)):
self._visitUnpack(elt)
if isinstance(value, ast.Tuple) and len(value.elts) == len(elt.elts):
for target, inner_value in zip(elt.elts, value.elts):
self.visitAssignTarget(target, stmt, inner_value)
else:
for target in elt.elts:
self.visitAssignTarget(target, stmt, None)
else:
if value is not None:
self.emit_type_check(
self.get_type(elt).klass, self.get_type(value).klass, stmt
)
else:
self.emit_type_check(self.get_type(elt).klass, DYNAMIC_TYPE, stmt)
self.visit(elt)
def visitAssign(self, node: Assign) -> None:
self.set_lineno(node)
self.visit(node.value)
dups = len(node.targets) - 1
for i in range(len(node.targets)):
elt = node.targets[i]
if i < dups:
self.emit("DUP_TOP")
if isinstance(elt, ast.AST):
self.visitAssignTarget(elt, node, node.value)
def visitAnnAssign(self, node: ast.AnnAssign) -> None:
self.set_lineno(node)
value = node.value
if value:
self.visit(value)
self.emit_type_check(
self.get_type(node.target).klass, self.get_type(value).klass, node
)
self.visit(node.target)
target = node.target
if isinstance(target, ast.Name):
# If we have a simple name in a module or class, store the annotation
if node.simple and isinstance(self.tree, (ast.Module, ast.ClassDef)):
self.emitStoreAnnotation(target.id, node.annotation)
elif isinstance(target, ast.Attribute):
if not node.value:
self.checkAnnExpr(target.value)
elif isinstance(target, ast.Subscript):
if not node.value:
self.checkAnnExpr(target.value)
self.checkAnnSubscr(target.slice)
else:
raise SystemError(
f"invalid node type {type(node).__name__} for annotated assignment"
)
if not node.simple:
self.checkAnnotation(node)
def visitConstant(self, node: Constant) -> None:
self.get_type(node).emit_constant(node, self)
def get_final_literal(self, node: AST) -> Optional[ast.Constant]:
return self.cur_mod.get_final_literal(node, self.scope)
def visitName(self, node: Name) -> None:
final_val = self.get_final_literal(node)
if final_val is not None:
# visit the constant directly
return self.defaultVisit(final_val)
self.get_type(node).emit_name(node, self)
def visitAugAssign(self, node: AugAssign) -> None:
self.get_type(node.target).emit_augassign(node, self)
def visitAugName(self, node: AugName, mode: str) -> None:
self.get_type(node).emit_augname(node, self, mode)
def visitCompare(self, node: Compare) -> None:
self.update_lineno(node)
self.visit(node.left)
cleanup = self.newBlock("cleanup")
left = node.left
for op, code in zip(node.ops[:-1], node.comparators[:-1]):
optype = self.get_type(op)
ltype = self.get_type(left)
if ltype != optype:
optype.emit_convert(ltype, self)
self.emitChainedCompareStep(op, optype, code, cleanup)
left = code
# now do the last comparison
if node.ops:
op = node.ops[-1]
optype = self.get_type(op)
ltype = self.get_type(left)
if ltype != optype:
optype.emit_convert(ltype, self)
code = node.comparators[-1]
self.visit(code)
rtype = self.get_type(code)
if rtype != optype:
optype.emit_convert(rtype, self)
optype.emit_compare(op, self)
if len(node.ops) > 1:
end = self.newBlock("end")
self.emit("JUMP_FORWARD", end)
self.nextBlock(cleanup)
self.emit("ROT_TWO")
self.emit("POP_TOP")
self.nextBlock(end)
def emitChainedCompareStep(
self,
op: cmpop,
optype: Value,
value: AST,
cleanup: Block,
jump: str = "JUMP_IF_ZERO_OR_POP",
) -> None:
self.visit(value)
rtype = self.get_type(value)
if rtype != optype:
optype.emit_convert(rtype, self)
self.emit("DUP_TOP")
self.emit("ROT_THREE")
optype.emit_compare(op, self)
self.emit(jump, cleanup)
self.nextBlock(label="compare_or_cleanup")
def visitBoolOp(self, node: BoolOp) -> None:
end = self.newBlock()
for child in node.values[:-1]:
self.get_type(child).emit_jumpif_pop(
child, end, type(node.op) == ast.Or, self
)
self.nextBlock()
self.visit(node.values[-1])
self.nextBlock(end)
def visitBinOp(self, node: BinOp) -> None:
self.get_type(node).emit_binop(node, self)
def visitUnaryOp(self, node: UnaryOp, type_ctx: Optional[Class] = None) -> None:
self.get_type(node).emit_unaryop(node, self)
def visitCall(self, node: Call) -> None:
self.get_type(node.func).emit_call(node, self)
def visitSubscript(self, node: ast.Subscript, aug_flag: bool = False) -> None:
self.get_type(node.value).emit_subscr(node, aug_flag, self)
def _visitReturnValue(self, value: ast.AST, expected: Class) -> None:
self.visit(value)
if expected is not DYNAMIC_TYPE and self.get_type(value) is DYNAMIC:
self.emit("CAST", expected.type_descr)
def visitReturn(self, node: ast.Return) -> None:
self.checkReturn(node)
expected = self.get_type(self.tree).klass
self.set_lineno(node)
value = node.value
is_return_constant = isinstance(value, ast.Constant)
opcode = "RETURN_VALUE"
oparg = 0
if value:
if not is_return_constant:
self._visitReturnValue(value, expected)
self.unwind_setup_entries(preserve_tos=True)
else:
self.unwind_setup_entries(preserve_tos=False)
self._visitReturnValue(value, expected)
if isinstance(expected, CType):
opcode = "RETURN_INT"
oparg = expected.instance.as_oparg()
else:
self.unwind_setup_entries(preserve_tos=False)
self.emit("LOAD_CONST", None)
self.emit(opcode, oparg)
def visitDictComp(self, node: DictComp) -> None:
dict_type = self.get_type(node)
if dict_type in (DICT_TYPE.instance, DICT_EXACT_TYPE.instance):
return super().visitDictComp(node)
klass = dict_type.klass
assert isinstance(klass, GenericClass) and (
klass.type_def is CHECKED_DICT_TYPE
or klass.type_def is CHECKED_DICT_EXACT_TYPE
), dict_type
self.compile_comprehension(
node,
sys.intern("<dictcomp>"),
node.key,
node.value,
"BUILD_CHECKED_MAP",
(dict_type.klass.type_descr, 0),
)
def compile_subgendict(
self, node: ast.Dict, begin: int, end: int, dict_descr: TypeDescr
) -> None:
n = end - begin
for i in range(begin, end):
k = node.keys[i]
assert k is not None
self.visit(k)
self.visit(node.values[i])
self.emit("BUILD_CHECKED_MAP", (dict_descr, n))
def visitDict(self, node: ast.Dict) -> None:
dict_type = self.get_type(node)
if dict_type in (DICT_TYPE.instance, DICT_EXACT_TYPE.instance):
return super().visitDict(node)
klass = dict_type.klass
assert isinstance(klass, GenericClass) and (
klass.type_def is CHECKED_DICT_TYPE
or klass.type_def is CHECKED_DICT_EXACT_TYPE
), dict_type
self.update_lineno(node)
elements = 0
is_unpacking = False
built_final_dict = False
# This is similar to the normal dict code generation, but instead of relying
# upon an opcode for BUILD_MAP_UNPACK we invoke the update method on the
# underlying dict type. Therefore the first dict that we create becomes
# the final dict. This allows us to not introduce a new opcode, but we should
# also be able to dispatch the INVOKE_METHOD rather efficiently.
dict_descr = dict_type.klass.type_descr
update_descr = dict_descr + ("update",)
for i, (k, v) in enumerate(zip(node.keys, node.values)):
is_unpacking = k is None
if elements == 0xFFFF or (elements and is_unpacking):
self.compile_subgendict(node, i - elements, i, dict_descr)
built_final_dict = True
elements = 0
if is_unpacking:
if not built_final_dict:
# {**foo, ...}, we need to generate the empty dict
self.emit("BUILD_CHECKED_MAP", (dict_descr, 0))
built_final_dict = True
self.emit("DUP_TOP")
self.visit(v)
self.emit_invoke_method(update_descr, 1)
self.emit("POP_TOP")
else:
elements += 1
if elements or not built_final_dict:
if built_final_dict:
self.emit("DUP_TOP")
self.compile_subgendict(
node, len(node.keys) - elements, len(node.keys), dict_descr
)
if built_final_dict:
self.emit_invoke_method(update_descr, 1)
self.emit("POP_TOP")
def visitFor(self, node: ast.For) -> None:
iter_type = self.get_type(node.iter)
return iter_type.emit_forloop(node, self)
def emit_invoke_method(self, descr: TypeDescr, arg_count: int) -> None:
# Emit a zero EXTENDED_ARG before so that we can optimize and insert the
# arg count
self.emit("EXTENDED_ARG", 0)
self.emit("INVOKE_METHOD", (descr, arg_count))
def defaultVisit(self, node: object, *args: object) -> None:
self.node = node
klass = node.__class__
meth = self._default_cache.get(klass, None)
if meth is None:
className = klass.__name__
meth = getattr(
super(Static38CodeGenerator, Static38CodeGenerator),
"visit" + className,
StaticCodeGenerator.generic_visit,
)
self._default_cache[klass] = meth
return meth(self, node, *args)
def compileJumpIf(self, test: AST, next: Block, is_if_true: bool) -> None:
self.get_type(test).emit_jumpif(test, next, is_if_true, self)
def _calculate_idx(
self, arg_name: str, non_cellvar_pos: int, cellvars: IndexedSet
) -> int:
try:
offset = cellvars.index(arg_name)
except ValueError:
return non_cellvar_pos
else:
# the negative sign indicates to the runtime/JIT that this is a cellvar
return -(offset + 1)
StaticCodeGenerator = Static38CodeGenerator