atr/manager.py (218 lines of code) (raw):

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Worker process manager."""

from __future__ import annotations

import asyncio
import datetime
import io
import logging
import os
import signal
import sys
from typing import Final

import sqlmodel

import atr.db as db
import atr.db.models as models

_LOGGER: Final = logging.getLogger(__name__)

# Global debug flag to control worker process output capturing
global_worker_debug: bool = False

# Global worker manager instance
# The unquoted "WorkerManager | None" forward reference works here because of
# "from __future__ import annotations"; a quoted forward reference cannot be
# combined with the "|" operator at runtime
global_worker_manager: WorkerManager | None = None


class WorkerManager:
    """Manager for a pool of worker processes."""

    def __init__(
        self,
        min_workers: int = 4,
        max_workers: int = 8,
        check_interval_seconds: float = 2.0,
        max_task_seconds: float = 300.0,
    ):
        self.min_workers = min_workers
        self.max_workers = max_workers
        self.check_interval_seconds = check_interval_seconds
        self.max_task_seconds = max_task_seconds
        self.workers: dict[int, WorkerProcess] = {}
        self.running = False
        self.check_task: asyncio.Task | None = None

    async def start(self) -> None:
        """Start the worker manager."""
        if self.running:
            return
        self.running = True
        _LOGGER.info("Starting worker manager in %s", os.getcwd())

        # Start initial workers
        for _ in range(self.min_workers):
            await self.spawn_worker()

        # Start monitoring task
        self.check_task = asyncio.create_task(self.monitor_workers())

    async def stop(self) -> None:
        """Stop all workers and the manager."""
        if not self.running:
            return
        self.running = False
        _LOGGER.info("Stopping worker manager")

        # Cancel monitoring task
        if self.check_task:
            self.check_task.cancel()
            try:
                await self.check_task
            except asyncio.CancelledError:
                ...

        # Stop all workers
        await self.stop_all_workers()

    async def stop_all_workers(self) -> None:
        """Stop all worker processes."""
        for worker in list(self.workers.values()):
            if worker.pid:
                try:
                    os.kill(worker.pid, signal.SIGTERM)
                except ProcessLookupError:
                    # The process may have already exited
                    ...
                except Exception as e:
                    _LOGGER.error(f"Error stopping worker {worker.pid}: {e}")

        # Wait for processes to exit
        for worker in list(self.workers.values()):
            try:
                await asyncio.wait_for(worker.process.wait(), timeout=5.0)
            except TimeoutError:
                if worker.pid:
                    try:
                        os.kill(worker.pid, signal.SIGKILL)
                    except ProcessLookupError:
                        # The process may have already exited
                        ...
                    except Exception as e:
                        _LOGGER.error(f"Error force killing worker {worker.pid}: {e}")
        self.workers.clear()
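    # Note: spawn_worker() below starts each worker with preexec_fn=os.setsid,
    # which places the worker in a new session and process group. This
    # detaches workers from the manager's controlling terminal, so signals
    # delivered to the manager's group (e.g. Ctrl-C) do not also reach them.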
    async def spawn_worker(self) -> None:
        """Spawn a new worker process."""
        if len(self.workers) >= self.max_workers:
            return

        try:
            # Get the absolute path to the project root (i.e. atr/..)
            abs_path = await asyncio.to_thread(os.path.abspath, __file__)
            project_root = os.path.dirname(os.path.dirname(abs_path))

            # Ensure PYTHONPATH includes our project root
            env = os.environ.copy()
            python_path = env.get("PYTHONPATH", "")
            env["PYTHONPATH"] = f"{project_root}:{python_path}" if python_path else project_root

            # Get absolute path to worker script
            worker_script = os.path.join(project_root, "atr", "worker.py")

            # Handle stdout and stderr based on debug setting
            stdout_target: int | io.TextIOWrapper = asyncio.subprocess.DEVNULL
            stderr_target: int | io.TextIOWrapper = asyncio.subprocess.DEVNULL

            # Generate a unique log file name for this worker if debugging is enabled
            log_file_path = None
            if global_worker_debug:
                timestamp = datetime.datetime.now(datetime.UTC).strftime("%Y%m%d_%H%M%S")
                log_file_name = f"worker_{timestamp}_{os.getpid()}.log"
                log_file_path = os.path.join(project_root, "state", log_file_name)

                # Open log file for writing
                log_file = await asyncio.to_thread(open, log_file_path, "w")
                stdout_target = log_file
                stderr_target = log_file
                _LOGGER.info(f"Worker output will be logged to {log_file_path}")

            # Start worker process with the updated environment
            # Use preexec_fn to create new process group
            process = await asyncio.create_subprocess_exec(
                sys.executable,
                worker_script,
                stdout=stdout_target,
                stderr=stderr_target,
                env=env,
                preexec_fn=os.setsid,
            )

            worker = WorkerProcess(process, datetime.datetime.now(datetime.UTC))
            if worker.pid:
                self.workers[worker.pid] = worker
                _LOGGER.info(f"Started worker process {worker.pid}")
                if global_worker_debug and log_file_path:
                    _LOGGER.info(f"Worker {worker.pid} logs: {log_file_path}")
            else:
                _LOGGER.error("Failed to start worker process: No PID assigned")
                if global_worker_debug and isinstance(stdout_target, io.TextIOWrapper):
                    await asyncio.to_thread(stdout_target.close)
        except Exception as e:
            _LOGGER.error(f"Error spawning worker: {e}")

    async def monitor_workers(self) -> None:
        """Monitor worker processes and restart them if needed."""
        while self.running:
            try:
                await self.check_workers()
                await asyncio.sleep(self.check_interval_seconds)
            except asyncio.CancelledError:
                break
            except Exception as e:
                _LOGGER.error(f"Error in worker monitor: {e}", exc_info=e)
                # TODO: How long should we wait before trying again?
                await asyncio.sleep(1.0)
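    # The loop in monitor_workers() is the manager's only periodic driver:
    # every check_interval_seconds it calls check_workers(), which prunes
    # dead workers, enforces max_task_seconds, refills the pool to
    # min_workers, and requeues tasks orphaned by dead workers.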
    async def check_workers(self) -> None:
        """Check worker processes and restart if needed."""
        exited_workers = []
        async with db.session() as data:
            # Check each worker first
            for pid, worker in list(self.workers.items()):
                # Check if process is running
                if not await worker.is_running():
                    exited_workers.append(pid)
                    _LOGGER.info(f"Worker {pid} has exited")
                    continue

                # Check if worker has been processing its task for too long
                # This also stops tasks if they have indeed been running for too long
                if await self.check_task_duration(data, pid, worker):
                    exited_workers.append(pid)

        # Remove exited workers
        for pid in exited_workers:
            self.workers.pop(pid, None)

        # Check for active tasks
        # try:
        #     async with get_session() as session:
        #         result = await session.execute(
        #             text("""
        #                 SELECT COUNT(*)
        #                 FROM task
        #                 WHERE status = 'QUEUED'
        #             """)
        #         )
        #         queued_count = result.scalar()
        #         logger.info(f"Found {queued_count} queued tasks waiting for workers")
        # except Exception as e:
        #     logger.error(f"Error checking queued tasks: {e}")

        # Spawn new workers if needed
        await self.maintain_worker_pool()

        # Reset any tasks that were being processed by now inactive workers
        await self.reset_broken_tasks()

    async def terminate_long_running_task(
        self, task: models.Task, worker: WorkerProcess, task_id: int, pid: int
    ) -> None:
        """
        Terminate a task that has been running for too long.

        Updates the task status and terminates the worker process.
        """
        try:
            # Mark the task as failed
            task.status = models.TaskStatus.FAILED
            task.completed = datetime.datetime.now(datetime.UTC)
            task.error = f"Task terminated after exceeding time limit of {self.max_task_seconds} seconds"

            if worker.pid:
                os.kill(worker.pid, signal.SIGTERM)
                _LOGGER.info(f"Worker {pid} terminated after processing task {task_id} for > {self.max_task_seconds}s")
        except ProcessLookupError:
            return
        except Exception as e:
            _LOGGER.error(f"Error stopping long-running worker {pid}: {e}")
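    # Note: terminate_long_running_task() only sends SIGTERM, and the caller
    # then drops the pid from self.workers. A worker that ignores SIGTERM
    # therefore keeps running untracked; only stop_all_workers() escalates
    # to SIGKILL.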
""" try: async with data.begin(): task = await data.task(pid=pid, status=models.TaskStatus.ACTIVE).get() if not task or not task.started: return False task_duration = (datetime.datetime.now(datetime.UTC) - task.started).total_seconds() if task_duration > self.max_task_seconds: await self.terminate_long_running_task(task, worker, task.id, pid) return True return False except Exception as e: _LOGGER.error(f"Error checking task duration for worker {pid}: {e}") # TODO: Return False here to avoid over-reporting errors return False async def maintain_worker_pool(self) -> None: """Ensure we maintain the minimum number of workers.""" current_count = len(self.workers) if current_count < self.min_workers: _LOGGER.info(f"Worker pool below minimum ({current_count} < {self.min_workers}), spawning new workers") while len(self.workers) < self.min_workers: await self.spawn_worker() _LOGGER.info(f"Worker pool restored to {len(self.workers)} workers") async def reset_broken_tasks(self) -> None: """Reset any tasks that were being processed by exited workers.""" try: async with db.session() as data: async with data.begin(): active_worker_pids = list(self.workers) update_stmt = ( sqlmodel.update(models.Task) .where( sqlmodel.and_( db.validate_instrumented_attribute(models.Task.pid).notin_(active_worker_pids), models.Task.status == models.TaskStatus.ACTIVE, ) ) .values(status=models.TaskStatus.QUEUED, started=None, pid=None) ) result = await data.execute(update_stmt) if result.rowcount > 0: _LOGGER.info(f"Reset {result.rowcount} tasks to state 'QUEUED' as their worker died") except Exception as e: _LOGGER.error(f"Error resetting broken tasks: {e}") class WorkerProcess: """Interface to control a worker process.""" def __init__(self, process: asyncio.subprocess.Process, started: datetime.datetime): self.process = process self.started = started self.last_checked = started @property def pid(self) -> int | None: return self.process.pid async def is_running(self) -> bool: """Check if the process is still running.""" if self.process.returncode is not None: # Process has already exited return False if not self.pid: # Process did not start return False try: os.kill(self.pid, 0) self.last_checked = datetime.datetime.now(datetime.UTC) return True except ProcessLookupError: # Process no longer exists return False except PermissionError: # Process exists, but we don't have permission to signal it # This shouldn't happen in our case since we own the process _LOGGER.warning(f"Permission error checking process {self.pid}") return False def get_worker_manager() -> WorkerManager: """Get the global worker manager instance.""" global global_worker_manager if global_worker_manager is None: global_worker_manager = WorkerManager() return global_worker_manager