signer/msk_auth_token_provider.go (267 lines of code) (raw):
package signer
import (
"context"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"fmt"
"log"
"net/http"
"net/url"
"runtime"
"strconv"
"strings"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/sts"
)
const (
	// ActionType is the query-parameter key under which the requested action is sent.
	ActionType = "Action"
	// ActionName is the action requested by the token: connecting to a Kafka (MSK) cluster.
	ActionName = "kafka-cluster:Connect"
	// SigningName is the SigV4 service name used when presigning the request.
	SigningName = "kafka-cluster"
	// UserAgentKey is the query-parameter key that carries this library's user agent.
	UserAgentKey = "User-Agent"
	// LibName is the library name reported as the first segment of the user agent.
	LibName = "aws-msk-iam-sasl-signer-go"
	// DateQueryKey is the presigned-URL query key holding the signing timestamp.
	DateQueryKey = "X-Amz-Date"
	// ExpiresQueryKey is the presigned-URL query key holding the token lifetime in seconds.
	ExpiresQueryKey = "X-Amz-Expires"
	// DefaultSessionName is the STS role session name used when the caller passes none.
	DefaultSessionName = "MSKSASLDefaultSession"
	// DefaultExpirySeconds is the presigned token lifetime in seconds (15 minutes).
	DefaultExpirySeconds = 900
)
var (
	// endpointURLTemplate is the fmt template for the regional Kafka host; %s is the region.
	endpointURLTemplate = "kafka.%s.amazonaws.com"
	// AwsDebugCreds, when set to true, makes token generation log the STS caller identity of
	// the credentials in use. It is false by default.
	// NOTE(review): plain package-level bool with no synchronization visible here.
	AwsDebugCreds bool
)
// GenerateAuthToken returns a SigV4-presigned MSK IAM auth token for region, encoded as
// unpadded base64url, together with the token's expiry in Unix milliseconds.
// Credentials are resolved through the AWS SDK default credential provider chain.
func GenerateAuthToken(ctx context.Context, region string) (string, int64, error) {
	creds, err := loadDefaultCredentials(ctx, region)
	if err != nil {
		return "", 0, fmt.Errorf("failed to load credentials: %w", err)
	}
	return constructAuthToken(ctx, region, creds)
}
// GenerateAuthTokenFromProfile returns a SigV4-presigned MSK IAM auth token for region,
// encoded as unpadded base64url, together with its expiry in Unix milliseconds.
// Credentials are loaded from the AWS shared-config profile named awsProfile.
func GenerateAuthTokenFromProfile(ctx context.Context, region string, awsProfile string) (string, int64, error) {
	creds, err := loadCredentialsFromProfile(ctx, region, awsProfile)
	if err != nil {
		return "", 0, fmt.Errorf("failed to load credentials: %w", err)
	}
	return constructAuthToken(ctx, region, creds)
}
// GenerateAuthTokenFromRole generates base64 encoded signed url as auth token by loading IAM credentials from an aws role Arn
func GenerateAuthTokenFromRole(
ctx context.Context, region string, roleArn string, stsSessionName string,
) (string, int64, error) {
return GenerateAuthTokenFromRoleWithExternalId(ctx, region, roleArn, stsSessionName, "")
}
// GenerateAuthTokenFromRoleWithExternalId returns a SigV4-presigned MSK IAM auth token for
// region using temporary credentials obtained by assuming roleArn via STS.
//
// An empty stsSessionName selects DefaultSessionName. The externalId is sent with the
// AssumeRole call only when it is non-empty; with an empty externalId this behaves exactly
// like GenerateAuthTokenFromRole.
func GenerateAuthTokenFromRoleWithExternalId(
	ctx context.Context, region string, roleArn string, stsSessionName string, externalId string,
) (string, int64, error) {
	sessionName := stsSessionName
	if sessionName == "" {
		sessionName = DefaultSessionName
	}

	creds, err := loadCredentialsFromRoleArn(ctx, region, roleArn, sessionName, externalId)
	if err != nil {
		return "", 0, fmt.Errorf("failed to load credentials: %w", err)
	}
	return constructAuthToken(ctx, region, creds)
}
// GenerateAuthTokenFromWebIdentity returns a SigV4-presigned MSK IAM auth token for region
// using temporary credentials obtained via STS AssumeRoleWithWebIdentity for roleArn and
// webIdentityToken. An empty stsSessionName selects DefaultSessionName.
func GenerateAuthTokenFromWebIdentity(
	ctx context.Context, region string, roleArn string, webIdentityToken string, stsSessionName string,
) (string, int64, error) {
	sessionName := stsSessionName
	if sessionName == "" {
		sessionName = DefaultSessionName
	}

	creds, err := loadCredentialsFromWebIdentityParameters(ctx, region, roleArn, webIdentityToken, sessionName)
	if err != nil {
		return "", 0, fmt.Errorf("failed to load credentials: %w", err)
	}
	return constructAuthToken(ctx, region, creds)
}
// GenerateAuthTokenFromCredentialsProvider returns a SigV4-presigned MSK IAM auth token for
// region, signed with credentials retrieved once from credentialsProvider, together with the
// token's expiry in Unix milliseconds.
//
// Use this to control how credentials are obtained, for example to reuse an STS client.
func GenerateAuthTokenFromCredentialsProvider(
	ctx context.Context, region string, credentialsProvider aws.CredentialsProvider,
) (string, int64, error) {
	creds, err := loadCredentialsFromCredentialsProvider(ctx, credentialsProvider)
	if err != nil {
		return "", 0, fmt.Errorf("failed to load credentials: %w", err)
	}
	return constructAuthToken(ctx, region, creds)
}
// loadDefaultCredentials loads the SDK default configuration for region and retrieves
// credentials from its default credential provider chain.
func loadDefaultCredentials(ctx context.Context, region string) (*aws.Credentials, error) {
	opts := []func(*config.LoadOptions) error{
		config.WithRegion(region),
	}
	cfg, err := config.LoadDefaultConfig(ctx, opts...)
	if err != nil {
		return nil, fmt.Errorf("unable to load SDK config: %w", err)
	}
	return loadCredentialsFromCredentialsProvider(ctx, cfg.Credentials)
}
// loadCredentialsFromProfile loads the SDK configuration for region using the shared-config
// profile awsProfile, then retrieves credentials from that configuration's provider.
func loadCredentialsFromProfile(ctx context.Context, region string, awsProfile string) (*aws.Credentials, error) {
	opts := []func(*config.LoadOptions) error{
		config.WithRegion(region),
		config.WithSharedConfigProfile(awsProfile),
	}
	cfg, err := config.LoadDefaultConfig(ctx, opts...)
	if err != nil {
		return nil, fmt.Errorf("unable to load SDK config: %w", err)
	}
	return loadCredentialsFromCredentialsProvider(ctx, cfg.Credentials)
}
// loadCredentialsFromRoleArn obtains temporary credentials by calling STS AssumeRole for
// roleArn with session name stsSessionName. The externalId is sent only when non-empty.
//
// A new STS client, built from the default config for region, is created on every call. Each
// token generation or refresh therefore makes a fresh STS request. To reuse a client or to
// customise the STS configuration (e.g. its endpoint), supply your own credentials provider
// via GenerateAuthTokenFromCredentialsProvider.
//
// A response without a Credentials member is reported as an error instead of causing a
// nil-pointer panic. Nil fields inside it become empty strings.
func loadCredentialsFromRoleArn(
	ctx context.Context, region string, roleArn string, stsSessionName string, externalId string,
) (*aws.Credentials, error) {
	cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(region))
	if err != nil {
		return nil, fmt.Errorf("unable to load SDK config: %w", err)
	}
	stsClient := sts.NewFromConfig(cfg)
	assumeRoleInput := &sts.AssumeRoleInput{
		RoleArn:         aws.String(roleArn),
		RoleSessionName: aws.String(stsSessionName),
	}
	if externalId != "" {
		assumeRoleInput.ExternalId = aws.String(externalId)
	}
	assumeRoleOutput, err := stsClient.AssumeRole(ctx, assumeRoleInput)
	if err != nil {
		return nil, fmt.Errorf("unable to assume role, %s: %w", roleArn, err)
	}
	// Guard against a response with no credentials; the original code dereferenced these
	// pointers unconditionally and would panic.
	if assumeRoleOutput == nil || assumeRoleOutput.Credentials == nil {
		return nil, fmt.Errorf("unable to assume role, %s: response contained no credentials", roleArn)
	}
	stsCreds := assumeRoleOutput.Credentials
	creds := aws.Credentials{
		AccessKeyID:     aws.ToString(stsCreds.AccessKeyId),
		SecretAccessKey: aws.ToString(stsCreds.SecretAccessKey),
		SessionToken:    aws.ToString(stsCreds.SessionToken),
	}
	return &creds, nil
}
// loadCredentialsFromWebIdentityParameters obtains temporary credentials by calling STS
// AssumeRoleWithWebIdentity with roleArn, webIdentityToken and stsSessionName.
//
// A new STS client, built from the default config for region, is created on every call. Each
// token generation or refresh therefore makes a fresh STS request. To reuse a client or to
// customise the STS configuration (e.g. its endpoint), supply your own credentials provider
// via GenerateAuthTokenFromCredentialsProvider.
//
// A response without a Credentials member is reported as an error instead of causing a
// nil-pointer panic. Nil fields inside it become empty strings.
func loadCredentialsFromWebIdentityParameters(
	ctx context.Context, region, roleArn, webIdentityToken, stsSessionName string,
) (*aws.Credentials, error) {
	cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(region))
	if err != nil {
		return nil, fmt.Errorf("unable to load SDK config: %w", err)
	}
	stsClient := sts.NewFromConfig(cfg)
	assumeRoleWithWebIdentityInput := &sts.AssumeRoleWithWebIdentityInput{
		RoleArn:          aws.String(roleArn),
		RoleSessionName:  aws.String(stsSessionName),
		WebIdentityToken: aws.String(webIdentityToken),
	}
	assumeRoleWithWebIdentityOutput, err := stsClient.AssumeRoleWithWebIdentity(ctx, assumeRoleWithWebIdentityInput)
	if err != nil {
		return nil, fmt.Errorf("unable to assume role with web identity, %s: %w", roleArn, err)
	}
	// Guard against a response with no credentials; the original code dereferenced these
	// pointers unconditionally and would panic.
	if assumeRoleWithWebIdentityOutput == nil || assumeRoleWithWebIdentityOutput.Credentials == nil {
		return nil, fmt.Errorf("unable to assume role with web identity, %s: response contained no credentials", roleArn)
	}
	stsCreds := assumeRoleWithWebIdentityOutput.Credentials
	creds := aws.Credentials{
		AccessKeyID:     aws.ToString(stsCreds.AccessKeyId),
		SecretAccessKey: aws.ToString(stsCreds.SecretAccessKey),
		SessionToken:    aws.ToString(stsCreds.SessionToken),
	}
	return &creds, nil
}
// loadCredentialsFromCredentialsProvider retrieves a single set of credentials from
// credentialsProvider by calling its Retrieve method with ctx.
//
// A nil provider is reported as an error instead of panicking. When retrieval fails, no
// credentials are returned alongside the error.
func loadCredentialsFromCredentialsProvider(
	ctx context.Context, credentialsProvider aws.CredentialsProvider,
) (*aws.Credentials, error) {
	if credentialsProvider == nil {
		return nil, fmt.Errorf("aws credentials provider cannot be nil")
	}
	creds, err := credentialsProvider.Retrieve(ctx)
	if err != nil {
		return nil, fmt.Errorf("unable to retrieve credentials: %w", err)
	}
	return &creds, nil
}
// constructAuthToken presigns a "kafka-cluster:Connect" request against the regional Kafka
// host using credentials. It appends the library user agent and returns the result encoded
// as unpadded base64url, together with the signed URL's expiry in Unix milliseconds.
//
// It fails if credentials is nil or lacks an access key ID or secret access key.
func constructAuthToken(ctx context.Context, region string, credentials *aws.Credentials) (string, int64, error) {
	if credentials == nil || credentials.AccessKeyID == "" || credentials.SecretAccessKey == "" {
		return "", 0, fmt.Errorf("aws credentials cannot be empty")
	}
	if AwsDebugCreds {
		// Optional diagnostic, enabled only through the package-level debug flag.
		logCallerIdentity(ctx, region, *credentials)
	}

	host := fmt.Sprintf(endpointURLTemplate, region)
	unsigned, err := buildRequest(DefaultExpirySeconds, host)
	if err != nil {
		return "", 0, fmt.Errorf("failed to build request for signing: %w", err)
	}

	presigned, err := signRequest(ctx, unsigned, region, credentials)
	if err != nil {
		return "", 0, fmt.Errorf("failed to sign request with aws sig v4: %w", err)
	}

	// The expiry is read from the presigned URL before the user agent is added.
	expiresAtMs, err := getExpirationTimeMs(presigned)
	if err != nil {
		return "", 0, fmt.Errorf("failed to extract expiration from signed url: %w", err)
	}

	tokenURL, err := addUserAgent(presigned)
	if err != nil {
		return "", 0, fmt.Errorf("failed to add user agent to the signed url: %w", err)
	}

	token := base64Encode(tokenURL)
	return token, expiresAtMs, nil
}
// buildRequest creates an unsigned GET request for https://<endpointURL>/. Its query string
// carries the Connect action and the requested presign lifetime in seconds (X-Amz-Expires).
func buildRequest(expirySeconds int, endpointURL string) (*http.Request, error) {
	params := make(url.Values, 2)
	params.Set(ActionType, ActionName)
	params.Set(ExpiresQueryKey, strconv.Itoa(expirySeconds))

	target := &url.URL{
		Scheme:   "https",
		Host:     endpointURL,
		Path:     "/",
		RawQuery: params.Encode(),
	}
	return http.NewRequest(http.MethodGet, target.String(), nil)
}
// signRequest presigns req with SigV4 for the SigningName service in region. It uses
// credentials, the SHA-256 hash of an empty payload, and the current UTC time, and returns
// the presigned URL. The signed headers returned by the signer are discarded.
func signRequest(ctx context.Context, req *http.Request, region string, credentials *aws.Credentials) (string, error) {
	emptyPayloadHash := calculateSHA256Hash("")
	signingTime := time.Now().UTC()

	presignedURL, _, err := v4.NewSigner().PresignHTTP(
		ctx, *credentials, req, emptyPayloadHash, SigningName, region, signingTime,
	)
	return presignedURL, err
}
// getExpirationTimeMs returns the absolute expiry of signedURL in Unix milliseconds. The
// expiry is the X-Amz-Date signing timestamp (format 20060102T150405Z) plus X-Amz-Expires
// seconds.
func getExpirationTimeMs(signedURL string) (int64, error) {
	u, err := url.Parse(signedURL)
	if err != nil {
		return 0, fmt.Errorf("failed to parse the signed url: %w", err)
	}
	q := u.Query()

	signedAt, err := time.Parse("20060102T150405Z", q.Get(DateQueryKey))
	if err != nil {
		return 0, fmt.Errorf("failed to parse the 'X-Amz-Date' param from signed url: %w", err)
	}
	lifetimeSec, err := strconv.ParseInt(q.Get(ExpiresQueryKey), 10, 64)
	if err != nil {
		return 0, fmt.Errorf("failed to parse the 'X-Amz-Expires' param from signed url: %w", err)
	}

	signedAtMs := signedAt.UnixNano() / int64(time.Millisecond)
	return signedAtMs + lifetimeSec*1000, nil
}
// calculateSHA256Hash returns the lowercase hex encoding of the SHA-256 digest of input.
func calculateSHA256Hash(input string) string {
	digest := sha256.New()
	// hash.Hash.Write never returns an error, so the result is safe to ignore.
	digest.Write([]byte(input))
	return hex.EncodeToString(digest.Sum(nil))
}
// base64Encode encodes signedURL with the URL-safe base64 alphabet and no '=' padding.
func base64Encode(signedURL string) string {
	enc := base64.RawURLEncoding
	out := make([]byte, enc.EncodedLen(len(signedURL)))
	enc.Encode(out, []byte(signedURL))
	return string(out)
}
// addUserAgent returns signedURL with its User-Agent query parameter set to
// "<LibName>/<version>/<Go runtime version>", replacing any existing value.
// Re-encoding the query orders the parameters by key.
func addUserAgent(signedURL string) (string, error) {
	u, err := url.Parse(signedURL)
	if err != nil {
		return "", fmt.Errorf("failed to parse signed url: %w", err)
	}

	segments := []string{LibName, version, runtime.Version()}
	q := u.Query()
	q.Set(UserAgentKey, strings.Join(segments, "/"))
	u.RawQuery = q.Encode()

	return u.String(), nil
}
// logCallerIdentity logs the STS caller identity (user ID, account, ARN) of awsCredentials,
// to help debug which credentials were picked up. It calls STS GetCallerIdentity with a
// config pinned to region and to those static credentials.
//
// Failures are logged and the function returns early. Previously, a failed config load or
// GetCallerIdentity call fell through to dereferencing a nil response and panicked.
func logCallerIdentity(ctx context.Context, region string, awsCredentials aws.Credentials) {
	cfg, err := config.LoadDefaultConfig(ctx,
		config.WithRegion(region),
		config.WithCredentialsProvider(credentials.StaticCredentialsProvider{
			Value: awsCredentials,
		}),
	)
	if err != nil {
		log.Printf("failed to load AWS configuration: %v", err)
		return
	}
	stsClient := sts.NewFromConfig(cfg)
	callerIdentity, err := stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{})
	if err != nil {
		log.Printf("failed to get caller identity: %v", err)
		return
	}
	// aws.ToString tolerates nil fields in the response.
	log.Printf("Credentials Identity: {UserId: %s, Account: %s, Arn: %s}\n",
		aws.ToString(callerIdentity.UserId),
		aws.ToString(callerIdentity.Account),
		aws.ToString(callerIdentity.Arn))
}