aws_signing_helper/cert_store_signer_windows.go (471 lines of code) (raw):
//go:build windows
package aws_signing_helper
// This code is based on the smimesign repository at
// https://github.com/github/smimesign
/*
#cgo windows LDFLAGS: -lcrypt32 -lncrypt
#include <windows.h>
#include <wincrypt.h>
#include <bcrypt.h>
#include <ncrypt.h>
//
// Go complains about LPCWSTR constants and the MAKELANGID function not being
// defined, so we define methods for them.
//
LPCWSTR GET_BCRYPT_SHA1_ALGORITHM() { return BCRYPT_SHA1_ALGORITHM; }
LPCWSTR GET_BCRYPT_SHA256_ALGORITHM() { return BCRYPT_SHA256_ALGORITHM; }
LPCWSTR GET_BCRYPT_SHA384_ALGORITHM() { return BCRYPT_SHA384_ALGORITHM; }
LPCWSTR GET_BCRYPT_SHA512_ALGORITHM() { return BCRYPT_SHA512_ALGORITHM; }
int MAKE_LANG_ID(int p, int s) {
return MAKELANGID(p, s);
}
*/
import "C"
import (
"crypto"
"crypto/ecdsa"
"crypto/rsa"
"crypto/sha256"
"crypto/sha512"
"crypto/x509"
"errors"
"fmt"
"golang.org/x/sys/windows"
"io"
"log"
"sort"
"strconv"
"strings"
"unsafe"
)
// winPrivateKey is a wrapper around a HCRYPTPROV_OR_NCRYPT_KEY_HANDLE.
//
// Exactly one of cspHandle (CryptoAPI) or cngKeyHandle (CNG) is populated by
// newWinPrivateKey, depending on the key spec reported when the handle was
// acquired. Sign dispatches on whichever handle is non-zero.
type winPrivateKey struct {
	// publicKey is the public key taken from the certificate that owns this
	// private key.
	publicKey crypto.PublicKey
	// mustFree records whether the handle has to be released explicitly by
	// Close (it is the "caller must free" output of
	// CryptAcquireCertificatePrivateKey).
	mustFree bool
	// CryptoAPI fields
	// cspHandle is the provider handle used by cryptoSignHash.
	cspHandle windows.Handle
	// keySpec is the key specification passed to CryptSignHash.
	keySpec uint32
	// CNG fields
	// cngKeyHandle is the key handle used by cngSignHash.
	cngKeyHandle windows.Handle
}
// WindowsCertStoreSigner signs with a private key whose certificate lives in a
// Windows system certificate store. It owns the handles below and releases
// them in Close.
type WindowsCertStoreSigner struct {
	// store is the open certificate store handle; closed by Close.
	store windows.Handle
	// certCtx is the certificate context of the selected certificate; freed
	// by Close.
	certCtx *windows.CertContext
	// cert is the parsed form of the selected certificate.
	cert *x509.Certificate
	// certChain is the parsed chain built for the selected certificate.
	certChain []*x509.Certificate
	// privateKey is loaded lazily and cached by getPrivateKey.
	privateKey *winPrivateKey
}
// Windows status codes and API flags used by this file.
const (
	// WIN_FALSE is the value Windows BOOL-returning APIs use to signal failure.
	WIN_FALSE C.WINBOOL = 0
	// ERROR_SUCCESS — The call succeeded
	ERROR_SUCCESS = 0x00000000
	// NTE_BAD_ALGID — Invalid algorithm specified
	NTE_BAD_ALGID = 0x80090008
	// NTE_SILENT_CONTEXT - KSP must display UI to operate
	NTE_SILENT_CONTEXT = 0x80090022
	// WIN_API_FLAG specifies the flags that should be passed to
	// CryptAcquireCertificatePrivateKey. This impacts whether the CryptoAPI or CNG
	// API will be used.
	//
	// Possible values are:
	//
	// 0x00000000 —                                      — Only use CryptoAPI.
	// 0x00010000 — CRYPT_ACQUIRE_ALLOW_NCRYPT_KEY_FLAG  — Prefer CryptoAPI.
	// 0x00020000 — CRYPT_ACQUIRE_PREFER_NCRYPT_KEY_FLAG — Prefer CNG.
	// 0x00040000 — CRYPT_ACQUIRE_ONLY_NCRYPT_KEY_FLAG   — Only use CNG.
	//
	// The value chosen here prefers CNG.
	WIN_API_FLAG = windows.CRYPT_ACQUIRE_PREFER_NCRYPT_KEY_FLAG
)
var (
	// ErrRequiresUI is returned when a provider has to display UI in order to
	// perform a signing operation (for example when a silent signing attempt
	// is rejected with NTE_SILENT_CONTEXT).
	ErrRequiresUI = errors.New("provider requires UI to operate")
)
// Error codes for Windows APIs - implements the error interface.
// Error codes are maintained on a per-thread basis. In order to
// get the last error code, C.GetLastError needs to be called (called
// within the checkError() function). Some Windows APIs only return
// BOOLs indicating success or failure, and for more detailed error
// information, error codes are used.
type errCode uint64

// Implements the error interface for errCode and returns a string
// version of the errCode.
//
// The result is "Error <HEX>" when no system message could be obtained, and
// "Error: <HEX> <message>" otherwise.
//
// NOTE(review): cMsg is still nil when it is passed (by value) to
// FormatMessage below and nothing in this function can update it, so the
// "Error <HEX>" fallback is presumably always what callers see. With
// FORMAT_MESSAGE_ALLOCATE_BUFFER the API is documented to want the address of
// the output pointer instead — confirm before changing, because errCause
// parses the trailing hex token of this string.
func (c errCode) Error() string {
	var cMsg C.LPSTR
	ret := C.FormatMessage(
		C.FORMAT_MESSAGE_ALLOCATE_BUFFER|
			C.FORMAT_MESSAGE_FROM_SYSTEM|
			C.FORMAT_MESSAGE_IGNORE_INSERTS,
		nil,
		C.DWORD(c),
		C.ulong(C.MAKE_LANG_ID(C.LANG_NEUTRAL, C.SUBLANG_DEFAULT)),
		cMsg,
		0, nil)
	// A zero return means no message was produced; fall back to the bare code.
	if ret == 0 {
		return fmt.Sprintf("Error %X", int(c))
	}
	if cMsg == nil {
		return fmt.Sprintf("Error %X", int(c))
	}
	goMsg := C.GoString(cMsg)
	return fmt.Sprintf("Error: %X %s", int(c), goMsg)
}
// securityStatus is the Go representation of the C SECURITY_STATUS type that
// some Windows API calls return directly. It implements the error interface
// so that an unrecognized status can be handed back to callers as an error.
type securityStatus uint64

// Error renders the status as "SECURITY_STATUS <n>", where <n> is the status
// converted to int and printed in decimal.
func (secStatus securityStatus) Error() string {
	code := int(secStatus)
	return fmt.Sprintf("SECURITY_STATUS %d", code)
}
// Gets the certificates that match the given CertIdentifier within the user's specified system
// certificate store. By default, that is "MY".
// If there is only a single matching certificate, then its chain will be returned too
func GetMatchingCertsAndChain(certIdentifier CertIdentifier) (store windows.Handle, certCtxs []*windows.CertContext, certChains [][]*x509.Certificate, certContainers []CertificateContainer, err error) {
storeName, err := windows.UTF16PtrFromString(certIdentifier.SystemStoreName)
if err != nil {
return 0, nil, nil, nil, errors.New("unable to UTF-16 encode personal certificate store name")
}
store, err = windows.CertOpenStore(windows.CERT_STORE_PROV_SYSTEM_W, 0, 0, windows.CERT_SYSTEM_STORE_CURRENT_USER, uintptr(unsafe.Pointer(storeName)))
if err != nil {
return 0, nil, nil, nil, errors.New("failed to open system cert store")
}
var (
// CertFindChainInStore parameters
encoding = uint32(windows.X509_ASN_ENCODING)
flags = uint32(windows.CERT_CHAIN_FIND_BY_ISSUER_CACHE_ONLY_FLAG | windows.CERT_CHAIN_FIND_BY_ISSUER_CACHE_ONLY_URL_FLAG)
findType = uint32(windows.CERT_CHAIN_FIND_BY_ISSUER)
params windows.CertChainFindByIssuerPara
paramsPtr unsafe.Pointer
chainCtx *windows.CertChainContext = nil
certCtx *windows.CertContext
)
params.Size = uint32(unsafe.Sizeof(params))
paramsPtr = unsafe.Pointer(¶ms)
var curCertCtx *windows.CertContext
var curCert *x509.Certificate
certContainerIndex := 0
for {
// Previous chainCtx should be freed here if it isn't nil
chainCtx, err = windows.CertFindChainInStore(store, encoding, flags, findType, paramsPtr, chainCtx)
if err != nil {
if strings.Contains(err.Error(), "Cannot find object or property.") {
break
}
err = errors.New("unable to find certificate chain in store")
goto fail
}
if chainCtx.ChainCount < 1 {
err = errors.New("bad chain")
goto fail
}
// When multiple valid certification paths that are found for a given
// certificate, only the first one is considered
simpleChain := *chainCtx.Chains
if simpleChain.NumElements < 1 {
err = errors.New("bad chain")
goto fail
}
// Convert the array into a pointer
chainElts := unsafe.Slice(simpleChain.Elements, simpleChain.NumElements)
// Build chain of certificates from each element's certificate context.
x509CertChain := make([]*x509.Certificate, len(chainElts))
for j := range chainElts {
curCertCtx = chainElts[j].CertContext
x509CertChain[j], err = exportCertContext(curCertCtx)
if err != nil {
if Debug {
log.Printf("unable to parse certificate with error (%s) - skipping\n", err)
}
goto nextIteration
}
}
curCert = x509CertChain[0]
if certMatches(certIdentifier, *curCert) {
certContainers = append(certContainers, CertificateContainer{certContainerIndex, curCert, ""})
certContainerIndex += 1
certChains = append(certChains, x509CertChain[:])
certCtx = chainElts[0].CertContext
// It's the responsibility of the caller to free the below once they are done using it.
windows.CertDuplicateCertificateContext(certCtx)
certCtxs = append(certCtxs, certCtx)
}
nextIteration:
}
if Debug {
log.Printf("found %d matching identities\n", len(certContainers))
}
return store, certCtxs, certChains, certContainers, nil
fail:
if chainCtx != nil {
windows.CertFreeCertificateChain(chainCtx)
chainCtx = nil
}
for i, curCertCtx := range certCtxs {
if curCertCtx != nil {
windows.CertFreeCertificateContext(curCertCtx)
certCtxs[i] = nil
}
}
windows.CertCloseStore(store, 0)
return 0, nil, nil, nil, err
}
// GetMatchingCerts gets the certificates that match a CertIdentifier.
//
// It is a convenience wrapper around GetMatchingCertsAndChain that keeps only
// the certificate containers and releases the store handle and certificate
// contexts before returning.
func GetMatchingCerts(certIdentifier CertIdentifier) ([]CertificateContainer, error) {
	store, certCtxs, _, certContainers, err := GetMatchingCertsAndChain(certIdentifier)
	if err != nil {
		// GetMatchingCertsAndChain has already released everything it
		// acquired; there is no store or context left to free here.
		return nil, err
	}

	for i, curCertCtx := range certCtxs {
		if curCertCtx != nil {
			windows.CertFreeCertificateContext(curCertCtx)
			certCtxs[i] = nil
		}
	}
	if store != 0 {
		windows.CertCloseStore(store, 0)
	}

	return certContainers, nil
}
// GetCertStoreSigner gets a WindowsCertStoreSigner based on the CertIdentifier.
//
// Exactly one certificate has to match unless useLatestExpiringCert is set, in
// which case the matches are sorted with CertificateContainerList and the last
// one is selected. The returned signingAlgorithm is derived from the type of
// the selected certificate's public key (ECDSA or RSA).
//
// On success the returned signer owns the store handle and the selected
// certificate context; the caller releases them by calling Close on it. All
// other matching certificate contexts are freed before returning. On failure
// every handle is released and a nil signer is returned.
func GetCertStoreSigner(certIdentifier CertIdentifier, useLatestExpiringCert bool) (signer Signer, signingAlgorithm string, err error) {
	store, certCtxs, certChains, certContainers, err := GetMatchingCertsAndChain(certIdentifier)
	if err != nil {
		// Nothing to clean up: the lookup releases its handles on failure.
		return nil, "", err
	}

	var winSigner *WindowsCertStoreSigner

	// releaseCertCtxs frees every certificate context still owned by this
	// function. Entries whose ownership moved to the signer are nil.
	releaseCertCtxs := func() {
		for i, curCertCtx := range certCtxs {
			if curCertCtx != nil {
				windows.CertFreeCertificateContext(curCertCtx)
				certCtxs[i] = nil
			}
		}
	}

	// fail releases every handle exactly once and produces the error return.
	fail := func(cause error) (Signer, string, error) {
		releaseCertCtxs()
		if winSigner != nil {
			// Releases the key handle, the selected context and the store.
			winSigner.Close()
		}
		if store != 0 {
			windows.CertCloseStore(store, 0)
			store = 0
		}
		return nil, "", cause
	}

	if len(certContainers) == 0 {
		return fail(errors.New("no matching certs found in cert store"))
	}
	if useLatestExpiringCert {
		sort.Sort(CertificateContainerList(certContainers))
	} else if len(certContainers) > 1 {
		return fail(errors.New("multiple matching identities"))
	}

	selectedCertContainer := certContainers[len(certContainers)-1]
	if Debug {
		log.Printf("selected certificate: %s", DefaultCertContainerToString(selectedCertContainer))
	}

	// Index refers to the position in certCtxs/certChains, which is unaffected
	// by the sort of certContainers above.
	winSigner = &WindowsCertStoreSigner{
		store:     store,
		cert:      selectedCertContainer.Cert,
		certCtx:   certCtxs[selectedCertContainer.Index],
		certChain: certChains[selectedCertContainer.Index],
	}
	// Ownership of the selected context and of the store moves to the signer,
	// so neither is released a second time by the helpers above.
	certCtxs[selectedCertContainer.Index] = nil
	store = 0

	privateKey, err := winSigner.getPrivateKey()
	if err != nil {
		return fail(err)
	}

	// Find the signing algorithm
	switch privateKey.publicKey.(type) {
	case *ecdsa.PublicKey:
		signingAlgorithm = aws4_x509_ecdsa_sha256
	case *rsa.PublicKey:
		signingAlgorithm = aws4_x509_rsa_sha256
	default:
		return fail(errors.New("unsupported algorithm"))
	}

	// The contexts of the certificates that were not selected are no longer
	// needed.
	releaseCertCtxs()

	return winSigner, signingAlgorithm, nil
}
// Certificate implements the aws_signing_helper.Signer interface. It hands
// back the *x509.Certificate stored on this signer; the error is always nil.
func (signer *WindowsCertStoreSigner) Certificate() (cert *x509.Certificate, err error) {
	cert = signer.cert
	return cert, nil
}
// CertificateChain implements the aws_signing_helper.Signer interface. It
// hands back the certificate chain stored on this signer; the error is always
// nil.
func (signer *WindowsCertStoreSigner) CertificateChain() ([]*x509.Certificate, error) {
	chain := signer.certChain
	return chain, nil
}
// Close implements the aws_signing_helper.Signer interface and closes the
// signer, releasing the private key handle (when this process is responsible
// for it), the certificate context and the certificate store.
//
// Close is safe to call more than once: every handle is cleared after it is
// released and is skipped on later calls.
func (signer *WindowsCertStoreSigner) Close() {
	if key := signer.privateKey; key != nil && key.mustFree {
		if key.cngKeyHandle != 0 {
			cngHandle := (*C.NCRYPT_KEY_HANDLE)(unsafe.Pointer(&key.cngKeyHandle))
			C.NCryptFreeObject(*cngHandle)
		}
		if key.cspHandle != 0 {
			windows.CryptReleaseContext(key.cspHandle, 0)
		}
	}
	signer.privateKey = nil

	// If the key's mustFree flag was false, key handles will be released on the
	// last free action of the certificate context.
	// See https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/nf-wincrypt-cryptacquirecertificateprivatekey
	if signer.certCtx != nil {
		windows.CertFreeCertificateContext(signer.certCtx)
		signer.certCtx = nil
	}

	// Only close a store that is actually open, so that a repeated Close does
	// not pass a zero handle to the API.
	if signer.store != 0 {
		windows.CertCloseStore(signer.store, 0)
		signer.store = 0
	}
}
// getPrivateKey returns the *winPrivateKey belonging to this identity. The key
// is acquired on first use through newWinPrivateKey and cached on the signer,
// so later calls return the same value without touching the Windows APIs.
func (signer *WindowsCertStoreSigner) getPrivateKey() (*winPrivateKey, error) {
	if cached := signer.privateKey; cached != nil {
		return cached, nil
	}

	identityCert, certErr := signer.Certificate()
	if certErr != nil {
		return nil, fmt.Errorf("failed to get identity certificate: %w", certErr)
	}

	key, keyErr := newWinPrivateKey(signer.certCtx, identityCert.PublicKey)
	// The result is stored unconditionally; on failure that leaves the cache nil.
	signer.privateKey = key
	if keyErr != nil {
		return nil, fmt.Errorf("failed to load identity private key: %w", keyErr)
	}
	return key, nil
}
// newWinPrivateKey acquires the private key handle for the certificate behind
// certCtx and wraps it, together with publicKey, in a *winPrivateKey.
//
// The handle is requested with WIN_API_FLAG. When the reported key spec is
// CERT_NCRYPT_KEY_SPEC the handle is stored as a CNG key handle; otherwise it
// is stored as a CryptoAPI provider handle along with its key spec. A nil
// publicKey is rejected before any Windows API is called.
func newWinPrivateKey(certCtx *windows.CertContext, publicKey crypto.PublicKey) (*winPrivateKey, error) {
	if publicKey == nil {
		return nil, errors.New("nil public key")
	}

	var (
		handle      windows.Handle
		spec        uint32
		callerFrees bool
	)
	// Get a handle for the found private key
	acquireErr := windows.CryptAcquireCertificatePrivateKey(certCtx, WIN_API_FLAG, nil, &handle, &spec, &callerFrees)
	if acquireErr != nil {
		return nil, acquireErr
	}

	key := &winPrivateKey{
		publicKey: publicKey,
		mustFree:  callerFrees,
	}
	if spec == C.CERT_NCRYPT_KEY_SPEC {
		key.cngKeyHandle = handle
	} else {
		key.cspHandle = handle
		key.keySpec = spec
	}
	return key, nil
}
// Public implements the crypto.Signer interface. It reports the public key
// held by this signer's private key wrapper, or nil when the private key
// cannot be obtained.
func (signer *WindowsCertStoreSigner) Public() crypto.PublicKey {
	if key, err := signer.getPrivateKey(); err == nil {
		return key.publicKey
	}
	return nil
}
// Sign implements the crypto.Signer interface.
//
// Despite its name, the digest argument is hashed here with the function
// selected by opts (SHA-256, SHA-384 or SHA-512; anything else yields
// ErrUnsupportedHash), and the resulting hash is what gets signed. The
// signature is produced through CryptoAPI when the private key has a provider
// handle and through CNG when it has a CNG key handle. rand is not used.
func (signer *WindowsCertStoreSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) {
	var hashed []byte
	switch opts.HashFunc() {
	case crypto.SHA256:
		h := sha256.New()
		h.Write(digest)
		hashed = h.Sum(nil)
	case crypto.SHA384:
		h := sha512.New384()
		h.Write(digest)
		hashed = h.Sum(nil)
	case crypto.SHA512:
		h := sha512.New()
		h.Write(digest)
		hashed = h.Sum(nil)
	default:
		return nil, ErrUnsupportedHash
	}

	key, err := signer.getPrivateKey()
	if err != nil {
		return nil, err
	}

	// The CryptoAPI handle takes precedence when both happen to be set.
	switch {
	case key.cspHandle != 0:
		return signer.cryptoSignHash(hashed, opts.HashFunc())
	case key.cngKeyHandle != 0:
		return signer.cngSignHash(hashed, opts.HashFunc())
	default:
		return nil, errors.New("bad private key")
	}
}
// cngSignHash signs a digest using CNG APIs.
//
// digest must be exactly hash.Size() bytes long. RSA keys are signed with
// PKCS#1 v1.5 padding for the given hash; for ECDSA keys the raw signature
// returned by CNG is converted with encodeEcdsaSigValue before it is returned.
//
// Signing is first attempted with NCRYPT_SILENT_FLAG. If that is rejected
// with ErrRequiresUI, the call is retried once without the flag.
func (signer *WindowsCertStoreSigner) cngSignHash(digest []byte, hash crypto.Hash) ([]byte, error) {
	if len(digest) != hash.Size() {
		return nil, errors.New("bad digest for hash")
	}

	privateKey, err := signer.getPrivateKey()
	if err != nil {
		return nil, err
	}

	var (
		// Input
		padPtr    = unsafe.Pointer(nil)
		digestPtr = (*C.BYTE)(&digest[0])
		digestLen = C.DWORD(len(digest))
		flags     = C.DWORD(0)

		// Output
		sigLen = C.DWORD(0)
	)

	// Set up pkcs1v1.5 padding for RSA
	if _, isRSA := privateKey.publicKey.(*rsa.PublicKey); isRSA {
		flags |= C.BCRYPT_PAD_PKCS1
		padInfo := C.BCRYPT_PKCS1_PADDING_INFO{}
		padPtr = unsafe.Pointer(&padInfo)

		switch hash {
		case crypto.SHA1:
			padInfo.pszAlgId = C.GET_BCRYPT_SHA1_ALGORITHM()
		case crypto.SHA256:
			padInfo.pszAlgId = C.GET_BCRYPT_SHA256_ALGORITHM()
		case crypto.SHA384:
			padInfo.pszAlgId = C.GET_BCRYPT_SHA384_ALGORITHM()
		case crypto.SHA512:
			padInfo.pszAlgId = C.GET_BCRYPT_SHA512_ALGORITHM()
		default:
			return nil, ErrUnsupportedHash
		}
	}

	// Get C.NCRYPT_KEY_HANDLE in order to do the signature
	cngKeyHandle := (*C.NCRYPT_KEY_HANDLE)(unsafe.Pointer(&privateKey.cngKeyHandle))

	// Get signature length (a nil output buffer asks only for the size).
	if err := checkStatus(C.NCryptSignHash(*cngKeyHandle, padPtr, digestPtr, digestLen, nil, 0, &sigLen, flags)); err != nil {
		return nil, fmt.Errorf("failed to get signature length: %w", err)
	}
	if sigLen == 0 {
		// Guard the &sig[0] below, which would panic on an empty buffer.
		return nil, errors.New("failed to get signature length: zero-length signature")
	}

	// Get signature, silently first and with UI allowed only if required.
	sig := make([]byte, sigLen)
	sigPtr := (*C.BYTE)(&sig[0])
	signErr := checkStatus(C.NCryptSignHash(*cngKeyHandle, padPtr, digestPtr, digestLen, sigPtr, sigLen, &sigLen, flags|C.NCRYPT_SILENT_FLAG))
	if errors.Is(signErr, ErrRequiresUI) {
		signErr = checkStatus(C.NCryptSignHash(*cngKeyHandle, padPtr, digestPtr, digestLen, sigPtr, sigLen, &sigLen, flags))
	}
	if signErr != nil {
		return nil, fmt.Errorf("failed to sign digest: %w", signErr)
	}

	// CNG returns a raw ECDSA signature, but we want ASN.1 DER encoding
	if _, isEC := privateKey.publicKey.(*ecdsa.PublicKey); isEC {
		if len(sig)%2 != 0 {
			return nil, errors.New("bad ecdsa signature from CNG")
		}

		return encodeEcdsaSigValue(sig)
	}

	return sig, nil
}
// cryptoSignHash signs a digest using CryptoAPI.
//
// digest must be exactly hash.Size() bytes long. A CryptoAPI hash object is
// created for the matching algorithm, loaded with digest, and signed with the
// private key's key spec. The signature is reversed before it is returned,
// because CryptoAPI produces it in little-endian order. ErrUnsupportedHash is
// returned for hashes that this function or the provider does not support.
func (signer *WindowsCertStoreSigner) cryptoSignHash(digest []byte, hash crypto.Hash) ([]byte, error) {
	if len(digest) != hash.Size() {
		return nil, errors.New("bad digest for hash")
	}

	// Figure out which CryptoAPI hash algorithm we're using
	var hashAlg C.ALG_ID

	switch hash {
	case crypto.SHA1:
		hashAlg = C.CALG_SHA1
	case crypto.SHA256:
		hashAlg = C.CALG_SHA_256
	case crypto.SHA384:
		hashAlg = C.CALG_SHA_384
	case crypto.SHA512:
		hashAlg = C.CALG_SHA_512
	default:
		return nil, ErrUnsupportedHash
	}

	privateKey, err := signer.getPrivateKey()
	if err != nil {
		return nil, err
	}

	// Instantiate a CryptoAPI hash object
	var cryptoHash C.HCRYPTHASH
	cspHandle := (*C.HCRYPTPROV)(unsafe.Pointer(&privateKey.cspHandle))

	if ok := C.CryptCreateHash(*cspHandle, hashAlg, 0, 0, &cryptoHash); ok == WIN_FALSE {
		createErr := lastError("failed to create hash")
		if errCause(createErr) == errCode(NTE_BAD_ALGID) {
			return nil, ErrUnsupportedHash
		}
		return nil, createErr
	}
	defer C.CryptDestroyHash(cryptoHash)

	// Make sure the hash size matches
	var (
		hashSize    C.DWORD
		hashSizePtr = (*C.BYTE)(unsafe.Pointer(&hashSize))
		hashSizeLen = C.DWORD(unsafe.Sizeof(hashSize))
	)

	if ok := C.CryptGetHashParam(cryptoHash, C.HP_HASHSIZE, hashSizePtr, &hashSizeLen, 0); ok == WIN_FALSE {
		return nil, lastError("failed to get hash size")
	}

	if hash.Size() != int(hashSize) {
		return nil, errors.New("invalid CryptoAPI hash")
	}

	// Put our digest into the hash object
	digestPtr := (*C.BYTE)(unsafe.Pointer(&digest[0]))
	if ok := C.CryptSetHashParam(cryptoHash, C.HP_HASHVAL, digestPtr, 0); ok == WIN_FALSE {
		return nil, lastError("failed to set hash digest")
	}

	// Get signature length (a nil output buffer asks only for the size).
	var sigLen C.DWORD

	if ok := C.CryptSignHash(cryptoHash, C.ulong(privateKey.keySpec), nil, 0, nil, &sigLen); ok == WIN_FALSE {
		return nil, lastError("failed to get signature length")
	}
	if sigLen == 0 {
		// Guard the &sig[0] below, which would panic on an empty buffer.
		return nil, errors.New("failed to get signature length")
	}

	// Get signature
	var (
		sig    = make([]byte, int(sigLen))
		sigPtr = (*C.BYTE)(unsafe.Pointer(&sig[0]))
	)

	if ok := C.CryptSignHash(cryptoHash, C.ulong(privateKey.keySpec), nil, 0, sigPtr, &sigLen); ok == WIN_FALSE {
		return nil, lastError("failed to sign digest")
	}

	// Reversing signature since it is little endian, but we want big endian
	for i, j := 0, len(sig)-1; i < j; i, j = i+1, j-1 {
		sig[i], sig[j] = sig[j], sig[i]
	}

	return sig, nil
}
// exportCertContext exports a windows.CertContext as an *x509.Certificate.
//
// The encoded certificate is copied into Go-owned memory before it is parsed.
// The bytes referenced by certCtx belong to Windows and go away when the
// context is released, whereas the parsed certificate may keep referring to
// the buffer it was parsed from and is kept by callers for longer than that.
//
// An error is returned for a nil context, an encoding other than
// X509_ASN_ENCODING, an empty encoding, or a certificate that fails to parse.
func exportCertContext(certCtx *windows.CertContext) (*x509.Certificate, error) {
	if certCtx == nil {
		return nil, errors.New("nil certificate context")
	}

	// Technically, we should never fail here, since the exportCertContext
	// function is only called when searching for certificates.
	if certCtx.EncodingType != windows.X509_ASN_ENCODING {
		return nil, errors.New("unknown certificate encoding type")
	}

	if certCtx.EncodedCert == nil || certCtx.Length == 0 {
		return nil, errors.New("empty certificate encoding")
	}

	der := make([]byte, certCtx.Length)
	copy(der, unsafe.Slice(certCtx.EncodedCert, certCtx.Length))

	return x509.ParseCertificate(der)
}
// errCause finds the errCode for the given error.
//
// An errCode wrapped anywhere in err's chain (as checkError does with %w) is
// returned directly, independent of how the error message is formatted. When
// the chain contains none, the last space-separated token of the message is
// parsed as a hexadecimal code, which is how the code was originally
// recovered; zero is returned if that token is not hexadecimal or err is nil.
func errCause(err error) errCode {
	if err == nil {
		return 0
	}

	var code errCode
	if errors.As(err, &code) {
		return code
	}

	// Legacy fallback for errors that only carry the code in their text.
	msg := err.Error()
	codeStr := msg[strings.LastIndex(msg, " ")+1:]
	// A parse failure leaves parsed at zero, which means "no code".
	parsed, _ := strconv.ParseUint(codeStr, 16, 64)

	return errCode(parsed)
}
// lastError reports the last error recorded for the current thread, prefixed
// with msg (see checkError). When no error code is recorded it still returns
// a non-nil error whose text is msg alone.
func lastError(msg string) error {
	err := checkError(msg)
	if err == nil {
		return errors.New(msg)
	}
	return err
}
// checkError reads the last error code of the current thread via
// GetLastError. A zero code yields nil; any other code is wrapped as an
// errCode behind the "msg: " prefix.
func checkError(msg string) error {
	code := errCode(C.GetLastError())
	if code == 0 {
		return nil
	}
	return fmt.Errorf("%s: %w", msg, code)
}
// checkStatus converts a SECURITY_STATUS into an error.
//
// ERROR_SUCCESS maps to nil, NTE_BAD_ALGID to ErrUnsupportedHash and
// NTE_SILENT_CONTEXT to ErrRequiresUI; any other status is returned as a
// securityStatus.
//
// The status is compared through its low 32 bits. SECURITY_STATUS is
// presumably a signed 32-bit C type, and the NTE_* codes have their top bit
// set: converting such a value straight to a wider unsigned type sign-extends
// it, so it would never equal constants such as 0x80090008.
func checkStatus(s C.SECURITY_STATUS) error {
	switch uint32(s) {
	case ERROR_SUCCESS:
		return nil
	case NTE_BAD_ALGID:
		return ErrUnsupportedHash
	case NTE_SILENT_CONTEXT:
		return ErrRequiresUI
	}

	// Unrecognized statuses keep their original representation.
	return securityStatus(s)
}