client/internal/bootstrap/auth.go (67 lines of code) (raw):
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package bootstrap
import (
"encoding/base64"
"fmt"
"strings"
"github.com/Azure/aks-secure-tls-bootstrap/client/internal/datamodel"
"github.com/Azure/go-autorest/autorest/adal"
"github.com/Azure/go-autorest/autorest/azure"
"go.uber.org/zap"
)
// certificateSecretPrefix marks a client secret whose remainder (after the prefix is trimmed)
// is base64-encoded PFX certificate data rather than a plain secret string.
const certificateSecretPrefix = "certificate:"
// extractAccessTokenFunc refreshes the given service principal token and returns its OAuth
// access token string. It is a seam so that unit tests can substitute fake implementations.
type extractAccessTokenFunc func(spt *adal.ServicePrincipalToken) (string, error)
// extractAccessToken is the non-fake extractAccessTokenFunc implementation. It forces a
// refresh of token and then returns the resulting OAuth access token string.
func extractAccessToken(token *adal.ServicePrincipalToken) (string, error) {
	refreshErr := token.Refresh()
	if refreshErr != nil {
		return "", fmt.Errorf("obtaining fresh access token: %w", refreshErr)
	}
	accessToken := token.OAuthToken()
	return accessToken, nil
}
// getAccessToken retrieves an AAD access token (JWT) for the specified resource.
//
// The credential is selected as follows:
//   - If customClientID is non-empty it is used as the managed identity client ID; otherwise
//     azureConfig.UserAssignedIdentityID is used. When either is set, an MSI access token is
//     requested for that client ID (MSI access tokens are retrieved from IMDS).
//   - Otherwise a service principal (SPN) token is retrieved directly from AAD for
//     azureConfig.ClientID, using the AAD endpoint of azureConfig.Cloud and azureConfig.TenantID.
//     If azureConfig.ClientSecret starts with certificateSecretPrefix, the remainder is
//     base64-decoded and treated as PFX certificate data (certificate auth); otherwise the
//     secret is used as-is.
//
// The resulting token is refreshed and converted to a string via c.extractAccessTokenFunc.
// A nil azureConfig is rejected with an error instead of causing a nil-pointer panic.
func (c *Client) getAccessToken(customClientID, resource string, azureConfig *datamodel.AzureConfig) (string, error) {
	if azureConfig == nil {
		return "", fmt.Errorf("azure config is nil")
	}

	userAssignedID := azureConfig.UserAssignedIdentityID
	if customClientID != "" {
		// An explicitly requested client ID takes precedence over the one in azure config.
		userAssignedID = customClientID
	}
	if userAssignedID != "" {
		c.logger.Info("generating MSI access token", zap.String("clientId", userAssignedID))
		token, err := adal.NewServicePrincipalTokenFromManagedIdentity(resource, &adal.ManagedIdentityOptions{
			ClientID: userAssignedID,
		})
		if err != nil {
			return "", fmt.Errorf("generating MSI access token: %w", err)
		}
		return c.extractAccessTokenFunc(token)
	}

	// SPN paths: both secret- and certificate-based auth go through the cloud's AAD endpoint.
	env, err := azure.EnvironmentFromName(azureConfig.Cloud)
	if err != nil {
		return "", fmt.Errorf("getting azure environment config for cloud %q: %w", azureConfig.Cloud, err)
	}
	oauthConfig, err := adal.NewOAuthConfig(env.ActiveDirectoryEndpoint, azureConfig.TenantID)
	if err != nil {
		return "", fmt.Errorf("creating oauth config with azure environment: %w", err)
	}

	if !strings.HasPrefix(azureConfig.ClientSecret, certificateSecretPrefix) {
		c.logger.Info("generating SPN access token with username and password", zap.String("clientId", azureConfig.ClientID))
		token, err := adal.NewServicePrincipalToken(*oauthConfig, azureConfig.ClientID, azureConfig.ClientSecret, resource)
		if err != nil {
			return "", fmt.Errorf("generating SPN access token with username and password: %w", err)
		}
		return c.extractAccessTokenFunc(token)
	}

	c.logger.Info("client secret contains certificate data, using certificate to generate SPN access token", zap.String("clientId", azureConfig.ClientID))
	certData, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(azureConfig.ClientSecret, certificateSecretPrefix))
	if err != nil {
		return "", fmt.Errorf("b64-decoding certificate data in client secret: %w", err)
	}
	// The PFX data is decoded with an empty password.
	certificate, privateKey, err := adal.DecodePfxCertificateData(certData, "")
	if err != nil {
		return "", fmt.Errorf("decoding pfx certificate data in client secret: %w", err)
	}
	c.logger.Info("generating SPN access token with certificate", zap.String("clientId", azureConfig.ClientID))
	token, err := adal.NewServicePrincipalTokenFromCertificate(*oauthConfig, azureConfig.ClientID, certificate, privateKey, resource)
	if err != nil {
		return "", fmt.Errorf("generating SPN access token with certificate: %w", err)
	}
	return c.extractAccessTokenFunc(token)
}