security/pkg/pki/util/keycertbundle.go (250 lines of code) (raw):

// Copyright Istio Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Provides utility methods to generate X.509 certificates with different // options. This implementation is Largely inspired from // https://golang.org/src/crypto/tls/generate_cert.go. package util import ( "crypto" "crypto/ecdsa" "crypto/rsa" "crypto/tls" "crypto/x509" "errors" "fmt" "os" "sync" "time" ) // KeyCertBundle stores the cert, private key, cert chain and root cert for an entity. It is thread safe. // The cert and privKey should be a public/private key pair. // The cert should be verifiable from the rootCert through the certChain. // cert and priveKey are pointers to the cert/key parsed from certBytes/privKeyBytes. type KeyCertBundle struct { certBytes []byte cert *x509.Certificate privKeyBytes []byte privKey *crypto.PrivateKey certChainBytes []byte rootCertBytes []byte // mutex protects the R/W to all keys and certs. mutex sync.RWMutex } // NewKeyCertBundleFromPem returns a new KeyCertBundle, regardless of whether or not the key can be correctly parsed. func NewKeyCertBundleFromPem(certBytes, privKeyBytes, certChainBytes, rootCertBytes []byte) *KeyCertBundle { bundle := &KeyCertBundle{} bundle.setAllFromPem(certBytes, privKeyBytes, certChainBytes, rootCertBytes) return bundle } // NewVerifiedKeyCertBundleFromPem returns a new KeyCertBundle, or error if the provided certs failed the // verification. func NewVerifiedKeyCertBundleFromPem(certBytes, privKeyBytes, certChainBytes, rootCertBytes []byte) ( *KeyCertBundle, error) { bundle := &KeyCertBundle{} if err := bundle.VerifyAndSetAll(certBytes, privKeyBytes, certChainBytes, rootCertBytes); err != nil { return nil, err } return bundle, nil } // NewVerifiedKeyCertBundleFromFile returns a new KeyCertBundle, or error if the provided certs failed the // verification. func NewVerifiedKeyCertBundleFromFile(certFile string, privKeyFile string, certChainFiles []string, rootCertFile string) ( *KeyCertBundle, error, ) { certBytes, err := os.ReadFile(certFile) if err != nil { return nil, err } privKeyBytes, err := os.ReadFile(privKeyFile) if err != nil { return nil, err } var certChainBytes []byte if len(certChainFiles) > 0 { for _, f := range certChainFiles { var b []byte if b, err = os.ReadFile(f); err != nil { return nil, err } certChainBytes = append(certChainBytes, b...) } } rootCertBytes, err := os.ReadFile(rootCertFile) if err != nil { return nil, err } return NewVerifiedKeyCertBundleFromPem(certBytes, privKeyBytes, certChainBytes, rootCertBytes) } // NewKeyCertBundleWithRootCertFromFile returns a new KeyCertBundle with the root cert without verification. func NewKeyCertBundleWithRootCertFromFile(rootCertFile string) (*KeyCertBundle, error) { var rootCertBytes []byte var err error if rootCertFile == "" { rootCertBytes = []byte{} } else { rootCertBytes, err = os.ReadFile(rootCertFile) if err != nil { return nil, err } } return &KeyCertBundle{ certBytes: []byte{}, cert: nil, privKeyBytes: []byte{}, privKey: nil, certChainBytes: []byte{}, rootCertBytes: rootCertBytes, }, nil } // GetAllPem returns all key/cert PEMs in KeyCertBundle together. Getting all values together avoids inconsistency. func (b *KeyCertBundle) GetAllPem() (certBytes, privKeyBytes, certChainBytes, rootCertBytes []byte) { b.mutex.RLock() certBytes = copyBytes(b.certBytes) privKeyBytes = copyBytes(b.privKeyBytes) certChainBytes = copyBytes(b.certChainBytes) rootCertBytes = copyBytes(b.rootCertBytes) b.mutex.RUnlock() return } // GetAll returns all key/cert in KeyCertBundle together. Getting all values together avoids inconsistency. // NOTE: Callers should not modify the content of cert and privKey. func (b *KeyCertBundle) GetAll() (cert *x509.Certificate, privKey *crypto.PrivateKey, certChainBytes, rootCertBytes []byte) { b.mutex.RLock() cert = b.cert privKey = b.privKey certChainBytes = copyBytes(b.certChainBytes) rootCertBytes = copyBytes(b.rootCertBytes) b.mutex.RUnlock() return } // GetCertChainPem returns the certificate chain PEM. func (b *KeyCertBundle) GetCertChainPem() []byte { b.mutex.RLock() defer b.mutex.RUnlock() return copyBytes(b.certChainBytes) } // GetRootCertPem returns the root certificate PEM. func (b *KeyCertBundle) GetRootCertPem() []byte { b.mutex.RLock() defer b.mutex.RUnlock() return copyBytes(b.rootCertBytes) } // VerifyAndSetAll verifies the key/certs, and sets all key/certs in KeyCertBundle together. // Setting all values together avoids inconsistency. func (b *KeyCertBundle) VerifyAndSetAll(certBytes, privKeyBytes, certChainBytes, rootCertBytes []byte) error { if err := Verify(certBytes, privKeyBytes, certChainBytes, rootCertBytes); err != nil { return err } b.setAllFromPem(certBytes, privKeyBytes, certChainBytes, rootCertBytes) return nil } // Setting all values together avoids inconsistency. func (b *KeyCertBundle) setAllFromPem(certBytes, privKeyBytes, certChainBytes, rootCertBytes []byte) { b.mutex.Lock() b.certBytes = copyBytes(certBytes) b.privKeyBytes = copyBytes(privKeyBytes) b.certChainBytes = copyBytes(certChainBytes) b.rootCertBytes = copyBytes(rootCertBytes) // cert and privKey are always reset to point to new addresses. This avoids modifying the pointed structs that // could be still used outside of the class. b.cert, _ = ParsePemEncodedCertificate(certBytes) privKey, _ := ParsePemEncodedKey(privKeyBytes) b.privKey = &privKey b.mutex.Unlock() } // CertOptions returns the certificate config based on currently stored cert. func (b *KeyCertBundle) CertOptions() (*CertOptions, error) { b.mutex.RLock() defer b.mutex.RUnlock() ids, err := ExtractIDs(b.cert.Extensions) if err != nil { return nil, fmt.Errorf("failed to extract id %v", err) } if len(ids) != 1 { return nil, fmt.Errorf("expect single id from the cert, found %v", ids) } opts := &CertOptions{ Host: ids[0], Org: b.cert.Issuer.Organization[0], IsCA: b.cert.IsCA, TTL: b.cert.NotAfter.Sub(b.cert.NotBefore), IsDualUse: ids[0] == b.cert.Subject.CommonName, } switch (*b.privKey).(type) { case *rsa.PrivateKey: size, err := GetRSAKeySize(*b.privKey) if err != nil { return nil, fmt.Errorf("failed to get RSA key size: %v", err) } opts.RSAKeySize = size case *ecdsa.PrivateKey: opts.ECSigAlg = EcdsaSigAlg default: return nil, errors.New("unknown private key type") } return opts, nil } // UpdateVerifiedKeyCertBundleFromFile Verifies and updates KeyCertBundle with new certs func (b *KeyCertBundle) UpdateVerifiedKeyCertBundleFromFile(certFile string, privKeyFile string, certChainFiles []string, rootCertFile string) error { certBytes, err := os.ReadFile(certFile) if err != nil { return err } privKeyBytes, err := os.ReadFile(privKeyFile) if err != nil { return err } certChainBytes := []byte{} if len(certChainFiles) != 0 { for _, f := range certChainFiles { var b []byte if b, err = os.ReadFile(f); err != nil { return err } certChainBytes = append(certChainBytes, b...) } } rootCertBytes, err := os.ReadFile(rootCertFile) if err != nil { return err } err = b.VerifyAndSetAll(certBytes, privKeyBytes, certChainBytes, rootCertBytes) if err != nil { return err } return nil } // ExtractRootCertExpiryTimestamp returns the unix timestamp when the root becomes expires. func (b *KeyCertBundle) ExtractRootCertExpiryTimestamp() (float64, error) { return extractCertExpiryTimestamp("root cert", b.GetRootCertPem()) } // ExtractCACertExpiryTimestamp returns the unix timestamp when the cert chain becomes expires. func (b *KeyCertBundle) ExtractCACertExpiryTimestamp() (float64, error) { return extractCertExpiryTimestamp("CA cert", b.GetCertChainPem()) } // TimeBeforeCertExpires returns the time duration before the cert gets expired. // It returns an error if it failed to extract the cert expiration timestamp. // The returned time duration could be a negative value indicating the cert has already been expired. func TimeBeforeCertExpires(certBytes []byte, now time.Time) (time.Duration, error) { if len(certBytes) == 0 { return 0, fmt.Errorf("no certificate found") } certExpiryTimestamp, err := extractCertExpiryTimestamp("cert", certBytes) if err != nil { return 0, fmt.Errorf("failed to extract cert expiration timestamp: %v", err) } certExpiry := time.Duration(certExpiryTimestamp-float64(now.Unix())) * time.Second return certExpiry, nil } // Verify that the cert chain, root cert and key/cert match. func Verify(certBytes, privKeyBytes, certChainBytes, rootCertBytes []byte) error { // Verify the cert can be verified from the root cert through the cert chain. rcp := x509.NewCertPool() rcp.AppendCertsFromPEM(rootCertBytes) icp := x509.NewCertPool() icp.AppendCertsFromPEM(certChainBytes) opts := x509.VerifyOptions{ Intermediates: icp, Roots: rcp, } cert, err := ParsePemEncodedCertificate(certBytes) if err != nil { return fmt.Errorf("failed to parse cert PEM: %v", err) } chains, err := cert.Verify(opts) if len(chains) == 0 || err != nil { return fmt.Errorf( "cannot verify the cert with the provided root chain and cert "+ "pool with error: %v", err) } // Verify that the key can be correctly parsed. if _, err = ParsePemEncodedKey(privKeyBytes); err != nil { return fmt.Errorf("failed to parse private key PEM: %v", err) } // Verify the cert and key match. if _, err := tls.X509KeyPair(certBytes, privKeyBytes); err != nil { return fmt.Errorf("the cert does not match the key") } return nil } func extractCertExpiryTimestamp(certType string, certPem []byte) (float64, error) { cert, err := ParsePemEncodedCertificate(certPem) if err != nil { return -1, fmt.Errorf("failed to parse the %s: %v", certType, err) } end := cert.NotAfter expiryTimestamp := float64(end.Unix()) if end.Before(time.Now()) { return expiryTimestamp, fmt.Errorf("expired %s found, x509.NotAfter %v, please transit your %s", certType, end, certType) } return expiryTimestamp, nil } func copyBytes(src []byte) []byte { bs := make([]byte, len(src)) copy(bs, src) return bs }