composer_local_dev/utils.py (211 lines of code) (raw):
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import os
import pathlib
import re
import subprocess
import sys
from functools import total_ordering
from typing import List, Optional, Tuple
import click
import rich.box
import rich.table
from google.api_core import exceptions as api_exception
from google.cloud.orchestration.airflow import service_v1
from rich.logging import RichHandler
from composer_local_dev import constants, errors
# Module-level logger, named after this module.
LOG = logging.getLogger(__name__)
# Name of the Cloud SDK executable on POSIX systems and on Windows;
# gcloud_cmd() picks between the two.
_CLOUD_CLI_POSIX_COMMAND = "gcloud"
_CLOUD_CLI_WINDOWS_COMMAND = "gcloud.cmd"
# Arguments appended to the executable by get_project_id() to print the
# Cloud SDK configuration as JSON.
_CLOUD_CLI_CONFIG_COMMAND = "config config-helper --format json"
# Record format and date format handed to logging.basicConfig() in
# setup_logging().
LOG_FORMAT = "%(name)s:%(message)s"
LOG_DATE_FORMAT = "%Y-%m-%dT%H:%M:%S"
@total_ordering
class ImageVersion:
    """Orderable wrapper around an image version returned by the API.

    Instances are ordered by release date first and by image version id
    second.  ``total_ordering`` derives the remaining comparison operators
    from ``__eq__`` and ``__lt__``.

    Attributes are read from the wrapped object in ``__init__``:
    ``image_version_id`` is copied as is and ``release_date`` is wrapped in
    ``ImageVersionReleaseDate``.
    """

    def __init__(self, image_version: "service_v1.types.ImageVersion"):
        self.image_version_id = image_version.image_version_id
        self.release_date = ImageVersionReleaseDate(image_version.release_date)

    def __repr__(self):
        return (
            f"{type(self).__name__}("
            f"image_version_id={self.image_version_id!r}, "
            f"release_date={str(self.release_date)!r})"
        )

    def _sort_key(self):
        """Return the (release date, version id) pair used for comparisons."""
        return self.release_date, self.image_version_id

    def __eq__(self, other):
        # Let Python fall back to its default handling for foreign types
        # instead of failing with AttributeError on a missing attribute.
        if not isinstance(other, ImageVersion):
            return NotImplemented
        return self._sort_key() == other._sort_key()

    def __lt__(self, other):
        if not isinstance(other, ImageVersion):
            return NotImplemented
        # Tuple comparison: release date decides, version id breaks ties.
        return self._sort_key() < other._sort_key()
@total_ordering
class ImageVersionReleaseDate:
    """Orderable, printable wrapper around a release date.

    The wrapped ``release_date`` only needs to expose ``year``, ``month`` and
    ``day`` attributes.  Equality and ordering are both based on that
    (year, month, day) triple so the two can never disagree;
    ``total_ordering`` derives the remaining comparison operators.
    """

    def __init__(self, release_date):
        self.release_date = release_date

    def _as_tuple(self):
        """Return the date as a (year, month, day) tuple for comparisons."""
        return (
            self.release_date.year,
            self.release_date.month,
            self.release_date.day,
        )

    def __str__(self):
        # Rendered as DD/MM/YYYY with zero-padded day and month.
        return (
            f"{self.release_date.day:0>2}/{self.release_date.month:0>2}/"
            f"{self.release_date.year}"
        )

    def __repr__(self):
        return f"{type(self).__name__}({self.release_date!r})"

    def __eq__(self, other):
        # Let Python fall back to its default handling for foreign types
        # instead of failing with AttributeError on a missing attribute.
        if not isinstance(other, ImageVersionReleaseDate):
            return NotImplemented
        return self._as_tuple() == other._as_tuple()

    def __lt__(self, other):
        if not isinstance(other, ImageVersionReleaseDate):
            return NotImplemented
        return self._as_tuple() < other._as_tuple()
def is_windows_os() -> bool:
    """Return True when ``os.name`` reports the ``nt`` platform family."""
    platform_family = os.name
    return platform_family == "nt"
def is_linux_os() -> bool:
    """Return True when ``sys.platform`` starts with ``linux``."""
    platform_name = sys.platform
    return platform_name.startswith("linux")
def gcloud_cmd() -> str:
    """Return the Cloud SDK executable name for the current platform."""
    return (
        _CLOUD_CLI_WINDOWS_COMMAND
        if is_windows_os()
        else _CLOUD_CLI_POSIX_COMMAND
    )
def get_project_id() -> Optional[str]:
    """Gets the project ID from the Cloud CLI.

    Runs the Cloud CLI ``config config-helper --format json`` command, parses
    its standard output as JSON and reads the value stored under
    ``configuration -> properties -> core -> project``.

    Returns:
        The project id found in the gcloud configuration.

    Raises:
        errors.InvalidAuthError: if the gcloud command could not be started
            or exited with a non-zero status.
        errors.ComposerCliError: if the command output is not valid JSON or
            does not contain the project id.
    """
    command = [gcloud_cmd(), *_CLOUD_CLI_CONFIG_COMMAND.split()]
    try:
        output = subprocess.run(
            command,
            check=True,
            capture_output=True,
            text=True,
        ).stdout
    except (subprocess.CalledProcessError, OSError) as err:
        LOG.debug(
            "Failed to get project ID from the Cloud CLI.", exc_info=True
        )
        raise errors.InvalidAuthError(err) from err
    LOG.debug("Cloud CLI output: %s", output)

    try:
        configuration = json.loads(output)
    except ValueError as err:
        raise errors.ComposerCliError(
            f"Failed to decode gcloud CLI configuration: {err}"
        ) from None

    try:
        project_id = configuration["configuration"]["properties"]["core"][
            "project"
        ]
    except (KeyError, TypeError):
        # TypeError covers valid JSON of an unexpected shape (for example a
        # list or null where an object was expected).
        raise errors.ComposerCliError(
            "gcloud configuration is missing project id."
        ) from None
    LOG.info("Using GCP project %s", project_id)
    return project_id
def resolve_gcloud_config_path() -> str:
    """
    Returns the path to the Cloud CLI's configuration directory.

    The environment variable named by ``constants.CLOUD_CLI_CONFIG_PATH_ENV``
    wins when it is set and is returned unchanged. Otherwise the platform
    default location is used, and an error is raised when it is not an
    existing directory.
    """
    env_override = os.environ.get(constants.CLOUD_CLI_CONFIG_PATH_ENV)
    if env_override is not None:
        return env_override

    use_appdata = is_windows_os() and "APPDATA" in os.environ
    if use_appdata:
        candidate = pathlib.Path(os.environ["APPDATA"], "gcloud")
    else:
        candidate = pathlib.Path("~/.config/gcloud").expanduser()

    # TODO (b/234553956) Check if found directory is correct gcloud config
    if not candidate.is_dir():
        raise errors.ComposerCliError(constants.GCLOUD_CONFIG_NOT_FOUND_ERROR)
    return str(candidate)
def resolve_kube_config_path() -> Optional[str]:
    """
    Returns the value of the environment variable named by
    ``constants.KUBECONFIG_PATH_ENV``, or None when it is not set.
    """
    variable_name = constants.KUBECONFIG_PATH_ENV
    return os.environ.get(variable_name)
def assert_environment_name_is_valid(env_name: str):
    """
    Raises ``errors.ComposerCliError`` unless ``env_name`` is a valid name.

    A valid name is 3 to 40 characters long and consists solely of
    characters from the [A-Za-z0-9_-] range. The length checks run first,
    so a name that is both too short and malformed reports the length
    problem.
    """
    name_length = len(env_name)
    if name_length < 3:
        error_template = constants.ENVIRONMENT_NAME_TOO_SHORT_ERROR
    elif name_length > 40:
        error_template = constants.ENVIRONMENT_NAME_TOO_LONG_ERROR
    elif re.search("[^A-Za-z0-9_-]", env_name):
        error_template = constants.ENVIRONMENT_NAME_NOT_VALID_ERROR
    else:
        return
    raise errors.ComposerCliError(error_template.format(env_name=env_name))
def get_airflow_composer_versions(image_version: str) -> Tuple[str, str]:
    """
    Split an image version into its Airflow and Composer versions.

    Args:
        image_version: Image version in format of
            'composer-x.y.z-airflow-a.b.c'; it is matched against
            ``constants.IMAGE_VERSION_PATTERN``.

    Returns:
        A ``(airflow_v, composer_v)`` tuple. ``airflow_v`` is the second
        pattern group with dots replaced by dashes (a-b-c format);
        ``composer_v`` is the first pattern group, unchanged.

    Raises:
        errors.ComposerCliError: if ``image_version`` does not match the
            pattern.
    """
    parsed = re.match(constants.IMAGE_VERSION_PATTERN, image_version)
    if parsed is None:
        raise errors.ComposerCliError(constants.INVALID_IMAGE_VERSION_ERROR)
    composer_v, dotted_airflow_v = parsed.group(1, 2)
    return dotted_airflow_v.replace(".", "-"), composer_v
def format_airflow_version_dotted(airflow_v: str) -> str:
    """Return ``airflow_v`` with every '-' turned into '.'."""
    version_parts = airflow_v.split("-")
    return ".".join(version_parts)
def get_image_version_tag(airflow_v: str, composer_v: str) -> str:
    """
    Build the Composer image version tag for the given versions.

    The Composer part comes first, followed by the Airflow part, even
    though the parameters are given in the opposite order.
    """
    composer_part = f"composer-{composer_v}"
    airflow_part = f"airflow-{airflow_v}"
    return f"{composer_part}-{airflow_part}"
def get_environment_status_table(envs_status: List) -> rich.table.Table:
    """
    Build the printable table of environment statuses.

    Each entry of ``envs_status`` contributes one row made of its ``name``,
    ``version`` and ``status`` attributes.
    """
    headers = ("Environment Name", "Version*", "State")
    status_table = rich.table.Table(box=rich.box.MINIMAL)
    for header in headers:
        status_table.add_column(header)
    rows = ((env.name, env.version, env.status) for env in envs_status)
    for row in rows:
        status_table.add_row(*row)
    return status_table
def filter_image_versions(image_versions: List) -> List:
    """
    Filter out Composer 1 versions out of list of image versions.

    An entry counts as Composer 1 when its ``image_version_id`` starts with
    ``composer-1`` and the major version ends there, i.e. the ``1`` is not
    followed by another digit. A plain ``startswith("composer-1")`` check
    would also discard majors such as ``composer-10``.

    Args:
        image_versions: Objects exposing an ``image_version_id`` string.

    Returns:
        A new list with the remaining entries in their original order.
    """
    return [
        version
        for version in image_versions
        if not re.match(r"composer-1(?!\d)", version.image_version_id)
    ]
def sort_and_limit_image_versions(image_versions: List, limit: int) -> List:
    """
    Return at most ``limit`` image versions, greatest first.

    The input list is left untouched; a sorted copy is sliced instead.
    """
    descending = list(image_versions)
    descending.sort(reverse=True)
    return descending[:limit]
def get_image_versions_table(image_versions: List) -> rich.table.Table:
    """
    Build the printable table of image versions.

    Each entry contributes one row made of its ``image_version_id`` and the
    string form of its ``release_date``.
    """
    headers = ("Image version", "Release Date")
    versions_table = rich.table.Table(box=rich.box.MINIMAL)
    for header in headers:
        versions_table.add_column(header)
    for entry in image_versions:
        version_id = entry.image_version_id
        released = str(entry.release_date)
        versions_table.add_row(version_id, released)
    return versions_table
def wrap_status_in_color(status: str) -> str:
    """
    Wrap container status in bold color markup tags.

    The status equal to ``constants.ContainerStatus.RUNNING`` is rendered
    green; every other status is rendered red.
    """
    if status == constants.ContainerStatus.RUNNING:
        color = "green"
    else:
        color = "red"
    return f"[bold {color}]{status}[/]"
def get_log_level(verbose: bool, debug: bool):
    """
    Pick the log level for our package from the verbose and debug flags.

    ``debug`` takes precedence and selects DEBUG; otherwise ``verbose``
    selects INFO, and with neither flag set the level is WARNING.
    """
    if debug:
        return logging.DEBUG
    return logging.INFO if verbose else logging.WARNING
def get_external_log_level(debug: bool):
    """
    Pick the log level for external packages: DEBUG when ``debug`` is set,
    WARNING otherwise.
    """
    return logging.DEBUG if debug else logging.WARNING
def setup_logging(verbose: bool, debug: bool):
    """
    Configure logging for the tool.

    The root logger is configured with a ``RichHandler`` at the level chosen
    by ``get_log_level``, warnings are routed to logging, and the ``docker``
    and ``urllib3`` loggers get the level chosen by
    ``get_external_log_level``.
    """
    logging.basicConfig(
        level=get_log_level(verbose, debug),
        format=LOG_FORMAT,
        datefmt=LOG_DATE_FORMAT,
        handlers=[RichHandler()],
    )
    logging.captureWarnings(True)
    third_party_level = get_external_log_level(debug)
    for logger_name in ("docker", "urllib3"):
        logging.getLogger(logger_name).setLevel(third_party_level)
def get_image_versions(
    project: str, location: str, include_past_releases: bool
):
    """
    Query Composer API to get list of released image versions for given
    project and location.

    Args:
        project: Project id used to build the request parent.
        location: Location used to build the request parent.
        include_past_releases: Forwarded to the list request unchanged.

    Returns:
        A list of ``ImageVersion`` objects, one per item in the response.

    Raises:
        errors.ComposerCliError: if the API reports an error while the
            versions are being listed.
    """
    client = service_v1.ImageVersionsClient()
    request = service_v1.ListImageVersionsRequest(
        parent=f"projects/{project}/locations/{location}",
        include_past_releases=include_past_releases,
    )
    try:
        page_result = client.list_image_versions(request=request)
        # The result is consumed inside the try block on purpose: the pager
        # may issue further API requests while it is being iterated, and
        # those failures must be reported the same way as the initial call.
        return [ImageVersion(response) for response in page_result]
    except api_exception.GoogleAPIError as err:
        raise errors.ComposerCliError(
            constants.LIST_VERSIONS_API_ERROR.format(err=str(err))
        ) from err
def resolve_project_id(project_id: Optional[str]) -> str:
    """
    Return ``project_id`` when it is given, otherwise look it up through
    the Google Cloud CLI.

    Raises:
        click.UsageError: if the id has to be looked up and the lookup
            fails with ``errors.ComposerCliError``.
    """
    if project_id is None:
        LOG.info(
            "Project id was not provided. It will be retrieved using Cloud CLI."
        )
        try:
            return get_project_id()
        except errors.ComposerCliError as err:
            usage_message = (
                f"Please provide Google Cloud project id "
                f"(using '-p' / '--project' option). Failed to retrieve "
                f"project id from gcloud configuration:\n{err}"
            )
            raise click.UsageError(usage_message)
    return project_id