pkg/internal/token/clientcertcredential.go (170 lines of code) (raw):

package token

import (
	"context"
	"crypto/rsa"
	"crypto/x509"
	"encoding/pem"
	"fmt"
	"os"
	"path/filepath"
	"strings"

	"github.com/Azure/azure-sdk-for-go/sdk/azcore"
	"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
	"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
	"github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache"
	"golang.org/x/crypto/pkcs12"
	"k8s.io/klog/v2"
)

// ClientCertificateCredential is a CredentialProvider that authenticates a
// service principal with a client certificate and RSA private key loaded from
// a PEM or PKCS#12 file. Token acquisition is delegated to
// azidentity.ClientCertificateCredential.
type ClientCertificateCredential struct {
	cred *azidentity.ClientCertificateCredential
}

var _ CredentialProvider = (*ClientCertificateCredential)(nil)

// newClientCertificateCredential validates opts and builds a
// ClientCertificateCredential.
//
// opts.ClientID, opts.TenantID and opts.ClientCert (the path to the
// certificate file) must be non-empty. opts.ClientCertPassword is used only
// for PKCS#12 files. When opts.UsePersistentCache is set, a persistent token
// cache is requested; if creating it fails, the error is logged at verbosity 5
// and the credential is still created with whatever cache.New returned.
// NOTE(review): presumably that is the zero azidentity.Cache (in-memory
// only) — confirm against the cache package.
func newClientCertificateCredential(opts *Options) (CredentialProvider, error) {
	if opts.ClientID == "" {
		return nil, fmt.Errorf("client ID cannot be empty")
	}
	if opts.TenantID == "" {
		return nil, fmt.Errorf("tenant ID cannot be empty")
	}
	if opts.ClientCert == "" {
		return nil, fmt.Errorf("client certificate cannot be empty")
	}

	// Load the key pair first so a bad path, password or file format fails
	// fast, before any persistent cache is created.
	cert, rsaPrivateKey, err := readCertificate(opts.ClientCert, opts.ClientCertPassword)
	if err != nil {
		return nil, fmt.Errorf("failed to read certificate: %w", err)
	}

	var c azidentity.Cache
	if opts.UsePersistentCache {
		// Best effort: a cache failure must not prevent authentication.
		c, err = cache.New(nil)
		if err != nil {
			klog.V(5).Infof("failed to create cache: %v", err)
		}
	}

	azOpts := &azidentity.ClientCertificateCredentialOptions{
		ClientOptions: azcore.ClientOptions{Cloud: opts.GetCloudConfiguration()},
		Cache:         c,
		// Include the certificate in the client assertion's x5c header (see the
		// azidentity docs; required for subject name/issuer authentication).
		SendCertificateChain:     true,
		DisableInstanceDiscovery: opts.DisableInstanceDiscovery,
	}
	if opts.httpClient != nil {
		azOpts.ClientOptions.Transport = opts.httpClient
	}

	cred, err := azidentity.NewClientCertificateCredential(
		opts.TenantID, opts.ClientID, []*x509.Certificate{cert}, rsaPrivateKey, azOpts)
	if err != nil {
		return nil, fmt.Errorf("failed to create client certificate credential: %w", err)
	}
	return &ClientCertificateCredential{cred: cred}, nil
}

// Name returns the identifier of this credential type.
func (c *ClientCertificateCredential) Name() string {
	return "ClientCertificateCredential"
}

// Authenticate is not supported by this credential; it always returns an
// empty record and errAuthenticateNotSupported.
func (c *ClientCertificateCredential) Authenticate(ctx context.Context, opts *policy.TokenRequestOptions) (azidentity.AuthenticationRecord, error) {
	return azidentity.AuthenticationRecord{}, errAuthenticateNotSupported
}

// GetToken obtains an access token for opts from the underlying azidentity
// credential.
func (c *ClientCertificateCredential) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) {
	return c.cred.GetToken(ctx, opts)
}

// NeedAuthenticate always reports false; this credential does not use
// Authenticate.
func (c *ClientCertificateCredential) NeedAuthenticate() bool {
	return false
}

// isPublicKeyEqual reports whether two RSA public keys have the same modulus
// and exponent. A nil key, or a key with a nil modulus, never compares equal.
func isPublicKeyEqual(key1, key2 *rsa.PublicKey) bool {
	if key1 == nil || key2 == nil || key1.N == nil || key2.N == nil {
		return false
	}
	return key1.E == key2.E && key1.N.Cmp(key2.N) == 0
}

// splitPEMBlock walks every PEM block in pemBlock and re-encodes them into two
// buffers: certPEM holds all CERTIFICATE blocks, and keyPEM holds all private
// key blocks. Key blocks are either "PRIVATE KEY" or "RSA PRIVATE KEY"
// (PKCS#1). parseRsaPrivateKey accepts both PKCS#1 and PKCS#8 payloads.
// Other block types are ignored.
func splitPEMBlock(pemBlock []byte) (certPEM []byte, keyPEM []byte) {
	for {
		var block *pem.Block
		block, pemBlock = pem.Decode(pemBlock)
		if block == nil {
			return certPEM, keyPEM
		}
		switch block.Type {
		case "CERTIFICATE":
			certPEM = append(certPEM, pem.EncodeToMemory(block)...)
		case "PRIVATE KEY", "RSA PRIVATE KEY":
			keyPEM = append(keyPEM, pem.EncodeToMemory(block)...)
		}
	}
}

// parseRsaPrivateKey decodes the first PEM block in privateKeyPEM and parses
// it as an RSA private key. It tries PKCS#1 first, then PKCS#8. A PKCS#8 key
// of any other algorithm is rejected. When both parsers fail, both errors are
// wrapped.
func parseRsaPrivateKey(privateKeyPEM []byte) (*rsa.PrivateKey, error) {
	block, _ := pem.Decode(privateKeyPEM)
	if block == nil {
		return nil, fmt.Errorf("failed to decode a pem block from private key")
	}

	pkcs1Key, errPkcs1 := x509.ParsePKCS1PrivateKey(block.Bytes)
	if errPkcs1 == nil {
		return pkcs1Key, nil
	}

	pkcs8Key, errPkcs8 := x509.ParsePKCS8PrivateKey(block.Bytes)
	if errPkcs8 != nil {
		return nil, fmt.Errorf("failed to parse private key as Pkcs#1 or Pkcs#8. (%w), (%w)", errPkcs1, errPkcs8)
	}
	rsaKey, ok := pkcs8Key.(*rsa.PrivateKey)
	if !ok {
		return nil, fmt.Errorf("pkcs8 contained non-RSA key. Expected RSA key")
	}
	return rsaKey, nil
}

// parseKeyPairFromPEMBlock extracts the RSA private key from pemBlock. It
// returns that key together with the first certificate in pemBlock whose RSA
// public key matches it. Certificates are checked in file order, and the
// first unparsable certificate aborts the search with an error.
func parseKeyPairFromPEMBlock(pemBlock []byte) (*x509.Certificate, *rsa.PrivateKey, error) {
	certPEM, keyPEM := splitPEMBlock(pemBlock)
	privateKey, err := parseRsaPrivateKey(keyPEM)
	if err != nil {
		return nil, nil, err
	}

	for {
		var block *pem.Block
		block, certPEM = pem.Decode(certPEM)
		if block == nil {
			break
		}
		cert, err := x509.ParseCertificate(block.Bytes)
		if err != nil {
			return nil, nil, fmt.Errorf("unable to parse certificate: %w", err)
		}
		// Certificates with non-RSA public keys cannot match an RSA private key.
		if pub, ok := cert.PublicKey.(*rsa.PublicKey); ok && isPublicKeyEqual(pub, &privateKey.PublicKey) {
			return cert, privateKey, nil
		}
	}
	return nil, nil, fmt.Errorf("unable to find a matching public certificate")
}

// decodePkcs12 converts the PKCS#12 archive pkcs (decrypted with password)
// into PEM blocks. It then locates the matching certificate and RSA private
// key via parseKeyPairFromPEMBlock.
func decodePkcs12(pkcs []byte, password string) (*x509.Certificate, *rsa.PrivateKey, error) {
	blocks, err := pkcs12.ToPEM(pkcs, password)
	if err != nil {
		return nil, nil, err
	}
	var pemData []byte
	for _, b := range blocks {
		pemData = append(pemData, pem.EncodeToMemory(b)...)
	}
	return parseKeyPairFromPEMBlock(pemData)
}

// readCertificate loads certFile and returns the certificate and RSA private
// key it contains. Files with a .pfx or .p12 extension (case-insensitive) are
// decoded as PKCS#12 using password. Any other file is parsed as PEM, and
// password is ignored.
func readCertificate(certFile, password string) (*x509.Certificate, *rsa.PrivateKey, error) {
	data, err := os.ReadFile(certFile)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to read the certificate file (%s): %w", certFile, err)
	}

	switch strings.ToLower(filepath.Ext(certFile)) {
	case ".pfx", ".p12":
		return decodePkcs12(data, password)
	default:
		return parseKeyPairFromPEMBlock(data)
	}
}