aws_msk_iam_sasl_signer/MSKAuthTokenProvider.py (108 lines of code) (raw):
# Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import base64
import importlib.metadata
import logging
from datetime import datetime, timezone
from urllib.parse import parse_qs, urlparse
import boto3
import botocore.session
from botocore.auth import SigV4QueryAuth
from botocore.awsrequest import AWSRequest
from botocore.config import Config
from botocore.credentials import CredentialProvider, Credentials
# Endpoint the signed URL targets; "{}" is filled in with the AWS region.
ENDPOINT_URL_TEMPLATE = "https://kafka.{}.amazonaws.com/"
# Token lifetime in seconds: passed as `expires` to the SigV4 signer and added
# to the signing time when the expiration timestamp is computed.
DEFAULT_TOKEN_EXPIRY_SECONDS = 900
# Session name used for the STS AssumeRole call when the caller supplies none.
DEFAULT_STS_SESSION_NAME = "MSKSASLDefaultSession"
# Name and value of the query parameter placed on the request before signing.
ACTION_TYPE = "Action"
ACTION_NAME = "kafka-cluster:Connect"
# Service name handed to the SigV4 signer.
SIGNING_NAME = "kafka-cluster"
# Query parameter name under which the user-agent is attached after signing.
USER_AGENT_KEY = "User-Agent"
# Distribution name of this library; used in the user-agent string and to look
# up the installed version.
LIB_NAME = "aws-msk-iam-sasl-signer-python"
def __get_user_agent__():
    """
    Builds the user-agent string that identifies this signer library.

    The version is read from the installed distribution metadata. When that
    metadata cannot be found (for example the package is vendored or run
    from a source checkout), the version falls back to "unknown" so that
    token generation is not aborted by a purely informational field.

    Returns:
        str: The user-agent in the form "<library name>/<version>".
    """
    try:
        lib_version = importlib.metadata.version(LIB_NAME)
    except importlib.metadata.PackageNotFoundError:
        # No distribution metadata available; keep the token usable.
        lib_version = "unknown"
    return f"{LIB_NAME}/{lib_version}"
def __load_default_credentials__():
    """
    Resolves IAM credentials through a botocore session created with
    default settings.

    Returns:
        The value returned by the session's ``get_credentials()``
        (a :class:`botocore.credentials.Credentials` object).
    """
    # A session built without arguments uses the default configuration
    return botocore.session.Session().get_credentials()
def __load_credentials_from_aws_profile__(aws_profile):
    """
    Resolves IAM credentials through a botocore session bound to a named
    aws profile.

    Parameters:
        - aws_profile (str): The name of the AWS profile the session is
          created with.

    Returns:
        The value returned by the session's ``get_credentials()``
        (a :class:`botocore.credentials.Credentials` object).
    """
    # Bind the session to the requested profile, then resolve credentials
    profile_session = botocore.session.Session(profile=aws_profile)
    profile_credentials = profile_session.get_credentials()
    return profile_credentials
def __load_credentials_from_aws_role_arn__(
    role_arn, sts_session_name=DEFAULT_STS_SESSION_NAME
):
    """
    Obtains IAM credentials by assuming the given role through STS.

    Every call creates a new sts client (default ``Config``, no explicit
    region or endpoint) and performs an AssumeRole request. The returned
    object is built from the values in that response. If this is not the
    desired behavior, please use your own credentials provider.

    Parameters:
        - role_arn (str): The ARN of the IAM role to assume for the session.
        - sts_session_name (str): The sts session name for the assumed
          role's session.

    Returns:
        :class:`botocore.credentials.Credentials` object
    """
    # Fresh sts client for this call
    sts = boto3.client("sts", config=Config())

    # Assume the role and pull the temporary credentials out of the response
    response = sts.assume_role(
        RoleArn=role_arn, RoleSessionName=sts_session_name
    )
    temporary = response["Credentials"]

    access_key_id = temporary["AccessKeyId"]
    secret_access_key = temporary["SecretAccessKey"]
    session_token = temporary["SessionToken"]
    return Credentials(access_key_id, secret_access_key, session_token)
def __load_credentials_from_aws_credentials_provider__(
    aws_credentials_provider
):
    """
    Asks the given aws credentials provider for IAM credentials.

    Parameters:
        - aws_credentials_provider
          (botocore.credentials.CredentialProvider): The aws credential
          provider whose ``load()`` method is called.

    Returns:
        Whatever the provider's ``load()`` returns
        (a :class:`botocore.credentials.Credentials` object).
    """
    # Delegate entirely to the provider
    loaded_credentials = aws_credentials_provider.load()
    return loaded_credentials
def generate_auth_token(region, aws_debug_creds=False):
    """
    Generates a base64-encoded signed url as auth token to authenticate
    with an Amazon MSK cluster using default IAM credentials.

    Args:
        region (str): The AWS region where the cluster is located.
        aws_debug_creds (bool): When true, and the root logger is enabled
            for DEBUG, the identity behind the credentials is looked up
            and logged before the token is built. Defaults to False.

    Returns:
        tuple: ``(token, expiration_ms)`` - the base64-encoded
        authorization token and its expiration time in milliseconds
        since the epoch.
    """
    credentials = __load_default_credentials__()

    # The identity lookup only happens when the caller opted in AND debug
    # logging would actually emit the result.
    if aws_debug_creds:
        if logging.getLogger().isEnabledFor(logging.DEBUG):
            __log_caller_identity(credentials)

    return __construct_auth_token(region, credentials)
def generate_auth_token_from_profile(region, aws_profile):
    """
    Generates a base64-encoded signed url as auth token to authenticate
    with an Amazon MSK cluster using IAM credentials from an aws named
    profile.

    Args:
        region (str): The AWS region where the cluster is located.
        aws_profile (str): The name of the AWS profile to use.

    Returns:
        tuple: ``(token, expiration_ms)`` - the base64-encoded
        authorization token and its expiration time in milliseconds
        since the epoch.
    """
    # Resolve the profile's credentials and sign with them
    return __construct_auth_token(
        region, __load_credentials_from_aws_profile__(aws_profile)
    )
def generate_auth_token_from_role_arn(
    region, role_arn, sts_session_name=DEFAULT_STS_SESSION_NAME
):
    """
    Generates a base64-encoded signed url as auth token to authenticate
    with an Amazon MSK cluster using IAM Credentials by assuming the
    provided role arn.

    Args:
        region (str): The AWS region where the cluster is located.
        role_arn (str): The ARN of the IAM role to assume for the session.
        sts_session_name (str): The sts session name for the assumed
            role's session. Optional.

    Returns:
        tuple: ``(token, expiration_ms)`` - the base64-encoded
        authorization token and its expiration time in milliseconds
        since the epoch.
    """
    # Assume the role first, then sign with the resulting credentials
    role_credentials = __load_credentials_from_aws_role_arn__(
        role_arn, sts_session_name
    )
    token_and_expiry = __construct_auth_token(region, role_credentials)
    return token_and_expiry
def generate_auth_token_from_credentials_provider(region,
                                                  aws_credentials_provider):
    """
    Generates a base64-encoded signed url as auth token to authenticate
    with an Amazon MSK cluster using IAM Credentials provided by a
    credentials provider.

    Args:
        region (str): The AWS region where the cluster is located.
        aws_credentials_provider (botocore.credentials.CredentialProvider):
            The credentials provider that provides IAM credentials.

    Returns:
        tuple: ``(token, expiration_ms)`` - the base64-encoded
        authorization token and its expiration time in milliseconds
        since the epoch.

    Raises:
        TypeError: If ``aws_credentials_provider`` is not a
            ``botocore.credentials.CredentialProvider`` instance.
    """
    # Happy path: a genuine provider is asked for credentials and used to sign
    if isinstance(aws_credentials_provider, CredentialProvider):
        provider_credentials = (
            __load_credentials_from_aws_credentials_provider__(
                aws_credentials_provider
            )
        )
        return __construct_auth_token(region, provider_credentials)

    # Anything else is rejected before any credentials are loaded
    raise TypeError(
        "aws_credentials_provider should be of type "
        "botocore.credentials.CredentialProvider "
    )
def __construct_auth_token(region, aws_credentials):
    """
    Private function that constructs the authorization token using IAM
    Credentials.

    Args:
        region (str): The AWS region where the cluster is located.
        aws_credentials (botocore.credentials.Credentials): The credentials
            used to sign the url. Only ``access_key`` and ``secret_key``
            are inspected here; the object itself is handed to the signer.

    Returns:
        tuple: ``(token, expiration_ms)`` - the signed url, url-safe
        base64-encoded with the trailing ``=`` padding removed, and the
        token's expiration time in milliseconds since the epoch.

    Raises:
        ValueError: If ``aws_credentials`` is ``None`` or its access key
            or secret key is empty.
    """
    # A missing credentials object is reported the same way as empty keys,
    # so callers always see the ValueError rather than an AttributeError.
    if (
        aws_credentials is None
        or not aws_credentials.access_key
        or not aws_credentials.secret_key
    ):
        raise ValueError("AWS Credentials can not be empty")

    # Regional endpoint the url is signed for
    endpoint_url = ENDPOINT_URL_TEMPLATE.format(region)

    # Signer producing query-string authentication with a fixed lifetime
    sig_v4 = SigV4QueryAuth(
        aws_credentials, SIGNING_NAME, region,
        expires=DEFAULT_TOKEN_EXPIRY_SECONDS
    )

    # Only the action parameter is present on the request when it is signed
    request = AWSRequest(
        method="GET", url=endpoint_url, params={ACTION_TYPE: ACTION_NAME}
    )
    sig_v4.add_auth(request)

    # The user-agent is attached after signing, then the request is prepared
    # to obtain the final url.
    request.params = {USER_AGENT_KEY: __get_user_agent__()}
    prepped = request.prepare()
    signed_url = prepped.url

    # Url-safe base64 encode and remove the padding from the end
    encoded = base64.urlsafe_b64encode(signed_url.encode("utf-8"))
    base64_encoded_signed_url = encoded.decode("utf-8").rstrip("=")

    return base64_encoded_signed_url, __get_expiration_time_ms(request)
def __get_expiration_time_ms(request):
    """
    Private function that parses the signed url and computes the token's
    expiration time.

    The signing time is taken from the ``X-Amz-Date`` query parameter. The
    lifetime is taken from the ``X-Amz-Expires`` query parameter when the
    url carries one, so the result reflects what was actually signed;
    otherwise ``DEFAULT_TOKEN_EXPIRY_SECONDS`` is used.

    Args:
        request (AWSRequest): The signed aws request object; only its
            ``url`` attribute is read.

    Returns:
        int: Expiration time in milliseconds since the epoch (UTC).

    Raises:
        KeyError: If the url has no ``X-Amz-Date`` query parameter.
        ValueError: If ``X-Amz-Date`` or ``X-Amz-Expires`` is malformed.
    """
    # Parse the query string of the signed request
    signed_params = parse_qs(urlparse(request.url).query)

    # X-Amz-Date is a UTC timestamp; strptime yields a naive datetime, so
    # the timezone is attached explicitly before converting to epoch time.
    signing_time = datetime.strptime(
        signed_params["X-Amz-Date"][0], "%Y%m%dT%H%M%SZ"
    ).replace(tzinfo=timezone.utc)

    # Prefer the lifetime recorded in the url over the module default
    expires_values = signed_params.get("X-Amz-Expires")
    if expires_values:
        lifetime_seconds = int(expires_values[0])
    else:
        lifetime_seconds = DEFAULT_TOKEN_EXPIRY_SECONDS

    expiration_timestamp_seconds = (
        int(signing_time.timestamp()) + lifetime_seconds
    )
    return expiration_timestamp_seconds * 1000
def __log_caller_identity(aws_credentials):
    """
    Private function that looks up and logs the identity behind the given
    credentials.

    Args:
        aws_credentials: Credentials object exposing ``access_key``,
            ``secret_key`` and ``token``; these are passed to the sts
            client used for the lookup.
    """
    # Build an sts client that authenticates with exactly these credentials
    client_kwargs = {
        "aws_access_key_id": aws_credentials.access_key,
        "aws_secret_access_key": aws_credentials.secret_key,
        "aws_session_token": aws_credentials.token,
    }
    sts = boto3.client("sts", **client_kwargs)

    # Ask sts who these credentials belong to
    identity = sts.get_caller_identity()
    user_id = identity.get('UserId')
    account = identity.get('Account')
    arn = identity.get('Arn')

    # Emit the identity at debug level
    logging.debug(
        "Credentials Identity: {UserId: %s, Account: %s, Arn: %s}",
        user_id,
        account,
        arn,
    )