google_guest_agent/agentcrypto/mtls_mds_windows.go (239 lines of code) (raw):

// Copyright 2023 Google LLC // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // https://www.apache.org/licenses/LICENSE-2.0 // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package agentcrypto import ( "context" "crypto/rand" "crypto/x509" "encoding/pem" "fmt" "os" "path/filepath" "syscall" "unsafe" "github.com/GoogleCloudPlatform/guest-agent/utils" "github.com/GoogleCloudPlatform/guest-logging-go/logger" "golang.org/x/sys/windows" "software.sslmate.com/src/go-pkcs12" ) const ( // rootCACertFileName is the root CA cert. rootCACertFileName = "mds-mtls-root.crt" // clientCredsFileName are client credentials, its basically the file // that has the EC private key and the client certificate concatenated. clientCredsFileName = "mds-mtls-client.key" // pfxFile stores client credentials in PFX format. pfxFile = "mds-mtls-client.key.pfx" // https://learn.microsoft.com/en-us/windows/win32/seccrypto/system-store-locations // my is predefined personal cert store. my = "MY" // root is predefined cert store for root trusted CA certs. root = "ROOT" // certificateIssuer is the issuer of client/root certificates for MDS mTLS. certificateIssuer = "google.internal" // maxCertEnumeration specifies the maximum number of times to search for a certificate // with a serial number from a given issuer before giving up. maxCertEnumeration = 5 ) var ( // defaultCredsDir is the directory location for MTLS MDS credentials. defaultCredsDir = filepath.Join(os.Getenv("ProgramData"), "Google", "Compute Engine") prevCtx *windows.CertContext ) // writeRootCACert writes Root CA cert from UEFI variable to output file. func (j *CredsJob) writeRootCACert(_ context.Context, cacert []byte, outputFile string) error { // Try to fetch previous certificate's serial number before it gets overwritten. num, err := serialNumber(outputFile) if err != nil { logger.Debugf("No previous MDS root certificate was found, will skip cleanup: %v", err) } if err := utils.SaferWriteFile(cacert, outputFile, 0644); err != nil { return err } if !j.useNativeStore.Load() { logger.Debugf("SkipNativeStore is enabled, will not write root cert to certstore") return nil } x509Cert, err := parseCertificate(cacert) if err != nil { return fmt.Errorf("failed to parse root CA cert: %w", err) } // https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/nf-wincrypt-certcreatecertificatecontext certContext, err := windows.CertCreateCertificateContext( windows.X509_ASN_ENCODING|windows.PKCS_7_ASN_ENCODING, &x509Cert.Raw[0], uint32(len(x509Cert.Raw))) if err != nil { return fmt.Errorf("CertCreateCertificateContext returned: %v", err) } // https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/nf-wincrypt-certfreecertificatecontext defer windows.CertFreeCertificateContext(certContext) // Adds certificate to Root Trusted certificates. if err := addCtxToLocalSystemStore(root, certContext, uint32(windows.CERT_STORE_ADD_REPLACE_EXISTING)); err != nil { return fmt.Errorf("failed to store root cert ctx in store: %w", err) } // MDS root cert was not refreshed or there's no previous cert, nothing to do, return. if num == "" || fmt.Sprintf("%x", x509Cert.SerialNumber) == num { return nil } // Certificate is refreshed. Best effort to find the certcontext and delete it. // Don't throw error here, it would skip client credential generation which // may be about to expire. oldCtx, err := findCert(root, certificateIssuer, num) if err != nil { logger.Warningf("Failed to find previous MDS root certificate with error: %v", err) return nil } if err := deleteCert(oldCtx, root); err != nil { logger.Warningf("Failed to delete previous MDS root certificate(%s) with error: %v", num, err) return nil } return nil } // findCert finds and returns certificate issued by issuer with the serial number in the given the store. func findCert(storeName, issuer, certID string) (*windows.CertContext, error) { logger.Infof("Searching for certificate with serial number %s in store %s by issuer %s", certID, storeName, issuer) storeNamePtr, err := syscall.UTF16PtrFromString(storeName) if err != nil { return nil, fmt.Errorf("UTF16PtrFromString(%s) failed with error: %v", storeName, err) } issuerPtr, err := syscall.UTF16PtrFromString(issuer) if err != nil { return nil, fmt.Errorf("UTF16PtrFromString(%s) failed with error: %v", issuer, err) } st, err := windows.CertOpenStore( windows.CERT_STORE_PROV_SYSTEM, 0, 0, windows.CERT_SYSTEM_STORE_LOCAL_MACHINE, uintptr(unsafe.Pointer(storeNamePtr))) if err != nil { return nil, fmt.Errorf("failed to open cert store: %w", err) } defer windows.CertCloseStore(st, 0) // prev is used for enumerating through all the certificates that matches the issuer. // On the first call to the function this parameter is NULL on all subsequent calls, // this parameter is the last CertContext pointer returned by the CertFindCertificateInStore function var prev *windows.CertContext // maxCertEnumeration would avoid requiring a infinite loop that relies on enumerating // until we get nil crt. for i := 1; i <= maxCertEnumeration; i++ { logger.Debugf("Attempt %d, searching certificate...", i) // https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/nf-wincrypt-certfindcertificateinstore crt, err := windows.CertFindCertificateInStore( st, windows.X509_ASN_ENCODING|windows.PKCS_7_ASN_ENCODING, 0, windows.CERT_FIND_ISSUER_STR, unsafe.Pointer(issuerPtr), prev) if err != nil { return nil, fmt.Errorf("unable to find certificate: %w", err) } if crt == nil { return nil, fmt.Errorf("no certificate by issuer %s with ID %s", issuer, certID) } x509Cert, err := certContextToX509(crt) if err != nil { return nil, fmt.Errorf("failed to parse certificate context: %w", err) } if fmt.Sprintf("%x", x509Cert.SerialNumber) == certID { return crt, nil } prev = crt } return nil, nil } // writeClientCredentials stores client credentials (certificate and private key). func (j *CredsJob) writeClientCredentials(creds []byte, outputFile string) error { num, err := serialNumber(outputFile) if err != nil { logger.Warningf("Could not get previous serial number, will skip cleanup: %v", err) } if err := utils.SaferWriteFile(creds, outputFile, 0644); err != nil { return fmt.Errorf("failed to write client key: %w", err) } pfx, err := generatePFX(creds) if err != nil { return fmt.Errorf("failed to generate PFX data from client credentials: %w", err) } p := filepath.Join(filepath.Dir(outputFile), pfxFile) if err := utils.SaferWriteFile(pfx, p, 0644); err != nil { return fmt.Errorf("failed to write PFX file: %w", err) } if !j.useNativeStore.Load() { logger.Debugf("SkipNativeStore is enabled, will not write root cert to certstore") return nil } blob := windows.CryptDataBlob{ Size: uint32(len(pfx)), Data: &pfx[0], } emptyPtr, err := syscall.UTF16PtrFromString("") if err != nil { return fmt.Errorf("UTF16PtrFromString(%q) empty pointer failed with error: %v", "", err) } // https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/nf-wincrypt-pfximportcertstore handle, err := windows.PFXImportCertStore(&blob, emptyPtr, windows.CRYPT_MACHINE_KEYSET) if err != nil { return fmt.Errorf("failed to import PFX in cert store: %w", err) } defer windows.CertCloseStore(handle, 0) var crtCtx *windows.CertContext // https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/nf-wincrypt-certenumcertificatesinstore crtCtx, err = windows.CertEnumCertificatesInStore(handle, crtCtx) if err != nil { return fmt.Errorf("failed to get cert context for PFX from store: %w", err) } defer windows.CertFreeCertificateContext(crtCtx) // Add certificate to personal store. if err := addCtxToLocalSystemStore(my, crtCtx, uint32(windows.CERT_STORE_ADD_NEWER)); err != nil { return fmt.Errorf("failed to store pfx cert context: %w", err) } // Search for previous certificate if its not already in memory. if prevCtx == nil && num != "" { prevCtx, err = findCert(my, certificateIssuer, num) if err != nil { logger.Warningf("Failed to find previous certificate with error: %v", err) } } // Remove previous certificate only after successful refresh. if err := deleteCert(prevCtx, my); err != nil { logger.Warningf("Failed to delete previous certificate(%s) with error: %v", num, err) } prevCtx = windows.CertDuplicateCertificateContext(crtCtx) return nil } // certContextToX509 creates an x509 Certificate from a Windows cert context. func certContextToX509(ctx *windows.CertContext) (*x509.Certificate, error) { der := unsafe.Slice(ctx.EncodedCert, int(ctx.Length)) return x509.ParseCertificate(der) } // generatePFX accepts certificate concatenated with private key and generates a PFX out of it. // https://learn.microsoft.com/en-us/windows-hardware/drivers/install/personal-information-exchange---pfx--files func generatePFX(creds []byte) (pfxData []byte, err error) { cert, key := pem.Decode(creds) x509Cert, err := x509.ParseCertificate(cert.Bytes) if err != nil { return []byte{}, fmt.Errorf("failed to parse client certificate: %w", err) } ecpvt, err := parsePvtKey(key) if err != nil { return []byte{}, fmt.Errorf("failed to parse EC PrivateKey from client credentials: %w", err) } return pkcs12.Encode(rand.Reader, ecpvt, x509Cert, nil, "") } func addCtxToLocalSystemStore(storeName string, certContext *windows.CertContext, disposition uint32) error { storeNamePtr, err := syscall.UTF16PtrFromString(storeName) if err != nil { return fmt.Errorf("UTF16PtrFromString(%s) failed with error: %v", storeName, err) } // https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/nf-wincrypt-certopenstore // https://learn.microsoft.com/en-us/windows-hardware/drivers/install/local-machine-and-current-user-certificate-stores // https://learn.microsoft.com/en-us/windows/win32/seccrypto/system-store-locations#cert_system_store_local_machine st, err := windows.CertOpenStore( windows.CERT_STORE_PROV_SYSTEM, 0, 0, windows.CERT_SYSTEM_STORE_LOCAL_MACHINE, uintptr(unsafe.Pointer(storeNamePtr))) if err != nil { return fmt.Errorf("failed to open cert store: %w", err) } // https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/nf-wincrypt-certclosestore defer windows.CertCloseStore(st, 0) // https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/nf-wincrypt-certaddcertificatecontexttostore if err := windows.CertAddCertificateContextToStore(st, certContext, disposition, nil); err != nil { return fmt.Errorf("failed to add certificate context to store: %w", err) } return nil } func deleteCert(crtCtx *windows.CertContext, storeName string) error { if crtCtx == nil { return nil } storeNamePtr, err := syscall.UTF16PtrFromString(storeName) if err != nil { return fmt.Errorf("UTF16PtrFromString(%s) failed with error: %v", storeName, err) } st, err := windows.CertOpenStore( windows.CERT_STORE_PROV_SYSTEM, 0, 0, windows.CERT_SYSTEM_STORE_LOCAL_MACHINE, uintptr(unsafe.Pointer(storeNamePtr))) if err != nil { return fmt.Errorf("failed to open cert store: %w", err) } defer windows.CertCloseStore(st, 0) var dlCtx *windows.CertContext dlCtx, err = windows.CertFindCertificateInStore( st, windows.X509_ASN_ENCODING|windows.PKCS_7_ASN_ENCODING, 0, windows.CERT_FIND_EXISTING, unsafe.Pointer(crtCtx), dlCtx, ) if err != nil { return fmt.Errorf("unable to find the certificate in %q store to delete: %w", storeName, err) } return windows.CertDeleteCertificateFromStore(dlCtx) }