images/airflow/2.9.2/python/mwaa/entrypoint.py
"""
This is the entrypoint of the Docker image when running Airflow components.
The script gets called with the Airflow component name, e.g. scheduler, as the
first and only argument. It accordingly runs the requested Airflow component
after setting up the necessary configurations.
"""
# Set up logging first thing to make sure all logs happen under the right setup. The
# reason for needing this is that typically a `logger` object is defined at the top
# of a module and is used throughout it. So, if we import a module before logging
# is set up, its `logger` object will not have the right setup.
# ruff: noqa: E402
# fmt: off
from airflow.config_templates.airflow_local_settings import DEFAULT_LOGGING_CONFIG
import logging.config
logging.config.dictConfig(DEFAULT_LOGGING_CONFIG)
# fmt: on
# Python imports
from datetime import datetime
import asyncio
import logging
import os
import sys
import time

# 3rd party imports
import boto3
from botocore.exceptions import ClientError

# Our imports
from mwaa.execute_command import execute_command
from mwaa.config.setup_environment import setup_environment_variables
from mwaa.config.sqs import (
    get_sqs_queue_name,
    should_create_queue,
)
from mwaa.utils.cmd import run_command
from mwaa.utils.dblock import with_db_lock
from mwaa.utils.statsd import get_statsd
from mwaa.utils.user_requirements import install_user_requirements
# Usually, we pass the `__name__` variable instead as that defaults to the
# module path, i.e. `mwaa.entrypoint` in this case. However, since this is
# the entrypoint script, `__name__` will have the value of `__main__`, hence
# we hard-code the module path.
logger = logging.getLogger("mwaa.entrypoint")
def _setup_console_log_level(command: str):
    # Set up console log level environment variable based on command.
    component_mapping = {
        'scheduler': 'MWAA__LOGGING__AIRFLOW_SCHEDULER_LOG_LEVEL',
        'worker': 'MWAA__LOGGING__AIRFLOW_WORKER_LOG_LEVEL',
        'webserver': 'MWAA__LOGGING__AIRFLOW_WEBSERVER_LOG_LEVEL',
    }
    if command in component_mapping:
        source_var = component_mapping[command]
        os.environ['AIRFLOW_CONSOLE_LOG_LEVEL'] = os.environ[source_var]
    else:
        os.environ['AIRFLOW_CONSOLE_LOG_LEVEL'] = 'INFO'
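# A minimal illustration of the mapping above (example values are assumptions): if
# the container is started with the "worker" command and
# MWAA__LOGGING__AIRFLOW_WORKER_LOG_LEVEL is set to "WARNING", then
# AIRFLOW_CONSOLE_LOG_LEVEL becomes "WARNING"; commands outside the mapping (e.g.
# "shell") fall back to "INFO". Note that the lookup assumes the corresponding
# MWAA__LOGGING__* variable is present; a missing variable would raise a KeyError.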
def _configure_root_logger(command: str):
    _setup_console_log_level(command)
    # Doing a local import because we can't import LOGGING_CONFIG before setting
    # AIRFLOW_CONSOLE_LOG_LEVEL, as doing so would leave the root logger's log level
    # at its default value.
    from mwaa.logging.config import LOGGING_CONFIG
    logging.config.dictConfig(LOGGING_CONFIG)
# TODO: Fix the "type: ignore"s in this file.

AVAILABLE_COMMANDS = [
    "webserver",
    "scheduler",
    "worker",
    "hybrid",
    "shell",
    "resetdb",
    "spy",
    "test-requirements",
    "test-startup-script",
]

# Save the start time of the container. This is used later with the sidecar
# monitoring because we need a grace period before we start reporting timeouts
# related to the sidecar endpoint not reporting health messages.
CONTAINER_START_TIME = time.time()
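# CONTAINER_START_TIME is handed to execute_command() at the bottom of main(),
# presumably so the launched component can compute that grace period relative to the
# container's start; the actual grace-period length is defined outside this module.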
async def airflow_db_init(environ: dict[str, str]):
    """
    Initialize the Airflow database.

    Before Airflow can be used, a call to `airflow db migrate` must be made. This
    function does that. It is called in the entrypoint to make sure that, for any
    Airflow component, the database is initialized before the component starts.

    :param environ: A dictionary containing the environment variables.
    """
    await run_command("python3 -m mwaa.database.migrate", env=environ)
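# The migration itself is delegated to the mwaa.database.migrate module, which (per
# the docstring above) performs the `airflow db migrate` step; any safeguards inside
# that module, such as coordination between components, live outside this file.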
async def increase_pool_size_if_default_size(environ: dict[str, str]):
    """
    Update the default pool size.

    Fix a regression where some 2.9.2 environments were created with the default
    `default_pool` size of 128. This function checks whether the environment was
    created during the problematic timeframe and updates the size if it has not
    already been changed by the customer.

    :param environ: A dictionary containing the environment variables.
    """
    created_at = os.environ.get("MWAA__CORE__CREATED_AT")
    problematic_pool_size = 128
    if created_at:
        try:
            date_format = "%a %b %d %H:%M:%S %Z %Y"
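            # This format matches strings like "Mon Jul 15 12:00:00 UTC 2024" (the
            # sample value is illustrative; the actual value is supplied via
            # MWAA__CORE__CREATED_AT).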
            created_date = datetime.strptime(created_at, date_format)
            # These dates include a small buffer around when 2.9.2 was released and
            # when the fix was fully deployed.
            issue_beginning = datetime(2024, 7, 8)
            issue_resolution = datetime(2024, 9, 6)
            if issue_beginning < created_date < issue_resolution:
                command_output = []
                # Get the current default_pool size.
                await run_command(
                    "airflow pools get default_pool | grep default_pool | awk '{print $3}'",
                    env=environ,
                    stdout_logging_method=lambda output: command_output.append(output),
                )
                # Increase the pool size if it is still the problematic default size.
                if len(command_output) == 1 and int(command_output[0]) == problematic_pool_size:
                    logger.info("Setting default_pool size to 10000.")
                    await run_command("airflow pools set default_pool 10000 default", env=environ)
                    stats = get_statsd()
                    stats.incr("mwaa.pool.increased_default_pool_size", 1)
        except Exception as error:
            logger.error(f"Error checking if pool issue is present: {error}")
@with_db_lock(5678)
async def create_airflow_user(environ: dict[str, str]):
    """
    Create the 'airflow' user.

    To be able to log in to the webserver, you need a user. This function creates a
    user with default credentials.

    Notice that this should only be used in a development context. In production,
    other means need to be employed to create users with strong passwords.
    Alternatively, with the MWAA setup, a plugin is employed to integrate with IAM
    (not implemented yet).

    :param environ: A dictionary containing the environment variables.
    """
    logger.info("Calling 'airflow users create' to create the webserver user.")
    await run_command(
        "airflow users create "
        "--username airflow "
        "--firstname Airflow "
        "--lastname Admin "
        "--email airflow@example.com "
        "--role Admin "
        "--password airflow",
        env=environ,
    )
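# With the defaults above, a local webserver login of airflow/airflow becomes
# available; as the docstring warns, this is only intended for development setups
# such as the Docker Compose one.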
@with_db_lock(1357)
def create_queue() -> None:
    """
    Create the SQS queue required by Celery.

    In our setup, we use SQS as the backend for Celery. Usually, the queue should be
    created beforehand. However, sometimes you might want to create it during
    startup. One such example is when using the elasticmq server as a mock SQS
    server.
    """
    if not should_create_queue():
        return
    queue_name = get_sqs_queue_name()
    endpoint = os.environ.get("MWAA__SQS__CUSTOM_ENDPOINT")
    sqs = boto3.client("sqs", endpoint_url=endpoint)  # type: ignore
    try:
        # Try to get the queue URL to check if the queue exists.
        sqs.get_queue_url(QueueName=queue_name)["QueueUrl"]  # type: ignore
        logger.info(f"Queue {queue_name} already exists.")
    except ClientError as e:
        # If the queue does not exist, create it.
        if (
            e.response.get("Error", {}).get("Message")  # type: ignore
            == "The specified queue does not exist."
        ):
            response = sqs.create_queue(QueueName=queue_name)  # type: ignore
            queue_url = response["QueueUrl"]  # type: ignore
            logger.info(f"Queue created: {queue_url}")
        else:
            # If there is a different error, raise it.
            raise e
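# For the local/mock scenario described in the docstring, MWAA__SQS__CUSTOM_ENDPOINT
# would point boto3 at the elasticmq container instead of real SQS (for example
# something like "http://sqs:9324"; the exact host and port are assumptions here),
# while should_create_queue() decides whether creating the queue at startup is
# appropriate at all.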
async def main() -> None:
    """Start execution of the script."""
    try:
        (
            _,
            command,
        ) = sys.argv
        if command not in AVAILABLE_COMMANDS:
            exit(
                f"Invalid command: {command}. "
                f'Use one of {", ".join(AVAILABLE_COMMANDS)}.'
            )
    except Exception as e:
        exit(
            f"Invalid arguments: {sys.argv}. Please provide one argument with one of "
            f'the values: {", ".join(AVAILABLE_COMMANDS)}. Error was {e}.'
        )
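    # At this point `command` is known to be one of AVAILABLE_COMMANDS: for example,
    # "scheduler" proceeds, while something like "backfill" (a hypothetical value not
    # in the list) or a call with missing/extra arguments exits with one of the
    # messages above.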
    _configure_root_logger(command)
    logger.info(f"Warming a Docker container for an Airflow {command}.")

    # Get the executor type.
    executor_type = os.environ.get("MWAA__CORE__EXECUTOR_TYPE", "CeleryExecutor")
    environ = setup_environment_variables(command, executor_type)
    await install_user_requirements(command, environ)
    if command == "test-requirements":
        print("Finished testing requirements")
        return
    await airflow_db_init(environ)
    await increase_pool_size_if_default_size(environ)
    if os.environ.get("MWAA__CORE__AUTH_TYPE", "").lower() == "testing":
        # When the auth type is set to "testing" (the simple, development-only auth
        # mode), we create an admin user "airflow" with password "airflow". We use
        # this to make the Docker Compose setup easy to use without having to create
        # a user manually. Needless to say, this shouldn't be used in production
        # environments.
        await create_airflow_user(environ)
    if executor_type.lower() == "celeryexecutor":
        create_queue()
    execute_command(command, environ, CONTAINER_START_TIME)
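# Startup order recap (from the body of main() above): configure logging, set up the
# environment, install user requirements, run the database migration, apply the
# default_pool fix, optionally create the test user and the Celery SQS queue, and
# finally hand off to execute_command() for the requested component.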
if __name__ == "__main__":
    asyncio.run(main())
elif os.environ.get("MWAA__CORE__TESTING_MODE", "false") != "true":
    logger.error("This module cannot be imported.")
    sys.exit(1)