pkg/decrypt/decypt_windows.go (133 lines of code) (raw):

// Copyright (c) Microsoft Corporation. // Licensed under the MIT License. package decrypt import ( "bytes" "fmt" "strconv" "syscall" "unsafe" "github.com/Azure/azure-extension-platform/pkg/extensionerrors" "github.com/Azure/azure-extension-platform/pkg/internal/crypto" "golang.org/x/sys/windows" ) type cryptDecryptMessagePara struct { cbSize uint32 dwMsgAndCertEncodingType uint32 cCertStore uint32 rghCertStore uintptr dwFlags uint32 } // decryptProtectedSettings decrypts the read protected settings using certificates func DecryptProtectedSettings(configFolder string, thumbprint string, decoded []byte) (string, error) { // Open My/Local handle, err := syscall.CertOpenStore(windows.CERT_STORE_PROV_SYSTEM, 0, 0, windows.CERT_SYSTEM_STORE_LOCAL_MACHINE, uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr("MY")))) if err != nil { return "", fmt.Errorf("VMextension: Cannot open certificate store due to '%v'", err) } if handle == 0 { return "", extensionerrors.ErrMustRunAsAdmin } defer syscall.CertCloseStore(handle, 0) // Convert the thumbprint to bytes. We do byte comparison vs string comparison because otherwise we'd need to normalize the strings decodedThumbprint, err := thumbprintStringToHex(thumbprint) if err != nil { return "", fmt.Errorf("VmExtension: Invalid thumbprint") } // Find the certificate by thumbprint const crypteENotFound = 0x80092004 var cert *syscall.CertContext var prevContext *syscall.CertContext found := false for { // Keep retrieving the next certificate cert, err := syscall.CertEnumCertificatesInStore(handle, prevContext) if err != nil { if errno, ok := err.(syscall.Errno); ok { if errno == crypteENotFound { // We've reached the last certificate break } } return "", fmt.Errorf("VmExtension: Could not enumerate certificates due to '%v'", err) } if cert == nil { break } // Determine the cert thumbprint foundthumbprint, err := crypto.GetCertificateThumbprint(cert) if err == nil && foundthumbprint != nil { // TODO: consider logging if we have an error. For now, we just ignore the cert if bytes.Compare(decodedThumbprint, foundthumbprint) == 0 { found = true break } } prevContext = cert } if !found { return "", extensionerrors.ErrCertWithThumbprintNotFound } // Decrypt the protected settings decryptedBytes, err := decryptDataWithCert(decoded, cert, uintptr(handle)) if err != nil { return "", err } // Now deserialize the data v := string(decryptedBytes[:]) return v, err } func thumbprintStringToHex(s string) ([]byte, error) { // Remove the UTF mark if we have one runes := []rune(s) if len(runes)%2 == 1 { runes = []rune(s)[1:] } length := len(runes) / 2 parts := make([]byte, length) for count := 0; count < length; count++ { r := runes[count*2 : count*2+2] sp := string(r) bp, err := strconv.ParseUint(sp, 16, 16) if err == nil { parts[count] = byte(bp) } } return parts, nil } // decryptDataWithCert calls the Windows APIs to do the decryption func decryptDataWithCert(decoded []byte, cert *syscall.CertContext, storeHandle uintptr) ([]byte, error) { var cryptDecryptMessagePara cryptDecryptMessagePara cryptDecryptMessagePara.cbSize = uint32(len(decoded)) cryptDecryptMessagePara.dwMsgAndCertEncodingType = uint32(windows.X509_ASN_ENCODING | windows.PKCS_7_ASN_ENCODING) cryptDecryptMessagePara.cCertStore = uint32(1) cryptDecryptMessagePara.rghCertStore = uintptr(unsafe.Pointer(&storeHandle)) cryptDecryptMessagePara.dwFlags = uint32(0) // Call it once to get the decrypted data size var pbEncryptedBlob *byte var cbDecryptedBlob uint32 pbEncryptedBlob = &decoded[0] raw, _, err := syscall.Syscall6( crypto.ProcCryptDecryptMessage.Addr(), 6, uintptr(unsafe.Pointer(&cryptDecryptMessagePara)), uintptr(unsafe.Pointer(pbEncryptedBlob)), uintptr(len(decoded)), uintptr(0), uintptr(unsafe.Pointer(&cbDecryptedBlob)), uintptr(0), ) if raw == 0 { errno := syscall.Errno(err) if errno == crypto.CrypteEAsn1BadTag { return nil, extensionerrors.ErrInvalidProtectedSettingsData } return nil, fmt.Errorf("VmExtension: Could not decrypt data due to '%d'", syscall.Errno(err)) } // Create our buffer if cbDecryptedBlob == 0 { return nil, nil } var decryptedBytes = make([]byte, cbDecryptedBlob) var pdecryptedBytes *byte pdecryptedBytes = &decryptedBytes[0] raw, _, err = syscall.Syscall6( crypto.ProcCryptDecryptMessage.Addr(), 6, uintptr(unsafe.Pointer(&cryptDecryptMessagePara)), uintptr(unsafe.Pointer(pbEncryptedBlob)), uintptr(len(decoded)), uintptr(unsafe.Pointer(pdecryptedBytes)), uintptr(unsafe.Pointer(&cbDecryptedBlob)), uintptr(0), ) if raw == 0 { return nil, fmt.Errorf("VmExtension: Could not decrypt data due to '%d'", syscall.Errno(err)) } // Get rid of the null terminator or deserialization will fail returnedBytes := decryptedBytes[:cbDecryptedBlob] return returnedBytes, nil }