composer_local_dev/environment.py (855 lines of code) (raw):

# Copyright 2022 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import getpass import io import json import logging import os import pathlib import platform import tarfile import time import warnings from functools import cached_property from typing import Dict, List, Optional, Tuple import docker from docker import errors as docker_errors from docker.types import Mount from google.api_core import exceptions as api_exception from google.auth import exceptions as auth_exception from google.cloud import artifactregistry_v1 from google.cloud.orchestration.airflow import service_v1 from composer_local_dev import console, constants, errors, files, utils LOG = logging.getLogger(__name__) DOCKER_FILES = pathlib.Path(__file__).parent / "docker_files" def timeout_occurred(start_time): """Returns whether time since start is greater than OPERATION_TIMEOUT.""" return time.time() - start_time >= constants.OPERATION_TIMEOUT_SECONDS def get_image_mounts( env_path: pathlib.Path, dags_path: str, gcloud_config_path: str, kube_config_path: Optional[str], requirements: pathlib.Path, database_mounts: Dict[pathlib.Path, str] ) -> List[docker.types.Mount]: """ Return list of docker volumes to be mounted inside container. Following paths are mounted: - requirements for python packages to be installed - dags, plugins and data for paths which contains dags, plugins and data - gcloud_config_path which contains user credentials to gcloud - kube_config_path which contains user cluster credentials for K8S [Optional] - environment airflow sqlite db file location - database_mounts which contains the path for database mounts """ mount_paths = { requirements: "composer_requirements.txt", dags_path: "gcs/dags/", env_path / "plugins": "gcs/plugins/", env_path / "data": "gcs/data/", gcloud_config_path: ".config/gcloud", **database_mounts, } # Add kube_config_path only if it's provided if kube_config_path: mount_paths[kube_config_path] = ".kube/" return [ docker.types.Mount( source=str(source), target=f"{constants.AIRFLOW_HOME}/{target}", type="bind", ) for source, target in mount_paths.items() ] def get_default_environment_variables( dag_dir_list_interval: int, project_id: str, default_db_variables: Dict[str, str] ) -> Dict: """Return environment variables that will be set inside container.""" return { "AIRFLOW__API__AUTH_BACKEND": "airflow.api.auth.backend.default", "AIRFLOW__WEBSERVER__EXPOSE_CONFIG": "true", "AIRFLOW__CORE__LOAD_EXAMPLES": "false", "AIRFLOW__SCHEDULER__DAG_DIR_LIST_INTERVAL": dag_dir_list_interval, "AIRFLOW__CORE__DAGS_FOLDER": "/home/airflow/gcs/dags", "AIRFLOW__CORE__PLUGINS_FOLDER": "/home/airflow/gcs/plugins", "AIRFLOW__CORE__DATA_FOLDER": "/home/airflow/gcs/data", "AIRFLOW__WEBSERVER__RELOAD_ON_PLUGIN_CHANGE": "True", "COMPOSER_PYTHON_VERSION": "3", # By default, the container runs as the user `airflow` with UID 999. Set # this env variable to "True" to make it run as the current host user. "COMPOSER_CONTAINER_RUN_AS_HOST_USER": "False", "COMPOSER_HOST_USER_NAME": f"{getpass.getuser()}", "COMPOSER_HOST_USER_ID": f"{os.getuid() if platform.system() != 'Windows' else ''}", "AIRFLOW_HOME": "/home/airflow/airflow", "AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT": f"google-cloud-platform://?" f"extra__google_cloud_platform__project={project_id}&" f"extra__google_cloud_platform__scope=" f"https://www.googleapis.com/auth/cloud-platform", **default_db_variables, } def parse_env_variable( line: str, env_file_path: pathlib.Path ) -> Tuple[str, str]: """Parse line in format of key=value and return (key, value) tuple.""" try: key, value = line.split("=", maxsplit=1) except ValueError: raise errors.FailedToParseVariablesError(env_file_path, line) return key.strip(), value.strip() def load_environment_variables(env_dir_path: pathlib.Path) -> Dict: """ Load environment variable to be sourced in the local Composer environment. Raises an error if the variables.env file does not exist in the ``env_dir_path``. Args: env_dir_path (pathlib.Path): Path to the local composer environment. Returns: Dict: Environment variables. """ env_file_path = env_dir_path / "variables.env" LOG.info("Loading environment variables from %s", env_file_path) if not env_file_path.is_file(): raise errors.ComposerCliError( f"Environment variables file '{env_file_path}' not found." ) env_vars = dict() with open(env_file_path) as fp: for line in fp: line = line.strip() if not line or line.startswith("#"): continue key, value = parse_env_variable(line, env_file_path) if key in constants.NOT_MODIFIABLE_ENVIRONMENT_VARIABLES: LOG.warning( "'%s' environment variable cannot be set " "and will be ignored.", key, ) elif key in constants.STRICT_ENVIRONMENT_VARIABLES: possible_values = constants.STRICT_ENVIRONMENT_VARIABLES[key] if value not in possible_values: LOG.warning( "'%s' environment variable can be set " "to the one of the following values: '%s'", key, ",".join(possible_values), ) else: env_vars[key] = value else: env_vars[key] = value return env_vars def filter_not_modifiable_env_vars(env_vars: Dict) -> Dict: """ Filter out environment variables that cannot be modified by user. """ filtered_env_vars = dict() for key, val in env_vars.items(): if key in constants.NOT_MODIFIABLE_ENVIRONMENT_VARIABLES: LOG.warning( "'%s' environment variable cannot be set and will be ignored.", key, ) elif key in constants.STRICT_ENVIRONMENT_VARIABLES: possible_values = constants.STRICT_ENVIRONMENT_VARIABLES[key] if val not in possible_values: LOG.warning( "'%s' environment variable can be set " "to the one of the following values: '%s'", key, ",".join(possible_values), ) else: env_vars[key] = val else: filtered_env_vars[key] = val return filtered_env_vars def get_software_config_from_environment( project: str, location: str, environment: str ): """Get software configuration from the Composer environment. Args: project (str): Composer GCP Project ID. location (str): Location of the Composer environment. environment (str): Composer environment name. Returns: SoftwareConfig: Software configuration of the Composer environment. """ LOG.info("Getting Cloud Composer environment configuration.") client = service_v1.EnvironmentsClient() name = f"projects/{project}/locations/{location}/environments/{environment}" request = service_v1.GetEnvironmentRequest(name=name) LOG.debug(f"GetEnvironmentRequest: %s", request) try: response = client.get_environment(request=request) except api_exception.GoogleAPIError as err: raise errors.ComposerCliError( constants.COMPOSER_SOFTWARE_CONFIG_API_ERROR.format(err=str(err)) ) LOG.debug(f"GetEnvironmentResponse: %s", response) return response.config.software_config def parse_airflow_override_to_env_var(airflow_override: str) -> str: """ Parse airflow override variable name in format of section-key to AIRFLOW__SECTION__KEY. """ section, key = airflow_override.split("-", maxsplit=1) return f"AIRFLOW__{section.upper()}__{key.upper()}" def get_airflow_overrides(software_config): """ Returns dictionary with environment variable names and their values mapped from Airflow overrides in Composer Software Config. """ return { parse_airflow_override_to_env_var(k): v for k, v in software_config.airflow_config_overrides.items() } def get_env_variables(software_config): """ Returns dictionary with environment variable names (with unset values) mapped from Airflow environment variables in Composer Software Config. """ return {k: "" for k, _ in software_config.env_variables.items()} def assert_image_exists(image_version: str): """Asserts that image version exists. Raises if the image does not exist. Warns if the API error occurs, or we cannot access to Artifact Registry. Args: image_version: Image version in format of 'composer-x.y.z-airflow-a.b.c' """ airflow_v, composer_v = utils.get_airflow_composer_versions(image_version) image_tag = utils.get_image_version_tag(airflow_v, composer_v) LOG.info("Asserting that %s composer image version exists", image_tag) image_url = constants.ARTIFACT_REGISTRY_IMAGE_URL.format( airflow_v=airflow_v, composer_v=composer_v ) client = artifactregistry_v1.ArtifactRegistryClient() request = artifactregistry_v1.GetTagRequest(name=image_url) LOG.debug(f"GetTagRequest for %s: %s", image_tag, str(request)) try: client.get_tag(request=request) except api_exception.NotFound: raise errors.ImageNotFoundError(image_version=image_tag) from None except api_exception.PermissionDenied: warnings.warn( constants.IMAGE_TAG_PERMISSION_DENIED_WARN.format( image_tag=image_tag ) ) except ( auth_exception.GoogleAuthError, api_exception.GoogleAPIError, ) as err: raise errors.InvalidAuthError(err) def get_docker_image_tag_from_image_version(image_version: str) -> str: """ Parse image version to Airflow and Composer versions and return image tag with those versions if it exists. Args: image_version: Image version in format of 'composer-x.y.z-airflow-a.b.c' Returns: Composer image tag in Artifact Registry """ airflow_v, composer_v = utils.get_airflow_composer_versions(image_version) return constants.DOCKER_REGISTRY_IMAGE_TAG.format( airflow_v=airflow_v, composer_v=composer_v ) def is_mount_permission_error(error: docker_errors.APIError) -> bool: """Checks if error is possibly a Docker mount permission error.""" return ( error.is_client_error() and error.response.status_code == constants.BAD_REQUEST_ERROR_CODE and "invalid mount config" in error.explanation ) def copy_to_container(container, src: pathlib.Path) -> None: """Copy entrypoint file to Docker container.""" logging.debug("Copying entrypoint file to Docker container.") stream = io.BytesIO() with tarfile.open(fileobj=stream, mode="w|") as tar, open(src, "rb") as f: info = tar.gettarinfo(fileobj=f) info.name = src.name tar.addfile(info, f) container.put_archive(constants.AIRFLOW_HOME, stream.getvalue()) class EnvironmentStatus: def __init__(self, name: str, version: str, status: str): self.name = name self.version = version self.status = status.capitalize() def get_image_version(env): """ Return environment image version. If the environment is running, get image version from the container tag. Otherwise, get image version from the configuration. """ try: container = env.get_container(env.container_name) except errors.EnvironmentNotRunningError: logging.debug( constants.IMAGE_VERSION_CONTAINER_MISSING.format(env_name=env.name) ) return env.image_version if not container.image.tags: LOG.warning( constants.IMAGE_VERSION_TAG_MISSING.format(env_name=env.name) ) return env.image_version tag = container.image.tags[0] image_tag = tag.split(":")[-1] airflow_v, composer_v = utils.get_airflow_composer_versions(image_tag) airflow_v = utils.format_airflow_version_dotted(airflow_v) return utils.get_image_version_tag(airflow_v, composer_v) def get_environments_status( envs: List[pathlib.Path], ) -> List[EnvironmentStatus]: """Get list of environment statuses.""" environments_status = [] for env_path in envs: try: env = Environment.load_from_config(env_path, None) env_status = env.status() image_version = get_image_version(env) except errors.InvalidConfigurationError: env_status = "Could not parse the config" image_version = "x" environment_status = EnvironmentStatus( env_path.name, image_version, env_status ) environments_status.append(environment_status) return environments_status class EnvironmentConfig: def __init__(self, env_dir_path: pathlib.Path, port: Optional[int]): self.env_dir_path = env_dir_path self.config = self.load_configuration_from_file() self.project_id = self.get_str_param("composer_project_id") self.image_version = self.get_str_param("composer_image_version") self.location = self.get_str_param("composer_location") self.dags_path = self.get_str_param("dags_path") self.dag_dir_list_interval = self.parse_int_param( "dag_dir_list_interval", allowed_range=(0,) ) self.port = ( port if port is not None else self.parse_int_param("port", allowed_range=(0, 65536)) ) self.database_engine = self.get_str_param("database_engine") def load_configuration_from_file(self) -> Dict: """ Load environment configuration from json file. Returns: Dict: Environment configuration dictionary. """ config_path = self.env_dir_path / "config.json" LOG.info("Loading configuration file from %s", config_path) if not config_path.is_file(): raise errors.ComposerCliError( f"Configuration file '{config_path}' not found." ) with open(config_path) as fp: try: config = json.load(fp) except json.JSONDecodeError as err: raise errors.FailedToParseConfigError(config_path, err) return config def get_str_param(self, name: str): """ Get parameter from the config. Raises an error if the parameter does not exist in the config. """ try: return self.config[name] except KeyError: raise errors.MissingRequiredParameterError(name) from None def parse_int_param( self, name: str, allowed_range: Optional[Tuple[int, int]] = None, ): """ Get parameter from the config and convert it to integer. Raises an error if the parameter value is not a valid integer. Optional ``allowed_range`` argument can be used to validate if parameter value is in the given range. Args: name: Name of the parameter in the config allowed_range: Tuple containing allowed range of values Returns: Parameter value converted to integer """ try: value = self.get_str_param(name) value = int(value) except ValueError as err: raise errors.FailedToParseConfigParamIntError(name, value) from None if allowed_range is None: return value if value < allowed_range[0] or ( len(allowed_range) > 1 and value > allowed_range[1] ): raise errors.FailedToParseConfigParamIntRangeError( name, value, allowed_range ) return value class Environment: def __init__( self, env_dir_path: pathlib.Path, project_id: str, image_version: str, location: str, dags_path: Optional[str], dag_dir_list_interval: int = 10, database_engine: str = constants.DatabaseEngine.sqlite3, port: Optional[int] = None, pypi_packages: Optional[Dict] = None, environment_vars: Optional[Dict] = None, ): self.name = env_dir_path.name self.container_name = f"{constants.CONTAINER_NAME}-{self.name}" self.db_container_name = f"{constants.DB_CONTAINER_NAME}-{self.name}" self.docker_network_name = f"{constants.DOCKER_NETWORK_NAME}-{self.name}" self.env_dir_path = env_dir_path self.airflow_db = self.env_dir_path / "airflow.db" self.entrypoint_file = DOCKER_FILES / "entrypoint.sh" self.run_file = DOCKER_FILES / "run_as_user.sh" self.requirements_file = self.env_dir_path / "requirements.txt" self.project_id = project_id self.image_version = image_version self.image_tag = get_docker_image_tag_from_image_version(image_version) self.db_image_tag = 'postgres:14-alpine' self.airflow_db_folder = self.env_dir_path / "postgresql_data" self.location = location self.dags_path = files.resolve_dags_path(dags_path, env_dir_path) self.dag_dir_list_interval = dag_dir_list_interval self.database_engine = database_engine self.is_database_sqlite3 = self.database_engine == constants.DatabaseEngine.sqlite3 self.port: int = port if port is not None else 8080 self.pypi_packages = ( pypi_packages if pypi_packages is not None else dict() ) self.environment_vars = ( environment_vars if environment_vars is not None else dict() ) self.docker_client = self.get_client() def get_client(self): try: return docker.from_env() except docker.errors.DockerException as err: logging.debug("Docker not found.", exc_info=True) raise errors.DockerNotAvailableError(err) from None def get_container( self, container_name: str, assert_running: bool = False, ignore_not_found: bool = False ): """ Returns created docker container and raises when it's not created. Args: container_name: name of the container assert_running: assert that container is running ignore_not_found: change the behaviour of raising error in case of not found the container """ try: container = self.docker_client.containers.get(container_name) if ( assert_running and container.status != constants.ContainerStatus.RUNNING ): raise errors.EnvironmentNotRunningError() from None return container except docker_errors.NotFound: logging.debug("Container not found.", exc_info=True) if not ignore_not_found: raise errors.EnvironmentNotFoundError() from None @classmethod def load_from_config(cls, env_dir_path: pathlib.Path, port: Optional[int]): """Create local environment using 'config.json' configuration file.""" config = EnvironmentConfig(env_dir_path, port) environment_vars = load_environment_variables(env_dir_path) return cls( env_dir_path=env_dir_path, project_id=config.project_id, image_version=config.image_version, location=config.location, dags_path=config.dags_path, dag_dir_list_interval=config.dag_dir_list_interval, port=config.port, database_engine=config.database_engine, environment_vars=environment_vars, ) @classmethod def from_source_environment( cls, source_environment: str, project: str, location: str, env_dir_path: pathlib.Path, web_server_port: Optional[int], dags_path: Optional[str], database_engine: str, ): """ Create Environment using configuration retrieved from Composer environment. """ software_config = get_software_config_from_environment( project, location, source_environment ) pypi_packages = {k: v for k, v in software_config.pypi_packages.items()} env_variables = get_env_variables(software_config) airflow_overrides = get_airflow_overrides(software_config) env_variables.update(airflow_overrides) env_variables = filter_not_modifiable_env_vars(env_variables) return cls( env_dir_path=env_dir_path, project_id=project, image_version=software_config.image_version, location=location, dags_path=dags_path, dag_dir_list_interval=10, port=web_server_port, pypi_packages=pypi_packages, environment_vars=env_variables, database_engine=database_engine, ) def pypi_packages_to_requirements(self): """Create requirements file using environment PyPi packagest list.""" reqs = sorted( f"{key}{value}" for key, value in self.pypi_packages.items() ) reqs_lines = "\n".join(reqs) with open(self.env_dir_path / "requirements.txt", "w") as fp: fp.write(reqs_lines) def environment_vars_to_env_file(self): """ Write fetched environment variables keys to `variables.env` file. """ env_vars = sorted( f"# {key}=" for key, _ in self.environment_vars.items() ) env_vars_lines = "\n".join(env_vars) with open(self.env_dir_path / "variables.env", "w") as fp: fp.write(env_vars_lines) def assert_requirements_exist(self): """Asserts that PyPi requirements file exist in environment directory.""" req_file = self.env_dir_path / "requirements.txt" if not req_file.is_file(): raise errors.ComposerCliError(f"Missing '{req_file}' file.") def write_environment_config_to_config_file(self): """Saves environment configuration to config.json file.""" config = { "composer_image_version": self.image_version, "composer_location": self.location, "composer_project_id": self.project_id, "dags_path": self.dags_path, "dag_dir_list_interval": int(self.dag_dir_list_interval), "port": int(self.port), "database_engine": self.database_engine, } with open(self.env_dir_path / "config.json", "w") as fp: json.dump(config, fp, indent=4) @cached_property def database_extras(self) -> Dict[str, Dict]: env_path = self.env_dir_path extras = { constants.DatabaseEngine.sqlite3: { "mounts": { "folders": {}, "files": { env_path / "airflow.db": "airflow/airflow.db", } }, "env_vars": {}, "ports": {}, }, constants.DatabaseEngine.postgresql: { "mounts": { "folders": { env_path / "postgresql_data": "/var/lib/postgresql/data", }, "files": { env_path / ".keep": "airflow/.keep", } }, "env_vars": { "PGDATA": "/var/lib/postgresql/data/pgdata", "POSTGRES_USER": "postgres", "POSTGRES_PASSWORD": "airflow", "POSTGRES_DB": "airflow", "AIRFLOW__DATABASE__SQL_ALCHEMY_CONN": f"postgresql+psycopg2://postgres:airflow@{self.db_container_name}:5432/airflow", }, "ports": { f"5432/tcp": "25432", }, }, } if self.database_engine in extras: return extras[self.database_engine] return extras[constants.DatabaseEngine.sqlite3] def create_container(self, **kwargs): try: return self.docker_client.containers.create(**kwargs) except docker_errors.APIError as err: logging.debug( "Received docker API error when creating container.", exc_info=True, ) if err.status_code == constants.CONFLICT_ERROR_CODE: raise errors.EnvironmentAlreadyRunningError( self.name ) from None raise def create_docker_container(self): """Creates docker container. Raises when docker container with the same name already exists. """ LOG.debug("Creating container") db_extras = self.database_extras grouped_db_mounts = db_extras["mounts"] db_mounts = { **grouped_db_mounts['files'], **grouped_db_mounts['folders'], } mounts = get_image_mounts( self.env_dir_path, self.dags_path, utils.resolve_gcloud_config_path(), utils.resolve_kube_config_path(), self.requirements_file, db_mounts, ) db_vars = db_extras["env_vars"] default_vars = get_default_environment_variables( self.dag_dir_list_interval, self.project_id, db_vars ) env_vars = {**default_vars, **self.environment_vars} if ( platform.system() == "Windows" and env_vars["COMPOSER_CONTAINER_RUN_AS_HOST_USER"] == "True" ): raise Exception( "COMPOSER_CONTAINER_RUN_AS_HOST_USER must be set to `False` on Windows" ) ports = { f"8080/tcp": self.port, } entrypoint = f"sh {constants.ENTRYPOINT_PATH}" memory_limit = constants.DOCKER_CONTAINER_MEMORY_LIMIT try: container = self.create_container( image=self.image_tag, name=self.container_name, entrypoint=entrypoint, environment=env_vars, mounts=mounts, ports=ports, mem_limit=memory_limit, detach=True, ) except docker_errors.ImageNotFound: LOG.debug( "Failed to create container with ImageNotFound error. " "Pulling the image..." ) self.pull_image() container = self.create_container( image=self.image_tag, name=self.container_name, entrypoint=entrypoint, environment=env_vars, mounts=mounts, ports=ports, mem_limit=memory_limit, detach=True, ) except docker_errors.APIError as err: error = f"Failed to create container with an error: {err}" if is_mount_permission_error(err): error += constants.DOCKER_PERMISSION_ERROR_HINT.format( docs_faq_url=constants.COMPOSER_FAQ_MOUNTING_LINK ) raise errors.EnvironmentStartError(error) copy_to_container(container, self.entrypoint_file) copy_to_container(container, self.run_file) return container def create_db_docker_container(self): """Creates docker container for database. Raises when docker container with the same name already exists. """ db_extras = self.database_extras grouped_db_mounts = db_extras["mounts"] db_mounts = { **grouped_db_mounts['files'], **grouped_db_mounts['folders'], } mounts = get_image_mounts( self.env_dir_path, self.dags_path, utils.resolve_gcloud_config_path(), utils.resolve_kube_config_path(), self.requirements_file, db_mounts, ) db_vars = db_extras["env_vars"] db_ports = db_extras["ports"] memory_limit = constants.DOCKER_CONTAINER_MEMORY_LIMIT self.docker_client.images.pull(self.db_image_tag) LOG.info("DB_VARS") LOG.info(db_vars) try: container = self.create_container( image=self.db_image_tag, name=self.db_container_name, environment=db_vars, mounts=mounts, ports=db_ports, mem_limit=memory_limit, detach=True, ) return container except docker_errors.APIError as err: error = f"Failed to create container for database with an error: {err}" if is_mount_permission_error(err): error += constants.DOCKER_PERMISSION_ERROR_HINT.format( docs_faq_url=constants.COMPOSER_FAQ_MOUNTING_LINK ) raise errors.EnvironmentStartError(error) def get_docker_network(self): try: return self.docker_client.networks.get(self.docker_network_name) except docker.errors.NotFound as _: return self.docker_client.networks.create(self.docker_network_name) except docker.errors.APIError as err: error = f"Failed to create/get network an error: {err}" raise errors.EnvironmentStartError(error) def pull_image(self): """Pull Composer docker image.""" try: # TODO: (b/237054183): Print detailed status (progress bar of image pulling) with console.get_console().status(constants.PULL_IMAGE_MSG): self.docker_client.images.pull(self.image_tag) except (docker_errors.ImageNotFound, docker_errors.APIError): logging.debug("Failed to pull composer image.", exc_info=True) raise errors.ImageNotFoundError(self.image_version) from None def pull_db_image(self): try: # TODO: (b/237054183): Print detailed status (progress bar of image pulling) with console.get_console().status(constants.DB_PULL_IMAGE_MSG): self.docker_client.images.pull(self.db_image_tag) except (docker_errors.ImageNotFound, docker_errors.APIError): logging.debug(f"Failed to pull database image ({self.db_image_tag}).", exc_info=True) raise errors.ImageNotFoundError(self.db_image_tag) from None def create_database_files(self): db_extras = self.database_extras db_mounts = db_extras["mounts"] for host_path in db_mounts['files'].keys(): files.create_empty_file(host_path, skip_if_exist=False) for host_path in db_mounts['folders'].keys(): files.create_empty_folder(host_path) def create(self): """Creates Composer local environment. Directory with environment name will be created under `composer` path and environment configuration will be saved to config.json and requirements.txt files. """ assert_image_exists(self.image_version) files.create_environment_directories(self.env_dir_path, self.dags_path) self.create_database_files() self.write_environment_config_to_config_file() self.pypi_packages_to_requirements() self.environment_vars_to_env_file() console.get_console().print( constants.CREATE_MESSAGE.format( env_dir=self.env_dir_path, env_name=self.name, config_path=self.env_dir_path / "config.json", requirements_path=self.env_dir_path / "requirements.txt", env_variables_path=self.env_dir_path / "variables.env", dags_path=self.dags_path, ) ) def assert_container_is_active(self, container_name): """ Asserts docker container is in running or created state (is active). """ status = self.get_container(container_name).status if status not in ( constants.ContainerStatus.RUNNING, constants.ContainerStatus.CREATED, ): raise errors.EnvironmentStartError() def wait_for_db_start(self): start_time = time.time() with console.get_console().status("[bold green]Starting database..."): self.assert_container_is_active(self.db_container_name) for line in self.get_container(self.db_container_name).logs(stream=True, timestamps=True): line = line.decode('utf-8').strip() console.get_console().print(line) if "database system is ready to accept connections" in line: start_duration = time.time() - start_time LOG.info("Database is started in %.2f seconds", start_duration) return if timeout_occurred(start_time): raise errors.EnvironmentStartTimeoutError() self.assert_container_is_active(self.db_container_name) raise errors.EnvironmentStartError() def wait_for_start(self): """ Poll environment logs to see if it is ready. When Airflow scheduler starts, it prints 'searching for files' in the logs. We are using it as marker of the environment readiness. """ start_time = time.time() with console.get_console().status("[bold green]Starting environment..."): self.assert_container_is_active(self.container_name) for line in self.get_container(self.container_name).logs(stream=True, timestamps=True): line = line.decode("utf-8").strip() console.get_console().print(line) # TODO: (b/234684803) Improve detecting container readiness if "Searching for files" in line: start_duration = time.time() - start_time LOG.info( "Environment started in %.2f seconds", start_duration ) return if timeout_occurred(start_time): raise errors.EnvironmentStartTimeoutError() self.assert_container_is_active(self.container_name) raise errors.EnvironmentStartError() def get_or_create_container(self, container_name: str): """ Get existing container or create new container if it does not exist. """ try: return self.get_container(container_name) except errors.EnvironmentNotRunningError: if container_name == self.container_name: # if the given container name is the main container return self.create_docker_container() else: # if the given container name is db container return self.create_db_docker_container() def start_container(self, container_name: str = None, assert_not_running=True): """ Start the given container """ container = self.get_or_create_container(container_name) if ( assert_not_running and container.status == constants.ContainerStatus.RUNNING ): raise errors.EnvironmentAlreadyRunningError(self.name) from None try: container.start() return container except docker.errors.APIError as err: logging.debug( "Starting environment failed with Docker API error.", exc_info=True, ) # TODO: (b/234552960) Test on different OS/language setting if ( err.status_code == constants.SERVER_ERROR_CODE and "port is already allocated" in str(err) ): container.remove() raise errors.ComposerCliError( constants.PORT_IN_USE_ERROR.format(port=self.port) ) error = f"Environment ({container_name}) failed to start with an error: {err}" raise errors.EnvironmentStartError(error) from None def start(self, assert_not_running=True): """Starts local composer environment. Before starting we are asserting that are required files in the environment directory. The docker container is created and started. This operation will raise an error if we are trying to use port that is already allocated. Started environment is polled until Airflow scheduler starts. """ assert_image_exists(self.image_version) self.assert_requirements_exist() files.assert_dag_path_exists(self.dags_path) self.create_database_files() db_path = self.airflow_db if self.is_database_sqlite3 else self.airflow_db_folder files.fix_file_permissions( entrypoint=self.entrypoint_file, run=self.run_file, requirements=self.requirements_file, db_path=db_path, ) files.fix_line_endings( entrypoint=self.entrypoint_file, run=self.run_file, requirements=self.requirements_file, ) if not self.is_database_sqlite3: LOG.info(f"Database engine is selected as {self.database_engine}. The container will start before") db_container = self.start_container(self.db_container_name, False) self.wait_for_db_start() self.ensure_container_is_attached_to_network(db_container) LOG.info(f"Database started!") container = self.start_container(self.container_name, assert_not_running) self.ensure_container_is_attached_to_network(container) self.wait_for_start() self.print_start_message() def ensure_container_is_attached_to_network(self, container): network = self.get_docker_network() existing_containers = [c.name for c in network.containers] if container.name in existing_containers: network.disconnect(container.name) network.connect(container) def print_start_message(self): """Print the start message after the environment is up and ready.""" console.get_console().print( constants.START_MESSAGE.format( env_name=self.name, dags_path=self.dags_path, port=self.port, ) ) def logs(self, follow, max_lines): """ Fetch and print logs from the running composer local environment. Container `logs` method returns blocking generator if follow is True, and byte-decoded string if follow is False. That's why we need two methods of handling and decoding logs. """ log_lines = self.get_container(self.container_name).logs( timestamps=True, stream=follow, follow=follow, tail=max_lines, ) if follow: LOG.debug( "Printing previous %s lines and following output " "from the container logs:", max_lines, ) for line in log_lines: line = line.decode("utf-8").strip() console.get_console().print(line) else: LOG.debug( "Printing previous %s lines from container logs:", max_lines ) log_lines = log_lines.decode("utf-8") for line in log_lines.split("\n"): console.get_console().print(line) def stop(self, remove_container=False): """ Stops the local composer environment. By default container is not removed. """ with console.get_console().status( f"[bold green]Stopping composer local environment..." ): db_container = self.get_container(self.db_container_name, ignore_not_found=True) if db_container: db_container.stop() if remove_container: db_container.remove() container = self.get_container(self.container_name, ignore_not_found=True) if container: container.stop() if remove_container: container.remove() if remove_container: network = self.get_docker_network() network.remove() def restart(self): """ Restarts the local composer environment. This operation will stop and remove container if it is running. Then it will start it again. """ try: self.stop(remove_container=True) except errors.EnvironmentNotRunningError: pass self.start(assert_not_running=False) def status(self) -> str: """Get status of the local composer environment.""" try: return self.get_container(self.container_name).status except errors.EnvironmentNotRunningError: return "Not started" def run_airflow_command(self, command: List) -> None: """ Run command list in the environment container. The commands are prefixed with `airflow`. """ container = self.get_container(self.container_name, assert_running=True) command.insert(0, "airflow") command.insert(0, "/home/airflow/run_as_user.sh") result = container.exec_run(cmd=command) console.get_console().print(result.output.decode()) def get_host_port(self) -> int: """ Return port of the running environment. If it fails to retrieve it, return port from the environment configuration. """ try: return self.get_container(self.container_name).ports["8080/tcp"][0]["HostPort"] except (IndexError, KeyError): LOG.info(constants.FAILED_TO_GET_DOCKER_PORT_WARN) return self.port def prepare_env_description(self, env_status: str) -> str: """Prepare description of the local composer environment.""" if env_status == constants.ContainerStatus.RUNNING: port = self.get_host_port() web_url = constants.WEBSERVER_URL_MESSAGE.format(port=port) else: web_url = "" env_status = utils.wrap_status_in_color(env_status) return constants.DESCRIBE_ENV_MESSAGE.format( name=self.name, state=env_status, web_url=web_url, image_version=self.image_version, dags_path=self.dags_path, gcloud_path=utils.resolve_gcloud_config_path(), ) + (constants.KUBECONFIG_PATH_MESSAGE.format(kube_config_path=utils.resolve_kube_config_path()) if utils.resolve_kube_config_path() else "no file") + constants.FINAL_ENV_MESSAGE def describe(self) -> None: """Describe the local composer environment.""" env_status = self.status() desc = self.prepare_env_description(env_status) console.get_console().print(desc) def remove(self, force, force_error): containers = {self.container_name} if not self.is_database_sqlite3: containers.add(self.db_container_name) for container_name in containers: container = self.get_container(container_name, ignore_not_found=True) if container is not None: if container.status == constants.ContainerStatus.RUNNING: if not force: raise force_error container.stop() container.remove() network = self.get_docker_network() network.remove()