components/_impl/workers/subprocess_rpc.py

"""Utilities to handle communication between parent worker. This module implements three principle facilities: 1) Raw IPC (via the Pipe class) 2) Exception propagation (via the SerializedException class) 3) A run loop for the worker (via the run_loop function) """ import contextlib import dataclasses import datetime import io import marshal import os import pickle import struct import sys import textwrap import threading import time import traceback import types import psutil import typing # Shared static values / namespace between worker and parent BOOTSTRAP_IMPORT_SUCCESS = b"BOOTSTRAP_IMPORT_SUCCESS" BOOTSTRAP_INPUT_LOOP_SUCCESS = b"BOOTSTRAP_INPUT_LOOP_SUCCESS" WORKER_IMPL_NAMESPACE = "__worker_impl_namespace" # Constants for passing to and from pipes _CHECK = b"\x00\x00" _TIMEOUT = b"\x01\x01" _DEAD = b"\x02\x02" assert len(_CHECK) == len(_TIMEOUT) == len(_DEAD) _ULL = "Q" # Unsigned long long _ULL_SIZE = len(struct.pack(_ULL, 0)) assert _ULL_SIZE == 8 # Text encoding for input commands. ENCODING = "utf-8" SUCCESS = "SUCCESS" # In Python, `sys.exit()` is a soft exit. It throws a SystemExit, and only # exits if that is not caught. `os._exit()` is not catchable, and is suitable # for cases where we really, really need to exit. This is of particular # importance because the worker run loop does its very best to swallow # exceptions. HARD_EXIT = "import os\nos._exit(0)".encode(ENCODING) # Precompute serialized normal return values EMPTY_RESULT = marshal.dumps({}) SUCCESS_BYTES = marshal.dumps(SUCCESS) # ============================================================================= # == Raw Communication ======================================================== # ============================================================================= # Windows does not allow subprocesses to inherit file descriptors, so instead # we have to go the the OS and get get the handle for the backing resource. IS_WINDOWS = sys.platform == "win32" if IS_WINDOWS: import msvcrt def to_handle(fd: typing.Optional[int]) -> typing.Optional[int]: return None if fd is None else msvcrt.get_osfhandle(fd) def from_handle(handle: typing.Optional[int], mode: int) -> typing.Optional[int]: return None if handle is None else msvcrt.open_osfhandle(handle, mode) else: to_handle = lambda fd: fd from_handle = lambda fd, _: fd class _TimeoutPIPE: """Allow Pipe to interrupt its read. `os.read` is a syscall, which means it is not interruptable. This means that normal timeout mechanisms such as `asyncio.wait_for(..., timeout=...)` will not work because they rely on the awaited function returning control to the event loop. An alternate formulation uses `run_in_executor` and `asyncio.wait`, which places the read on a side thread under the hood. However this is also not suitable, because: 1) This additional machinery increases the cost when data is already present in the Pipe (most common case) ~1000x, from O(us) to O(ms) 2) We have to poll the future, which wastes the awaitable nature `read` Instead of trying to interrupt the pipe read, we cause it to terminate by writing to the pipe; because we control the read (via `Pipe.read`) we can catch the sentinel timeout value and raise appropriately. This class is designed to be extremely lightweight. Timeouts should be on the order of seconds (or minutes), and are only to prevent deadlocks in the case of catastrophic worker failure. As a result, we prioritize low resource usage over the ability to support small timeouts. 
""" _singleton_lock = threading.Lock() _singleton: typing.Optional["_TimeoutPIPE"] = None _loop_lock = threading.Lock() _active_reads: typing.Dict[int, typing.Tuple[float, float, int]] _loop_cadence = 1 # second @classmethod def singleton(cls) -> "_TimeoutPIPE": # This class will spawn a thread, so we only want one active at a time. with cls._singleton_lock: if cls._singleton is None: cls._singleton = cls() return cls._singleton def __init__(self) -> None: self._active_reads = {} self._thread = threading.Thread(target=self._loop) self._thread.daemon = True self._thread.start() def _loop(self): # This loop is scoped to the life of the process, so we rely on process # teardown to pull the rug out from under the daemonic thread running # this function. while True: time.sleep(self._loop_cadence) now = time.time() with self._loop_lock: for w_fd, (timeout, start_time, writer_pid) in tuple(self._active_reads.items()): # if child process is in zombie status, check its exit code if psutil.pid_exists(writer_pid): p = psutil.Process(writer_pid) if p.status() == psutil.STATUS_ZOMBIE: # wait 1 second for the exit code exit_code = p.wait(timeout=self._loop_cadence) if exit_code: os.write(w_fd, _DEAD + struct.pack(_ULL, abs(int(exit_code)))) self.pop(w_fd) # check if process timeout if timeout: if now - start_time >= timeout and w_fd in self._active_reads: os.write(w_fd, _TIMEOUT) self.pop(w_fd) def pop(self, w_fd: int) -> None: self._active_reads.pop(w_fd, None) @classmethod @contextlib.contextmanager def maybe_timeout_read(cls, pipe: "Pipe") -> None: timeout = pipe.timeout # Spawn a loop thread to periodically check the liveness of subprocess w_fd = pipe.write_fd assert w_fd is not None, "Cannot timeout without write file descriptor." assert pipe.get_writer_pid() is not None, "Cannot check process liveness without pid." singleton = cls.singleton() with singleton._loop_lock: # This will only occur in the case of concurrent reads on different # threads (not supported) or a leaked case. assert w_fd not in singleton._active_reads, f"{w_fd} is already being watched." singleton._active_reads[w_fd] = (timeout, time.time(), pipe.get_writer_pid()) try: yield finally: singleton.pop(w_fd) class Pipe: """Helper class to move data in a robust fashion. This class handles: 1) Child process liveness checks if pipe is read by parent 2) File descriptor lifetimes 3) File descriptor inheritance 4) Message packing and unpacking 5) (Optional) timeouts for reads NOTE: we don't check liveness of parent since the parent process shouldn't regularly fail without proper clean up. 
""" def __init__( self, # writer_pid only exists when `self` is a pipe read by parent # in which case, write_pid is the pid of the child process writer_pid: typing.Optional[int] = None, read_handle: typing.Optional[int] = None, write_handle: typing.Optional[int] = None, timeout: typing.Optional[float] = None, timeout_callback: typing.Callable[[], typing.NoReturn] = (lambda: None), ) -> None: self._writer_pid = writer_pid self._owns_pipe = read_handle is None and write_handle is None if self._owns_pipe: self.read_fd, self.write_fd = os.pipe() else: self.read_fd = from_handle(read_handle, os.O_RDONLY) self.write_fd = from_handle(write_handle, os.O_WRONLY) self.read_handle = read_handle or to_handle(self.read_fd) self.write_handle = write_handle or to_handle(self.write_fd) self.timeout = timeout self.timeout_callback = timeout_callback def _read(self, size: int) -> bytes: """Handle the low level details of reading from the PIPE.""" if self.read_fd is None: raise IOError("Cannot read from PIPE, we do not have the read handle") # `self._write_pid` is not None iff `self` is the read pipe from parent process # only support timeout and child process liveness check in this case if self._writer_pid: with _TimeoutPIPE.maybe_timeout_read(self): raw_msg = os.read(self.read_fd, len(_CHECK) + size) else: raw_msg = os.read(self.read_fd, len(_CHECK) + size) check_bytes, msg = raw_msg[:len(_CHECK)], raw_msg[len(_CHECK):] if check_bytes == _TIMEOUT: self.timeout_callback() # Give caller the chance to cleanup. raise IOError(f"Exceeded timeout: {self.timeout}") if check_bytes == _DEAD: raise IOError(f"Subprocess terminates with code {int.from_bytes(msg, sys.byteorder)}") if check_bytes != _CHECK: raise IOError(f"{check} != {_CHECK}, {msg}") if len(msg) != size: raise IOError(f"len(msg) != size: {len(msg)} vs. {size}") return msg def read(self) -> bytes: msg_size = struct.unpack(_ULL, self._read(_ULL_SIZE))[0] return self._read(msg_size) def write(self, msg: bytes) -> None: if self.write_fd is None: raise IOError("Cannot write from PIPE, we do not have the write handle") assert isinstance(msg, bytes), msg packed_msg = ( # First read: message length _CHECK + struct.pack(_ULL, len(msg)) + # Second read: message contents _CHECK + msg ) os.write(self.write_fd, packed_msg) def get_writer_pid(self) -> int: assert self._writer_pid is not None, "Writer pid is not specified. Maybe calling from child process or input pipe.\ Please report a bug." return self._writer_pid def set_writer_pid(self, writer_pid: int) -> None: self._writer_pid = writer_pid def _close_fds(self): """Factor cleanup to a helper so we can test when it runs.""" os.close(self.read_fd) os.close(self.write_fd) def __del__(self) -> None: if self._owns_pipe: self._close_fds() # ============================================================================= # == Exception Propagation =================================================== # ============================================================================= class ExceptionUnpickler(pickle.Unpickler): """Unpickler which is specialized for Exception types. When we catch an exception that we want to raise in another process, we need to include the type of Exception. For custom exceptions this is a problem, because pickle dynamically resolves imports which means we might not be able to unpickle in the parent. (And reviving them by replaying the constructor args might not work.) So in the interest of robustness, we confine ourselves to builtin Exceptions. (With UnserializableException as a fallback.) 


# =============================================================================
# == Exception Propagation ====================================================
# =============================================================================

class ExceptionUnpickler(pickle.Unpickler):
    """Unpickler which is specialized for Exception types.

    When we catch an exception that we want to raise in another process, we
    need to include the type of Exception. For custom exceptions this is a
    problem, because pickle dynamically resolves imports, which means we might
    not be able to unpickle in the parent. (And reviving them by replaying the
    constructor args might not work.) So in the interest of robustness, we
    confine ourselves to builtin Exceptions. (With UnserializableException as
    a fallback.)

    However it is not possible to marshal even builtin Exception types, so
    instead we use pickle and check that the type is builtin in `find_class`.
    """

    @classmethod
    def load_bytes(cls, data: bytes) -> typing.Type[Exception]:
        result = cls(io.BytesIO(data)).load()

        # Make sure we have an Exception class, but not an instantiated
        # Exception. (The instance check must come first: `issubclass`
        # raises a TypeError if its first argument is not a class.)
        if isinstance(result, Exception):
            raise pickle.UnpicklingError(
                f"{result} is an Exception instance, not a class.")

        if not (isinstance(result, type) and issubclass(result, Exception)):
            raise pickle.UnpicklingError(f"{result} is not an Exception")

        return result  # type: ignore[no-any-return]

    def find_class(self, module: str, name: str) -> typing.Any:
        if module != "builtins":
            raise pickle.UnpicklingError(f"Invalid object: {module}.{name}")
        return super().find_class(module, name)


class UnserializableException(Exception):
    """Fallback class for if a non-builtin Exception is raised."""
    def __init__(self, type_repr: str, args_repr: str) -> None:
        self.type_repr = type_repr
        self.args_repr = args_repr
        super().__init__(type_repr, args_repr)


class ChildTraceException(Exception):
    """Used to display a raising child's stack trace in the parent's stderr."""
    pass


@dataclasses.dataclass(init=True, frozen=True)
class SerializedException:
    _is_serializable: bool
    _type_bytes: bytes
    _args_bytes: bytes

    # Fallbacks for UnserializableException
    _type_repr: str
    _args_repr: str

    _traceback_print: str

    @staticmethod
    def from_exception(e: Exception, tb: types.TracebackType) -> "SerializedException":
        """Best effort attempt to serialize Exception.

        Because this will be used to communicate from a subprocess to its
        parent, we want to surface as much information as possible. It is not
        possible to serialize a traceback because it is too intertwined with
        the runtime; however what we really want is the traceback string so
        we can print it. We can grab that string and send it without issue.
        (And providing a stack trace is very important for debuggability.)

        ExceptionUnpickler explicitly refuses to load any non-builtin
        exception (for the same reason we prefer `marshal` to `pickle`), so
        we won't be able to serialize all cases. However we don't want to
        simply give up, as this would make it difficult for a user to
        diagnose what's going on. So instead we extract what information we
        can, and raise an UnserializableException in the main process with
        whatever we were able to scrape up from the child process.
        """
        try:
            print_file = io.StringIO()
            # Positional args keep this working both before and after Python
            # 3.10, which removed the `etype` keyword.
            traceback.print_exception(type(e), e, tb, file=print_file)
            print_file.seek(0)
            traceback_print: str = print_file.read()

        except Exception:
            traceback_print = textwrap.dedent("""
                Traceback
                    Failed to extract traceback from worker. This is not expected.
            """).strip()

        try:
            args_bytes: bytes = marshal.dumps(e.args)
            type_bytes = pickle.dumps(e.__class__)

            # Make sure we'll be able to get something out on the other side.
            revived_type = ExceptionUnpickler.load_bytes(data=type_bytes)
            revived_type(*marshal.loads(args_bytes))  # Validate constructor args.
            is_serializable: bool = True

        except Exception:
            is_serializable = False
            args_bytes = b""
            type_bytes = b""

        # __repr__ can contain arbitrary code, so we can't trust it to noexcept.
        def hardened_repr(o: typing.Any) -> str:
            try:
                return repr(o)
            except Exception:
                return "< Unknown >"

        return SerializedException(
            _is_serializable=is_serializable,
            _type_bytes=type_bytes,
            _args_bytes=args_bytes,
            _type_repr=hardened_repr(e.__class__),
            _args_repr=hardened_repr(getattr(e, "args", None)),
            _traceback_print=traceback_print,
        )

    @staticmethod
    def raise_from(
        serialized_e: "SerializedException",
        extra_context: typing.Optional[str] = None,
    ) -> None:
        """Revive `serialized_e`, and raise.

        We raise the revived exception type (if possible) so that any higher
        try/except logic will see the original exception type. In other words:

        ```
            try:
                worker.run("assert False")
            except AssertionError:
                ...
        ```

        will flow identically to:

        ```
            try:
                assert False
            except AssertionError:
                ...
        ```

        If for some reason we can't move the true exception type to the main
        process (e.g. a custom Exception), we raise UnserializableException
        as a fallback.
        """
        if serialized_e._is_serializable:
            revived_type = ExceptionUnpickler.load_bytes(data=serialized_e._type_bytes)
            e = revived_type(*marshal.loads(serialized_e._args_bytes))
        else:
            e = UnserializableException(serialized_e._type_repr, serialized_e._args_repr)

        traceback_str = serialized_e._traceback_print
        if extra_context:
            traceback_str = f"{traceback_str}\n{extra_context}"

        raise e from ChildTraceException(traceback_str)
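
# Example (sketch, not executed): the child serializes a caught exception and
# the parent revives it, preserving the builtin exception type for enclosing
# `except` clauses. (`_run_block` below transports the dataclass as
# `marshal.dumps(dataclasses.asdict(...))`.)
#
#   try:
#       assert False, "worker failed"
#   except Exception as e:
#       serialized = SerializedException.from_exception(e, sys.exc_info()[2])
#
#   # ...later, in the parent:
#   SerializedException.raise_from(serialized)  # raises AssertionError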


# =============================================================================
# == Snippet Execution ========================================================
# =============================================================================

def _log_progress(suffix: str) -> None:
    now = datetime.datetime.now().strftime("[%Y-%m-%d] %H:%M:%S.%f")
    print(f"{now}: TIMER_SUBPROCESS_{suffix}")


def _run_block(
    *,
    input_pipe: Pipe,
    output_pipe: Pipe,
    globals_dict: typing.Dict[str, typing.Any],
) -> None:
    result = EMPTY_RESULT
    try:
        _log_progress("BEGIN_READ")
        cmd = input_pipe.read().decode(ENCODING)

        _log_progress("BEGIN_EXEC")
        exec(  # noqa: P204
            compile(cmd, "<subprocess-worker>", "exec"),
            globals_dict,
        )

        _log_progress("SUCCESS")
        result = SUCCESS_BYTES

    except (Exception, KeyboardInterrupt, SystemExit) as e:
        tb = sys.exc_info()[2]
        assert tb is not None
        serialized_e = SerializedException.from_exception(e, tb)
        result = marshal.dumps(dataclasses.asdict(serialized_e))
        _log_progress("FAILED")

    finally:
        output_pipe.write(result)
        _log_progress("FINISHED")
        sys.stdout.flush()
        sys.stderr.flush()


def run_loop(
    *,
    input_handle: int,
    output_pipe: Pipe,
    load_handle: int,
) -> None:
    input_pipe = Pipe(read_handle=input_handle)

    # In general, we want a clean separation between user code and framework
    # code. However, certain methods in SubprocessWorker (store and load)
    # want to access implementation details in this module. As a result, we
    # run tasks through a context where globals start out clean EXCEPT for
    # a namespace where we can stash implementation details.
    globals_dict = {
        WORKER_IMPL_NAMESPACE: {
            "subprocess_rpc": sys.modules[__name__],
            "marshal": marshal,
            "load_pipe": Pipe(write_handle=load_handle),
        }
    }

    output_pipe.write(BOOTSTRAP_INPUT_LOOP_SUCCESS)
    while True:
        _run_block(
            input_pipe=input_pipe,
            output_pipe=output_pipe,
            globals_dict=globals_dict,
        )
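
# Child-side bootstrap (sketch; the actual wiring lives in SubprocessWorker,
# outside this module, and the `inherited_*` handle values here are
# hypothetical placeholders passed down by the parent):
#
#   run_loop(
#       input_handle=inherited_input_handle,
#       output_pipe=Pipe(write_handle=inherited_output_handle),
#       load_handle=inherited_load_handle,
#   )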