aws_msk_iam_sasl_signer/MSKAuthTokenProvider.py (108 lines of code) (raw):

# Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Generate base64-encoded, SigV4-presigned URLs used as SASL auth tokens for
Amazon MSK IAM authentication.
"""

import base64
import functools
import importlib.metadata
import logging
from datetime import datetime, timezone
from urllib.parse import parse_qs, urlparse

import boto3
import botocore.session
from botocore.auth import SigV4QueryAuth
from botocore.awsrequest import AWSRequest
from botocore.config import Config
from botocore.credentials import CredentialProvider, Credentials

ENDPOINT_URL_TEMPLATE = "https://kafka.{}.amazonaws.com/"
DEFAULT_TOKEN_EXPIRY_SECONDS = 900
DEFAULT_STS_SESSION_NAME = "MSKSASLDefaultSession"
ACTION_TYPE = "Action"
ACTION_NAME = "kafka-cluster:Connect"
SIGNING_NAME = "kafka-cluster"
USER_AGENT_KEY = "User-Agent"
LIB_NAME = "aws-msk-iam-sasl-signer-python"

# Version reported in the user-agent when the package metadata cannot be
# found (e.g. the module is imported from a source checkout that was never
# installed as a distribution).
_UNKNOWN_VERSION = "unknown"


@functools.lru_cache(maxsize=1)
def __get_user_agent__():
    """
    Builds the user-agent.

    The installed package version cannot change while the process runs, so
    the result is computed once and cached.

    Returns:
        str: The user-agent identifying this signer library, formatted as
        "<LIB_NAME>/<version>".
    """
    try:
        version = importlib.metadata.version(LIB_NAME)
    except importlib.metadata.PackageNotFoundError:
        # Missing distribution metadata should not make token generation
        # fail: the user-agent is only appended as an unsigned query param.
        version = _UNKNOWN_VERSION
    return f"{LIB_NAME}/{version}"


def __load_default_credentials__():
    """
    Loads IAM credentials from the default credentials chain.

    Returns:
        :class:`botocore.credentials.Credentials` object, or None when the
        default chain does not resolve any credentials.
    """
    # Create a botocore session with default settings
    botocore_session = botocore.session.Session()
    return botocore_session.get_credentials()


def __load_credentials_from_aws_profile__(aws_profile):
    """
    Loads IAM credentials from a named AWS profile.

    Parameters:
        - aws_profile (str): The name of the AWS profile to use for the
          session.

    Returns:
        :class:`botocore.credentials.Credentials` object, or None when the
        profile does not resolve any credentials.
    """
    # Create a botocore session with an aws named profile
    botocore_session = botocore.session.Session(profile=aws_profile)
    return botocore_session.get_credentials()


def __load_credentials_from_aws_role_arn__(
    role_arn, sts_session_name=DEFAULT_STS_SESSION_NAME
):
    """
    Loads IAM credentials by assuming an AWS IAM role.

    Each call creates a new STS client with the default client configuration
    and performs a fresh AssumeRole call. The returned credentials are static,
    so they are not refreshed automatically. If this is not the desired
    behavior, please use your own credentials provider.

    Parameters:
        - role_arn (str): The ARN of the IAM role to assume for the session.
        - sts_session_name (str): The STS session name for the assumed role's
          session.

    Returns:
        :class:`botocore.credentials.Credentials` object
    """
    # Create sts client
    sts_client = boto3.client("sts", config=Config())
    assumed_role = sts_client.assume_role(
        RoleArn=role_arn, RoleSessionName=sts_session_name
    )
    assumed_role_credentials = assumed_role["Credentials"]
    return Credentials(
        assumed_role_credentials["AccessKeyId"],
        assumed_role_credentials["SecretAccessKey"],
        assumed_role_credentials["SessionToken"],
    )


def __load_credentials_from_aws_credentials_provider__(
    aws_credentials_provider
):
    """
    Loads IAM credentials from an AWS credentials provider.

    Parameters:
        - aws_credentials_provider (botocore.credentials.CredentialProvider):
          The AWS credentials provider.

    Returns:
        Whatever the provider's ``load()`` returns: normally a
        :class:`botocore.credentials.Credentials` object, possibly None.
    """
    # Load credentials
    return aws_credentials_provider.load()


def generate_auth_token(region, aws_debug_creds=False):
    """
    Generates a base64-encoded signed url as auth token to authenticate
    with an Amazon MSK cluster using default IAM credentials.

    Args:
        region (str): The AWS region where the cluster is located.
        aws_debug_creds (bool): When True and the root logger is enabled for
            DEBUG, log the caller identity of the credentials in use. This
            makes an extra STS GetCallerIdentity call. Optional.

    Returns:
        tuple[str, int]: The base64-encoded authorization token and its
        expiration time in milliseconds since the Unix epoch.

    Raises:
        ValueError: If no usable credentials were resolved.
    """
    # Load credentials
    aws_credentials = __load_default_credentials__()
    __log_caller_identity_if_enabled(aws_credentials, aws_debug_creds)
    return __construct_auth_token(region, aws_credentials)


def generate_auth_token_from_profile(region, aws_profile,
                                     aws_debug_creds=False):
    """
    Generates a base64-encoded signed url as auth token to authenticate
    with an Amazon MSK cluster using IAM credentials from an AWS named
    profile.

    Args:
        region (str): The AWS region where the cluster is located.
        aws_profile (str): The name of the AWS profile to use.
        aws_debug_creds (bool): When True and the root logger is enabled for
            DEBUG, log the caller identity of the credentials in use.
            Optional.

    Returns:
        tuple[str, int]: The base64-encoded authorization token and its
        expiration time in milliseconds since the Unix epoch.

    Raises:
        ValueError: If the profile resolved no usable credentials.
    """
    # Load credentials
    aws_credentials = __load_credentials_from_aws_profile__(aws_profile)
    __log_caller_identity_if_enabled(aws_credentials, aws_debug_creds)
    return __construct_auth_token(region, aws_credentials)


def generate_auth_token_from_role_arn(
    region, role_arn, sts_session_name=DEFAULT_STS_SESSION_NAME,
    aws_debug_creds=False
):
    """
    Generates a base64-encoded signed url as auth token to authenticate
    with an Amazon MSK cluster using IAM credentials obtained by assuming
    the provided role ARN.

    Args:
        region (str): The AWS region where the cluster is located.
        role_arn (str): The ARN of the IAM role to assume for the session.
        sts_session_name (str): The STS session name for the assumed role's
            session. Optional.
        aws_debug_creds (bool): When True and the root logger is enabled for
            DEBUG, log the caller identity of the credentials in use.
            Optional.

    Returns:
        tuple[str, int]: The base64-encoded authorization token and its
        expiration time in milliseconds since the Unix epoch.

    Raises:
        ValueError: If the assumed role yielded empty credentials.
    """
    # Load credentials
    aws_credentials = __load_credentials_from_aws_role_arn__(
        role_arn, sts_session_name
    )
    __log_caller_identity_if_enabled(aws_credentials, aws_debug_creds)
    return __construct_auth_token(region, aws_credentials)


def generate_auth_token_from_credentials_provider(region,
                                                  aws_credentials_provider,
                                                  aws_debug_creds=False):
    """
    Generates a base64-encoded signed url as auth token to authenticate
    with an Amazon MSK cluster using IAM credentials provided by a
    credentials provider.

    Args:
        region (str): The AWS region where the cluster is located.
        aws_credentials_provider (botocore.credentials.CredentialProvider):
            The credentials provider that provides IAM credentials.
        aws_debug_creds (bool): When True and the root logger is enabled for
            DEBUG, log the caller identity of the credentials in use.
            Optional.

    Returns:
        tuple[str, int]: The base64-encoded authorization token and its
        expiration time in milliseconds since the Unix epoch.

    Raises:
        TypeError: If aws_credentials_provider is not a CredentialProvider.
        ValueError: If the provider returned no usable credentials.
    """
    # Check the type of the credentials provider
    if not isinstance(aws_credentials_provider, CredentialProvider):
        raise TypeError(
            "aws_credentials_provider should be of type "
            "botocore.credentials.CredentialProvider "
        )

    # Load credentials
    aws_credentials = __load_credentials_from_aws_credentials_provider__(
        aws_credentials_provider
    )
    __log_caller_identity_if_enabled(aws_credentials, aws_debug_creds)
    return __construct_auth_token(region, aws_credentials)


def __freeze_credentials(aws_credentials):
    """
    Private function that snapshots and validates credentials.

    botocore credential objects expose ``get_frozen_credentials()``, which
    returns a consistent (access_key, secret_key, token) snapshot. This
    matters for refreshable credentials, whose individual attributes could
    otherwise come from two different refreshes. Objects without that method
    are used as-is.

    Args:
        aws_credentials: A botocore credentials object (or compatible object
            with access_key / secret_key / token attributes), or None.

    Returns:
        The frozen credentials object.

    Raises:
        ValueError: If the credentials are None or have an empty access key
            or secret key.
    """
    if aws_credentials is None:
        raise ValueError("AWS Credentials can not be empty")
    get_frozen = getattr(aws_credentials, "get_frozen_credentials", None)
    frozen_credentials = get_frozen() if callable(get_frozen) \
        else aws_credentials
    if not frozen_credentials.access_key or not frozen_credentials.secret_key:
        raise ValueError("AWS Credentials can not be empty")
    return frozen_credentials


def __construct_auth_token(region, aws_credentials):
    """
    Private function that constructs the authorization token using IAM
    credentials.

    Args:
        region (str): The AWS region where the cluster is located.
        aws_credentials: The botocore credentials used to sign the url.

    Returns:
        tuple[str, int]: The base64-encoded (unpadded, url-safe) signed url
        and its expiration time in milliseconds since the Unix epoch.

    Raises:
        ValueError: If the credentials are missing or empty.
    """
    # Validate once and sign with a single consistent snapshot
    frozen_credentials = __freeze_credentials(aws_credentials)

    # Extract endpoint URL
    endpoint_url = ENDPOINT_URL_TEMPLATE.format(region)

    # Set up resource path and query parameters
    query_params = {ACTION_TYPE: ACTION_NAME}

    # Create SigV4 instance
    sig_v4 = SigV4QueryAuth(
        frozen_credentials, SIGNING_NAME, region,
        expires=DEFAULT_TOKEN_EXPIRY_SECONDS
    )

    # Create request with url and parameters
    request = AWSRequest(method="GET", url=endpoint_url, params=query_params)

    # Add auth to the request and prepare the request
    sig_v4.add_auth(request)
    # The User-Agent param is set after signing, so it is not covered by the
    # signature computed above.
    request.params = {USER_AGENT_KEY: __get_user_agent__()}
    prepped = request.prepare()

    # Get the signed url
    signed_url = prepped.url

    # Base 64 encode and remove the padding from the end
    base64_bytes = base64.urlsafe_b64encode(signed_url.encode("utf-8"))
    base64_encoded_signed_url = base64_bytes.decode("utf-8").rstrip("=")
    return base64_encoded_signed_url, __get_expiration_time_ms(request)


def __get_expiration_time_ms(request):
    """
    Private function that parses the signed url and computes the token's
    expiration time.

    Args:
        request (AWSRequest): The signed aws request object. Its url must
            carry the X-Amz-Date query parameter.

    Returns:
        int: Signing time plus DEFAULT_TOKEN_EXPIRY_SECONDS, in milliseconds
        since the Unix epoch.
    """
    # Parse the signed request
    parsed_url = urlparse(request.url)
    parsed_url_params = parse_qs(parsed_url.query)
    parsed_signing_time = datetime.strptime(
        parsed_url_params["X-Amz-Date"][0], "%Y%m%dT%H%M%SZ"
    )

    # X-Amz-Date is in UTC ("Z" suffix); make the datetime timezone-aware so
    # timestamp() does not interpret it in local time.
    signing_time = parsed_signing_time.replace(tzinfo=timezone.utc)

    # Expiration as a Unix timestamp in seconds
    expiration_timestamp_seconds = (
        int(signing_time.timestamp()) + DEFAULT_TOKEN_EXPIRY_SECONDS
    )

    # Convert to milliseconds
    return expiration_timestamp_seconds * 1000


def __log_caller_identity_if_enabled(aws_credentials, aws_debug_creds):
    """
    Private function that logs the caller identity only when requested and
    when the root logger is enabled for DEBUG.

    Args:
        aws_credentials: The credentials to be used to generate signed url.
        aws_debug_creds (bool): Whether identity logging was requested.
    """
    if aws_debug_creds and logging.getLogger().isEnabledFor(logging.DEBUG):
        __log_caller_identity(aws_credentials)


def __log_caller_identity(aws_credentials):
    """
    Private function that logs the caller identity via STS
    GetCallerIdentity.

    Args:
        aws_credentials: The credentials to be used to generate signed url.

    Raises:
        ValueError: If the credentials are missing or empty.
    """
    frozen_credentials = __freeze_credentials(aws_credentials)

    # Create sts client
    sts_client = boto3.client(
        "sts",
        aws_access_key_id=frozen_credentials.access_key,
        aws_secret_access_key=frozen_credentials.secret_key,
        aws_session_token=frozen_credentials.token,
    )

    # Get caller identity
    caller_identity = sts_client.get_caller_identity()

    # Log the identity in debug mode
    logging.debug("Credentials Identity: {UserId: %s, Account: %s, Arn: %s}",
                  caller_identity.get('UserId'),
                  caller_identity.get('Account'),
                  caller_identity.get('Arn'))