Lib/compiler/strict/compiler.py (228 lines of code) (raw):
# Copyright (c) Facebook, Inc. and its affiliates. (http://www.facebook.com)
from __future__ import annotations
import ast
import builtins
from contextlib import nullcontext
from symtable import SymbolTable as PythonSymbolTable, SymbolTableFactory
from types import CodeType
from typing import (
Callable,
ContextManager,
Iterable,
List,
Optional,
Tuple,
final,
Dict,
Set,
TYPE_CHECKING,
)
from _static import __build_cinder_class__
from _strictmodule import (
StrictAnalysisResult,
StrictModuleLoader,
NONSTRICT_MODULE_KIND,
STATIC_MODULE_KIND,
STUB_KIND_MASK_TYPING,
)
from ..errors import TypedSyntaxError
from ..pycodegen import compile as python_compile
from ..readonly import readonly_compile
from ..static import Compiler as StaticCompiler, ModuleTable, StaticCodeGenerator
from . import strict_compile
from .class_conflict_checker import check_class_conflict
from .common import StrictModuleError
from .rewriter import rewrite, remove_annotations
if TYPE_CHECKING:
from _strictmodule import IStrictModuleLoader, StrictModuleLoaderFactory
def getSymbolTable(mod: StrictAnalysisResult) -> PythonSymbolTable:
    """
    Wrap the raw symbol table carried by a strict analysis result in the
    stdlib ``symtable.SymbolTable`` API, labelled with the module's file name.
    """
    factory = SymbolTableFactory()
    raw_table = mod.symtable
    return factory(raw_table, mod.file_name)
# Signature of a timing hook: called with (module name, file name, phase
# label such as "declaration_visit" or "compile") and returning a context
# manager that wraps the timed phase.
TIMING_LOGGER_TYPE = Callable[[str, str, str], ContextManager[None]]
@final
class Compiler(StaticCompiler):
    """
    Compiler for strict and static modules.

    Runs the strict module loader's analysis over a module's source and then
    picks a code generation path: plain compilation for modules that are not
    valid strict modules, the static compiler for static modules, and
    ``strict_compile`` for the remaining strict modules.
    """

    def __init__(
        self,
        import_path: Iterable[str],
        stub_root: str,
        allow_list_prefix: Iterable[str],
        allow_list_exact: Iterable[str],
        log_time_func: Optional[Callable[[], TIMING_LOGGER_TYPE]] = None,
        raise_on_error: bool = False,
        enable_patching: bool = False,
        loader_factory: StrictModuleLoaderFactory = StrictModuleLoader,
        use_py_compiler: bool = False,
    ) -> None:
        """
        Args:
            import_path: search path handed to the strict module loader.
            stub_root: stub root directory, handed to the loader as a string.
            allow_list_prefix: module-name prefixes handed to the loader.
            allow_list_exact: exact module names handed to the loader.
            log_time_func: optional zero-argument callable returning a timing
                hook; when set, the declaration and compile phases of static
                modules run inside the context manager that hook returns.
            raise_on_error: raise errors instead of only recording them on
                the analysis result.
            enable_patching: forwarded to ``compile`` for static modules.
            loader_factory: callable used to construct the strict loader.
            use_py_compiler: compile non-strict modules with pycodegen's
                ``compile`` instead of the builtin ``compile``.
        """
        super().__init__(StaticCodeGenerator)
        self.import_path: List[str] = list(import_path)
        self.stub_root = stub_root
        # Materialize the allow lists once (like import_path) so that one-shot
        # iterables such as generators are not left exhausted on the instance
        # after the loader has consumed them.
        self.allow_list_prefix: List[str] = list(allow_list_prefix)
        self.allow_list_exact: List[str] = list(allow_list_exact)
        # NOTE(review): the meaning of the trailing positional ``True`` is
        # defined by the loader factory; confirm against its signature.
        self.loader: IStrictModuleLoader = loader_factory(
            self.import_path,
            str(stub_root),
            list(self.allow_list_prefix),
            list(self.allow_list_exact),
            True,
        )
        self.raise_on_error = raise_on_error
        self.log_time_func = log_time_func
        self.enable_patching = enable_patching
        # Written by _compile_static and read by _get_rewritten_ast (which
        # import_module uses when declaring static dependencies).
        self.track_import_call: bool = False
        # Names already found to be valid but non-static; import_module
        # returns None for them without asking the loader again.
        self.not_static: Set[str] = set()
        self.use_py_compiler = use_py_compiler
        # Copy of the builtins namespace taken at construction time and passed
        # to rewrite / strict_compile / compile. ``__builtins__`` is a dict in
        # imported modules but the ``builtins`` module itself in __main__, so
        # accept both forms.
        self.original_builtins: Dict[str, object] = dict(
            __builtins__ if isinstance(__builtins__, dict) else vars(__builtins__)
        )

    def _timing_context(
        self, name: str, filename: str, phase: str
    ) -> ContextManager[None]:
        """Return the timing hook's context for ``phase``, or a no-op context."""
        log = self.log_time_func
        return log()(name, filename, phase) if log else nullcontext()

    def import_module(self, name: str, optimize: int) -> Optional[ModuleTable]:
        """
        Return the static module table for ``name``, declaring it on first use.

        A module the loader reports as valid, error-free and static has its
        preprocessed AST rewritten and registered through ``add_module``. A
        valid, error-free module of another kind is remembered in
        ``not_static``. Returns whatever ``self.modules`` holds for ``name``
        afterwards (None when nothing was registered).
        """
        res = self.modules.get(name)
        if res is not None:
            return res
        if name in self.not_static:
            return None
        mod = self.loader.check(name)
        # NOTE(review): self.modules is re-checked after loader.check();
        # presumably check() can register this module as a side effect.
        if mod.is_valid and name not in self.modules and len(mod.errors) == 0:
            if mod.module_kind == STATIC_MODULE_KIND:
                root = mod.ast_preprocessed
                if STUB_KIND_MASK_TYPING & mod.stub_kind:
                    # The typing stub bit is set: drop annotations first.
                    root = remove_annotations(root)
                root = self._get_rewritten_ast(name, mod, root, optimize)
                with self._timing_context(name, mod.file_name, "declaration_visit"):
                    self.add_module(name, mod.file_name, root, optimize)
            else:
                self.not_static.add(name)
        return self.modules.get(name)

    def _get_rewritten_ast(
        self, name: str, mod: StrictAnalysisResult, root: ast.Module, optimize: int
    ) -> ast.Module:
        """Rewrite ``root`` for static compilation using ``mod``'s symbol table."""
        symbols = getSymbolTable(mod)
        return rewrite(
            root,
            symbols,
            mod.file_name,
            name,
            optimize=optimize,
            is_static=True,
            track_import_call=self.track_import_call,
            builtins=self.original_builtins,
        )

    def load_compiled_module_from_source(
        self,
        source: str | bytes,
        filename: str,
        name: str,
        optimize: int,
        submodule_search_locations: Optional[List[str]] = None,
        track_import_call: bool = False,
        force_strict: bool = False,
    ) -> Tuple[CodeType | None, bool]:
        """
        Analyze ``source`` with the strict loader and compile it.

        Returns ``(code, is_valid_strict)``. ``code`` comes from the basic,
        static or strict path depending on the analysis. It is None when
        static compilation failed and ``raise_on_error`` is off.

        Raises:
            StrictModuleError: when ``raise_on_error`` is set and the analysis
                or the class-conflict check reports an error.
        """
        if force_strict:
            self.loader.set_force_strict_by_name(name)
        mod = self.loader.check_source(
            source, filename, name, submodule_search_locations or []
        )
        errors = mod.errors
        is_valid_strict = (
            mod.is_valid
            and len(errors) == 0
            and (force_strict or (mod.module_kind != NONSTRICT_MODULE_KIND))
        )
        if errors and self.raise_on_error:
            # if raise on error, just raise the first error
            error = errors[0]
            raise StrictModuleError(error[0], error[1], error[2], error[3])
        elif is_valid_strict:
            symbols = getSymbolTable(mod)
            try:
                check_class_conflict(mod.ast, filename, symbols)
            except StrictModuleError as e:
                if self.raise_on_error:
                    raise
                # NOTE(review): the recorded error does not clear
                # is_valid_strict, so the module is still compiled as strict;
                # confirm this is intended.
                mod.errors.append((e.msg, e.filename, e.lineno, e.col))
        if not is_valid_strict:
            code = self._compile_basic(mod.ast, filename, optimize)
        elif mod.module_kind == STATIC_MODULE_KIND:
            code = self._compile_static(
                mod, filename, name, optimize, track_import_call
            )
        else:
            code = self._compile_strict(
                mod, filename, name, optimize, track_import_call
            )
        return code, is_valid_strict

    def _compile_readonly(
        self, name: str, root: ast.Module, filename: str, optimize: int
    ) -> CodeType:
        """
        Compile ``root`` with ``readonly_compile``.

        TODO: this should eventually replace compile_basic once all python sources
        are compiled through this compiler
        """
        return readonly_compile(name, filename, root, flags=0, optimize=optimize)

    def _compile_basic(
        self, root: ast.Module, filename: str, optimize: int
    ) -> CodeType:
        """Compile ``root`` as ordinary code (pycodegen or builtin ``compile``)."""
        compile_method = python_compile if self.use_py_compiler else compile
        return compile_method(
            root,
            filename,
            "exec",
            optimize=optimize,
        )

    def _compile_strict(
        self,
        mod: StrictAnalysisResult,
        filename: str,
        name: str,
        optimize: int,
        track_import_call: bool,
    ) -> CodeType:
        """Rewrite the preprocessed AST and compile it with ``strict_compile``."""
        symbols = getSymbolTable(mod)
        tree = rewrite(
            mod.ast_preprocessed,
            symbols,
            filename,
            name,
            optimize=optimize,
            track_import_call=track_import_call,
            builtins=self.original_builtins,
        )
        return strict_compile(name, filename, tree, optimize, self.original_builtins)

    def _compile_static(
        self,
        mod: StrictAnalysisResult,
        filename: str,
        name: str,
        optimize: int,
        track_import_call: bool,
    ) -> CodeType | None:
        """
        Compile a static module with the static compiler.

        Reuses a cached AST from ``ast_cache`` when present, otherwise rewrites
        the preprocessed AST. A ``TypedSyntaxError`` is recorded on
        ``mod.errors`` and yields None, or is re-raised as a
        ``StrictModuleError`` when ``raise_on_error`` is set.
        """
        self._ensure_build_class_patched()
        # Stored on the instance so _get_rewritten_ast sees the same setting.
        self.track_import_call = track_import_call
        root = self.ast_cache.get(name)
        if root is None:
            root = self._get_rewritten_ast(name, mod, mod.ast_preprocessed, optimize)
        code = None
        try:
            with self._timing_context(name, filename, "compile"):
                code = self.compile(
                    name,
                    filename,
                    root,
                    optimize,
                    enable_patching=self.enable_patching,
                    builtins=self.original_builtins,
                )
        except TypedSyntaxError as e:
            msg = e.msg or "unknown error during static compilation"
            err_filename = e.filename or filename
            lineno = e.lineno or 1
            err = StrictModuleError(msg, err_filename, lineno, 0)
            # Record the error in the same (msg, filename, lineno, col) tuple
            # shape used by load_compiled_module_from_source, so mod.errors
            # stays uniformly indexable.
            mod.errors.append((msg, err_filename, lineno, 0))
            if self.raise_on_error:
                raise err from e
        return code

    def _ensure_build_class_patched(self) -> None:
        """Install ``__build_cinder_class__`` as ``builtins.__build_class__``."""
        # pyre-ignore[61]: __build_class__ isn't exposed in builtins.pyi.
        if builtins.__build_class__ is not __build_cinder_class__:
            builtins.__build_class__ = __build_cinder_class__