emails/management/commands/process_emails_from_sqs.py (367 lines of code) (raw):

""" Process the SQS email queue. The SQS queue is processed using the long poll method, which waits until at least one message is available, or wait_seconds is reached. See: https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/sqs-short-and-long-polling.html#sqs-long-polling https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sqs.html#SQS.Queue.receive_messages """ import gc import json import logging import shlex import time from datetime import UTC, datetime from multiprocessing import Pool from typing import Any, cast from urllib.parse import urlsplit from django import setup from django.core.management.base import CommandError from django.db import connection from django.http import HttpResponse import boto3 from botocore.exceptions import ClientError from codetiming import Timer from markus.utils import generate_tag from mypy_boto3_sqs.service_resource import Message as SQSMessage from mypy_boto3_sqs.service_resource import Queue as SQSQueue from sentry_sdk import capture_exception from emails.management.command_from_django_settings import ( CommandFromDjangoSettings, SettingToLocal, ) from emails.sns import VerificationFailed, verify_from_sns from emails.utils import gauge_if_enabled, incr_if_enabled from emails.views import _sns_inbound_logic, validate_sns_arn_and_type logger = logging.getLogger("eventsinfo.process_emails_from_sqs") class Command(CommandFromDjangoSettings): help = "Fetch email tasks from SQS and process them." settings_to_locals = [ SettingToLocal( "PROCESS_EMAIL_BATCH_SIZE", "batch_size", "Number of SQS messages to fetch at a time.", lambda batch_size: 0 < batch_size <= 10, ), SettingToLocal( "PROCESS_EMAIL_WAIT_SECONDS", "wait_seconds", "Time to wait for messages with long polling.", lambda wait_seconds: wait_seconds > 0, ), SettingToLocal( "PROCESS_EMAIL_VISIBILITY_SECONDS", "visibility_seconds", "Time to mark a message as reserved for this process.", lambda visibility_seconds: visibility_seconds > 0, ), SettingToLocal( "PROCESS_EMAIL_HEALTHCHECK_PATH", "healthcheck_path", "Path to file to write healthcheck data.", lambda healthcheck_path: healthcheck_path is not None, ), SettingToLocal( "PROCESS_EMAIL_DELETE_FAILED_MESSAGES", "delete_failed_messages", ( "If a message fails to process, delete it from the queue," " instead of letting SQS resend or move to a dead-letter queue." ), lambda delete_failed_messages: delete_failed_messages in (True, False), ), SettingToLocal( "PROCESS_EMAIL_MAX_SECONDS", "max_seconds", "Maximum time to process before exiting, or None to run forever.", lambda max_seconds: max_seconds is None or max_seconds > 0.0, ), SettingToLocal( "PROCESS_EMAIL_MAX_SECONDS_PER_MESSAGE", "max_seconds_per_message", "Maximum time to process a message before cancelling.", lambda max_seconds: max_seconds > 0.0, ), SettingToLocal( "AWS_REGION", "aws_region", "AWS region of the SQS queue", lambda aws_region: bool(aws_region), ), SettingToLocal( "AWS_SQS_EMAIL_QUEUE_URL", "sqs_url", "URL of the SQL queue", lambda sqs_url: bool(sqs_url), ), SettingToLocal( "PROCESS_EMAIL_VERBOSITY", "verbosity", "Default verbosity of the process logs", lambda verbosity: verbosity in range(5), ), ] # Added by CommandFromDjangoSettings.init_from_settings batch_size: int wait_seconds: int visibility_seconds: int healthcheck_path: str delete_failed_messages: bool max_seconds: float | None max_seconds_per_message: float aws_region: str sqs_url: str verbosity: int def handle(self, verbosity: int, *args: Any, **kwargs: Any) -> None: """Handle call from command line (called by BaseCommand)""" self.init_from_settings(verbosity) self.init_locals() logger.info( "Starting process_emails_from_sqs", extra={ "batch_size": self.batch_size, "wait_seconds": self.wait_seconds, "visibility_seconds": self.visibility_seconds, "healthcheck_path": self.healthcheck_path, "delete_failed_messages": self.delete_failed_messages, "max_seconds": self.max_seconds, "max_seconds_per_message": self.max_seconds_per_message, "aws_region": self.aws_region, "sqs_url": self.sqs_url, "verbosity": self.verbosity, }, ) try: self.queue = self.create_client() except ClientError as e: raise CommandError("Unable to connect to SQS") from e process_data = self.process_queue() logger.info("Exiting process_emails_from_sqs", extra=process_data) def init_locals(self) -> None: """Initialize command attributes that don't come from settings.""" self.queue_name = urlsplit(self.sqs_url).path.split("/")[-1] self.halt_requested = False self.start_time: float = 0.0 self.cycles: int = 0 self.total_messages: int = 0 self.failed_messages: int = 0 self.pause_count: int = 0 self.queue_count: int = 0 self.queue_count_delayed: int = 0 self.queue_count_not_visible: int = 0 def create_client(self) -> SQSQueue: """Create the SQS client.""" if not self.aws_region: raise ValueError("self.aws_region must be truthy value.") if not self.sqs_url: raise ValueError("self.sqs_url must be truthy value.") sqs_client = boto3.resource("sqs", region_name=self.aws_region) return sqs_client.Queue(self.sqs_url) def process_queue(self) -> dict[str, Any]: """ Process the SQS email queue until an exit condition is reached. Return is a dict suitable for logging context, with these keys: * exit_on: Why processing exited - "interrupt", "max_seconds", "unknown" * cycles: How many polling cycles completed * total_s: The total execution time, in seconds with millisecond precision * total_messages: The number of messages processed, with and without errors * failed_messages: The number of messages that failed with errors, omitted if none * pause_count: The number of 1-second pauses due to temporary errors """ exit_on = "unknown" self.cycles = 0 self.total_messages = 0 self.failed_messages = 0 self.pause_count = 0 self.start_time = time.monotonic() while not self.halt_requested: try: cycle_data: dict[str, Any] = { "cycle_num": self.cycles, "cycle_s": 0.0, } cycle_data.update(self.refresh_and_emit_queue_count_metrics()) self.write_healthcheck() # Check if we should exit due to time limit if self.max_seconds is not None: elapsed = time.monotonic() - self.start_time if elapsed >= self.max_seconds: exit_on = "max_seconds" break # Request and process a chunk of messages with Timer(logger=None) as cycle_timer: message_batch, queue_data = self.poll_queue_for_messages() cycle_data.update(queue_data) cycle_data.update(self.process_message_batch(message_batch)) # Collect data and log progress self.total_messages += len(message_batch) self.failed_messages += int(cycle_data.get("failed_count", 0)) self.pause_count += int(cycle_data.get("pause_count", 0)) cycle_data["message_total"] = self.total_messages cycle_data["cycle_s"] = round(cycle_timer.last, 3) logger.log( ( logging.INFO if (message_batch or self.verbosity > 1) else logging.DEBUG ), ( f"Cycle {self.cycles}: processed" f" {self.pluralize(len(message_batch), 'message')}" ), extra=cycle_data, ) self.cycles += 1 gc.collect() # Force garbage collection of boto3 SQS client resources except KeyboardInterrupt: self.halt_requested = True exit_on = "interrupt" process_data = { "exit_on": exit_on, "cycles": self.cycles, "total_s": round(time.monotonic() - self.start_time, 3), "total_messages": self.total_messages, } if self.failed_messages: process_data["failed_messages"] = self.failed_messages if self.pause_count: process_data["pause_count"] = self.pause_count return process_data def refresh_and_emit_queue_count_metrics(self) -> dict[str, float | int]: """ Query SQS queue attributes, store backlog metrics, and emit them as gauge stats Return is a dict suitable for logging context, with these keys: * queue_load_s: How long, in seconds (millisecond precision) it took to load attributes * queue_count: Approximate number of messages in queue * queue_count_delayed: Approx. messages not yet ready for receiving * queue_count_not_visible: Approx. messages reserved by other receiver """ # Load attributes from SQS with Timer(logger=None) as attribute_timer: self.queue.load() # Save approximate queue counts self.queue_count = int(self.queue.attributes["ApproximateNumberOfMessages"]) self.queue_count_delayed = int( self.queue.attributes["ApproximateNumberOfMessagesDelayed"] ) self.queue_count_not_visible = int( self.queue.attributes["ApproximateNumberOfMessagesNotVisible"] ) # Emit gauges for approximate queue counts queue_tag = generate_tag("queue", self.queue_name) gauge_if_enabled("email_queue_count", self.queue_count, tags=[queue_tag]) gauge_if_enabled( "email_queue_count_delayed", self.queue_count_delayed, tags=[queue_tag] ) gauge_if_enabled( "email_queue_count_not_visible", self.queue_count_not_visible, tags=[queue_tag], ) return { "queue_load_s": round(attribute_timer.last, 3), "queue_count": self.queue_count, "queue_count_delayed": self.queue_count_delayed, "queue_count_not_visible": self.queue_count_not_visible, } def poll_queue_for_messages( self, ) -> tuple[list[SQSMessage], dict[str, float | int]]: """Request a batch of messages, using the long-poll method. Return is a tuple: * message_batch: a list of messages, which may be empty * data: A dict suitable for logging context, with these keys: - message_count: the number of messages - sqs_poll_s: The poll time, in seconds with millisecond precision """ with Timer(logger=None) as poll_timer: message_batch = self.queue.receive_messages( MaxNumberOfMessages=self.batch_size, VisibilityTimeout=self.visibility_seconds, WaitTimeSeconds=self.wait_seconds, ) return ( message_batch, { "message_count": len(message_batch), "sqs_poll_s": round(poll_timer.last, 3), }, ) def process_message_batch(self, message_batch: list[SQSMessage]) -> dict[str, Any]: """ Process a batch of messages. Arguments: * messages - a list of SQS messages, possibly empty Return is a dict suitable for logging context, with these keys: * process_s: How long processing took, omitted if no messages * pause_count: How many pauses were taken for temporary errors, omitted if 0 * pause_s: How long pauses took, omitted if no pauses * failed_count: How many messages failed to process, omitted if 0 Times are in seconds, with millisecond precision """ if not message_batch: return {} failed_count = 0 pause_time = 0.0 pause_count = 0 process_time = 0.0 for message in message_batch: self.write_healthcheck() with Timer(logger=None) as message_timer: message_data = self.process_message(message) if not message_data["success"]: failed_count += 1 if message_data["success"] or self.delete_failed_messages: message.delete() pause_time += message_data.get("pause_s", 0.0) pause_count += message_data.get("pause_count", 0) message_data["message_process_time_s"] = round(message_timer.last, 3) process_time += message_timer.last logger.log(logging.INFO, "Message processed", extra=message_data) batch_data = {"process_s": round((process_time - pause_time), 3)} if pause_count: batch_data["pause_count"] = pause_count batch_data["pause_s"] = round(pause_time, 3) if failed_count: batch_data["failed_count"] = failed_count return batch_data def process_message(self, message: SQSMessage) -> dict[str, Any]: """ Process an SQS message, which may include sending an email. Return is a dict suitable for logging context, with these keys: * success: True if message was processed successfully * error: The processing error, omitted on success * message_body_quoted: Set if the message was non-JSON, omitted for valid JSON * pause_count: Set to 1 if paused due to temporary error, or omitted with no error * pause_s: The pause in seconds (ms precision) for temp error, or omitted * pause_error: The temporary error, or omitted if no temp error * client_error_code: The error code for non-temp or retry error, omitted on success """ incr_if_enabled("process_message_from_sqs", 1) results = {"success": True, "sqs_message_id": message.message_id} raw_body = message.body try: json_body = json.loads(raw_body) except ValueError as e: results["success"] = False results["error"] = f"Failed to load message.body: {e}" results["message_body_quoted"] = shlex.quote(raw_body) return results try: verified_json_body = verify_from_sns(json_body) except (KeyError, VerificationFailed) as e: logger.error("Failed SNS verification", extra={"error": str(e)}) results["success"] = False results["error"] = f"Failed SNS verification: {e}" return results topic_arn = verified_json_body["TopicArn"] message_type = verified_json_body["Type"] error_details = validate_sns_arn_and_type(topic_arn, message_type) if error_details: results["success"] = False results.update(error_details) return results def success_callback(result: HttpResponse) -> None: """Handle return from successful call to _sns_inbound_logic""" # TODO: extract data from _sns_inbound_logic return pass def error_callback(exc_info: BaseException) -> None: """Handle exception raised by _sns_inbound_logic""" capture_exception(exc_info) results["success"] = False if isinstance(exc_info, ClientError): incr_if_enabled("message_from_sqs_error") err = exc_info.response["Error"] logger.error("sqs_client_error", extra=err) results["error"] = err results["client_error_code"] = err["Code"].lower() else: incr_if_enabled("email_processing_failure") results["error"] = str(exc_info) results["error_type"] = type(exc_info).__name__ # Run in a multiprocessing Pool # This will start a subprocess, which needs to run django.setup # The benefit is that the subprocess can be terminated # The penalty is that is is slower to start pool_start_time = time.monotonic() with Pool(1, initializer=setup) as pool: future = pool.apply_async( run_sns_inbound_logic, [topic_arn, message_type, verified_json_body], callback=success_callback, error_callback=error_callback, ) setup_time = time.monotonic() - pool_start_time results["subprocess_setup_time_s"] = round(setup_time, 3) message_start_time = time.monotonic() message_duration = 0.0 while message_duration < self.max_seconds_per_message: self.write_healthcheck() future.wait(1.0) message_duration = time.monotonic() - message_start_time if future.ready(): break results["message_process_time_s"] = round(message_duration, 3) if not future.ready(): error = f"Timed out after {self.max_seconds_per_message:0.1f} seconds." results["success"] = False results["error"] = error return results def write_healthcheck(self) -> None: """Update the healthcheck file with operations data, if path is set.""" data: dict[str, str | int] = { "timestamp": datetime.now(tz=UTC).isoformat(), "cycles": self.cycles, "total_messages": self.total_messages, "failed_messages": self.failed_messages, "pause_count": self.pause_count, "queue_count": int(self.queue.attributes["ApproximateNumberOfMessages"]), "queue_count_delayed": int( self.queue.attributes["ApproximateNumberOfMessagesDelayed"] ), "queue_count_not_visible": int( self.queue.attributes["ApproximateNumberOfMessagesNotVisible"] ), } with open(self.healthcheck_path, "w", encoding="utf-8") as healthcheck_file: json.dump(data, healthcheck_file) def pluralize(self, value: int, singular: str, plural: str | None = None) -> str: """Returns 's' suffix to make plural, like 's' in tasks""" if value == 1: return f"{value} {singular}" else: return f"{value} {plural or (singular + 's')}" def run_sns_inbound_logic( topic_arn: str, message_type: str, json_body: str ) -> HttpResponse: # Reset any exiting connection, verify it is usable with connection.cursor() as cursor: cursor.db.queries_log.clear() if not cursor.db.is_usable(): cursor.db.close() return cast(HttpResponse, _sns_inbound_logic(topic_arn, message_type, json_body))