aws_advanced_python_wrapper/aws_secrets_manager_plugin.py (152 lines of code) (raw):

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations from json import JSONDecodeError, loads from re import search from types import SimpleNamespace from typing import TYPE_CHECKING, Callable, Dict, Optional, Set, Tuple import boto3 from botocore.exceptions import ClientError, EndpointConnectionError if TYPE_CHECKING: from boto3 import Session from aws_advanced_python_wrapper.driver_dialect import DriverDialect from aws_advanced_python_wrapper.hostinfo import HostInfo from aws_advanced_python_wrapper.pep249 import Connection from aws_advanced_python_wrapper.plugin_service import PluginService from aws_advanced_python_wrapper.errors import AwsWrapperError from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory from aws_advanced_python_wrapper.utils.log import Logger from aws_advanced_python_wrapper.utils.messages import Messages from aws_advanced_python_wrapper.utils.properties import (Properties, WrapperProperties) from aws_advanced_python_wrapper.utils.region_utils import RegionUtils from aws_advanced_python_wrapper.utils.telemetry.telemetry import \ TelemetryTraceLevel logger = Logger(__name__) class AwsSecretsManagerPlugin(Plugin): _SUBSCRIBED_METHODS: Set[str] = {"connect", "force_connect"} _SECRETS_ARN_PATTERN = r"^arn:aws:secretsmanager:(?P<region>[^:\n]*):[^:\n]*:([^:/\n]*[:/])?(.*)$" _secrets_cache: Dict[Tuple, SimpleNamespace] = {} _secret_key: Tuple = () @property def subscribed_methods(self) -> Set[str]: return self._SUBSCRIBED_METHODS def __init__(self, plugin_service: PluginService, props: Properties, session: Optional[Session] = None): self._plugin_service = plugin_service self._session = session secret_id = WrapperProperties.SECRETS_MANAGER_SECRET_ID.get(props) if not secret_id: raise AwsWrapperError( Messages.get_formatted("AwsSecretsManagerPlugin.MissingRequiredConfigParameter", WrapperProperties.SECRETS_MANAGER_SECRET_ID.name)) self._region_utils = RegionUtils() region: str = self._get_rds_region(secret_id, props) secrets_endpoint = WrapperProperties.SECRETS_MANAGER_ENDPOINT.get(props) self._secret_key: Tuple = (secret_id, region, secrets_endpoint) telemetry_factory = self._plugin_service.get_telemetry_factory() self._fetch_credentials_counter = telemetry_factory.create_counter("secrets_manager.fetch_credentials.count") def connect( self, target_driver_func: Callable, driver_dialect: DriverDialect, host_info: HostInfo, props: Properties, is_initial_connection: bool, connect_func: Callable) -> Connection: return self._connect(props, connect_func) def force_connect( self, target_driver_func: Callable, driver_dialect: DriverDialect, host_info: HostInfo, props: Properties, is_initial_connection: bool, force_connect_func: Callable) -> Connection: return self._connect(props, force_connect_func) def _connect(self, props: Properties, connect_func: Callable) -> Connection: secret_fetched: bool = self._update_secret() try: self._apply_secret_to_properties(props) return connect_func() except Exception as e: if not self._plugin_service.is_login_exception(error=e) or secret_fetched: raise AwsWrapperError( Messages.get_formatted("AwsSecretsManagerPlugin.ConnectException", e)) from e secret_fetched = self._update_secret(True) if secret_fetched: try: self._apply_secret_to_properties(props) return connect_func() except Exception as unhandled_error: raise AwsWrapperError( Messages.get_formatted("AwsSecretsManagerPlugin.UnhandledException", unhandled_error)) from unhandled_error raise AwsWrapperError(Messages.get_formatted("AwsSecretsManagerPlugin.FailedLogin", e)) from e def _update_secret(self, force_refetch: bool = False) -> bool: """ Called to update credentials from the cache, or from the AWS Secrets Manager service. :param force_refetch: Allows ignoring cached credentials and force fetches the latest credentials from the service. :return: `True`, if credentials were fetched from the service. """ telemetry_factory = self._plugin_service.get_telemetry_factory() context = telemetry_factory.open_telemetry_context("fetch credentials", TelemetryTraceLevel.NESTED) self._fetch_credentials_counter.inc() try: fetched: bool = False self._secret: Optional[SimpleNamespace] = AwsSecretsManagerPlugin._secrets_cache.get(self._secret_key) endpoint = self._secret_key[2] if not self._secret or force_refetch: try: self._secret = self._fetch_latest_credentials() if self._secret: AwsSecretsManagerPlugin._secrets_cache[self._secret_key] = self._secret fetched = True except (ClientError, AttributeError) as e: logger.debug("AwsSecretsManagerPlugin.FailedToFetchDbCredentials", e) raise AwsWrapperError( Messages.get_formatted("AwsSecretsManagerPlugin.FailedToFetchDbCredentials", e)) from e except JSONDecodeError as e: logger.debug("AwsSecretsManagerPlugin.JsonDecodeError", e) raise AwsWrapperError( Messages.get_formatted("AwsSecretsManagerPlugin.JsonDecodeError", e)) except EndpointConnectionError: logger.debug("AwsSecretsManagerPlugin.EndpointOverrideInvalidConnection", endpoint) raise AwsWrapperError( Messages.get_formatted("AwsSecretsManagerPlugin.EndpointOverrideInvalidConnection", endpoint)) except ValueError: logger.debug("AwsSecretsManagerPlugin.EndpointOverrideMisconfigured", endpoint) raise AwsWrapperError( Messages.get_formatted("AwsSecretsManagerPlugin.EndpointOverrideMisconfigured", endpoint)) return fetched except Exception as ex: context.set_success(False) context.set_exception(ex) raise ex finally: context.close_context() def _fetch_latest_credentials(self): """ Fetches the current credentials from AWS Secrets Manager service. :return: a Secret object containing the credentials fetched from the AWS Secrets Manager service. """ session = self._session if self._session else boto3.Session() client = session.client( 'secretsmanager', region_name=self._secret_key[1], endpoint_url=self._secret_key[2], ) secret = client.get_secret_value( SecretId=self._secret_key[0], ) client.close() return loads(secret.get("SecretString"), object_hook=lambda d: SimpleNamespace(**d)) def _apply_secret_to_properties(self, properties: Properties): """ Updates credentials in provided properties. Other plugins in the plugin chain may change them if needed. Eventually, credentials will be used to open a new connection in :py:class:`DefaultConnectionPlugin`. :param properties: Properties to store credentials. """ if self._secret: WrapperProperties.USER.set(properties, self._secret.username) WrapperProperties.PASSWORD.set(properties, self._secret.password) def _get_rds_region(self, secret_id: str, props: Properties) -> str: session = self._session if self._session else boto3.Session() region = self._region_utils.get_region(props, WrapperProperties.SECRETS_MANAGER_REGION.name, session=session) if region: return region match = search(self._SECRETS_ARN_PATTERN, secret_id) if match: region = match.group("region") if region: return self._region_utils.verify_region(region) else: raise AwsWrapperError( Messages.get_formatted("AwsSecretsManagerPlugin.MissingRequiredConfigParameter", WrapperProperties.SECRETS_MANAGER_REGION.name)) class AwsSecretsManagerPluginFactory(PluginFactory): def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin: return AwsSecretsManagerPlugin(plugin_service, props)