pkg/auth/cred.go (125 lines of code) (raw):

/* Copyright (c) Microsoft Corporation. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package auth import ( "context" "crypto/tls" "crypto/x509" "encoding/pem" "fmt" "os" "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/services/keyvault/2016-10-01/keyvault" "github.com/Azure/go-autorest/autorest" "github.com/Azure/go-autorest/autorest/to" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential" "github.com/azure/gpu-provisioner/pkg/utils" "github.com/pkg/errors" ) const ( e2eOverlayResourceVersionKey = "AKS_E2E_OVERLAY_RESOURCE_VERSION" ) // ClientAssertionCredential authenticates an application with assertions provided by a callback function. type ClientAssertionCredential struct { assertion, file string client confidential.Client lastRead time.Time } // NewCredential provides a token credential for msi and service principal auth func NewCredential(cfg *Config, authorizer autorest.Authorizer) (azcore.TokenCredential, error) { if cfg == nil { return nil, fmt.Errorf("failed to create credential, nil config provided") } // Azure AD Workload Identity webhook will inject the following env vars: // AZURE_FEDERATED_TOKEN_FILE is the service account token path // AZURE_AUTHORITY_HOST is the AAD authority hostname tokenFilePath := os.Getenv("AZURE_FEDERATED_TOKEN_FILE") authority := os.Getenv("AZURE_AUTHORITY_HOST") if tokenFilePath == "" || authority == "" { return nil, fmt.Errorf("required environment variables not set, AZURE_FEDERATED_TOKEN_FILE: %s, AZURE_AUTHORITY_HOST: %s", tokenFilePath, authority) } c := &ClientAssertionCredential{file: tokenFilePath} var cred confidential.Credential isE2E := utils.WithDefaultBool("E2E_TEST_MODE", false) if isE2E { armClientCert, err := getE2ETestingCert(authorizer) if err != nil { return nil, err } certPEM, keyPEM := splitPEMBlock([]byte(to.String(armClientCert))) if len(certPEM) == 0 { return nil, errors.New("malformed cert pem format") } // Load client cert cert, err := tls.X509KeyPair(certPEM, keyPEM) if err != nil { return nil, err } leafCert := []tls.Certificate{cert} cred, err = confidential.NewCredFromCert([]*x509.Certificate{leafCert[0].Leaf}, keyPEM) if err != nil { return nil, err } } else { cred = confidential.NewCredFromAssertionCallback( func(ctx context.Context, _ confidential.AssertionRequestOptions) (string, error) { return c.readJWTFromFS() }, ) } // create the confidential client to request an AAD token confidentialClientApp, err := confidential.New( fmt.Sprintf("%s%s/oauth2/token", authority, cfg.TenantID), cfg.UserAssignedIdentityID, cred) if err != nil { return nil, fmt.Errorf("failed to create confidential client app: %w", err) } c.client = confidentialClientApp return c, nil } // GetToken implements the TokenCredential interface func (c *ClientAssertionCredential) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) { // get the token from the confidential client token, err := c.client.AcquireTokenByCredential(ctx, opts.Scopes) if err != nil { return azcore.AccessToken{}, err } return azcore.AccessToken{ Token: token.AccessToken, ExpiresOn: token.ExpiresOn, }, nil } // readJWTFromFS reads the jwt from file system // Source: https://github.com/Azure/azure-workload-identity/blob/d126293e3c7c669378b225ad1b1f29cf6af4e56d/examples/msal-go/token_credential.go#L88 func (c *ClientAssertionCredential) readJWTFromFS() (string, error) { if now := time.Now(); c.lastRead.Add(5 * time.Minute).Before(now) { content, err := os.ReadFile(c.file) if err != nil { return "", err } c.assertion = string(content) c.lastRead = now } return c.assertion, nil } func getE2ETestingCert(authorizer autorest.Authorizer) (*string, error) { e2eOverlayResourceVersion := os.Getenv(e2eOverlayResourceVersionKey) if e2eOverlayResourceVersion == "" { return nil, fmt.Errorf("E2E overlay resource version is not set") } keyVaultUrl := fmt.Sprintf("https://hcp%s.vault.azure.net/", e2eOverlayResourceVersion) client := keyvault.New() client.Authorizer = authorizer ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute) defer cancel() result, err := client.GetSecret(ctx, keyVaultUrl, "e2e-arm-client-cert", "") if err != nil { //+gocover:ignore:block keyvault fetch return nil, err } return result.Value, nil } // split the pem block to cert/key func splitPEMBlock(pemBlock []byte) (certPEM []byte, keyPEM []byte) { for { var derBlock *pem.Block derBlock, pemBlock = pem.Decode(pemBlock) if derBlock == nil { break } if derBlock.Type == "CERTIFICATE" { certPEM = append(certPEM, pem.EncodeToMemory(derBlock)...) } else if derBlock.Type == "PRIVATE KEY" { keyPEM = append(keyPEM, pem.EncodeToMemory(derBlock)...) } } return certPEM, keyPEM }