images/airflow/2.10.1/python/mwaa/subprocess/conditions.py (342 lines of code) (raw):

""" Contain various conditions that can be added to a process. We frequently need to monitor and control running processes. For example, we might want to limit the running time of a process to X minutes. Or, we might need to monitor the health of a process and restart it if necessary. Process conditions make this possible. """ from __future__ import annotations # Python imports from collections import deque from dataclasses import dataclass, field from datetime import datetime, timedelta from dateutil.tz import tz from functools import cached_property from types import FrameType, TracebackType from typing import Callable, Deque, Optional import logging import signal import socket import sys import time # 3rd-party imports from airflow.configuration import conf from sqlalchemy import create_engine from sqlalchemy.pool import NullPool # Our imports from mwaa.celery.task_monitor import WorkerTaskMonitor from mwaa.config.database import get_db_connection_string from mwaa.logging.utils import throttle from mwaa.subprocess import ProcessStatus from mwaa.utils.plogs import generate_plog from mwaa.utils.statsd import get_statsd logger = logging.getLogger(__name__) @dataclass class ProcessConditionResponse: """ Encapsulates all the information regarding a single execution of a health check. """ condition: ProcessCondition successful: bool message: str = "" timestamp: datetime = field(default_factory=lambda: datetime.now(tz.tzutc())) @property def name(self): """ Return the name of the condition this response is about. :returns The name of the condition. """ return self.condition.name def __str__(self) -> str: """ Return the string representation of this response. :returns The string representation of this response. """ if self.successful: return ( f"At {self.timestamp} condition {self.name} succeeded with " f"message: {self.message}" ) else: return ( f"At {self.timestamp} condition {self.name} failed with " f"message: {self.message}" ) _PROCESS_CONDITION_DEFAULT_MAX_HISTORY = 10 class ProcessCondition: """ Base class for all process conditions. A process condition is, as the name suggests, a condition that must be satisfied for a process to continue running. The """ def __init__( self, name: str | None = None, max_history: int = _PROCESS_CONDITION_DEFAULT_MAX_HISTORY, ): """ Initialize the process condition. :param name: The name of the condition. You can pass None for this to use the name of the class of the condition, which is usually sufficient, unless you want to customize the name. """ self.name = name if name else self.__class__.__name__ self.history: Deque[ProcessConditionResponse] = deque(maxlen=max_history) self.closed = False def prepare(self): """ Called by the Subprocess class to indicate the start of the subprocess. """ pass def close(self): """ Free any resources obtained by the condition. """ if self.closed: return self._close() self.closed = True def _close(self): pass def __enter__(self) -> ProcessCondition: """ Enter the runtime context related to this object. """ self.prepare() return self def __exit__( self, exc_type: type[BaseException], exc_value: BaseException, traceback: TracebackType, ): """ Exit the runtime context related to this object. """ self._close() def check(self, process_status: ProcessStatus) -> ProcessConditionResponse: """ Execute the condition and return the response. :returns A ProcessConditionResponse containing data about the response. """ response = self._check(process_status) self.history.append(response) return response def _check(self, process_status: ProcessStatus) -> ProcessConditionResponse: """ Execute the condition and return the response. :returns A ProcessConditionResponse containing data about the response. """ raise NotImplementedError() SIDECAR_DEFAULT_HEALTH_PORT = 8200 SOCKET_BUFFER_SIZE = 1024 # Socket timeout. Wait time upon receiving a message from sidecar. # The connection to the socket is considered as timed out if it has to wait more than # this threshold In seconds _SOCKET_TIMEOUT_SECONDS = 1 # The time to wait for the sidecar, during which timeouts from the sidecar are ignored. _SIDECAR_WAIT_PERIOD = timedelta(minutes=5) class SidecarHealthCondition(ProcessCondition): """ A health check that reads health status from the sidecar. The sidecar is another container that lives adjacent to the airflow container of the MWAA Fargate tasks. It is responsible for a couple of tasks, including health monitoring. The latter's logic sends back the result of its health assessment back to the main container, which is then read by this class. """ def __init__( self, airflow_component: str, container_start_time: float, port: int = SIDECAR_DEFAULT_HEALTH_PORT, ): """ :param airflow_component: The airflow component to check. :param container_start_time: The epoch in seconds, i.e. time.time(), when the container started. :param port: The port the sidecar sends health monitoring results to. """ super().__init__() self.airflow_component = airflow_component self.port = port self.socket: socket.socket | None self.container_start_time: float = container_start_time def prepare(self): """ Called by the Subprocess class to indicate the start of the subprocess. Here, we create the UDP socket that listens for health messages from the MWAA sidecar. """ self.socket = socket.socket( socket.AF_INET, # Internet socket.SOCK_DGRAM, # UDP ) logger.info(f"Binding to port {self.port}") self.socket.bind(("127.0.0.1", self.port)) self.socket.settimeout(_SOCKET_TIMEOUT_SECONDS) def _close(self): """ Free the socket that was created. """ if self.socket: self.socket.close() def _generate_autorestart_plog(self): """ Generate a processable log that the service can ingest to know that a restart on an Airlfow worker/scheduler has happened and report health metrics. """ # Unlike normal logs, plogs are ingested by the service to take various actions. # Hence, we always use 'print', to avoid log level accidentally stopping them. print( generate_plog( "AutoRestartLogsProcessor", f"[{self.airflow_component}] Restarting process...", ) ) @throttle(seconds=60, instance_level_throttling=True) # avoid excessive calls to process conditions def _check(self, process_status: ProcessStatus) -> ProcessConditionResponse: """ Execute the condition and return the response. :returns A ProcessConditionResponse containing data about the response. """ if self.socket is None: raise RuntimeError( "Unexpected error: socket object and start time shouldn't be None." ) try: status, _ = self.socket.recvfrom(SOCKET_BUFFER_SIZE) status = status.decode("utf-8") match status.lower(): case "red": response = ProcessConditionResponse( condition=self, successful=False, message=f"Status received from sidecar: {status}", ) logger.error(response.message) case "blue" | "yellow": # We treat blue/yellow as healthy to avoid unnecessary restarts, # but we log a warning. response = ProcessConditionResponse( condition=self, successful=True, message=f"Status received from sidecar: {status}", ) logger.warning(response.message) case "healthy": response = ProcessConditionResponse( condition=self, successful=True, message=f"Status received from sidecar: {status}", ) logger.info(response.message) case _: response = ProcessConditionResponse( condition=self, successful=True, message=f"Unexpected response retrieved from sidecar: {status}. " "Treating the status as HEALTHY. This may be a false positive " "so it should be investigated, unless it is happening at the " "start of the container before the sidecar monitoring is up " "and emitting health indicators.", ) logger.warning(response.message) except Exception: if ( time.time() - self.container_start_time > _SIDECAR_WAIT_PERIOD.total_seconds() ): response = ProcessConditionResponse( condition=self, successful=True, message="Reading the health status from the sidecar timed out. " "Unable to positively determine health, so assuming healthy. This " "may be a false positive so it should be investigated.", ) logger.error(response.message, exc_info=sys.exc_info()) else: response = ProcessConditionResponse( condition=self, successful=True, message="Reading the health status from the sidecar timed out, but " "ignoring this since the container just started, so the sidecar " "monitoring might not have been initialized yet.", ) logger.info(response.message) if not response.successful: self._generate_autorestart_plog() return response class TimeoutCondition(ProcessCondition): """ A timeout condition is used to control the running time of a process. A timeout condition always checks, except when the process has run for more than the allowed time, at which point this condition fails check, and results in terminating the process. """ def __init__(self, timeout: timedelta): """ Initialize the timeout condition. :param timeout: The maximum time the process is allowed to run. """ super().__init__() self.timeout = timeout self.start_time: float | None = None def prepare(self): """ Called by the Subprocess class to indicate the start of the subprocess. Here, we set the `start_time` field to save the time the process started. """ self.start_time = time.time() @throttle(seconds=60, instance_level_throttling=True) # avoid excessive calls to process conditions def _check(self, process_status: ProcessStatus) -> ProcessConditionResponse: """ Execute the condition and return the response. :returns A ProcessConditionResponse containing data about the response. """ if not self.start_time: raise RuntimeError("TimeoutCondition has not been initialized") running_time_ms = (time.time() - self.start_time) * 1000 timeout_ms = self.timeout.total_seconds() * 1000 if running_time_ms < timeout_ms: return ProcessConditionResponse(condition=self, successful=True) else: return ProcessConditionResponse( condition=self, successful=False, message=f"Process timed out after running for more than " f"{running_time_ms} milliseconds when the maximum running time " f"is {timeout_ms} milliseconds.", ) class AirflowDbReachableCondition(ProcessCondition): """ A condition for ensuring the Airflow database is reachable. """ def __init__(self, airflow_component: str): """ Initialize the check object. :param airflow_component: The airflow component to check. """ super().__init__() self.airflow_component = airflow_component self.healthy: bool = True def prepare(self): """ Initialize the condition. """ # values taken from entrypoint.sh engine_args = {} if not self._is_db_connection_pooling_enabled: logging.info( "Connection pooling is disabled. AirflowDbReachableCondition will not pool connections." ) engine_args["poolclass"] = NullPool else: logging.info( "Connection pooling is enabled. AirflowDbReachableCondition will pool connections." ) self.engine = create_engine( get_db_connection_string(), connect_args={"connect_timeout": 3}, **engine_args, ) def _generate_health_plog(self, healthy: bool, health_changed: bool): if health_changed: health_status = ( "CONNECTION_BECAME_HEALTHY" if healthy else "CONNECTION_BECAME_UNHEALTHY" ) else: health_status = "CONNECTION_HEALTHY" if healthy else "CONNECTION_UNHEALTHY" # Unlike normal logs, plogs are ingested by the service to take various actions. # Hence, we always use 'print', to avoid log level accidentally stopping them. print( generate_plog( "RDSHealthLogsProcessor", f"[{self.airflow_component}] connection with RDS Meta DB is {health_status}.", ) ) @throttle(seconds=60, instance_level_throttling=True) # avoid excessive calls to process conditions def _check(self, process_status: ProcessStatus) -> ProcessConditionResponse: """ Execute the condition and return the response. :returns A ProcessConditionResponse containing data about the response. """ # only run db healthcheck from scheduler for now. logging.info("Performing health check on Airflow DB.") # test db connectivity try: # try connection to the RDS metadata db and run a test query with self.engine.connect() as connection: # type: ignore connection.execute("SELECT 1") # type: ignore healthy = True message = "Successfully connected to database." logger.info(message) except Exception as ex: healthy = False message = f"Couldn't connect to database. Error: {ex}" logger.error(message) response = ProcessConditionResponse( condition=self, # This condition is currently informational only, to report metrics and # logs. We have an issue to implement a restart after a certain grace # period: https://github.com/aws/amazon-mwaa-docker-images/issues/75 successful=True, message=message, ) self._generate_health_plog( healthy, healthy != self.healthy, # compare the current health against the last. ) self.healthy = healthy return response @cached_property def _is_db_connection_pooling_enabled(self) -> bool: return conf.getboolean( # type: ignore "database", "sql_alchemy_pool_enabled", fallback=False ) class TaskMonitoringCondition(ProcessCondition): """ A condition for regularly communicating with the worker task monitor to ensure graceful shutdown of workers in case of auto scaling. :param worker_task_monitor: The worker task monitor. See the implementation of WorkerTaskMonitor for more details on what this monitor does. :param terminate_if_idle: Whether to terminate the worker if it is idle. """ WORKER_MONITOR_CLOSED_TIME_THRESHOLD = timedelta(seconds=30) def __init__( self, worker_task_monitor: WorkerTaskMonitor, terminate_if_idle: bool, ): """ Initialize the instance. :param worker_task_monitor: The worker task monitor. See the implementation of WorkerTaskMonitor for more details on what this monitor does. """ super().__init__() self.worker_task_monitor = worker_task_monitor self.terminate_if_idle = terminate_if_idle self.stats = get_statsd() self.last_check_time = 0 self.time_since_closed = 0 def prepare(self): """ Initialize the condition. """ pass def _close(self): """ Close the auto scaling condition. """ self.worker_task_monitor.close() def _get_failed_condition_response(self, message: str) -> ProcessConditionResponse: """ Get the failed condition response. :param message: The message to include in the response. :returns The failed condition response. """ logger.info(message) return ProcessConditionResponse( condition=self, successful=False, message=message, ) def _publish_metrics(self): is_closed = self.worker_task_monitor.is_closed() dt = time.time() - self.last_check_time self.last_check_time = time.time() if not is_closed: self.time_since_closed = 0 else: self.time_since_closed += dt self.stats.incr("mwaa.task_monitor.task_count", self.worker_task_monitor.get_current_task_count()) self.stats.incr("mwaa.task_monitor.cleanup_task_count", self.worker_task_monitor.get_cleanup_task_count()) if self.time_since_closed > TaskMonitoringCondition.WORKER_MONITOR_CLOSED_TIME_THRESHOLD.total_seconds(): self.stats.incr("mwaa.task_monitor.worker_shutdown_delayed", 1) @throttle(seconds=10, instance_level_throttling=True) # avoid excessive calls to process conditions def _check(self, process_status: ProcessStatus) -> ProcessConditionResponse: """ Execute the condition and return the response. :returns A ProcessConditionResponse containing data about the response. """ if process_status == ProcessStatus.RUNNING: self.worker_task_monitor.cleanup_abandoned_resources() self.worker_task_monitor.process_next_signal() # If the allowed time limit for waiting for activation signal has been breached, then we give up on further wait and exit. if self.worker_task_monitor.is_activation_wait_time_limit_breached(): return self._get_failed_condition_response("Allowed time limit for activation has been breached. Exiting") # If the worker is marked to be killed, then we exit the worker without waiting for the tasks to be completed. elif self.worker_task_monitor.is_marked_for_kill(): return self._get_failed_condition_response("Worker has been marked for kill. Exiting.") # If the worker is marked to be terminated, then we exit the worker after waiting for the tasks to be completed. elif self.worker_task_monitor.is_marked_for_termination(): logger.info("Worker has been marked for termination, checking for idleness before terminating.") if self.worker_task_monitor.is_worker_idle(): return self._get_failed_condition_response("Worker marked for termination has become idle. Exiting.") elif self.worker_task_monitor.is_termination_time_limit_breached(): return self._get_failed_condition_response("Allowed time limit for graceful termination has been breached. Exiting") else: logger.info("Worker marked for termination is NOT yet idle. Waiting.") elif self.terminate_if_idle and self.worker_task_monitor.is_worker_idle(): # After detecting worker idleness, we pause further work consumption via # Celery, wait and check again for idleness. logger.info("Worker process is idle and needs to be terminated. Pausing task consumption.") self.worker_task_monitor.pause_task_consumption() if self.worker_task_monitor.is_worker_idle(): return self._get_failed_condition_response("Worker which should be terminated if idle has been " "found to be idle. Exiting.") else: logger.info("Worker picked up new tasks during shutdown, reviving worker.") self.worker_task_monitor.resume_task_consumption() self.worker_task_monitor.reset_monitor_state() else: logger.info("Worker process is either NOT idle or has not been marked for termination. No action is needed.") else: logger.info( f"Worker process finished (status is {process_status.name}). " "No need to monitor tasks anymore." ) self._publish_metrics() # For all other scenarios, the condition defaults to returning success. return ProcessConditionResponse( condition=self, successful=True, )