atr/worker.py
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""worker.py - Task worker process for ATR"""
# TODO: If started is older than some threshold and status
# is active but the pid is no longer running, we can revert
# the task to status='QUEUED'. For this to work, ideally we
# need to check wall clock time as well as CPU time.
import asyncio
import datetime
import inspect
import json
import logging
import os
import resource
import signal
import traceback
from typing import Any, Final
import sqlmodel
import atr.db as db
import atr.db.models as models
import atr.tasks as tasks
import atr.tasks.checks as checks
import atr.tasks.task as task
_LOGGER: Final = logging.getLogger(__name__)
# Resource limits: the 5 minute CPU limit is currently disabled; only the 1GB memory limit below is enforced
# _CPU_LIMIT_SECONDS: Final = 300
_MEMORY_LIMIT_BYTES: Final = 1024 * 1024 * 1024
# # Create tables if they don't exist
# SQLModel.metadata.create_all(engine)
def main() -> None:
"""Main entry point."""
import atr.config as config
conf = config.get()
if os.path.isdir(conf.STATE_DIR):
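# Run from the state directory so relative paths, such as the worker log file, resolve there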
os.chdir(conf.STATE_DIR)
_setup_logging()
_LOGGER.info(f"Starting worker process with pid {os.getpid()}")
# Track the worker's background tasks so the signal handler can cancel them on shutdown
# (named background_tasks to avoid shadowing the atr.tasks module imported above)
background_tasks: list[asyncio.Task] = []
async def _handle_signal(signum: int) -> None:
_LOGGER.info(f"Received signal {signum}, shutting down...")
await db.shutdown_database()
for t in background_tasks:
t.cancel()
_LOGGER.debug("Cancelled all running tasks")
asyncio.get_event_loop().stop()
_LOGGER.debug("Stopped event loop")
for s in (signal.SIGTERM, signal.SIGINT):
signal.signal(s, lambda signum, frame: asyncio.create_task(_handle_signal(signum)))
_worker_resources_limit_set()
async def _start() -> None:
await db.init_database_for_worker()
background_tasks.append(asyncio.create_task(_worker_loop_run()))
await asyncio.gather(*background_tasks)
asyncio.run(_start())
# If the worker decides to stop running (see #230 in _worker_loop_run()), shut down the database gracefully
asyncio.run(db.shutdown_database())
_LOGGER.info("Exiting worker process")
def _setup_logging() -> None:
# Configure logging
log_format = "[%(asctime)s.%(msecs)03d] [%(process)d] [%(levelname)s] %(message)s"
date_format = "%Y-%m-%d %H:%M:%S"
logging.basicConfig(filename="atr-worker.log", format=log_format, datefmt=date_format, level=logging.INFO)
# Task functions
async def _task_error_handle(task_id: int, e: Exception) -> None:
"""Handle task error by updating the database with error information."""
if isinstance(e, task.Error):
_LOGGER.error(f"Task {task_id} failed: {e.message}")
_LOGGER.error("".join(traceback.format_exception(e)))
result = json.dumps(e.result)
async with db.session() as data:
async with data.begin():
task_obj = await data.task(id=task_id).get()
if task_obj:
task_obj.status = task.FAILED
task_obj.completed = datetime.datetime.now(datetime.UTC)
task_obj.error = e.message
task_obj.result = result
else:
_LOGGER.error(f"Task {task_id} failed: {e}")
_LOGGER.error("".join(traceback.format_exception(e)))
async with db.session() as data:
async with data.begin():
task_obj = await data.task(id=task_id).get()
if task_obj:
task_obj.status = task.FAILED
task_obj.completed = datetime.datetime.now(datetime.UTC)
task_obj.error = str(e)
async def _task_next_claim() -> tuple[int, str, list[str] | dict[str, Any]] | None:
"""
Attempt to claim the oldest unclaimed task.
Returns (task_id, task_type, task_args) if successful.
Returns None if no tasks are available.
"""
async with db.session() as data:
async with data.begin():
# Get the ID of the oldest queued task
oldest_queued_task = (
sqlmodel.select(models.Task.id)
.where(models.Task.status == task.QUEUED)
.order_by(db.validate_instrumented_attribute(models.Task.added).asc())
.limit(1)
)
# Use an UPDATE with a WHERE clause to atomically claim the task
# This ensures that only one worker can claim a specific task
now = datetime.datetime.now(datetime.UTC)
update_stmt = (
sqlmodel.update(models.Task)
.where(sqlmodel.and_(models.Task.id == oldest_queued_task, models.Task.status == task.QUEUED))
.values(status=task.ACTIVE, started=now, pid=os.getpid())
.returning(
db.validate_instrumented_attribute(models.Task.id),
db.validate_instrumented_attribute(models.Task.task_type),
db.validate_instrumented_attribute(models.Task.task_args),
)
)
result = await data.execute(update_stmt)
claimed_task = result.first()
if claimed_task:
task_id, task_type, task_args = claimed_task
_LOGGER.info(f"Claimed task {task_id} ({task_type}) with args {task_args}")
return task_id, task_type, task_args
return None
async def _task_process(task_id: int, task_type: str, task_args: list[str] | dict[str, Any]) -> None:
"""Process a claimed task."""
_LOGGER.info(f"Processing task {task_id} ({task_type}) with raw args {task_args}")
try:
task_type_member = models.TaskType(task_type)
except ValueError as e:
_LOGGER.error(f"Invalid task type: {task_type}")
await _task_result_process(task_id, tuple(), task.FAILED, str(e))
return
task_results: tuple[Any, ...]
try:
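# Resolve the handler for this task type, then inspect its signature to decide how to call it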
handler = tasks.resolve(task_type_member)
sig = inspect.signature(handler)
params = list(sig.parameters.values())
# Check whether the handler is a check handler
if (len(params) == 1) and (params[0].annotation == checks.FunctionArguments):
_LOGGER.debug(f"Handler {handler.__name__} expects checks.FunctionArguments, fetching full task details")
async with db.session() as data:
task_obj = await data.task(id=task_id).demand(
ValueError(f"Task {task_id} disappeared during processing")
)
# Validate required fields from the Task object itself
if task_obj.release_name is None:
raise ValueError(f"Task {task_id} is missing required release_name")
if task_obj.draft_revision is None:
raise ValueError(f"Task {task_id} is missing required draft_revision")
if not isinstance(task_args, dict):
raise TypeError(
f"Task {task_id} ({task_type}) has non-dict raw args"
f" {task_args} which should represent keyword_args"
)
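# Pass the Recorder through a factory so it is only constructed when the check asks for it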
async def recorder_factory() -> checks.Recorder:
return await checks.Recorder.create(
checker=handler,
release_name=task_obj.release_name or "",
draft_revision=task_obj.draft_revision or "",
primary_rel_path=task_obj.primary_rel_path,
)
function_arguments = checks.FunctionArguments(
recorder=recorder_factory,
release_name=task_obj.release_name,
draft_revision=task_obj.draft_revision,
primary_rel_path=task_obj.primary_rel_path,
extra_args=task_args,
)
_LOGGER.debug(f"Calling {handler.__name__} with structured arguments: {function_arguments}")
handler_result = await handler(function_arguments)
else:
# Otherwise, it's not a check handler
handler_result = await handler(task_args)
task_results = (handler_result,)
status = task.COMPLETED
error = None
except Exception as e:
task_results = tuple()
status = task.FAILED
error_details = traceback.format_exc()
_LOGGER.error(f"Task {task_id} failed processing: {error_details}")
error = str(e)
await _task_result_process(task_id, task_results, status, error)
async def _task_result_process(
task_id: int, task_results: tuple[Any, ...], status: models.TaskStatus, error: str | None = None
) -> None:
"""Process and store task results in the database."""
async with db.session() as data:
async with data.begin():
# Find the task by ID
task_obj = await data.task(id=task_id).get()
if task_obj:
# Update task properties
task_obj.status = status
task_obj.completed = datetime.datetime.now(datetime.UTC)
task_obj.result = task_results
if (status == task.FAILED) and error:
task_obj.error = error
# Worker functions
async def _worker_loop_run() -> None:
"""Main worker loop."""
processed = 0
max_to_process = 10
while True:
try:
claimed = await _task_next_claim()
if claimed:
task_id, task_type, task_args = claimed
await _task_process(task_id, task_type, task_args)
processed += 1
# Only process max_to_process tasks and then exit
# This prevents memory leaks from accumulating
# Another worker will be started automatically when one exits
if processed >= max_to_process:
break
else:
# No tasks available, wait 100ms before checking again
await asyncio.sleep(0.1)
except Exception:
# TODO: Should probably be more robust about this
_LOGGER.exception("Worker loop error")
await asyncio.sleep(1)
def _worker_resources_limit_set() -> None:
"""Set CPU and memory limits for this process."""
# # Set CPU time limit (disabled; see _CPU_LIMIT_SECONDS above)
# try:
# resource.setrlimit(resource.RLIMIT_CPU, (_CPU_LIMIT_SECONDS, _CPU_LIMIT_SECONDS))
# _LOGGER.info(f"Set CPU time limit to {_CPU_LIMIT_SECONDS} seconds")
# except ValueError as e:
# _LOGGER.warning(f"Could not set CPU time limit: {e}")
# Set memory limit
try:
resource.setrlimit(resource.RLIMIT_AS, (_MEMORY_LIMIT_BYTES, _MEMORY_LIMIT_BYTES))
_LOGGER.info(f"Set memory limit to {_MEMORY_LIMIT_BYTES} bytes")
except ValueError as e:
_LOGGER.warning(f"Could not set memory limit: {e}")
if __name__ == "__main__":
_LOGGER.info("Starting ATR worker...")
try:
main()
except Exception as e:
with open("atr-worker-error.log", "a") as f:
f.write(f"{datetime.datetime.now(datetime.UTC)}: {e}\n")
f.flush()