"""
Process the SQS email queue.

The SQS queue is processed using the long poll method, which waits until at
least one message is available, or wait_seconds is reached.

See:
https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/sqs-short-and-long-polling.html#sqs-long-polling
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sqs.html#SQS.Queue.receive_messages
"""

import gc
import json
import logging
import shlex
import time
from datetime import UTC, datetime
from multiprocessing import Pool
from typing import Any, cast
from urllib.parse import urlsplit

from django import setup
from django.core.management.base import CommandError
from django.db import connection
from django.http import HttpResponse

import boto3
from botocore.exceptions import ClientError
from codetiming import Timer
from markus.utils import generate_tag
from mypy_boto3_sqs.service_resource import Message as SQSMessage
from mypy_boto3_sqs.service_resource import Queue as SQSQueue
from sentry_sdk import capture_exception

from emails.management.command_from_django_settings import (
    CommandFromDjangoSettings,
    SettingToLocal,
)
from emails.sns import VerificationFailed, verify_from_sns
from emails.utils import gauge_if_enabled, incr_if_enabled
from emails.views import _sns_inbound_logic, validate_sns_arn_and_type

# Module-level logger used throughout this file. The name is a fixed string
# rather than __name__, so it does not change with the module's import path.
logger = logging.getLogger("eventsinfo.process_emails_from_sqs")


class Command(CommandFromDjangoSettings):
    """
    Management command that long-polls the SQS email queue and processes it.

    Configuration is declared in settings_to_locals. As used here, each
    SettingToLocal entry gives a Django setting name, the attribute name it
    populates on the command, a help description, and a validator callable.
    (Argument meanings are inferred from their use in this file - confirm
    against the SettingToLocal definition.)
    """

    help = "Fetch email tasks from SQS and process them."

    settings_to_locals = [
        # Polling behaviour
        SettingToLocal(
            "PROCESS_EMAIL_BATCH_SIZE",
            "batch_size",
            "Number of SQS messages to fetch at a time.",
            lambda batch_size: 0 < batch_size <= 10,
        ),
        SettingToLocal(
            "PROCESS_EMAIL_WAIT_SECONDS",
            "wait_seconds",
            "Time to wait for messages with long polling.",
            lambda wait_seconds: wait_seconds > 0,
        ),
        SettingToLocal(
            "PROCESS_EMAIL_VISIBILITY_SECONDS",
            "visibility_seconds",
            "Time to mark a message as reserved for this process.",
            lambda visibility_seconds: visibility_seconds > 0,
        ),
        # Healthcheck and failure handling
        SettingToLocal(
            "PROCESS_EMAIL_HEALTHCHECK_PATH",
            "healthcheck_path",
            "Path to file to write healthcheck data.",
            lambda healthcheck_path: healthcheck_path is not None,
        ),
        SettingToLocal(
            "PROCESS_EMAIL_DELETE_FAILED_MESSAGES",
            "delete_failed_messages",
            (
                "If a message fails to process, delete it from the queue,"
                " instead of letting SQS resend or move to a dead-letter queue."
            ),
            lambda delete_failed_messages: delete_failed_messages in (True, False),
        ),
        # Time limits
        SettingToLocal(
            "PROCESS_EMAIL_MAX_SECONDS",
            "max_seconds",
            "Maximum time to process before exiting, or None to run forever.",
            lambda max_seconds: max_seconds is None or max_seconds > 0.0,
        ),
        SettingToLocal(
            "PROCESS_EMAIL_MAX_SECONDS_PER_MESSAGE",
            "max_seconds_per_message",
            "Maximum time to process a message before cancelling.",
            lambda max_seconds: max_seconds > 0.0,
        ),
        # AWS connection details
        SettingToLocal(
            "AWS_REGION",
            "aws_region",
            "AWS region of the SQS queue",
            lambda aws_region: bool(aws_region),
        ),
        SettingToLocal(
            "AWS_SQS_EMAIL_QUEUE_URL",
            "sqs_url",
            "URL of the SQS queue",
            lambda sqs_url: bool(sqs_url),
        ),
        # Logging
        SettingToLocal(
            "PROCESS_EMAIL_VERBOSITY",
            "verbosity",
            "Default verbosity of the process logs",
            lambda verbosity: verbosity in range(5),
        ),
    ]

    # Added by CommandFromDjangoSettings.init_from_settings
    batch_size: int
    wait_seconds: int
    visibility_seconds: int
    healthcheck_path: str
    delete_failed_messages: bool
    max_seconds: float | None
    max_seconds_per_message: float
    aws_region: str
    sqs_url: str
    verbosity: int

    def handle(self, verbosity: int, *args: Any, **kwargs: Any) -> None:
        """
        Run the command (called by BaseCommand when invoked from the command line).

        Loads the configuration, logs it, creates the SQS queue resource, then
        processes the queue until process_queue() returns.

        Raises CommandError if creating the queue resource raises ClientError.
        """
        self.init_from_settings(verbosity)
        self.init_locals()

        # Log the effective configuration, keyed by attribute name
        logged_attributes = (
            "batch_size",
            "wait_seconds",
            "visibility_seconds",
            "healthcheck_path",
            "delete_failed_messages",
            "max_seconds",
            "max_seconds_per_message",
            "aws_region",
            "sqs_url",
            "verbosity",
        )
        startup_context = {name: getattr(self, name) for name in logged_attributes}
        logger.info("Starting process_emails_from_sqs", extra=startup_context)

        try:
            self.queue = self.create_client()
        except ClientError as e:
            raise CommandError("Unable to connect to SQS") from e

        summary = self.process_queue()
        logger.info("Exiting process_emails_from_sqs", extra=summary)

    def init_locals(self) -> None:
        """Initialize command attributes that don't come from settings."""
        self.queue_name = urlsplit(self.sqs_url).path.split("/")[-1]
        self.halt_requested = False
        self.start_time: float = 0.0
        self.cycles: int = 0
        self.total_messages: int = 0
        self.failed_messages: int = 0
        self.pause_count: int = 0
        self.queue_count: int = 0
        self.queue_count_delayed: int = 0
        self.queue_count_not_visible: int = 0

    def create_client(self) -> SQSQueue:
        """Create the SQS client."""
        if not self.aws_region:
            raise ValueError("self.aws_region must be truthy value.")
        if not self.sqs_url:
            raise ValueError("self.sqs_url must be truthy value.")
        sqs_client = boto3.resource("sqs", region_name=self.aws_region)
        return sqs_client.Queue(self.sqs_url)

    def process_queue(self) -> dict[str, Any]:
        """
        Process the SQS email queue until an exit condition is reached.

        Return is a dict suitable for logging context, with these keys:
        * exit_on: Why processing exited - "interrupt", "max_seconds", "unknown"
        * cycles: How many polling cycles completed
        * total_s: The total execution time, in seconds with millisecond precision
        * total_messages: The number of messages processed, with and without errors
        * failed_messages: The number of messages that failed with errors,
          omitted if none
        * pause_count: The number of 1-second pauses due to temporary errors
        """
        exit_on = "unknown"
        self.cycles = 0
        self.total_messages = 0
        self.failed_messages = 0
        self.pause_count = 0
        self.start_time = time.monotonic()

        while not self.halt_requested:
            try:
                cycle_data: dict[str, Any] = {
                    "cycle_num": self.cycles,
                    "cycle_s": 0.0,
                }
                cycle_data.update(self.refresh_and_emit_queue_count_metrics())
                self.write_healthcheck()

                # Check if we should exit due to time limit
                if self.max_seconds is not None:
                    elapsed = time.monotonic() - self.start_time
                    if elapsed >= self.max_seconds:
                        exit_on = "max_seconds"
                        break

                # Request and process a chunk of messages
                with Timer(logger=None) as cycle_timer:
                    message_batch, queue_data = self.poll_queue_for_messages()
                    cycle_data.update(queue_data)
                    cycle_data.update(self.process_message_batch(message_batch))

                # Collect data and log progress
                self.total_messages += len(message_batch)
                self.failed_messages += int(cycle_data.get("failed_count", 0))
                self.pause_count += int(cycle_data.get("pause_count", 0))
                cycle_data["message_total"] = self.total_messages
                cycle_data["cycle_s"] = round(cycle_timer.last, 3)
                logger.log(
                    (
                        logging.INFO
                        if (message_batch or self.verbosity > 1)
                        else logging.DEBUG
                    ),
                    (
                        f"Cycle {self.cycles}: processed"
                        f" {self.pluralize(len(message_batch), 'message')}"
                    ),
                    extra=cycle_data,
                )

                self.cycles += 1
                gc.collect()  # Force garbage collection of boto3 SQS client resources

            except KeyboardInterrupt:
                self.halt_requested = True
                exit_on = "interrupt"

        process_data = {
            "exit_on": exit_on,
            "cycles": self.cycles,
            "total_s": round(time.monotonic() - self.start_time, 3),
            "total_messages": self.total_messages,
        }
        if self.failed_messages:
            process_data["failed_messages"] = self.failed_messages
        if self.pause_count:
            process_data["pause_count"] = self.pause_count
        return process_data

    def refresh_and_emit_queue_count_metrics(self) -> dict[str, float | int]:
        """
        Reload the SQS queue attributes, store the backlog counts, and emit gauges.

        Return is a dict suitable for logging context, with these keys:
        * queue_load_s: How long, in seconds (millisecond precision) it took to
          load attributes
        * queue_count: Approximate number of messages in queue
        * queue_count_delayed: Approx. messages not yet ready for receiving
        * queue_count_not_visible: Approx. messages reserved by other receiver
        """
        # Time the attribute reload from SQS
        with Timer(logger=None) as load_timer:
            self.queue.load()

        # Keep the approximate counts on the command for later reporting
        attributes = self.queue.attributes
        self.queue_count = int(attributes["ApproximateNumberOfMessages"])
        self.queue_count_delayed = int(attributes["ApproximateNumberOfMessagesDelayed"])
        self.queue_count_not_visible = int(
            attributes["ApproximateNumberOfMessagesNotVisible"]
        )

        # Emit one gauge per count, each tagged with the queue name
        queue_tag = generate_tag("queue", self.queue_name)
        gauges = (
            ("email_queue_count", self.queue_count),
            ("email_queue_count_delayed", self.queue_count_delayed),
            ("email_queue_count_not_visible", self.queue_count_not_visible),
        )
        for stat_name, stat_value in gauges:
            gauge_if_enabled(stat_name, stat_value, tags=[queue_tag])

        metrics: dict[str, float | int] = {
            "queue_load_s": round(load_timer.last, 3),
            "queue_count": self.queue_count,
            "queue_count_delayed": self.queue_count_delayed,
            "queue_count_not_visible": self.queue_count_not_visible,
        }
        return metrics

    def poll_queue_for_messages(
        self,
    ) -> tuple[list[SQSMessage], dict[str, float | int]]:
        """Receive one batch of messages from the queue with a long poll.

        The request uses batch_size, visibility_seconds and wait_seconds as the
        maximum message count, visibility timeout and wait time.

        Return is a tuple:
        * message_batch: a list of messages, which may be empty
        * data: A dict suitable for logging context, with these keys:
            - message_count: the number of messages
            - sqs_poll_s: The poll time, in seconds with millisecond precision
        """
        receive_options = {
            "MaxNumberOfMessages": self.batch_size,
            "VisibilityTimeout": self.visibility_seconds,
            "WaitTimeSeconds": self.wait_seconds,
        }
        with Timer(logger=None) as poll_timer:
            message_batch = self.queue.receive_messages(**receive_options)

        poll_data: dict[str, float | int] = {
            "message_count": len(message_batch),
            "sqs_poll_s": round(poll_timer.last, 3),
        }
        return message_batch, poll_data

    def process_message_batch(self, message_batch: list[SQSMessage]) -> dict[str, Any]:
        """
        Process a batch of messages.

        Arguments:
        * messages - a list of SQS messages, possibly empty

        Return is a dict suitable for logging context, with these keys:
        * process_s: How long processing took, omitted if no messages
        * pause_count: How many pauses were taken for temporary errors, omitted if 0
        * pause_s: How long pauses took, omitted if no pauses
        * failed_count: How many messages failed to process, omitted if 0

        Times are in seconds, with millisecond precision
        """
        if not message_batch:
            return {}
        failed_count = 0
        pause_time = 0.0
        pause_count = 0
        process_time = 0.0
        for message in message_batch:
            self.write_healthcheck()
            with Timer(logger=None) as message_timer:
                message_data = self.process_message(message)
                if not message_data["success"]:
                    failed_count += 1
                if message_data["success"] or self.delete_failed_messages:
                    message.delete()
                pause_time += message_data.get("pause_s", 0.0)
                pause_count += message_data.get("pause_count", 0)

            message_data["message_process_time_s"] = round(message_timer.last, 3)
            process_time += message_timer.last
            logger.log(logging.INFO, "Message processed", extra=message_data)

        batch_data = {"process_s": round((process_time - pause_time), 3)}
        if pause_count:
            batch_data["pause_count"] = pause_count
            batch_data["pause_s"] = round(pause_time, 3)
        if failed_count:
            batch_data["failed_count"] = failed_count
        return batch_data

    def process_message(self, message: SQSMessage) -> dict[str, Any]:
        """
        Process an SQS message, which may include sending an email.

        The message body is parsed as JSON, verified with verify_from_sns, and
        checked with validate_sns_arn_and_type. If any of these fail, the
        result is returned early with success set to False. Otherwise
        run_sns_inbound_logic is started in a one-worker multiprocessing Pool
        and polled until it finishes or max_seconds_per_message elapses.

        Return is a dict suitable for logging context, with these keys:
        * success: True if message was processed successfully
        * sqs_message_id: The message_id of the SQS message
        * error: The processing error, omitted on success. For a ClientError
          raised by the subprocess this is the response's "Error" value,
          otherwise it is a string
        * error_type: The exception class name for a non-ClientError exception
          raised by the subprocess, omitted otherwise
        * message_body_quoted: Set if the message was non-JSON, omitted for valid JSON
        * subprocess_setup_time_s: Time to create the pool and submit the task,
          omitted if processing ended before the subprocess step
        * message_process_time_s: Time spent waiting on the subprocess,
          omitted if processing ended before the subprocess step
        * pause_count: Set to 1 if paused due to temporary error, or omitted
          with no error
        * pause_s: The pause in seconds (ms precision) for temp error, or omitted
        * pause_error: The temporary error, or omitted if no temp error
        * client_error_code: The error code for non-temp or retry error,
          omitted on success

        Any keys returned by validate_sns_arn_and_type for a rejected message
        are merged into the result as well.

        NOTE(review): pause_count, pause_s and pause_error are not assigned
        anywhere in this method - confirm whether they are still produced.
        """
        incr_if_enabled("process_message_from_sqs", 1)
        results = {"success": True, "sqs_message_id": message.message_id}
        raw_body = message.body

        # Step 1: the body must be valid JSON
        try:
            json_body = json.loads(raw_body)
        except ValueError as e:
            results["success"] = False
            results["error"] = f"Failed to load message.body: {e}"
            # Shell-quote the raw body so it is safe to include in the log context
            results["message_body_quoted"] = shlex.quote(raw_body)
            return results

        # Step 2: the JSON must pass SNS verification
        try:
            verified_json_body = verify_from_sns(json_body)
        except (KeyError, VerificationFailed) as e:
            logger.error("Failed SNS verification", extra={"error": str(e)})
            results["success"] = False
            results["error"] = f"Failed SNS verification: {e}"
            return results

        # Step 3: the topic ARN and message type must be acceptable
        topic_arn = verified_json_body["TopicArn"]
        message_type = verified_json_body["Type"]
        error_details = validate_sns_arn_and_type(topic_arn, message_type)
        if error_details:
            results["success"] = False
            results.update(error_details)
            return results

        def success_callback(result: HttpResponse) -> None:
            """Handle return from successful call to _sns_inbound_logic"""
            # TODO: extract data from _sns_inbound_logic return
            pass

        def error_callback(exc_info: BaseException) -> None:
            """Handle exception raised by _sns_inbound_logic"""
            # Report to Sentry, then record the failure in the enclosing results dict
            capture_exception(exc_info)
            results["success"] = False
            if isinstance(exc_info, ClientError):
                incr_if_enabled("message_from_sqs_error")
                err = exc_info.response["Error"]
                logger.error("sqs_client_error", extra=err)
                results["error"] = err
                results["client_error_code"] = err["Code"].lower()
            else:
                incr_if_enabled("email_processing_failure")
                results["error"] = str(exc_info)
                results["error_type"] = type(exc_info).__name__

        # Run in a multiprocessing Pool
        # This will start a subprocess, which needs to run django.setup
        # The benefit is that the subprocess can be terminated
        # The penalty is that it is slower to start
        pool_start_time = time.monotonic()
        with Pool(1, initializer=setup) as pool:
            future = pool.apply_async(
                run_sns_inbound_logic,
                [topic_arn, message_type, verified_json_body],
                callback=success_callback,
                error_callback=error_callback,
            )
            setup_time = time.monotonic() - pool_start_time
            results["subprocess_setup_time_s"] = round(setup_time, 3)

            # Poll about once a second, refreshing the healthcheck each time,
            # until the task finishes or the per-message time limit is reached
            message_start_time = time.monotonic()
            message_duration = 0.0
            while message_duration < self.max_seconds_per_message:
                self.write_healthcheck()
                future.wait(1.0)
                message_duration = time.monotonic() - message_start_time
                if future.ready():
                    break

            results["message_process_time_s"] = round(message_duration, 3)
            # Still not ready after the loop means the time limit was hit
            if not future.ready():
                error = f"Timed out after {self.max_seconds_per_message:0.1f} seconds."
                results["success"] = False
                results["error"] = error
        return results

    def write_healthcheck(self) -> None:
        """Update the healthcheck file with operations data, if path is set."""
        data: dict[str, str | int] = {
            "timestamp": datetime.now(tz=UTC).isoformat(),
            "cycles": self.cycles,
            "total_messages": self.total_messages,
            "failed_messages": self.failed_messages,
            "pause_count": self.pause_count,
            "queue_count": int(self.queue.attributes["ApproximateNumberOfMessages"]),
            "queue_count_delayed": int(
                self.queue.attributes["ApproximateNumberOfMessagesDelayed"]
            ),
            "queue_count_not_visible": int(
                self.queue.attributes["ApproximateNumberOfMessagesNotVisible"]
            ),
        }
        with open(self.healthcheck_path, "w", encoding="utf-8") as healthcheck_file:
            json.dump(data, healthcheck_file)

    def pluralize(self, value: int, singular: str, plural: str | None = None) -> str:
        """Returns 's' suffix to make plural, like 's' in tasks"""
        if value == 1:
            return f"{value} {singular}"
        else:
            return f"{value} {plural or (singular + 's')}"


def run_sns_inbound_logic(
    topic_arn: str, message_type: str, json_body: str
) -> HttpResponse:
    """
    Check the database connection, then run _sns_inbound_logic.

    This is the task that Command.process_message submits to its
    multiprocessing Pool. Before delegating, the connection's query log is
    cleared and the connection is closed if it reports itself unusable.

    NOTE(review): json_body is annotated as str, but process_message passes
    the verified SNS body, which it indexes by key - confirm the intended type.
    """
    # Reset any existing connection, verify it is usable
    with connection.cursor() as cursor:
        database = cursor.db
        database.queries_log.clear()
        if not database.is_usable():
            database.close()

    response = _sns_inbound_logic(topic_arn, message_type, json_body)
    return cast(HttpResponse, response)
