pkg/provider/provider.go (601 lines of code) (raw):

package provider import ( "crypto/ecdsa" "crypto/elliptic" "crypto/rsa" "crypto/x509" "encoding/base64" "encoding/hex" "encoding/pem" "fmt" "math/big" "os" "path/filepath" "reflect" "regexp" "sort" "strconv" "strings" "time" "github.com/Azure/secrets-store-csi-driver-provider-azure/pkg/auth" "github.com/Azure/secrets-store-csi-driver-provider-azure/pkg/metrics" "github.com/Azure/secrets-store-csi-driver-provider-azure/pkg/provider/types" "github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys" "github.com/Azure/go-autorest/autorest/azure" "github.com/pkg/errors" "golang.org/x/crypto/pkcs12" "golang.org/x/net/context" "gopkg.in/yaml.v3" "k8s.io/klog/v2" ) // Provider implements the secrets-store-csi-driver provider interface type Interface interface { GetSecretsStoreObjectContent(ctx context.Context, attrib, secrets map[string]string, defaultFilePermission os.FileMode) ([]types.SecretFile, error) } type provider struct { reporter metrics.StatsReporter constructPEMChain bool writeCertAndKeyInSeparateFiles bool defaultCloudEnvironment azure.Environment } // mountConfig holds the information for the mount event type mountConfig struct { // the name of the Azure Key Vault instance keyvaultName string // the type of azure cloud based on azure go sdk azureCloudEnvironment azure.Environment // authConfig is the config parameters for accessing Key Vault authConfig auth.Config // tenantID in AAD tenantID string // podName is the pod name podName string // podNamespace is the pod namespace podNamespace string } type keyvaultObject struct { content string fileNameSuffix string version string } // NewProvider creates a new provider func NewProvider(constructPEMChain, writeCertAndKeyInSeparateFiles bool, defaultCloudEnvironment azure.Environment) Interface { return &provider{ reporter: metrics.NewStatsReporter(), constructPEMChain: constructPEMChain, writeCertAndKeyInSeparateFiles: writeCertAndKeyInSeparateFiles, defaultCloudEnvironment: defaultCloudEnvironment, } } // parseAzureEnvironment returns azure environment by name func (p *provider) parseAzureEnvironment(cloudName string) (azure.Environment, error) { if cloudName == "" { return p.defaultCloudEnvironment, nil } return azure.EnvironmentFromName(cloudName) } func (mc *mountConfig) initializeKvClient(vaultURI string) (KeyVault, error) { kvEndpoint := strings.TrimSuffix(mc.azureCloudEnvironment.KeyVaultEndpoint, "/") cred, err := mc.authConfig.GetCredential(mc.podName, mc.podNamespace, kvEndpoint, mc.azureCloudEnvironment.ActiveDirectoryEndpoint, mc.tenantID, types.PodIdentityNMIPort) if err != nil { return nil, err } return NewClient(cred, vaultURI) } func (mc *mountConfig) getVaultURL() (vaultURL *string, err error) { // Key Vault name must be a 3-24 character string if len(mc.keyvaultName) < 3 || len(mc.keyvaultName) > 24 { return nil, errors.Errorf("Invalid vault name: %q, must be between 3 and 24 chars", mc.keyvaultName) } // See docs for validation spec: https://docs.microsoft.com/en-us/azure/key-vault/about-keys-secrets-and-certificates#objects-identifiers-and-versioning isValid := regexp.MustCompile(`^[-A-Za-z0-9]+$`).MatchString if !isValid(mc.keyvaultName) { return nil, errors.Errorf("Invalid vault name: %q, must match [-a-zA-Z0-9]{3,24}", mc.keyvaultName) } vaultDNSSuffixValue := mc.azureCloudEnvironment.KeyVaultDNSSuffix vaultURI := "https://" + mc.keyvaultName + "." + vaultDNSSuffixValue + "/" return &vaultURI, nil } // GetSecretsStoreObjectContent gets the objects (secret, key, certificate) from keyvault and returns the content // to the CSI driver. The driver will write the content to the file system. func (p *provider) GetSecretsStoreObjectContent(ctx context.Context, attrib, secrets map[string]string, defaultFilePermission os.FileMode) ([]types.SecretFile, error) { keyvaultName := types.GetKeyVaultName(attrib) cloudName := types.GetCloudName(attrib) userAssignedIdentityID := types.GetUserAssignedIdentityID(attrib) tenantID := types.GetTenantID(attrib) cloudEnvFileName := types.GetCloudEnvFileName(attrib) podName := types.GetPodName(attrib) podNamespace := types.GetPodNamespace(attrib) usePodIdentity, err := types.GetUsePodIdentity(attrib) if err != nil { return nil, fmt.Errorf("failed to parse usePodIdentity flag, error: %w", err) } useVMManagedIdentity, err := types.GetUseVMManagedIdentity(attrib) if err != nil { return nil, fmt.Errorf("failed to parse useVMManagedIdentity flag, error: %w", err) } // attributes for workload identity workloadIdentityClientID := types.GetClientID(attrib) saTokens := types.GetServiceAccountTokens(attrib) if keyvaultName == "" { return nil, fmt.Errorf("keyvaultName is not set") } if tenantID == "" { return nil, fmt.Errorf("tenantId is not set") } err = setAzureEnvironmentFilePath(cloudEnvFileName) if err != nil { return nil, fmt.Errorf("failed to set AZURE_ENVIRONMENT_FILEPATH env to %s, error %w", cloudEnvFileName, err) } azureCloudEnv, err := p.parseAzureEnvironment(cloudName) if err != nil { return nil, fmt.Errorf("cloudName %s is not valid, error: %w", cloudName, err) } // parse bound service account tokens for workload identity only if the clientID is set var workloadIdentityToken string if workloadIdentityClientID != "" { if workloadIdentityToken, err = auth.ParseServiceAccountToken(saTokens); err != nil { return nil, fmt.Errorf("failed to parse workload identity tokens, error: %w", err) } } authConfig, err := auth.NewConfig(usePodIdentity, useVMManagedIdentity, userAssignedIdentityID, workloadIdentityClientID, workloadIdentityToken, secrets) if err != nil { return nil, fmt.Errorf("failed to create auth config, error: %w", err) } mc := &mountConfig{ keyvaultName: keyvaultName, azureCloudEnvironment: azureCloudEnv, authConfig: authConfig, tenantID: tenantID, podName: podName, podNamespace: podNamespace, } objectsStrings := types.GetObjects(attrib) if objectsStrings == "" { return nil, fmt.Errorf("objects is not set") } klog.V(2).InfoS("objects string defined in secret provider class", "objects", objectsStrings, "pod", klog.ObjectRef{Namespace: podNamespace, Name: podName}) objects, err := types.GetObjectsArray(objectsStrings) if err != nil { return nil, fmt.Errorf("failed to yaml unmarshal objects, error: %w", err) } klog.V(2).InfoS("unmarshaled objects yaml array", "objectsArray", objects.Array, "pod", klog.ObjectRef{Namespace: podNamespace, Name: podName}) keyVaultObjects := []types.KeyVaultObject{} for i, object := range objects.Array { var keyVaultObject types.KeyVaultObject err = yaml.Unmarshal([]byte(object), &keyVaultObject) if err != nil { return nil, fmt.Errorf("unmarshal failed for keyVaultObjects at index %d, error: %w", i, err) } // remove whitespace from all fields in keyVaultObject formatKeyVaultObject(&keyVaultObject) if err = validate(keyVaultObject); err != nil { return nil, wrapObjectTypeError(err, keyVaultObject.ObjectType, keyVaultObject.ObjectName, keyVaultObject.ObjectVersion) } keyVaultObjects = append(keyVaultObjects, keyVaultObject) } klog.V(5).InfoS("unmarshaled key vault objects", "keyVaultObjects", keyVaultObjects, "count", len(keyVaultObjects), "pod", klog.ObjectRef{Namespace: podNamespace, Name: podName}) if len(keyVaultObjects) == 0 { return nil, nil } vaultURL, err := mc.getVaultURL() if err != nil { return nil, errors.Wrap(err, "failed to get vault") } klog.V(2).InfoS("vault url", "vaultName", mc.keyvaultName, "vaultURL", *vaultURL, "pod", klog.ObjectRef{Namespace: podNamespace, Name: podName}) // the keyvault name is per SPC and we don't need to recreate the client for every single keyvault object defined kvClient, err := mc.initializeKvClient(*vaultURL) if err != nil { return nil, errors.Wrap(err, "failed to get keyvault client") } files := []types.SecretFile{} for _, keyVaultObject := range keyVaultObjects { klog.V(5).InfoS("fetching object from key vault", "objectName", keyVaultObject.ObjectName, "objectType", keyVaultObject.ObjectType, "keyvault", mc.keyvaultName, "pod", klog.ObjectRef{Namespace: podNamespace, Name: podName}) resolvedKvObjects, err := p.resolveObjectVersions(ctx, kvClient, keyVaultObject) if err != nil { return nil, err } for _, resolvedKvObject := range resolvedKvObjects { // fetch the object from Key Vault result, err := p.getKeyVaultObjectContent(ctx, kvClient, resolvedKvObject) if err != nil { return nil, err } for idx := range result { r := result[idx] objectContent, err := getContentBytes(r.content, resolvedKvObject.ObjectType, resolvedKvObject.ObjectEncoding) if err != nil { return nil, err } // objectUID is a unique identifier in the format <object type>/<object name> // This is the object id the user sees in the SecretProviderClassPodStatus objectUID := resolvedKvObject.GetObjectUID() file := types.SecretFile{ Path: resolvedKvObject.GetFileName() + r.fileNameSuffix, Content: objectContent, UID: objectUID, Version: r.version, } // the validity of file permission is already checked in the validate function above file.FileMode, _ = resolvedKvObject.GetFilePermission(defaultFilePermission) files = append(files, file) klog.V(5).InfoS("added file to the gRPC response", "file", file.Path, "pod", klog.ObjectRef{Namespace: podNamespace, Name: podName}) } } } return files, nil } func (p *provider) resolveObjectVersions(ctx context.Context, kvClient KeyVault, kvObject types.KeyVaultObject) (versions []types.KeyVaultObject, err error) { if kvObject.IsSyncingSingleVersion() { // version history less than or equal to 1 means only sync the latest and // don't add anything to the file name return []types.KeyVaultObject{kvObject}, nil } kvObjectVersions, err := p.getKeyVaultObjectVersions(ctx, kvClient, kvObject) if err != nil { return nil, err } return getLatestNKeyVaultObjects(kvObject, kvObjectVersions), nil } /* Given a base key vault object and a list of object versions and their created dates, find the latest kvObject.ObjectVersionHistory versions and return key vault objects with the appropriate alias and version. The alias is determine by the index of the version starting with 0 at the specified version (or latest if no version is specified). */ func getLatestNKeyVaultObjects(kvObject types.KeyVaultObject, kvObjectVersions types.KeyVaultObjectVersionList) []types.KeyVaultObject { baseFileName := kvObject.GetFileName() objects := []types.KeyVaultObject{} sort.Sort(kvObjectVersions) // if we're being asked for the latest, then there's no need to skip any versions foundFirst := kvObject.ObjectVersion == "" || kvObject.ObjectVersion == "latest" for _, objectVersion := range kvObjectVersions { foundFirst = foundFirst || objectVersion.Version == kvObject.ObjectVersion if foundFirst { length := len(objects) newObject := kvObject newObject.ObjectAlias = filepath.Join(baseFileName, strconv.Itoa(length)) newObject.ObjectVersion = objectVersion.Version objects = append(objects, newObject) if length+1 > int(kvObject.ObjectVersionHistory) { break } } } return objects } func (p *provider) getKeyVaultObjectVersions(ctx context.Context, kvClient KeyVault, kvObject types.KeyVaultObject) (versions types.KeyVaultObjectVersionList, err error) { start := time.Now() defer func() { var errMsg string if err != nil { errMsg = err.Error() } p.reporter.ReportKeyvaultRequest(ctx, time.Since(start).Seconds(), kvObject.ObjectType, kvObject.ObjectName, errMsg) }() switch kvObject.ObjectType { case types.VaultObjectTypeSecret: return getSecretVersions(ctx, kvClient, kvObject) case types.VaultObjectTypeKey: return getKeyVersions(ctx, kvClient, kvObject) case types.VaultObjectTypeCertificate: return getCertificateVersions(ctx, kvClient, kvObject) default: err := errors.Errorf("Invalid vaultObjectTypes. Should be secret, key, or cert") return nil, wrapObjectTypeError(err, kvObject.ObjectType, kvObject.ObjectName, kvObject.ObjectVersion) } } func getSecretVersions(ctx context.Context, kvClient KeyVault, kvObject types.KeyVaultObject) ([]types.KeyVaultObjectVersion, error) { return kvClient.GetSecretVersions(ctx, kvObject.ObjectName) } func getKeyVersions(ctx context.Context, kvClient KeyVault, kvObject types.KeyVaultObject) ([]types.KeyVaultObjectVersion, error) { return kvClient.GetKeyVersions(ctx, kvObject.ObjectName) } func getCertificateVersions(ctx context.Context, kvClient KeyVault, kvObject types.KeyVaultObject) ([]types.KeyVaultObjectVersion, error) { return kvClient.GetCertificateVersions(ctx, kvObject.ObjectName) } // getKeyVaultObjectContent gets content of the keyvault object func (p *provider) getKeyVaultObjectContent(ctx context.Context, kvClient KeyVault, kvObject types.KeyVaultObject) (result []keyvaultObject, err error) { start := time.Now() defer func() { var errMsg string if err != nil { errMsg = err.Error() } p.reporter.ReportKeyvaultRequest(ctx, time.Since(start).Seconds(), kvObject.ObjectType, kvObject.ObjectName, errMsg) }() switch kvObject.ObjectType { case types.VaultObjectTypeSecret: return p.getSecret(ctx, kvClient, kvObject) case types.VaultObjectTypeKey: return p.getKey(ctx, kvClient, kvObject) case types.VaultObjectTypeCertificate: return p.getCertificate(ctx, kvClient, kvObject) default: err := errors.Errorf("Invalid vaultObjectTypes. Should be secret, key, or cert") return nil, wrapObjectTypeError(err, kvObject.ObjectType, kvObject.ObjectName, kvObject.ObjectVersion) } } // getSecret retrieves the secret from the vault func (p *provider) getSecret(ctx context.Context, kvClient KeyVault, kvObject types.KeyVaultObject) ([]keyvaultObject, error) { secret, err := kvClient.GetSecret(ctx, kvObject.ObjectName, kvObject.ObjectVersion) if err != nil { return nil, wrapObjectTypeError(err, kvObject.ObjectType, kvObject.ObjectName, kvObject.ObjectVersion) } if secret.Value == nil { return nil, errors.Errorf("secret value is nil") } if secret.ID == nil { return nil, errors.Errorf("secret id is nil") } content := *secret.Value id := *secret.ID version := id.Version() result := []keyvaultObject{} // if the secret is part of a certificate, then we need to convert the certificate and key to PEM format if secret.Kid != nil && len(*secret.Kid) > 0 { switch *secret.ContentType { case types.CertTypePem: case types.CertTypePfx: // object format requested is pfx, then return the content as is if strings.EqualFold(kvObject.ObjectFormat, types.ObjectFormatPFX) { break } // convert to pem as that's the default object format for this provider if content, err = p.decodePKCS12(*secret.Value); err != nil { return nil, wrapObjectTypeError(err, kvObject.ObjectType, kvObject.ObjectName, kvObject.ObjectVersion) } default: err := errors.Errorf("failed to get certificate. unknown content type '%s'", *secret.ContentType) return nil, wrapObjectTypeError(err, kvObject.ObjectType, kvObject.ObjectName, kvObject.ObjectVersion) } if p.writeCertAndKeyInSeparateFiles { // when writeCertAndKeyInSeparateFiles feature flag is enabled, we write the cert and key in separate files // with suffixes .crt and .key respectively. These files are written in addition to the default file which // contains the cert and key in a single file to maintain backward compatibility with the existing behavior. cert, key := splitCertAndKey(content) result = append(result, keyvaultObject{version: version, content: cert, fileNameSuffix: ".crt"}, keyvaultObject{version: version, content: key, fileNameSuffix: ".key"}, ) } } result = append(result, keyvaultObject{content: content, version: version}) return result, nil } // getKey retrieves the key from the vault func (p *provider) getKey(ctx context.Context, kvClient KeyVault, kvObject types.KeyVaultObject) ([]keyvaultObject, error) { keybundle, err := kvClient.GetKey(ctx, kvObject.ObjectName, kvObject.ObjectVersion) if err != nil { return nil, wrapObjectTypeError(err, kvObject.ObjectType, kvObject.ObjectName, kvObject.ObjectVersion) } if keybundle.Key == nil { return nil, errors.Errorf("key value is nil") } if keybundle.Key.KID == nil { return nil, errors.Errorf("key id is nil") } id := *keybundle.Key.KID version := id.Version() // for object type "key" the public key is written to the file in PEM format switch *keybundle.Key.Kty { case azkeys.JSONWebKeyTypeRSA, azkeys.JSONWebKeyTypeRSAHSM: nb := keybundle.Key.N eb := keybundle.Key.E e := new(big.Int).SetBytes(eb).Int64() pKey := &rsa.PublicKey{ N: new(big.Int).SetBytes(nb), E: int(e), } derBytes, err := x509.MarshalPKIXPublicKey(pKey) if err != nil { return nil, wrapObjectTypeError(err, kvObject.ObjectType, kvObject.ObjectName, kvObject.ObjectVersion) } pubKeyBlock := &pem.Block{ Type: "PUBLIC KEY", Bytes: derBytes, } var pemData []byte pemData = append(pemData, pem.EncodeToMemory(pubKeyBlock)...) return []keyvaultObject{{content: string(pemData), version: version}}, nil case azkeys.JSONWebKeyTypeEC, azkeys.JSONWebKeyTypeECHSM: xb := keybundle.Key.X yb := keybundle.Key.Y crv, err := getCurve(*keybundle.Key.Crv) if err != nil { return nil, wrapObjectTypeError(err, kvObject.ObjectType, kvObject.ObjectName, kvObject.ObjectVersion) } pKey := &ecdsa.PublicKey{ X: new(big.Int).SetBytes(xb), Y: new(big.Int).SetBytes(yb), Curve: crv, } derBytes, err := x509.MarshalPKIXPublicKey(pKey) if err != nil { return nil, wrapObjectTypeError(err, kvObject.ObjectType, kvObject.ObjectName, kvObject.ObjectVersion) } pubKeyBlock := &pem.Block{ Type: "PUBLIC KEY", Bytes: derBytes, } var pemData []byte pemData = append(pemData, pem.EncodeToMemory(pubKeyBlock)...) return []keyvaultObject{{content: string(pemData), version: version}}, nil default: err := errors.Errorf("failed to get key. key type '%s' currently not supported", *keybundle.Key.Kty) return nil, wrapObjectTypeError(err, kvObject.ObjectType, kvObject.ObjectName, kvObject.ObjectVersion) } } // getCertificate retrieves the certificate from the vault func (p *provider) getCertificate(ctx context.Context, kvClient KeyVault, kvObject types.KeyVaultObject) ([]keyvaultObject, error) { // for object type "cert" the certificate is written to the file in PEM format certbundle, err := kvClient.GetCertificate(ctx, kvObject.ObjectName, kvObject.ObjectVersion) if err != nil { return nil, wrapObjectTypeError(err, kvObject.ObjectType, kvObject.ObjectName, kvObject.ObjectVersion) } if certbundle.CER == nil { return nil, errors.Errorf("certificate value is nil") } if certbundle.ID == nil { return nil, errors.Errorf("certificate id is nil") } id := *certbundle.ID version := id.Version() certBlock := &pem.Block{ Type: types.CertificateType, Bytes: certbundle.CER, } var pemData []byte pemData = append(pemData, pem.EncodeToMemory(certBlock)...) return []keyvaultObject{{content: string(pemData), version: version}}, nil } func wrapObjectTypeError(err error, objectType, objectName, objectVersion string) error { return errors.Wrapf(err, "failed to get objectType:%s, objectName:%s, objectVersion:%s", objectType, objectName, objectVersion) } // decodePkcs12 decodes PKCS#12 client certificates by extracting the public certificates, the private // keys and converts it to PEM format func (p *provider) decodePKCS12(value string) (content string, err error) { pfxRaw, err := base64.StdEncoding.DecodeString(value) if err != nil { return "", err } // using ToPEM to extract more than one certificate and key in pfxData pemBlock, err := pkcs12.ToPEM(pfxRaw, "") if err != nil { return "", err } var pemKeyData, pemCertData, pemData []byte for _, block := range pemBlock { // PEM block encoded form contains the headers // -----BEGIN Type----- // Headers // base64-encoded Bytes // -----END Type----- // Setting headers to nil to ensure no headers included in the encoded block block.Headers = make(map[string]string) if block.Type == types.CertificateType { pemCertData = append(pemCertData, pem.EncodeToMemory(block)...) } else { key, err := parsePrivateKey(block.Bytes) if err != nil { return "", err } // pkcs1 RSA private key PEM file is specific for RSA keys. RSA is not used exclusively inside X509 // and SSL/TLS, a more generic key format is available in the form of PKCS#8 that identifies the type // of private key and contains the relevant data. // Converting to pkcs8 private key as ToPEM uses pkcs1 // The driver determines the key type from the pkcs8 form of the key and marshals appropriately block.Bytes, err = x509.MarshalPKCS8PrivateKey(key) if err != nil { return "", err } pemKeyData = append(pemKeyData, pem.EncodeToMemory(block)...) } } // construct the pem chain in the order // SERVER, INTERMEDIATE, ROOT if p.constructPEMChain { pemCertData, err = fetchCertChains(pemCertData) if err != nil { return "", err } } pemData = append(pemData, pemKeyData...) pemData = append(pemData, pemCertData...) return string(pemData), nil } func getCurve(crv azkeys.JSONWebKeyCurveName) (elliptic.Curve, error) { switch crv { case azkeys.JSONWebKeyCurveNameP256: return elliptic.P256(), nil case azkeys.JSONWebKeyCurveNameP384: return elliptic.P384(), nil case azkeys.JSONWebKeyCurveNameP521: return elliptic.P521(), nil default: return nil, fmt.Errorf("curve %s is not supported", crv) } } func parsePrivateKey(block []byte) (interface{}, error) { if key, err := x509.ParsePKCS1PrivateKey(block); err == nil { return key, nil } if key, err := x509.ParsePKCS8PrivateKey(block); err == nil { return key, nil } if key, err := x509.ParseECPrivateKey(block); err == nil { return key, nil } return nil, fmt.Errorf("failed to parse key for type pkcs1, pkcs8 or ec") } // setAzureEnvironmentFilePath sets the AZURE_ENVIRONMENT_FILEPATH env var which is used by // go-autorest for AZURESTACKCLOUD func setAzureEnvironmentFilePath(envFileName string) error { if envFileName == "" { return nil } klog.V(5).InfoS("setting AZURE_ENVIRONMENT_FILEPATH for custom cloud", "fileName", envFileName) return os.Setenv(azure.EnvironmentFilepathName, envFileName) } // getContentBytes takes the given content string and returns the bytes to write to disk // If an encoding is specified it will decode the string first func getContentBytes(content, objectType, objectEncoding string) ([]byte, error) { if !strings.EqualFold(objectType, types.VaultObjectTypeSecret) || len(objectEncoding) == 0 || strings.EqualFold(objectEncoding, types.ObjectEncodingUtf8) { return []byte(content), nil } if strings.EqualFold(objectEncoding, types.ObjectEncodingBase64) { return base64.StdEncoding.DecodeString(content) } if strings.EqualFold(objectEncoding, types.ObjectEncodingHex) { return hex.DecodeString(content) } return make([]byte, 0), fmt.Errorf("invalid objectEncoding. Should be utf-8, base64, or hex") } // formatKeyVaultObject formats the fields in KeyVaultObject func formatKeyVaultObject(object *types.KeyVaultObject) { if object == nil { return } objectPtr := reflect.ValueOf(object) objectValue := objectPtr.Elem() for i := 0; i < objectValue.NumField(); i++ { field := objectValue.Field(i) if field.Type() != reflect.TypeOf("") { continue } str := field.Interface().(string) str = strings.TrimSpace(str) field.SetString(str) } } type node struct { cert *x509.Certificate parent *node isParent bool } // implementation xref: https://social.technet.microsoft.com/wiki/contents/articles/3147.pki-certificate-chaining-engine-cce.aspx#Building_the_Certificate_Chain func fetchCertChains(data []byte) ([]byte, error) { var newCertChain []*x509.Certificate var pemData []byte nodes := make([]*node, 0) currData := data for { // decode pem to der first block, rest := pem.Decode(currData) currData = rest if block == nil { break } cert, err := x509.ParseCertificate(block.Bytes) if err != nil { return pemData, err } // this should not be the case because ParseCertificate should return a non nil // certificate when there is no error. if cert == nil { return pemData, fmt.Errorf("certificate is nil") } nodes = append(nodes, &node{ cert: cert, parent: nil, isParent: false, }) } // at the end of this computation, the output will be a single linked list // the tail of the list will be the root node (which has no parents) // the head of the list will be the leaf node (whose parent will be intermediate certs) // (head) leaf -> intermediates -> root (tail) for i := range nodes { for j := range nodes { // ignore same node to prevent generating a cycle if i == j { continue } // a leaf cert SubjectKeyId is optional per RFC3280 if nodes[i].cert.AuthorityKeyId == nil && nodes[j].cert.SubjectKeyId == nil { continue } // if ith node AuthorityKeyId is same as jth node SubjectKeyId, jth node was used // to sign the ith certificate if string(nodes[i].cert.AuthorityKeyId) == string(nodes[j].cert.SubjectKeyId) { nodes[j].isParent = true nodes[i].parent = nodes[j] break } } } var leaf *node for i := range nodes { if !nodes[i].isParent { // this is the leaf node as it's not a parent for any other node // TODO (aramase) handle errors if there are more than 1 leaf nodes leaf = nodes[i] break } } if leaf == nil { return nil, fmt.Errorf("no leaf found") } processedNodes := 0 // iterate through the directed list and append the nodes to new cert chain for leaf != nil { processedNodes++ // ensure we aren't stuck in a cyclic loop if processedNodes > len(nodes) { return pemData, fmt.Errorf("constructing chain resulted in cycle") } newCertChain = append(newCertChain, leaf.cert) leaf = leaf.parent } if len(nodes) != len(newCertChain) { klog.Warning("certificate chain is not complete due to missing intermediate/root certificates in the cert from key vault") // if we're unable to construct the full chain, return the original order we got from the key vault return data, nil } for _, cert := range newCertChain { b := &pem.Block{ Type: types.CertificateType, Bytes: cert.Raw, } pemData = append(pemData, pem.EncodeToMemory(b)...) } return pemData, nil } // splitCertAndKey takes the given data and splits it into cert and key // this function doesn't check if the returned cert and key is not empty as this // can't be enforced. It is possible the secret in the key vault only contains the // cert or key. func splitCertAndKey(certAndKey string) (certs string, privKey string) { // split the cert and key for PEM format // This does not handle the case where cert and key is in PFX format // TODO(aramase) consider adding support for PFX format if there is an ask var cert, key []byte data := []byte(certAndKey) for { block, rest := pem.Decode(data) if block == nil { break } if block.Type == types.CertificateType { cert = append(cert, pem.EncodeToMemory(block)...) } else { key = append(key, pem.EncodeToMemory(block)...) } data = rest } certs = string(cert) privKey = string(key) return certs, privKey }