pkg/internal/crypto/crypto_windows.go (166 lines of code) (raw):
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
package crypto
import (
"fmt"
"syscall"
"time"
"unsafe"
)
// Win32 crypto constants used by this package.
const (
// CertHashPropID is the property ID passed as dwPropId to
// CertGetCertificateContextProperty when retrieving a certificate thumbprint.
CertHashPropID = 3
// CrypteEAsn1BadTag equals 0x8009310B; presumably the CRYPT_E_ASN1_BADTAG
// HRESULT. It is not referenced in this file, so confirm against callers.
CrypteEAsn1BadTag = 2148086027
// certNCryptKeySpec is the dwKeySpec value (CERT_NCRYPT_KEY_SPEC) that marks
// an acquired key handle as one to release with NCryptFreeObject rather than
// CryptReleaseContext.
certNCryptKeySpec = 0xFFFFFFFF
)
var (
	// Modcrypt32 is the lazily loaded crypt32.dll, which hosts the certificate
	// store and message APIs used by this package.
	Modcrypt32 = syscall.NewLazyDLL("crypt32.dll")
	// modNcrypt is the lazily loaded ncrypt.dll, which hosts the CNG key
	// storage APIs.
	modNcrypt = syscall.NewLazyDLL("ncrypt.dll")

	// Entry points exported by crypt32.dll.
	procCertGetCertificateContextProperty = Modcrypt32.NewProc("CertGetCertificateContextProperty")
	ProcCryptDecryptMessage               = Modcrypt32.NewProc("CryptDecryptMessage")
	procCryptAcquireCertificatePrivateKey = Modcrypt32.NewProc("CryptAcquireCertificatePrivateKey")

	// procNCryptFreeObject is exported by ncrypt.dll, not crypt32.dll. Resolving
	// it from crypt32.dll would make LazyProc.Addr panic the first time a CNG
	// key handle has to be released.
	procNCryptFreeObject = modNcrypt.NewProc("NCryptFreeObject")
)
// CryptAlgorithmIdentifier appears to mirror the Win32
// CRYPT_ALGORITHM_IDENTIFIER layout: a pointer to an object-identifier string
// followed by an encoded-parameters blob. It is embedded by value in certInfo
// and certPublicKeyInfo, so field order and widths must not change.
type CryptAlgorithmIdentifier struct {
// PszObjID holds a raw pointer value (presumably to a C string owned by Windows).
PszObjID uintptr
// Parameters is the blob carrying the algorithm parameters.
Parameters CryptObjectIDBlob
}
// cryptIntegerBlob is a length-plus-pointer blob used for the serialNumber
// field of certInfo. The layout is part of the native structure and must be
// kept as is.
type cryptIntegerBlob struct {
// cbData is the byte count of the data referenced by pbData.
cbData uint32
// pbData holds the raw address of the data.
pbData uintptr
}
// CryptObjectIDBlob is a length-plus-pointer blob used for the Parameters
// field of CryptAlgorithmIdentifier. The layout is part of the native
// structure and must be kept as is.
type CryptObjectIDBlob struct {
// CbData is the byte count of the data referenced by PbData.
CbData uint32
// PbData holds the raw address of the data.
PbData uintptr
}
// cryptBitBlob is a length-plus-pointer blob with an extra unused-bit count.
// certInfo and certPublicKeyInfo use it for the public key and the unique IDs.
// The layout is part of the native structure and must be kept as is.
type cryptBitBlob struct {
// cbData is the byte count of the data referenced by pbData.
cbData uint32
// pbData holds the raw address of the data.
pbData uintptr
// cUnusedBits is presumably the number of unused bits in the final byte.
cUnusedBits uint32
}
// certNameBlob is a length-plus-pointer blob used for the issuer and subject
// fields of certInfo. The layout is part of the native structure and must be
// kept as is.
type certNameBlob struct {
// cbData is the byte count of the data referenced by pbData.
cbData uint32
// pbData holds the raw address of the data.
pbData uintptr
}
// certPublicKeyInfo pairs an algorithm identifier with the public key bits.
// It is embedded by value in certInfo (subjectPublicKeyInfo), so its layout
// must not change.
type certPublicKeyInfo struct {
// Algorithm identifies the public key algorithm.
Algorithm CryptAlgorithmIdentifier
// PublicKey is the blob referencing the key bits.
PublicKey cryptBitBlob
}
// certInfo is a hand-written counterpart of the native CERT_INFO structure,
// which the syscall package does not provide. It is only ever read through the
// CertInfo pointer of a CertContext returned by Windows, so the field order and
// types must stay in step with the native layout.
type certInfo struct {
dwVersion uint32
serialNumber cryptIntegerBlob
signatureAlgorithm CryptAlgorithmIdentifier
issuer certNameBlob
// notBefore and notAfter bound the validity period checked by isAUsableCert.
notBefore syscall.Filetime
notAfter syscall.Filetime
subject certNameBlob
subjectPublicKeyInfo certPublicKeyInfo
issuerUniqueID cryptBitBlob
subjectUniqueID cryptBitBlob
// cExtension and rgExtension are presumably the extension count and the raw
// address of the extension array; neither is read in this file.
cExtension uint32
rgExtension uintptr
}
// CertContext is this package's view of a certificate context returned by
// CertEnumCertificatesInStore. Unlike syscall.CertContext it exposes a typed
// CertInfo pointer. GetAUsableCert casts the same native pointer to both types,
// so the two layouts must stay identical.
type CertContext struct {
EncodingType uint32
// EncodedCert and Length presumably describe the encoded certificate bytes.
EncodedCert *byte
Length uint32
// CertInfo points at the decoded certificate fields (validity dates etc.).
CertInfo *certInfo
// Store is presumably the handle of the store the certificate belongs to.
Store syscall.Handle
}
// GetAUsableCert walks the certificate store identified by handle and returns
// the first certificate accepted by isAUsableCert, i.e. one that
//   - is currently inside its validity period, and
//   - has a private key that can be acquired.
//
// The enumeration uses this package's CertContext because syscall.CertContext
// does not expose the CERT_INFO structure; the matching context is then
// reinterpreted as a *syscall.CertContext for the caller. Per the Windows API
// contract the returned context is owned by the caller (presumably to be
// released with syscall.CertFreeCertificateContext — confirm at call sites).
//
// An error is returned when the store is exhausted (or the enumeration fails)
// before a usable certificate is found.
func GetAUsableCert(handle syscall.Handle) (cert *syscall.CertContext, _ error) {
	procCertEnumCertificatesInStore := Modcrypt32.NewProc("CertEnumCertificatesInStore")

	var prevCert *CertContext
	for {
		// Passing the previous context continues the enumeration after it.
		ret, _, errno := syscall.Syscall(
			procCertEnumCertificatesInStore.Addr(),
			2,
			uintptr(handle),
			uintptr(unsafe.Pointer(prevCert)),
			0)
		if ret == 0 {
			// NULL means there are no more certificates (or the call failed).
			// Either way we did not find a cert, which is an error for us.
			// Without this check the nil context would be dereferenced below.
			return nil, fmt.Errorf("VmExtension: could not find a usable certificate in the store: %w", errno)
		}

		// The context lives in memory owned by Windows, not the Go heap.
		testCert := (*CertContext)(unsafe.Pointer(ret))
		if isAUsableCert(testCert) {
			// Callers expect a syscall.CertContext, which has the same layout.
			return (*syscall.CertContext)(unsafe.Pointer(ret)), nil
		}
		prevCert = testCert
	}
}
// isAUsableCert reports whether cert can be used right now: the current time
// must lie within [notBefore, notAfter] and the private key must be
// obtainable (see hasPrivateKey). A nil context, or one without certificate
// info, is never usable.
func isAUsableCert(cert *CertContext) (usable bool) {
	if cert == nil || cert.CertInfo == nil {
		return false
	}

	// toTime converts a FILETIME (100ns ticks since 1601-01-01 UTC) to a
	// time.Time. Seconds and nanoseconds are computed separately because
	// Filetime.Nanoseconds overflows int64 for dates after the year 2262, such
	// as the 9999-12-31 "no expiry" notAfter value. An overflow would make a
	// valid certificate look expired.
	toTime := func(ft syscall.Filetime) time.Time {
		const (
			ticksPerSecond   = 10000000
			unixEpochInTicks = 116444736000000000 // 1970-01-01 expressed in FILETIME ticks
		)
		ticks := int64(ft.HighDateTime)<<32 + int64(ft.LowDateTime) - unixEpochInTicks
		return time.Unix(ticks/ticksPerSecond, (ticks%ticksPerSecond)*100)
	}

	// First check that the cert is inside its validity window.
	now := time.Now()
	if now.Before(toTime(cert.CertInfo.notBefore)) || now.After(toTime(cert.CertInfo.notAfter)) {
		return false
	}

	// Then check that it has a private key.
	return hasPrivateKey(cert)
}
// hasPrivateKey reports whether the private key belonging to cert can be
// acquired through CryptAcquireCertificatePrivateKey. The key is only probed:
// any handle Windows asks us to free is released before returning.
//
// NOTE(review): dwFlags is 0, so CNG keys are presumably never handed back
// (that would need CRYPT_ACQUIRE_ALLOW_NCRYPT_KEY_FLAG). Confirm whether
// CNG-backed certificates are meant to be rejected.
func hasPrivateKey(cert *CertContext) bool {
	if cert == nil {
		return false
	}

	var (
		keyHandle      uintptr
		dwKeySpec      uint32
		callerMustFree uint32
	)
	ret, _, _ := syscall.Syscall6(
		procCryptAcquireCertificatePrivateKey.Addr(),
		6,
		uintptr(unsafe.Pointer(cert)),
		0, // dwFlags
		0, // pvParameters
		uintptr(unsafe.Pointer(&keyHandle)),
		uintptr(unsafe.Pointer(&dwKeySpec)),
		uintptr(unsafe.Pointer(&callerMustFree)))
	if ret == 0 {
		// A zero return always means failure, whatever the last-error value
		// is. If we can't retrieve the private key, move on to the next cert.
		return false
	}

	// Windows tells us whether the handle is ours to release.
	if callerMustFree == 0 {
		return true
	}
	if dwKeySpec == certNCryptKeySpec {
		// CERT_NCRYPT_KEY_SPEC: the handle is a CNG key, freed via NCryptFreeObject.
		// Best-effort cleanup; a failure here does not change the answer.
		syscall.Syscall(procNCryptFreeObject.Addr(), 1, keyHandle, 0, 0)
	} else {
		// Otherwise it is a CSP handle. Best-effort cleanup as above.
		_ = syscall.CryptReleaseContext(syscall.Handle(keyHandle), 0)
	}
	return true
}
// GetCertificateThumbprint returns the hash property (CertHashPropID) of cert
// as raw bytes. CertGetCertificateContextProperty is invoked twice: once with
// no output buffer to learn the hash size, then again with a buffer of that
// size. A reported size of zero yields (nil, nil).
func GetCertificateThumbprint(cert *syscall.CertContext) ([]byte, error) {
	var hashLen uint32

	// queryHash issues one CertGetCertificateContextProperty call, writing into
	// dst (which may be nil) and updating hashLen with the size reported back.
	queryHash := func(dst *byte) error {
		ok, _, errno := syscall.Syscall6(
			procCertGetCertificateContextProperty.Addr(),
			4,
			uintptr(unsafe.Pointer(cert)),     // pCertContext
			uintptr(CertHashPropID),           // dwPropId
			uintptr(unsafe.Pointer(dst)),      // pvData
			uintptr(unsafe.Pointer(&hashLen)), // pcbData
			0,
			0,
		)
		if ok == 0 {
			return fmt.Errorf("VmExtension: Could not hash certificate due to '%d'", errno)
		}
		return nil
	}

	// Size probe: no buffer supplied.
	if err := queryHash(nil); err != nil {
		return nil, err
	}
	if hashLen == 0 {
		return nil, nil
	}

	// Actual retrieval into a buffer of the reported size.
	hash := make([]byte, hashLen)
	if err := queryHash(&hash[0]); err != nil {
		return nil, err
	}
	return hash, nil
}