# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""worker.py - Task worker process for ATR"""

# TODO: If started is older than some threshold and status
# is active but the pid is no longer running, we can revert
# the task to status='QUEUED'. For this to work, ideally we
# need to check wall clock time as well as CPU time.

import asyncio
import datetime
import inspect
import json
import logging
import os
import resource
import signal
import traceback
from typing import Any, Final

import sqlmodel

import atr.db as db
import atr.db.models as models
import atr.tasks as tasks
import atr.tasks.checks as checks
import atr.tasks.task as task

_LOGGER: Final = logging.getLogger(__name__)

# Resource limits: memory is capped at 1 GiB
# The 5 minute CPU limit is currently disabled; see _worker_resources_limit_set()
# _CPU_LIMIT_SECONDS: Final = 300
_MEMORY_LIMIT_BYTES: Final = 1024 * 1024 * 1024


def main() -> None:
    """Main entry point."""
    import atr.config as config

    conf = config.get()
    if os.path.isdir(conf.STATE_DIR):
        os.chdir(conf.STATE_DIR)

    _setup_logging()

    _LOGGER.info(f"Starting worker process with pid {os.getpid()}")

    # Named background_tasks to avoid shadowing the atr.tasks module imported above
    background_tasks: list[asyncio.Task] = []

    async def _handle_signal(signum: int) -> None:
        _LOGGER.info(f"Received signal {signum}, shutting down...")

        await db.shutdown_database()

        for t in background_tasks:
            t.cancel()

        _LOGGER.debug("Cancelled all running tasks")
        asyncio.get_running_loop().stop()
        _LOGGER.debug("Stopped event loop")

    _worker_resources_limit_set()

    async def _start() -> None:
        # Register the signal handlers on the running loop; scheduling a
        # coroutine with asyncio.create_task() is only safe while the loop runs
        loop = asyncio.get_running_loop()
        for s in (signal.SIGTERM, signal.SIGINT):
            loop.add_signal_handler(s, lambda signum=s: asyncio.create_task(_handle_signal(signum)))

        await db.init_database_for_worker()
        background_tasks.append(asyncio.create_task(_worker_loop_run()))
        await asyncio.gather(*background_tasks)

    asyncio.run(_start())

    # If the worker decides to stop running (see #230 in _worker_loop_run()), shut down the database gracefully
    asyncio.run(db.shutdown_database())
    _LOGGER.info("Exiting worker process")


def _setup_logging() -> None:
    """Configure file based logging for the worker process."""
    log_format = "[%(asctime)s.%(msecs)03d] [%(process)d] [%(levelname)s] %(message)s"
    date_format = "%Y-%m-%d %H:%M:%S"

    logging.basicConfig(filename="atr-worker.log", format=log_format, datefmt=date_format, level=logging.INFO)
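    # Lines in atr-worker.log then look like, for example:
    #   [2025-01-01 12:00:00.123] [4242] [INFO] Starting worker process with pid 4242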


# Task functions


async def _task_error_handle(task_id: int, e: Exception) -> None:
    """Handle a task error by updating the database with error information."""
    # A task.Error carries its own message and a structured result payload;
    # any other exception is stored by its string form alone
    if isinstance(e, task.Error):
        message = e.message
        result: str | None = json.dumps(e.result)
    else:
        message = str(e)
        result = None

    _LOGGER.error(f"Task {task_id} failed: {message}")
    _LOGGER.error("".join(traceback.format_exception(e)))

    async with db.session() as data:
        async with data.begin():
            task_obj = await data.task(id=task_id).get()
            if task_obj:
                task_obj.status = task.FAILED
                task_obj.completed = datetime.datetime.now(datetime.UTC)
                task_obj.error = message
                if result is not None:
                    task_obj.result = result


async def _task_next_claim() -> tuple[int, str, list[str] | dict[str, Any]] | None:
    """
    Attempt to claim the oldest unclaimed task.
    Returns (task_id, task_type, task_args) if successful.
    Returns None if no tasks are available.
    """
    async with db.session() as data:
        async with data.begin():
            # Get the ID of the oldest queued task
            oldest_queued_task = (
                sqlmodel.select(models.Task.id)
                .where(models.Task.status == task.QUEUED)
                .order_by(db.validate_instrumented_attribute(models.Task.added).asc())
                .limit(1)
            )

            # Use an UPDATE with a WHERE clause to atomically claim the task
            # This ensures that only one worker can claim a specific task
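            # The statement compiles to roughly the following SQL, assuming a
            # dialect with UPDATE ... RETURNING support (e.g. PostgreSQL or
            # SQLite >= 3.35) and the default table name for models.Task:
            #
            #   UPDATE task
            #   SET status = 'ACTIVE', started = :now, pid = :pid
            #   WHERE task.id = (
            #       SELECT task.id FROM task
            #       WHERE task.status = 'QUEUED'
            #       ORDER BY task.added ASC LIMIT 1
            #   ) AND task.status = 'QUEUED'
            #   RETURNING task.id, task.task_type, task.task_args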
            now = datetime.datetime.now(datetime.UTC)
            update_stmt = (
                sqlmodel.update(models.Task)
                .where(sqlmodel.and_(models.Task.id == oldest_queued_task, models.Task.status == task.QUEUED))
                .values(status=task.ACTIVE, started=now, pid=os.getpid())
                .returning(
                    db.validate_instrumented_attribute(models.Task.id),
                    db.validate_instrumented_attribute(models.Task.task_type),
                    db.validate_instrumented_attribute(models.Task.task_args),
                )
            )

            result = await data.execute(update_stmt)
            claimed_task = result.first()

            if claimed_task:
                task_id, task_type, task_args = claimed_task
                _LOGGER.info(f"Claimed task {task_id} ({task_type}) with args {task_args}")
                return task_id, task_type, task_args

            return None


async def _task_process(task_id: int, task_type: str, task_args: list[str] | dict[str, Any]) -> None:
    """Process a claimed task."""
    _LOGGER.info(f"Processing task {task_id} ({task_type}) with raw args {task_args}")
    try:
        task_type_member = models.TaskType(task_type)
    except ValueError as e:
        _LOGGER.error(f"Invalid task type: {task_type}")
        await _task_result_process(task_id, tuple(), task.FAILED, str(e))
        return

    task_results: tuple[Any, ...]
    try:
        handler = tasks.resolve(task_type_member)
        sig = inspect.signature(handler)
        params = list(sig.parameters.values())

        # Check whether the handler is a check handler
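        # A check handler is declared with a single checks.FunctionArguments
        # parameter, along the lines of this hypothetical example:
        #
        #   async def verify_archive(args: checks.FunctionArguments) -> str | None: ...
        #
        # Handlers with any other signature receive the raw task_args directly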
        if (len(params) == 1) and (params[0].annotation == checks.FunctionArguments):
            _LOGGER.debug(f"Handler {handler.__name__} expects checks.FunctionArguments, fetching full task details")
            async with db.session() as data:
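                # demand() is expected to raise the supplied ValueError if the
                # task row can no longer be found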
                task_obj = await data.task(id=task_id).demand(
                    ValueError(f"Task {task_id} disappeared during processing")
                )

            # Validate required fields from the Task object itself
            if task_obj.release_name is None:
                raise ValueError(f"Task {task_id} is missing required release_name")
            if task_obj.draft_revision is None:
                raise ValueError(f"Task {task_id} is missing required draft_revision")

            if not isinstance(task_args, dict):
                raise TypeError(
                    f"Task {task_id} ({task_type}) has non-dict raw args"
                    f" {task_args} which should represent keyword_args"
                )

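            # The recorder is passed as a factory rather than a ready made
            # instance, so the handler controls if and when it is created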
            async def recorder_factory() -> checks.Recorder:
                return await checks.Recorder.create(
                    checker=handler,
                    release_name=task_obj.release_name or "",
                    draft_revision=task_obj.draft_revision or "",
                    primary_rel_path=task_obj.primary_rel_path,
                )

            function_arguments = checks.FunctionArguments(
                recorder=recorder_factory,
                release_name=task_obj.release_name,
                draft_revision=task_obj.draft_revision,
                primary_rel_path=task_obj.primary_rel_path,
                extra_args=task_args,
            )
            _LOGGER.debug(f"Calling {handler.__name__} with structured arguments: {function_arguments}")
            handler_result = await handler(function_arguments)
        else:
            # Otherwise, it's not a check handler
            handler_result = await handler(task_args)

        task_results = (handler_result,)
        status = task.COMPLETED
        error = None
    except Exception as e:
        task_results = tuple()
        status = task.FAILED
        error_details = traceback.format_exc()
        _LOGGER.error(f"Task {task_id} failed processing: {error_details}")
        error = str(e)
    await _task_result_process(task_id, task_results, status, error)


async def _task_result_process(
    task_id: int, task_results: tuple[Any, ...], status: models.TaskStatus, error: str | None = None
) -> None:
    """Process and store task results in the database."""
    async with db.session() as data:
        async with data.begin():
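            # Changes made to the ORM object below are flushed and committed
            # automatically when the data.begin() block exits without error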
            # Find the task by ID
            task_obj = await data.task(id=task_id).get()
            if task_obj:
                # Update task properties
                task_obj.status = status
                task_obj.completed = datetime.datetime.now(datetime.UTC)
                task_obj.result = task_results

                if (status == task.FAILED) and error:
                    task_obj.error = error


# Worker functions


async def _worker_loop_run() -> None:
    """Main worker loop."""
    processed = 0
    max_to_process = 10
    while True:
        try:
            claimed = await _task_next_claim()
            if claimed:
                task_id, task_type, task_args = claimed
                await _task_process(task_id, task_type, task_args)
                processed += 1
                # Only process max_to_process tasks and then exit
                # This prevents memory leaks from accumulating
                # Another worker will be started automatically when one exits
                if processed >= max_to_process:
                    break
            else:
                # No tasks available, wait 100ms before checking again
                await asyncio.sleep(0.1)
        except Exception:
            # TODO: Should probably be more robust about this
            _LOGGER.exception("Worker loop error")
            await asyncio.sleep(1)


def _worker_resources_limit_set() -> None:
    """Set CPU and memory limits for this process."""
    # # Set CPU time limit
    # try:
    #     resource.setrlimit(resource.RLIMIT_CPU, (_CPU_LIMIT_SECONDS, _CPU_LIMIT_SECONDS))
    #     _LOGGER.info(f"Set CPU time limit to {_CPU_LIMIT_SECONDS} seconds")
    # except ValueError as e:
    #     _LOGGER.warning(f"Could not set CPU time limit: {e}")

    # Set the address space limit; allocations beyond it fail with
    # MemoryError rather than invoking the kernel OOM killer
    try:
        resource.setrlimit(resource.RLIMIT_AS, (_MEMORY_LIMIT_BYTES, _MEMORY_LIMIT_BYTES))
        _LOGGER.info(f"Set memory limit to {_MEMORY_LIMIT_BYTES} bytes")
    except ValueError as e:
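        # setrlimit raises ValueError if the requested limit exceeds the hard limit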
        _LOGGER.warning(f"Could not set memory limit: {e}")


if __name__ == "__main__":
    try:
        main()
    except Exception:
        # Logging may not be configured if main() fails early, so append the
        # full traceback to a separate error log rather than using _LOGGER
        with open("atr-worker-error.log", "a") as f:
            f.write(f"{datetime.datetime.now(datetime.UTC)}: {traceback.format_exc()}\n")
            f.flush()
