pkg/auth/auth.go (125 lines of code) (raw):

// Copyright (c) Microsoft and contributors. All rights reserved. // // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. package auth import ( "crypto/rsa" "crypto/x509" "fmt" "net/http" "os" "regexp" "github.com/Azure/kubernetes-kms/pkg/config" "github.com/Azure/kubernetes-kms/pkg/consts" "github.com/Azure/go-autorest/autorest" "github.com/Azure/go-autorest/autorest/adal" "github.com/Azure/go-autorest/autorest/azure" "golang.org/x/crypto/pkcs12" "monis.app/mlog" ) // GetKeyvaultToken() returns token for Keyvault endpoint. func GetKeyvaultToken(config *config.AzureConfig, env *azure.Environment, resource string, proxyMode bool) (authorizer autorest.Authorizer, err error) { servicePrincipalToken, err := GetServicePrincipalToken(config, env.ActiveDirectoryEndpoint, resource, proxyMode) if err != nil { return nil, err } authorizer = autorest.NewBearerAuthorizer(servicePrincipalToken) return authorizer, nil } // GetServicePrincipalToken creates a new service principal token based on the configuration. func GetServicePrincipalToken(config *config.AzureConfig, aadEndpoint, resource string, proxyMode bool) (adal.OAuthTokenProvider, error) { oauthConfig, err := adal.NewOAuthConfig(aadEndpoint, config.TenantID) if err != nil { return nil, fmt.Errorf("failed to create OAuth config, error: %v", err) } if config.UseManagedIdentityExtension { mlog.Info("using managed identity extension to retrieve access token") msiEndpoint, err := adal.GetMSIVMEndpoint() if err != nil { return nil, fmt.Errorf("failed to get managed service identity endpoint, error: %v", err) } // using user-assigned managed identity to access keyvault if len(config.UserAssignedIdentityID) > 0 { mlog.Info("using User-assigned managed identity to retrieve access token", "clientID", redactClientCredentials(config.UserAssignedIdentityID)) return adal.NewServicePrincipalTokenFromMSIWithUserAssignedID(msiEndpoint, resource, config.UserAssignedIdentityID) } mlog.Info("using system-assigned managed identity to retrieve access token") // using system-assigned managed identity to access keyvault return adal.NewServicePrincipalTokenFromMSI( msiEndpoint, resource) } if len(config.ClientSecret) > 0 && len(config.ClientID) > 0 { mlog.Info("azure: using client_id+client_secret to retrieve access token", "clientID", redactClientCredentials(config.ClientID), "clientSecret", redactClientCredentials(config.ClientSecret)) spt, err := adal.NewServicePrincipalToken( *oauthConfig, config.ClientID, config.ClientSecret, resource) if err != nil { return nil, err } if proxyMode { return addTargetTypeHeader(spt), nil } return spt, nil } if len(config.AADClientCertPath) > 0 && len(config.AADClientCertPassword) > 0 { mlog.Info("using jwt client_assertion (client_cert+client_private_key) to retrieve access token") certData, err := os.ReadFile(config.AADClientCertPath) if err != nil { return nil, fmt.Errorf("failed to read client certificate from file %s, error: %v", config.AADClientCertPath, err) } certificate, privateKey, err := decodePkcs12(certData, config.AADClientCertPassword) if err != nil { return nil, fmt.Errorf("failed to decode the client certificate, error: %v", err) } spt, err := adal.NewServicePrincipalTokenFromCertificate( *oauthConfig, config.ClientID, certificate, privateKey, resource) if err != nil { return nil, err } if proxyMode { return addTargetTypeHeader(spt), nil } return spt, nil } return nil, fmt.Errorf("no credentials provided for accessing keyvault") } // ParseAzureEnvironment returns azure environment by name. func ParseAzureEnvironment(cloudName string) (*azure.Environment, error) { var env azure.Environment var err error if cloudName == "" { env = azure.PublicCloud } else { env, err = azure.EnvironmentFromName(cloudName) } return &env, err } // decodePkcs12 decodes a PKCS#12 client certificate by extracting the public certificate and // the private RSA key. func decodePkcs12(pkcs []byte, password string) (*x509.Certificate, *rsa.PrivateKey, error) { privateKey, certificate, err := pkcs12.Decode(pkcs, password) if err != nil { return nil, nil, fmt.Errorf("decoding the PKCS#12 client certificate: %v", err) } rsaPrivateKey, isRsaKey := privateKey.(*rsa.PrivateKey) if !isRsaKey { return nil, nil, fmt.Errorf("PKCS#12 certificate must contain a RSA private key") } return certificate, rsaPrivateKey, nil } // redactClientCredentials applies regex to a sensitive string and return the redacted value. func redactClientCredentials(sensitiveString string) string { r := regexp.MustCompile(`^(\S{4})(\S|\s)*(\S{4})$`) return r.ReplaceAllString(sensitiveString, "$1##### REDACTED #####$3") } // addTargetTypeHeader adds the target header if proxy mode is enabled. func addTargetTypeHeader(spt *adal.ServicePrincipalToken) *adal.ServicePrincipalToken { spt.SetSender(autorest.CreateSender( (func() autorest.SendDecorator { return func(s autorest.Sender) autorest.Sender { return autorest.SenderFunc(func(r *http.Request) (*http.Response, error) { r.Header.Set(consts.RequestHeaderTargetType, consts.TargetTypeAzureActiveDirectory) return s.Do(r) }) } })())) return spt }