Lib/compiler/strict/loader.py (303 lines of code) (raw):

# Copyright (c) Facebook, Inc. and its affiliates. (http://www.facebook.com) from __future__ import annotations import builtins import os import sys from cinder import StrictModule from enum import Enum from importlib.abc import Loader from importlib.machinery import ( BYTECODE_SUFFIXES, EXTENSION_SUFFIXES, SOURCE_SUFFIXES, ExtensionFileLoader, FileFinder, ModuleSpec, SourceFileLoader, SourcelessFileLoader, ) from types import CodeType, ModuleType from typing import ( Callable, Collection, Iterable, List, Mapping, Optional, Tuple, Type, cast, final, ) from ..consts import CO_STATICALLY_COMPILED from .common import DEFAULT_STUB_PATH, FIXED_MODULES, MAGIC_NUMBER from .compiler import Compiler, TIMING_LOGGER_TYPE from .track_import_call import tracker # Force immediate resolution of Compiler in case it's deferred from Lazy Imports Compiler = Compiler _MAGIC_STRICT: bytes = (MAGIC_NUMBER + 2 ** 15).to_bytes(2, "little") + b"\r\n" # We don't actually need to increment anything here, because the strict modules # AST rewrite has no impact on pycs for non-strict modules. So we just always # use two zero bytes. This simplifies generating "fake" strict pycs for # known-not-to-be-strict third-party modules. _MAGIC_NONSTRICT: bytes = (0).to_bytes(2, "little") + b"\r\n" _MAGIC_LEN: int = len(_MAGIC_STRICT) @final class _PatchState(Enum): """Singleton used for tracking values which have not yet been patched.""" Patched = 1 Deleted = 2 Unpatched = 3 # Unfortunately module passed in could be a mock object, # which also has a `patch` method that clashes with the StrictModule method. # Directly get the function to avoid name clash. 
def _set_patch(module: StrictModule, name: str, value: object) -> None:
    """Patch `name` on `module` to `value` via the `patch` function of its type."""
    type(module).patch(module, name, value)


def _del_patch(module: StrictModule, name: str) -> None:
    """Delete `name` from `module` via the `patch_delete` function of its type."""
    type(module).patch_delete(module, name)


@final
class StrictModuleTestingPatchProxy:
    """Provides a proxy object which enables patching of a strict module if the
    module has been loaded with the StrictSourceWithPatchingFileLoader. The
    proxy can be used as a context manager in which case exiting the with block
    will result in the patches being disabled. The process will be terminated
    if the patches are not unapplied and the proxy is deallocated."""

    def __init__(self, module: StrictModule) -> None:
        """Wrap `module`.

        Raises:
            ValueError: if the module was not created with patching enabled.
        """
        # Attribute access on the proxy is redirected to the wrapped module, so
        # the proxy's own state has to be written through `object` directly.
        object.__setattr__(self, "module", module)
        # Maps patched name -> original value (or _PatchState.Patched when the
        # module did not define the name before it was patched).
        object.__setattr__(self, "_patches", {})
        object.__setattr__(self, "__name__", module.__name__)
        object.__setattr__(
            self, "_final_constants", getattr(module, "__final_constants__", ())
        )
        if not type(module).__patch_enabled__.__get__(module, type(module)):
            raise ValueError(f"strict module {module} does not allow patching")

    def __setattr__(self, name: str, value: object) -> None:
        """Patch `name` on the wrapped module, remembering its original value.

        Raises:
            AttributeError: if `name` is listed in the module's final constants.
        """
        patches = object.__getattribute__(self, "_patches")
        prev_patched = patches.get(name, _PatchState.Unpatched)
        module = object.__getattribute__(self, "module")
        final_constants = object.__getattribute__(self, "_final_constants")
        if name in final_constants:
            raise AttributeError(
                f"Cannot patch Final attribute `{name}` of module `{module.__name__}`"
            )

        if value is prev_patched:
            # We're restoring the previous value
            del patches[name]
        elif prev_patched is _PatchState.Unpatched:
            # We're overwriting a value
            # only set patches[name] when name is patched for the first time
            patches[name] = getattr(module, name, _PatchState.Patched)

        if value is _PatchState.Deleted:
            _del_patch(module, name)
        else:
            _set_patch(module, name, value)

    def __delattr__(self, name: str) -> None:
        """Delete `name` from the wrapped module, recording it as a patch."""
        StrictModuleTestingPatchProxy.__setattr__(self, name, _PatchState.Deleted)

    def __getattribute__(self, name: str) -> object:
        """Read `name` from the wrapped module rather than from the proxy."""
        res = getattr(object.__getattribute__(self, "module"), name)
        return res

    def __enter__(self) -> StrictModuleTestingPatchProxy:
        return self

    def __exit__(self, *excinfo: object) -> None:
        # Looked up on the class because instance attribute access is
        # forwarded to the wrapped module.
        StrictModuleTestingPatchProxy.cleanup(self)

    def cleanup(self, ignore: Optional[Collection[str]] = None) -> None:
        """Undo every recorded patch.

        Names listed in `ignore` are dropped from the bookkeeping without
        being restored on the module.
        """
        patches = object.__getattribute__(self, "_patches")
        module = object.__getattribute__(self, "module")
        # Iterate over a snapshot: entries are removed while we go.
        for name, value in list(patches.items()):
            if ignore and name in ignore:
                del patches[name]
                continue
            if value is _PatchState.Patched:
                # value is patched means that module originally
                # does not contain this field.
                try:
                    _del_patch(module, name)
                except AttributeError:
                    pass
                finally:
                    del patches[name]
            else:
                setattr(self, name, value)
        assert not patches

    def __del__(self) -> None:
        """Abort the process if the proxy dies with patches still applied."""
        patches = object.__getattribute__(self, "_patches")
        if patches:
            print(
                "Patch(es)",
                ", ".join(patches.keys()),
                "failed to be detached from strict module",
                "'" + object.__getattribute__(self, "module").__name__ + "'",
                file=sys.stderr,
            )
            os.abort()


__builtins__: ModuleType


class StrictSourceFileLoader(SourceFileLoader):
    """Source file loader that compiles and executes strict modules.

    Bytecode for modules handled by this loader is cached under a path
    carrying an extra `.strict` (or `.strict.patch`) tag and is prefixed with
    a small header recording whether the module turned out to be strict.
    """

    # Whether the module most recently loaded by this loader is strict.
    strict: bool = False
    # Compiler shared at class level; created lazily by ensure_compiler().
    compiler: Optional[Compiler] = None
    module: Optional[ModuleType] = None

    def __init__(
        self,
        fullname: str,
        path: str,
        import_path: Optional[Iterable[str]] = None,
        stub_path: Optional[str] = None,
        allow_list_prefix: Optional[Iterable[str]] = None,
        allow_list_exact: Optional[Iterable[str]] = None,
        enable_patching: bool = False,
        log_source_load: Optional[Callable[[str, Optional[str], bool], None]] = None,
        track_import_call: bool = False,
        init_cached_properties: Optional[
            Callable[
                [Mapping[str, str | Tuple[str, bool]]],
                Callable[[Type[object]], Type[object]],
            ]
        ] = None,
        log_time_func: Optional[Callable[[], TIMING_LOGGER_TYPE]] = None,
        use_py_compiler: bool = False,
    ) -> None:
        """Store the loader configuration.

        When `stub_path` is None it is taken from the
        `strict-module-stubs-path` -X option, then the
        `PYTHONSTRICTMODULESTUBSPATH` environment variable, then
        `DEFAULT_STUB_PATH`.

        Raises:
            ValueError: if the resolved stub path is non-empty but is not an
                existing directory.
        """
        self.name = fullname
        self.path = path
        self.import_path: Iterable[str] = import_path or list(sys.path)
        configured_stub_path = sys._xoptions.get(
            "strict-module-stubs-path"
        ) or os.getenv("PYTHONSTRICTMODULESTUBSPATH")
        if stub_path is None:
            stub_path = configured_stub_path or DEFAULT_STUB_PATH
        if stub_path and not os.path.isdir(stub_path):
            raise ValueError(f"Strict module stubs path does not exist: {stub_path}")
        self.stub_path: str = stub_path
        self.allow_list_prefix: Iterable[str] = allow_list_prefix or []
        self.allow_list_exact: Iterable[str] = allow_list_exact or []
        self.enable_patching = enable_patching
        self.log_source_load: Optional[
            Callable[[str, Optional[str], bool], None]
        ] = log_source_load
        self.bytecode_found = False
        self.bytecode_path: Optional[str] = None
        self.track_import_call = track_import_call
        self.init_cached_properties = init_cached_properties
        self.log_time_func = log_time_func
        self.use_py_compiler = use_py_compiler

    @classmethod
    def ensure_compiler(
        cls,
        path: Iterable[str],
        stub_path: str,
        allow_list_prefix: Iterable[str],
        allow_list_exact: Iterable[str],
        log_time_func: Optional[Callable[[], TIMING_LOGGER_TYPE]],
        enable_patching: bool = False,
    ) -> Compiler:
        """Return the class-level compiler, creating it on first use.

        The arguments are only consulted when the compiler does not exist yet;
        later calls return the cached instance regardless of what is passed.
        """
        if (comp := cls.compiler) is None:
            comp = cls.compiler = Compiler(
                path,
                stub_path,
                allow_list_prefix,
                allow_list_exact,
                raise_on_error=True,
                log_time_func=log_time_func,
                enable_patching=enable_patching,
            )
        return comp

    def get_data(self, path: bytes | str) -> bytes:
        """Read `path`; for bytecode files, read the strict-tagged path instead,
        record strictness from its header and return the data without it.

        Raises:
            OSError: if a bytecode file starts with an unknown header.
        """
        assert isinstance(path, str)
        is_pyc = False
        if path.endswith(tuple(BYTECODE_SUFFIXES)):
            is_pyc = True
            path = add_strict_tag(path, self.enable_patching)
            self.bytecode_path = path
        data = super().get_data(path)
        if is_pyc:
            self.bytecode_found = True
            magic = data[:_MAGIC_LEN]
            if magic == _MAGIC_NONSTRICT:
                self.strict = False
            elif magic == _MAGIC_STRICT:
                self.strict = True
            else:
                # This is a bit ugly: OSError is the only kind of error that
                # get_code() ignores from get_data(). But this is way better
                # than the alternative of copying and modifying everything.
                raise OSError(f"Bad magic number {data[:4]!r} in {path}")
            data = data[_MAGIC_LEN:]
        return data

    def set_data(self, path: bytes | str, data: bytes, *, _mode=0o666) -> None:
        """Write `data` to `path`; bytecode goes to the strict-tagged path and
        is prefixed with the header matching the current `self.strict`."""
        assert isinstance(path, str)
        if path.endswith(tuple(BYTECODE_SUFFIXES)):
            path = add_strict_tag(path, self.enable_patching)
            magic = _MAGIC_STRICT if self.strict else _MAGIC_NONSTRICT
            data = magic + data
        return super().set_data(path, data, _mode=_mode)

    def should_force_strict(self) -> bool:
        """Return whether the module must be treated as strict even without a
        `__strict__`/`__static__` marker. Always False in this class."""
        return False

    @staticmethod
    def _package_search_locations(path: str) -> Optional[List[str]]:
        """Return the submodule search locations for the source file `path`.

        Since a namespace package will never call `source_to_code` (there is
        no source!), there are only two possibilities: non-package (None) or
        regular package (one entry, the directory containing the
        "__init__.py").
        """
        if path.endswith("__init__.py"):
            return [os.path.dirname(path)]
        return None

    # pyre-ignore[40]: Non-static method `source_to_code` cannot override a static
    # method defined in `importlib.abc.InspectLoader`.
    def source_to_code(
        self, data: bytes | str, path: str, *, _optimize: int = -1
    ) -> CodeType:
        """Compile `data`, routing it through the strict compiler when the
        module is forced strict or names `__strict__`/`__static__`.

        Sets `self.strict` to reflect whether the result is a strict module.
        """
        log_source_load = self.log_source_load
        if log_source_load is not None:
            log_source_load(path, self.bytecode_path, self.bytecode_found)
        # pyre-ignore[28]: typeshed doesn't know about _optimize arg
        code = super().source_to_code(data, path, _optimize=_optimize)
        force = self.should_force_strict()
        if force or "__strict__" in code.co_names or "__static__" in code.co_names:
            submodule_search_locations = self._package_search_locations(path)
            # Usually _optimize will be -1 (which means "default to the value
            # of sys.flags.optimize"). But this default happens very deep in
            # Python's compiler (in PyAST_CompileObject), so if we just pass
            # around -1 and rely on that, it means we can't make any of our own
            # decisions based on that flag. So instead we do the default right
            # here, so we have the correct optimize flag value throughout our
            # compiler.
            opt = sys.flags.optimize if _optimize == -1 else _optimize
            # Let the ast transform attempt to validate the strict module.  This
            # will return an unmodified module if import __strict__ isn't
            # actually at the top-level
            code, is_valid_strict = self.ensure_compiler(
                self.import_path,
                self.stub_path,
                self.allow_list_prefix,
                self.allow_list_exact,
                self.log_time_func,
                self.enable_patching,
            ).load_compiled_module_from_source(
                data,
                path,
                self.name,
                opt,
                submodule_search_locations,
                self.track_import_call,
                force_strict=force,
            )
            self.strict = is_valid_strict
            assert code is not None
            return code
        self.strict = False
        return code

    def exec_module(self, module: ModuleType) -> None:
        """Execute the module's code, swapping in a StrictModule if it is strict.

        Raises:
            ImportError: if no code object is available, or if a strict module
                has no spec.
        """
        # This ends up being slightly convoluted, because create_module
        # gets called, then source_to_code gets called, so we don't know if
        # we have a strict module until after we were requested to create it.
        # So we'll run the module code we get back in the module that was
        # initially published in sys.modules, check and see if it's a strict
        # module, and then run the strict module body after replacing the
        # entry in sys.modules with a StrictModule entry.  This shouldn't
        # really be observable because no user code runs between publishing
        # the normal module in sys.modules and replacing it with the
        # StrictModule.
        code = self.get_code(module.__name__)
        if code is None:
            raise ImportError(
                f"Cannot import module {module.__name__}; get_code() returned None"
            )
        # fix up the pyc path
        cached = getattr(module, "__cached__", None)
        if cached:
            # pyre-ignore[16]: `ModuleType` has no attribute `__cached__`.
            module.__cached__ = cached = add_strict_tag(cached, self.enable_patching)
        spec: Optional[ModuleSpec] = module.__spec__
        if cached and spec and spec.cached:
            spec.cached = cached

        if self.track_import_call:
            tracker.enter_import()
        try:
            if self.strict:
                if spec is None:
                    raise ImportError(f"Missing module spec for {module.__name__}")

                new_dict = {
                    "<fixed-modules>": cast(object, FIXED_MODULES),
                    "<builtins>": builtins.__dict__,
                    "<init-cached-properties>": self.init_cached_properties,
                }
                if code.co_flags & CO_STATICALLY_COMPILED:
                    new_dict["<imported-from>"] = code.co_consts[-1]

                new_dict.update(module.__dict__)
                strict_mod = StrictModule(new_dict, self.enable_patching)

                sys.modules[module.__name__] = strict_mod

                exec(code, new_dict)
            else:
                exec(code, module.__dict__)
        finally:
            # Keep enter/exit balanced even when the module body (or the spec
            # check above) raises.
            if self.track_import_call:
                tracker.exit_import()


def add_strict_tag(path: str, enable_patching: bool) -> str:
    """Insert `.strict` (or `.strict.patch`) before the final extension of `path`."""
    base, __, ext = path.rpartition(".")
    enable_patching_marker = ".patch" if enable_patching else ""

    return f"{base}.strict{enable_patching_marker}.{ext}"


def _get_supported_file_loaders() -> List[Tuple[Loader, List[str]]]:
    """Returns a list of file-based module loaders.

    Each item is a tuple (loader, suffixes).
    """
    extensions = ExtensionFileLoader, EXTENSION_SUFFIXES
    source = StrictSourceFileLoader, SOURCE_SUFFIXES
    bytecode = SourcelessFileLoader, BYTECODE_SUFFIXES
    return cast(List[Tuple[Loader, List[str]]], [extensions, source, bytecode])


def install() -> None:
    """Installs a loader which is capable of loading and validating strict modules"""
    supported_loaders = _get_supported_file_loaders()

    # Insert ahead of the first path hook that is not a class; if every hook
    # is a class, insert at the very front.
    for index, hook in enumerate(sys.path_hooks):
        if not isinstance(hook, type):
            sys.path_hooks.insert(index, FileFinder.path_hook(*supported_loaders))
            break
    else:
        sys.path_hooks.insert(0, FileFinder.path_hook(*supported_loaders))

    # We need to clear the path_importer_cache so that our new FileFinder will
    # start being used for existing directories we've loaded modules from.
    sys.path_importer_cache.clear()


if __name__ == "__main__":
    install()
    # Drop this script's own entry so argv[0] is the module to run.
    del sys.argv[0]
    mod: object = __import__(sys.argv[0])
    if not isinstance(mod, StrictModule):
        raise TypeError(
            "compiler.strict.loader should be used to run strict modules: "
            + type(mod).__name__
        )
    mod.__main__()  # pyre-ignore[16]: `object` has no attribute `__main__`.