# atr/manager.py
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Worker process manager."""
from __future__ import annotations
import asyncio
import datetime
import io
import logging
import os
import signal
import sys
from typing import Final
import sqlmodel
import atr.db as db
import atr.db.models as models
_LOGGER: Final = logging.getLogger(__name__)
# Global debug flag to control worker process output capturing
global_worker_debug: bool = False
# Global worker manager instance
# The forward reference in "WorkerManager | None" only works because of
# "from __future__ import annotations"; without it, a string forward reference
# would have to be written as Optional["WorkerManager"]
global_worker_manager: WorkerManager | None = None
class WorkerManager:
"""Manager for a pool of worker processes."""
def __init__(
self,
min_workers: int = 4,
max_workers: int = 8,
check_interval_seconds: float = 2.0,
max_task_seconds: float = 300.0,
):
self.min_workers = min_workers
self.max_workers = max_workers
self.check_interval_seconds = check_interval_seconds
self.max_task_seconds = max_task_seconds
self.workers: dict[int, WorkerProcess] = {}
self.running = False
self.check_task: asyncio.Task | None = None
async def start(self) -> None:
"""Start the worker manager."""
if self.running:
return
self.running = True
_LOGGER.info("Starting worker manager in %s", os.getcwd())
# Start initial workers
for _ in range(self.min_workers):
await self.spawn_worker()
# Start monitoring task
self.check_task = asyncio.create_task(self.monitor_workers())
async def stop(self) -> None:
"""Stop all workers and the manager."""
if not self.running:
return
self.running = False
_LOGGER.info("Stopping worker manager")
# Cancel monitoring task
if self.check_task:
self.check_task.cancel()
try:
await self.check_task
except asyncio.CancelledError:
...
# Stop all workers
await self.stop_all_workers()
async def stop_all_workers(self) -> None:
"""Stop all worker processes."""
for worker in list(self.workers.values()):
if worker.pid:
try:
os.kill(worker.pid, signal.SIGTERM)
except ProcessLookupError:
# The process may have already exited
...
except Exception as e:
_LOGGER.error(f"Error stopping worker {worker.pid}: {e}")
# Wait for processes to exit
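        # Escalate to SIGKILL for any worker that does not exit within the five second grace period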
for worker in list(self.workers.values()):
try:
await asyncio.wait_for(worker.process.wait(), timeout=5.0)
except TimeoutError:
if worker.pid:
try:
os.kill(worker.pid, signal.SIGKILL)
except ProcessLookupError:
# The process may have already exited
...
except Exception as e:
_LOGGER.error(f"Error force killing worker {worker.pid}: {e}")
self.workers.clear()
async def spawn_worker(self) -> None:
"""Spawn a new worker process."""
if len(self.workers) >= self.max_workers:
return
try:
# Get the absolute path to the project root (i.e. atr/..)
abs_path = await asyncio.to_thread(os.path.abspath, __file__)
project_root = os.path.dirname(os.path.dirname(abs_path))
# Ensure PYTHONPATH includes our project root
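            # so that the spawned worker can import the atr package regardless of its working directory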
env = os.environ.copy()
python_path = env.get("PYTHONPATH", "")
env["PYTHONPATH"] = f"{project_root}:{python_path}" if python_path else project_root
# Get absolute path to worker script
worker_script = os.path.join(project_root, "atr", "worker.py")
# Handle stdout and stderr based on debug setting
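            # By default all worker output is discarded; with global_worker_debug set,
            # it is redirected to a per-worker log file instead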
stdout_target: int | io.TextIOWrapper = asyncio.subprocess.DEVNULL
stderr_target: int | io.TextIOWrapper = asyncio.subprocess.DEVNULL
# Generate a unique log file name for this worker if debugging is enabled
log_file_path = None
if global_worker_debug:
timestamp = datetime.datetime.now(datetime.UTC).strftime("%Y%m%d_%H%M%S")
log_file_name = f"worker_{timestamp}_{os.getpid()}.log"
log_file_path = os.path.join(project_root, "state", log_file_name)
# Open log file for writing
log_file = await asyncio.to_thread(open, log_file_path, "w")
stdout_target = log_file
stderr_target = log_file
_LOGGER.info(f"Worker output will be logged to {log_file_path}")
# Start worker process with the updated environment
# Use preexec_fn to create new process group
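            # setsid places the worker in its own session and process group, so signals
            # sent to the manager's group (e.g. Ctrl-C in a terminal) do not also hit the workers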
process = await asyncio.create_subprocess_exec(
sys.executable,
worker_script,
stdout=stdout_target,
stderr=stderr_target,
env=env,
preexec_fn=os.setsid,
            )
            # The child receives a duplicate of the log file descriptor, so close the
            # parent's copy here to avoid leaking one open file handle per worker
            if isinstance(stdout_target, io.TextIOWrapper):
                await asyncio.to_thread(stdout_target.close)
worker = WorkerProcess(process, datetime.datetime.now(datetime.UTC))
if worker.pid:
self.workers[worker.pid] = worker
_LOGGER.info(f"Started worker process {worker.pid}")
if global_worker_debug and log_file_path:
_LOGGER.info(f"Worker {worker.pid} logs: {log_file_path}")
else:
_LOGGER.error("Failed to start worker process: No PID assigned")
if global_worker_debug and isinstance(stdout_target, io.TextIOWrapper):
await asyncio.to_thread(stdout_target.close)
except Exception as e:
_LOGGER.error(f"Error spawning worker: {e}")
async def monitor_workers(self) -> None:
"""Monitor worker processes and restart them if needed."""
while self.running:
try:
await self.check_workers()
await asyncio.sleep(self.check_interval_seconds)
except asyncio.CancelledError:
break
except Exception as e:
_LOGGER.error(f"Error in worker monitor: {e}", exc_info=e)
# TODO: How long should we wait before trying again?
await asyncio.sleep(1.0)
async def check_workers(self) -> None:
"""Check worker processes and restart if needed."""
exited_workers = []
async with db.session() as data:
# Check each worker first
for pid, worker in list(self.workers.items()):
# Check if process is running
if not await worker.is_running():
exited_workers.append(pid)
_LOGGER.info(f"Worker {pid} has exited")
continue
                # Check whether the worker has been processing its task for too long;
                # if so, check_task_duration also terminates the worker and marks the task as failed
if await self.check_task_duration(data, pid, worker):
exited_workers.append(pid)
# Remove exited workers
for pid in exited_workers:
self.workers.pop(pid, None)
# Check for active tasks
# try:
# async with get_session() as session:
# result = await session.execute(
# text("""
# SELECT COUNT(*)
# FROM task
# WHERE status = 'QUEUED'
# """)
# )
# queued_count = result.scalar()
# logger.info(f"Found {queued_count} queued tasks waiting for workers")
# except Exception as e:
# logger.error(f"Error checking queued tasks: {e}")
# Spawn new workers if needed
await self.maintain_worker_pool()
# Reset any tasks that were being processed by now inactive workers
await self.reset_broken_tasks()
async def terminate_long_running_task(
self, task: models.Task, worker: WorkerProcess, task_id: int, pid: int
) -> None:
"""
Terminate a task that has been running for too long.
Updates the task status and terminates the worker process.
"""
try:
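            # The attribute changes below are persisted by the caller's transaction,
            # since check_task_duration wraps this call in data.begin()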
# Mark the task as failed
task.status = models.TaskStatus.FAILED
task.completed = datetime.datetime.now(datetime.UTC)
task.error = f"Task terminated after exceeding time limit of {self.max_task_seconds} seconds"
if worker.pid:
os.kill(worker.pid, signal.SIGTERM)
_LOGGER.info(f"Worker {pid} terminated after processing task {task_id} for > {self.max_task_seconds}s")
except ProcessLookupError:
return
except Exception as e:
_LOGGER.error(f"Error stopping long-running worker {pid}: {e}")
async def check_task_duration(self, data: db.Session, pid: int, worker: WorkerProcess) -> bool:
"""
Check whether a worker has been processing its task for too long.
Returns True if the worker has been terminated.
"""
try:
async with data.begin():
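                # Fetch the ACTIVE task, if any, that this worker's PID has claimed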
task = await data.task(pid=pid, status=models.TaskStatus.ACTIVE).get()
if not task or not task.started:
return False
task_duration = (datetime.datetime.now(datetime.UTC) - task.started).total_seconds()
if task_duration > self.max_task_seconds:
await self.terminate_long_running_task(task, worker, task.id, pid)
return True
return False
except Exception as e:
_LOGGER.error(f"Error checking task duration for worker {pid}: {e}")
            # Return False so that a transient error here is not also counted as a terminated worker
return False
async def maintain_worker_pool(self) -> None:
"""Ensure we maintain the minimum number of workers."""
current_count = len(self.workers)
if current_count < self.min_workers:
_LOGGER.info(f"Worker pool below minimum ({current_count} < {self.min_workers}), spawning new workers")
while len(self.workers) < self.min_workers:
await self.spawn_worker()
_LOGGER.info(f"Worker pool restored to {len(self.workers)} workers")
async def reset_broken_tasks(self) -> None:
"""Reset any tasks that were being processed by exited workers."""
try:
async with db.session() as data:
async with data.begin():
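                    # Iterating the workers dict yields its keys, i.e. the PIDs of the workers we still track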
active_worker_pids = list(self.workers)
update_stmt = (
sqlmodel.update(models.Task)
.where(
sqlmodel.and_(
db.validate_instrumented_attribute(models.Task.pid).notin_(active_worker_pids),
models.Task.status == models.TaskStatus.ACTIVE,
)
)
.values(status=models.TaskStatus.QUEUED, started=None, pid=None)
)
result = await data.execute(update_stmt)
if result.rowcount > 0:
_LOGGER.info(f"Reset {result.rowcount} tasks to state 'QUEUED' as their worker died")
except Exception as e:
_LOGGER.error(f"Error resetting broken tasks: {e}")
class WorkerProcess:
"""Interface to control a worker process."""
def __init__(self, process: asyncio.subprocess.Process, started: datetime.datetime):
self.process = process
self.started = started
self.last_checked = started
@property
def pid(self) -> int | None:
return self.process.pid
async def is_running(self) -> bool:
"""Check if the process is still running."""
if self.process.returncode is not None:
# Process has already exited
return False
if not self.pid:
# Process did not start
return False
try:
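            # Signal 0 performs the existence and permission checks without actually delivering a signal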
os.kill(self.pid, 0)
self.last_checked = datetime.datetime.now(datetime.UTC)
return True
except ProcessLookupError:
# Process no longer exists
return False
except PermissionError:
# Process exists, but we don't have permission to signal it
# This shouldn't happen in our case since we own the process
_LOGGER.warning(f"Permission error checking process {self.pid}")
return False
def get_worker_manager() -> WorkerManager:
"""Get the global worker manager instance."""
global global_worker_manager
if global_worker_manager is None:
global_worker_manager = WorkerManager()
return global_worker_manager