proxy_worker/dispatcher.py (398 lines of code) (raw):

# Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. import asyncio import concurrent.futures import logging import os import queue import sys import threading import traceback import typing from asyncio import AbstractEventLoop from dataclasses import dataclass from typing import Any, Optional import grpc from proxy_worker import protos from proxy_worker.logging import ( CONSOLE_LOG_PREFIX, disable_console_logging, enable_console_logging, error_logger, is_system_log_category, logger, ) from proxy_worker.utils.common import ( get_app_setting, get_script_file_name, is_envvar_true, ) from proxy_worker.utils.constants import ( PYTHON_ENABLE_DEBUG_LOGGING, PYTHON_THREADPOOL_THREAD_COUNT, ) from proxy_worker.version import VERSION from .utils.dependency import DependencyManager # Library worker import reloaded in init and reload request _library_worker = None class ContextEnabledTask(asyncio.Task): AZURE_INVOCATION_ID = '__azure_function_invocation_id__' def __init__(self, coro, loop, context=None, **kwargs): super().__init__(coro, loop=loop, context=context, **kwargs) current_task = asyncio.current_task(loop) if current_task is not None: invocation_id = getattr( current_task, self.AZURE_INVOCATION_ID, None) if invocation_id is not None: self.set_azure_invocation_id(invocation_id) def set_azure_invocation_id(self, invocation_id: str) -> None: setattr(self, self.AZURE_INVOCATION_ID, invocation_id) _invocation_id_local = threading.local() def get_current_invocation_id() -> Optional[Any]: loop = asyncio._get_running_loop() if loop is not None: current_task = asyncio.current_task(loop) if current_task is not None: task_invocation_id = getattr(current_task, ContextEnabledTask.AZURE_INVOCATION_ID, None) if task_invocation_id is not None: return task_invocation_id return getattr(_invocation_id_local, 'invocation_id', None) class AsyncLoggingHandler(logging.Handler): def emit(self, record: logging.LogRecord) -> None: # Since we disable console log after gRPC channel is initiated, # we should redirect all the messages into dispatcher. # When dispatcher receives an exception, it should switch back # to console logging. However, it is possible that # __current_dispatcher__ is set to None as there are still messages # buffered in this handler, not calling the emit yet. msg = self.format(record) try: Dispatcher.current.on_logging(record, msg) except RuntimeError as runtime_error: # This will cause 'Dispatcher not found' failure. # Logging such of an issue will cause infinite loop of gRPC logging # To mitigate, we should suppress the 2nd level error logging here # and use print function to report exception instead. print(f'{CONSOLE_LOG_PREFIX} ERROR: {str(runtime_error)}', file=sys.stderr, flush=True) @dataclass class WorkerRequest: name: str request: str properties: Optional[dict[str, typing.Any]] = None class DispatcherMeta(type): __current_dispatcher__: Optional["Dispatcher"] = None @property def current(cls): disp = cls.__current_dispatcher__ if disp is None: raise RuntimeError('no currently running Dispatcher is found') return disp class Dispatcher(metaclass=DispatcherMeta): _GRPC_STOP_RESPONSE = object() def __init__(self, loop: AbstractEventLoop, host: str, port: int, worker_id: str, request_id: str, grpc_connect_timeout: float, grpc_max_msg_len: int = -1) -> None: self._loop = loop self._host = host self._port = port self._request_id = request_id self._worker_id = worker_id self._grpc_connect_timeout: float = grpc_connect_timeout self._grpc_max_msg_len: int = grpc_max_msg_len self._old_task_factory: Optional[Any] = None self._grpc_resp_queue: queue.Queue = queue.Queue() self._grpc_connected_fut = loop.create_future() self._grpc_thread: Optional[threading.Thread] = threading.Thread( name='grpc_local-thread', target=self.__poll_grpc) self._sync_call_tp: Optional[concurrent.futures.Executor] = ( self._create_sync_call_tp(self._get_sync_tp_max_workers())) def on_logging(self, record: logging.LogRecord, formatted_msg: str) -> None: if record.levelno >= logging.CRITICAL: log_level = protos.RpcLog.Critical elif record.levelno >= logging.ERROR: log_level = protos.RpcLog.Error elif record.levelno >= logging.WARNING: log_level = protos.RpcLog.Warning elif record.levelno >= logging.INFO: log_level = protos.RpcLog.Information elif record.levelno >= logging.DEBUG: log_level = protos.RpcLog.Debug else: log_level = getattr(protos.RpcLog, 'None') if is_system_log_category(record.name): log_category = protos.RpcLog.RpcLogCategory.Value('System') else: # customers using logging will yield 'root' in record.name log_category = protos.RpcLog.RpcLogCategory.Value('User') log = dict( level=log_level, message=formatted_msg, category=record.name, log_category=log_category ) invocation_id = get_current_invocation_id() if invocation_id is not None: log['invocation_id'] = invocation_id self._grpc_resp_queue.put_nowait( protos.StreamingMessage( request_id=self.request_id, rpc_log=protos.RpcLog(**log))) @property def request_id(self) -> str: return self._request_id @property def worker_id(self) -> str: return self._worker_id @classmethod async def connect(cls, host: str, port: int, worker_id: str, request_id: str, connect_timeout: float): loop = asyncio.events.get_event_loop() disp = cls(loop, host, port, worker_id, request_id, connect_timeout) # Safety check for mypy if disp._grpc_thread is not None: disp._grpc_thread.start() await disp._grpc_connected_fut logger.info('Successfully opened gRPC channel to %s:%s ', host, port) return disp def __poll_grpc(self): options = [] if self._grpc_max_msg_len: options.append(('grpc_local.max_receive_message_length', self._grpc_max_msg_len)) options.append(('grpc_local.max_send_message_length', self._grpc_max_msg_len)) channel = grpc.insecure_channel( f'{self._host}:{self._port}', options) try: grpc.channel_ready_future(channel).result( timeout=self._grpc_connect_timeout) except Exception as ex: self._loop.call_soon_threadsafe( self._grpc_connected_fut.set_exception, ex) return else: self._loop.call_soon_threadsafe( self._grpc_connected_fut.set_result, True) stub = protos.FunctionRpcStub(channel) def gen(resp_queue): while True: msg = resp_queue.get() if msg is self._GRPC_STOP_RESPONSE: grpc_req_stream.cancel() return yield msg grpc_req_stream = stub.EventStream(gen(self._grpc_resp_queue)) try: for req in grpc_req_stream: self._loop.call_soon_threadsafe( self._loop.create_task, self._dispatch_grpc_request(req)) except Exception as ex: if ex is grpc_req_stream: # Yes, this is how grpc_req_stream iterator exits. return error_logger.exception( 'unhandled error in gRPC thread. Exception: {0}'.format( ''.join(traceback.format_exception(ex)))) raise async def _dispatch_grpc_request(self, request): content_type = request.WhichOneof("content") match content_type: case "worker_init_request": request_handler = self._handle__worker_init_request case "function_environment_reload_request": request_handler = self._handle__function_environment_reload_request case "functions_metadata_request": request_handler = self._handle__functions_metadata_request case "function_load_request": request_handler = self._handle__function_load_request case "worker_status_request": request_handler = self._handle__worker_status_request case "invocation_request": request_handler = self._handle__invocation_request case _: # Don't crash on unknown messages. Log the error and return. logger.error("Unknown StreamingMessage content type: %s", content_type) return resp = await request_handler(request) self._grpc_resp_queue.put_nowait(resp) async def dispatch_forever(self): # sourcery skip: swap-if-expression if DispatcherMeta.__current_dispatcher__ is not None: raise RuntimeError('there can be only one running dispatcher per ' 'process') self._old_task_factory = self._loop.get_task_factory() DispatcherMeta.__current_dispatcher__ = self try: forever = self._loop.create_future() self._grpc_resp_queue.put_nowait( protos.StreamingMessage( request_id=self.request_id, start_stream=protos.StartStream( worker_id=self.worker_id))) # In Python 3.11+, constructing a task has an optional context # parameter. Allow for this param to be passed to ContextEnabledTask self._loop.set_task_factory( lambda loop, coro, context=None, **kwargs: ContextEnabledTask( coro, loop=loop, context=context, **kwargs)) # Detach console logging before enabling GRPC channel logging logger.info('Detaching console logging.') disable_console_logging() # Attach gRPC logging to the root logger. Since gRPC channel is # established, should use it for system and user logs logging_handler = AsyncLoggingHandler() root_logger = logging.getLogger() log_level = logging.INFO if not is_envvar_true( PYTHON_ENABLE_DEBUG_LOGGING) else logging.DEBUG root_logger.setLevel(log_level) root_logger.addHandler(logging_handler) logger.info('Switched to gRPC logging.') logging_handler.flush() try: await forever finally: logger.warning('Detaching gRPC logging due to exception.') logging_handler.flush() root_logger.removeHandler(logging_handler) # Reenable console logging when there's an exception enable_console_logging() logger.warning('Switched to console logging due to exception.') finally: DispatcherMeta.__current_dispatcher__ = None self._loop.set_task_factory(self._old_task_factory) self.stop() def stop(self) -> None: if self._grpc_thread is not None: self._grpc_resp_queue.put_nowait(self._GRPC_STOP_RESPONSE) self._grpc_thread.join() self._grpc_thread = None self._stop_sync_call_tp() def _stop_sync_call_tp(self): """Deallocate the current synchronous thread pool and assign self._sync_call_tp to None. If the thread pool does not exist, this will be a no op. """ if getattr(self, '_sync_call_tp', None): assert self._sync_call_tp is not None # mypy fix self._sync_call_tp.shutdown() self._sync_call_tp = None @staticmethod def _create_sync_call_tp(max_worker: Optional[int]) -> concurrent.futures.Executor: """Create a thread pool executor with max_worker. This is a wrapper over ThreadPoolExecutor constructor. Consider calling this method after _stop_sync_call_tp() to ensure only 1 synchronous thread pool is running. """ return concurrent.futures.ThreadPoolExecutor( max_workers=max_worker ) @staticmethod def _get_sync_tp_max_workers() -> typing.Optional[int]: def tp_max_workers_validator(value: str) -> bool: try: int_value = int(value) except ValueError: logger.warning('%s must be an integer', PYTHON_THREADPOOL_THREAD_COUNT) return False if int_value < 1: logger.warning( '%s must be set to a value between 1 and sys.maxint. ' 'Reverting to default value for max_workers', PYTHON_THREADPOOL_THREAD_COUNT, 1) return False return True max_workers = get_app_setting(setting=PYTHON_THREADPOOL_THREAD_COUNT, validator=tp_max_workers_validator) # We can box the app setting as int for earlier python versions. return int(max_workers) if max_workers else None @staticmethod def reload_library_worker(directory: str): global _library_worker v2_scriptfile = os.path.join(directory, get_script_file_name()) if os.path.exists(v2_scriptfile): try: import azure_functions_worker_v2 # NoQA _library_worker = azure_functions_worker_v2 logger.debug("azure_functions_worker_v2 import succeeded: %s", _library_worker.__file__) except ImportError: logger.debug("azure_functions_worker_v2 library not found: : %s", traceback.format_exc()) else: try: import azure_functions_worker_v1 # NoQA _library_worker = azure_functions_worker_v1 logger.debug("azure_functions_worker_v1 import succeeded: %s", _library_worker.__file__) # type: ignore[union-attr] except ImportError: logger.debug("azure_functions_worker_v1 library not found: %s", traceback.format_exc()) async def _handle__worker_init_request(self, request): logger.info('Received WorkerInitRequest, ' 'python version %s, ' 'worker version %s, ' 'request ID %s. ' 'To enable debug level logging, please refer to ' 'https://aka.ms/python-enable-debug-logging', sys.version, VERSION, self.request_id) if DependencyManager.is_in_linux_consumption(): import azure_functions_worker_v2 if DependencyManager.should_load_cx_dependencies(): DependencyManager.prioritize_customer_dependencies() directory = request.worker_init_request.function_app_directory self.reload_library_worker(directory) init_request = WorkerRequest(name="WorkerInitRequest", request=request, properties={"protos": protos, "host": self._host}) init_response = await ( _library_worker.worker_init_request( # type: ignore[union-attr] init_request)) return protos.StreamingMessage( request_id=self.request_id, worker_init_response=init_response) async def _handle__function_environment_reload_request(self, request): logger.info('Received FunctionEnvironmentReloadRequest, ' 'request ID: %s, ' 'To enable debug level logging, please refer to ' 'https://aka.ms/python-enable-debug-logging', self.request_id) func_env_reload_request = \ request.function_environment_reload_request directory = func_env_reload_request.function_app_directory DependencyManager.prioritize_customer_dependencies(directory) self.reload_library_worker(directory) env_reload_request = WorkerRequest(name="FunctionEnvironmentReloadRequest", request=request, properties={"protos": protos, "host": self._host}) env_reload_response = await ( _library_worker.function_environment_reload_request( # type: ignore[union-attr] # noqa env_reload_request)) return protos.StreamingMessage( request_id=self.request_id, function_environment_reload_response=env_reload_response) async def _handle__worker_status_request(self, request): # Logging is not necessary in this request since the response is used # for host to judge scale decisions of out-of-proc languages. # Having log here will reduce the responsiveness of the worker. return protos.StreamingMessage( request_id=request.request_id, worker_status_response=protos.WorkerStatusResponse()) async def _handle__functions_metadata_request(self, request): logger.info( 'Received WorkerMetadataRequest, request ID %s, ' 'worker id: %s', self.request_id, self.worker_id) metadata_request = WorkerRequest(name="WorkerMetadataRequest", request=request) metadata_response = await ( _library_worker.functions_metadata_request( # type: ignore[union-attr] metadata_request)) return protos.StreamingMessage( request_id=request.request_id, function_metadata_response=metadata_response) async def _handle__function_load_request(self, request): func_request = request.function_load_request function_id = func_request.function_id function_metadata = func_request.metadata function_name = function_metadata.name logger.info( 'Received WorkerLoadRequest, request ID %s, function_id: %s,' 'function_name: %s, worker_id: %s', self.request_id, function_id, function_name, self.worker_id) load_request = WorkerRequest(name="FunctionLoadRequest ", request=request) load_response = await ( _library_worker.function_load_request( # type: ignore[union-attr] load_request)) return protos.StreamingMessage( request_id=self.request_id, function_load_response=load_response) async def _handle__invocation_request(self, request): invoc_request = request.invocation_request invocation_id = invoc_request.invocation_id function_id = invoc_request.function_id logger.info( 'Received FunctionInvocationRequest, request ID %s, function_id: %s,' 'invocation_id: %s, worker_id: %s', self.request_id, function_id, invocation_id, self.worker_id) invocation_request = WorkerRequest(name="FunctionInvocationRequest", request=request, properties={ "threadpool": self._sync_call_tp}) invocation_response = await ( _library_worker.invocation_request( # type: ignore[union-attr] invocation_request)) return protos.StreamingMessage( request_id=self.request_id, invocation_response=invocation_response)