aws_advanced_python_wrapper/utils/iam_utils.py (95 lines of code) (raw):

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from datetime import datetime
from typing import TYPE_CHECKING, Dict, Optional

import boto3

from aws_advanced_python_wrapper.errors import AwsWrapperError
from aws_advanced_python_wrapper.utils.log import Logger
from aws_advanced_python_wrapper.utils.messages import Messages
from aws_advanced_python_wrapper.utils.rds_url_type import RdsUrlType
from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils
from aws_advanced_python_wrapper.utils.telemetry.telemetry import \
    TelemetryTraceLevel

if TYPE_CHECKING:
    from aws_advanced_python_wrapper.hostinfo import HostInfo
    from aws_advanced_python_wrapper.plugin_service import PluginService
    from boto3 import Session

# WrapperProperties is used at runtime below, so this import must stay outside
# the TYPE_CHECKING guard.
from aws_advanced_python_wrapper.utils.properties import (Properties,
                                                          WrapperProperties)

logger = Logger(__name__)


class IamAuthUtils:
    """Static helpers for resolving IAM connection settings and auth tokens."""

    @staticmethod
    def get_iam_host(props: Properties, host_info: HostInfo):
        """Return the host to use for IAM authentication.

        The ``IAM_HOST`` wrapper property wins when it is set to a truthy
        value; otherwise ``host_info.host`` is used. The chosen host is
        validated before it is returned.

        :param props: connection properties to read ``IAM_HOST`` from.
        :param host_info: fallback source of the host name.
        :return: the validated host.
        :raises AwsWrapperError: if the chosen host fails validation.
        """
        # Read the property once instead of twice.
        host = WrapperProperties.IAM_HOST.get(props) or host_info.host
        IamAuthUtils.validate_iam_host(host)
        return host

    @staticmethod
    def validate_iam_host(host: str | None):
        """Raise if ``host`` is missing or is not an accepted RDS URL type.

        :param host: host name to validate.
        :raises AwsWrapperError: if ``host`` is ``None``, or if ``RdsUtils``
            classifies it as ``RdsUrlType.OTHER`` or ``RdsUrlType.IP_ADDRESS``.
        """
        if host is None:
            raise AwsWrapperError(
                Messages.get_formatted("IamAuthPlugin.InvalidHost", "[No host provided]"))

        rds_type = RdsUtils().identify_rds_type(host)
        if rds_type in (RdsUrlType.OTHER, RdsUrlType.IP_ADDRESS):
            raise AwsWrapperError(Messages.get_formatted("IamAuthPlugin.InvalidHost", host))

    @staticmethod
    def get_port(props: Properties, host_info: HostInfo, dialect_default_port: int) -> int:
        """Return the port to use for IAM authentication.

        Precedence: a positive ``IAM_DEFAULT_PORT`` property, then the port on
        ``host_info`` when one is specified, then ``dialect_default_port``.

        :param props: connection properties to read ``IAM_DEFAULT_PORT`` from.
        :param host_info: host whose explicit port is the second choice.
        :param dialect_default_port: last-resort port.
        :return: the selected port.
        """
        default_port: int = WrapperProperties.IAM_DEFAULT_PORT.get_int(props)
        if default_port > 0:
            return default_port

        if host_info.is_port_specified():
            return host_info.port

        return dialect_default_port

    @staticmethod
    def get_cache_key(user: Optional[str], hostname: Optional[str], port: int, region: Optional[str]) -> str:
        """Return the cache key ``"<region>:<hostname>:<port>:<user>"``."""
        return f"{region}:{hostname}:{port}:{user}"

    @staticmethod
    def generate_authentication_token(
            plugin_service: PluginService,
            user: Optional[str],
            host_name: Optional[str],
            port: Optional[int],
            region: Optional[str],
            credentials: Optional[Dict[str, str]] = None,
            client_session: Optional[Session] = None) -> str:
        """Generate a database auth token through an ``rds`` client.

        The call is wrapped in a nested telemetry context named
        ``"fetch authentication token"``; the context is marked as failed when
        an exception escapes and is always closed.

        :param plugin_service: provides the telemetry factory.
        :param user: database user name passed as ``DBUsername``.
        :param host_name: database host passed as ``DBHostname``.
        :param port: database port passed as ``Port``.
        :param region: region name used to create the client.
        :param credentials: optional mapping whose ``AccessKeyId``,
            ``SecretAccessKey`` and ``SessionToken`` entries are handed to the
            client; when ``None`` the session's own credentials are used.
        :param client_session: session to create the client from; a new
            ``boto3.Session()`` is created when this is falsy.
        :return: the generated token.
        :raises Exception: any error from client creation or token generation
            is recorded on the telemetry context and re-raised unchanged.
        """
        telemetry_factory = plugin_service.get_telemetry_factory()
        context = telemetry_factory.open_telemetry_context(
            "fetch authentication token", TelemetryTraceLevel.NESTED)

        try:
            session = client_session if client_session else boto3.Session()

            if credentials is not None:
                client = session.client(
                    'rds',
                    region_name=region,
                    aws_access_key_id=credentials.get('AccessKeyId'),
                    aws_secret_access_key=credentials.get('SecretAccessKey'),
                    aws_session_token=credentials.get('SessionToken')
                )
            else:
                client = session.client(
                    'rds',
                    region_name=region
                )

            # Close the client even when token generation fails, so a failed
            # request does not leak the client.
            try:
                token = client.generate_db_auth_token(
                    DBHostname=host_name,
                    Port=port,
                    DBUsername=user
                )
            finally:
                client.close()

            # NOTE(review): the generated token is a credential and is handed
            # to the debug logger here - confirm this is intended.
            logger.debug("IamAuthUtils.GeneratedNewAuthToken", token)
            return token
        except Exception as ex:
            context.set_success(False)
            context.set_exception(ex)
            # Bare raise keeps the original traceback intact.
            raise
        finally:
            context.close_context()


class TokenInfo:
    """Pairs a token with the moment it expires."""

    def __init__(self, token: str, expiration: datetime):
        """Store the token and its expiration time.

        :param token: the token value.
        :param expiration: point in time after which the token is expired.
        """
        self._token = token
        self._expiration = expiration

    @property
    def token(self) -> str:
        """The stored token."""
        return self._token

    @property
    def expiration(self) -> datetime:
        """The stored expiration time."""
        return self._expiration

    def is_expired(self) -> bool:
        """Return ``True`` when the current time is past the expiration.

        The current time is taken with the expiration's own ``tzinfo``: a
        naive expiration is compared against naive local time, and a
        timezone-aware expiration is compared against an aware "now" rather
        than raising ``TypeError``.
        """
        return datetime.now(self._expiration.tzinfo) > self._expiration