azure-protected-vm-secrets/Windows/BcryptAesWrapper.cpp (365 lines of code) (raw):
#include "..\pch.h"
#ifndef PLATFORM_UNIX
#define UMDF_USING_NTSTATUS
#include <windows.h>
#include <bcrypt.h>
#ifndef _NTSTATUS_
#include <ntstatus.h>
#endif
#include "..\BcryptError.h"
#else
#endif
#include <memory>
#include <vector>
#include "..\AesWrapper.h"
#include "BcryptAesWrapper.h"
#include "..\ReturnCodes.h"
// Rounds numToRound up to a multiple of `multiple`.
// For a non-negative numToRound and a positive multiple the result is the
// smallest multiple that is >= numToRound. The arithmetic relies on integer
// division, so `multiple` must not be zero.
__inline long long __round_up(long long numToRound, long long multiple) {
    // Bias the value so that truncating division lands on the next multiple.
    const long long biased = numToRound + multiple - 1;
    const long long wholeMultiples = biased / multiple;
    return wholeMultiples * multiple;
}
// Builds the chaining state used for chained GCM BCryptEncrypt/BCryptDecrypt
// calls against the given algorithm provider.
//
// The constructor queries the provider for its authentication tag lengths and
// block length, sizes the tag / MAC-context / init-vector buffers from those
// values, and wires the buffers into the authenticated cipher mode info
// structure with the chain-calls flag set. The nonce is not configured here;
// it is supplied later through SetNonce().
//
// Throws BcryptError (LibraryError_Bcrypt_propertyError) when either property
// query fails.
GcmChainingInfo::GcmChainingInfo(BCRYPT_ALG_HANDLE algHandle)
{
    // Start from a zeroed structure stamped by the BCrypt initialisation macro.
    this->authInfo = { 0 };
    BCRYPT_INIT_AUTH_MODE_INFO(this->authInfo);

    // Query both properties and throw an exception if either lookup fails.
    // No clean up is needed on these paths: no buffers are allocated until
    // both lookups have succeeded.
    DWORD propertyBytes = 0;
    BCRYPT_AUTH_TAG_LENGTHS_STRUCT tagLengths;
    NTSTATUS status = BCryptGetProperty(
        algHandle,
        BCRYPT_AUTH_TAG_LENGTH,
        (BYTE*)&tagLengths,
        sizeof(tagLengths),
        &propertyBytes,
        0);
    if (!BCRYPT_SUCCESS(status)) {
        // LibraryErrors class, Bcrypt subclass, property not found
        throw BcryptError(status, "BCryptGetProperty(BCRYPT_AUTH_TAG_LENGTH) failed",
            ErrorCode::LibraryError_Bcrypt_propertyError);
    }

    DWORD blockBytes = 0;
    status = BCryptGetProperty(
        algHandle,
        BCRYPT_BLOCK_LENGTH,
        (BYTE*)&blockBytes,
        sizeof(blockBytes),
        &propertyBytes,
        0);
    if (!BCRYPT_SUCCESS(status)) {
        // LibraryErrors class, Bcrypt subclass, property not found
        throw BcryptError(status, "BCryptGetProperty(BCRYPT_BLOCK_LENGTH) failed",
            ErrorCode::LibraryError_Bcrypt_propertyError);
    }

    // Tag and MAC-context buffers are both sized to the maximum tag length.
    this->authTag = std::vector<unsigned char>(tagLengths.dwMaxLength);
    this->macContext = std::vector<unsigned char>(tagLengths.dwMaxLength);
    // The init vector has the same length as the cipher block. In GCM the
    // nonce plays the IV role, but an IV buffer must still be handed to
    // BCryptEncrypt/BCryptDecrypt.
    this->initVector = std::vector<unsigned char>(blockBytes);

    // Chained operation: request chaining and expose the buffers owned by
    // this object to BCrypt.
    this->authInfo.dwFlags = BCRYPT_AUTH_MODE_CHAIN_CALLS_FLAG;
    this->authInfo.pbTag = this->authTag.data();
    this->authInfo.cbTag = this->authTag.size();
    this->authInfo.pbMacContext = this->macContext.data();
    this->authInfo.cbMacContext = this->macContext.size();

    // No authenticated data is attached, and the length fields start at zero.
    this->authInfo.pbAuthData = nullptr;
    this->authInfo.cbAuthData = 0;
    this->authInfo.cbAAD = 0;
    this->authInfo.cbData = 0;
}
// Nothing to release by hand: the destructor body performs no work and the
// members clean up after themselves.
GcmChainingInfo::~GcmChainingInfo() = default;
// Stores a private copy of the nonce and points the cipher mode info at that
// copy, so the caller's vector does not need to outlive this object.
void GcmChainingInfo::SetNonce(const std::vector<unsigned char> &nonce) noexcept
{
    // Copy first, then swap in, so the member always owns its own storage.
    std::vector<unsigned char> ownedNonce(nonce.begin(), nonce.end());
    this->nonce.swap(ownedNonce);

    // Re-point authInfo at the freshly stored bytes.
    this->authInfo.cbNonce = this->nonce.size();
    this->authInfo.pbNonce = this->nonce.data();
}
// Returns a copy of the stored nonce bytes.
std::vector<unsigned char> GcmChainingInfo::GetNonce() noexcept
{
    std::vector<unsigned char> nonceCopy(this->nonce);
    return nonceCopy;
}
// Replaces the stored init vector with a private copy of the supplied bytes.
void GcmChainingInfo::SetInitVector(const std::vector<unsigned char> &initVector) noexcept
{
    // Copy first, then swap in, so the member always owns its own storage.
    std::vector<unsigned char> ownedVector(initVector.begin(), initVector.end());
    this->initVector.swap(ownedVector);
}
// Returns a pointer to the cipher mode info structure held inside this object.
// The pointer refers to a member, so it is only usable while this object lives.
BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO* GcmChainingInfo::GetAuthInfo() noexcept
{
    BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO* const info = &this->authInfo;
    return info;
}
// Returns a copy of the stored init vector bytes.
std::vector<unsigned char> GcmChainingInfo::GetInitVector() noexcept
{
    std::vector<unsigned char> vectorCopy(this->initVector);
    return vectorCopy;
}
// Opens the AES algorithm provider, switches it to GCM chaining mode and
// caches the provider properties this wrapper needs later: the supported
// authentication tag lengths, the block length and the key-object length.
// A key-object buffer of that length is allocated for SetKey().
//
// Throws BcryptError:
//   - LibraryError_Bcrypt_providerError when the provider cannot be opened;
//   - LibraryError_Bcrypt_propertyError when a property cannot be set or read.
// Other exceptions (for example from allocating the key-object buffer) are
// propagated unchanged.
GcmWrapper::GcmWrapper()
{
#ifndef PLATFORM_UNIX
    this->hAesHandle = nullptr;
    this->hAesKey = nullptr;
    this->authTagLengths = { 0 };
    this->authInfo = { 0 };
    DWORD bytesDone = 0;
    NTSTATUS bcryptResult = BCryptOpenAlgorithmProvider(
        &(this->hAesHandle), BCRYPT_AES_ALGORITHM, 0, 0);
    if (STATUS_SUCCESS != bcryptResult) {
        // LibraryErrors class, Bcrypt subclass, provider/handler error
        throw BcryptError(bcryptResult, "BCryptOpenAlgorithmProvider failed",
            ErrorCode::LibraryError_Bcrypt_providerError);
    }

    // From here on the provider handle is open. The destructor does not run
    // for an object whose constructor throws, so every failure below must
    // close the handle itself before the exception leaves the constructor.
    try {
        bcryptResult = BCryptSetProperty(
            this->hAesHandle,
            BCRYPT_CHAINING_MODE,
            (PUCHAR)BCRYPT_CHAIN_MODE_GCM,
            sizeof(BCRYPT_CHAIN_MODE_GCM),
            0);
        if (!BCRYPT_SUCCESS(bcryptResult)) {
            // LibraryErrors class, Bcrypt subclass, propertyError
            throw BcryptError(bcryptResult, "BCryptSetProperty(BCRYPT_CHAINING_MODE) failed",
                ErrorCode::LibraryError_Bcrypt_propertyError);
        }

        bcryptResult = BCryptGetProperty(
            this->hAesHandle,
            BCRYPT_AUTH_TAG_LENGTH,
            (BYTE*)&(this->authTagLengths),
            sizeof(authTagLengths),
            &bytesDone,
            0);
        if (!BCRYPT_SUCCESS(bcryptResult)) {
            // LibraryErrors class, Bcrypt subclass, propertyError
            throw BcryptError(bcryptResult, "BCryptGetProperty(BCRYPT_AUTH_TAG_LENGTH) failed",
                ErrorCode::LibraryError_Bcrypt_propertyError);
        }

        bcryptResult = BCryptGetProperty(
            this->hAesHandle,
            BCRYPT_BLOCK_LENGTH,
            (BYTE*)&(this->blockLength),
            sizeof(blockLength),
            &bytesDone,
            0);
        if (!BCRYPT_SUCCESS(bcryptResult)) {
            // LibraryErrors class, Bcrypt subclass, propertyError
            throw BcryptError(bcryptResult, "BCryptGetProperty(BCRYPT_BLOCK_LENGTH) failed",
                ErrorCode::LibraryError_Bcrypt_propertyError);
        }

        bcryptResult = BCryptGetProperty(
            this->hAesHandle,
            BCRYPT_OBJECT_LENGTH,
            (PUCHAR) & (this->objectLength),
            sizeof(this->objectLength),
            &bytesDone,
            0);
        if (!BCRYPT_SUCCESS(bcryptResult)) {
            // LibraryErrors class, Bcrypt subclass, propertyError
            throw BcryptError(bcryptResult, "BCryptGetProperty(BCRYPT_OBJECT_LENGTH) failed",
                ErrorCode::LibraryError_Bcrypt_propertyError);
        }

        // Backing storage handed to BCryptGenerateSymmetricKey in SetKey().
        this->objectValue = std::vector<unsigned char>(this->objectLength);
    }
    catch (...) {
        // Best-effort release of the provider; the original exception is what
        // the caller needs to see, so the close status is deliberately ignored.
        (void)BCryptCloseAlgorithmProvider(this->hAesHandle, 0);
        this->hAesHandle = nullptr;
        throw;
    }
#else
#endif // !PLATFORM_UNIX
}
// Releases the key handle and the algorithm provider handle, if present.
//
// The destructor never throws. Unless a destructor is explicitly declared
// noexcept(false) it is implicitly noexcept, so an escaping exception would
// call std::terminate; and a throw while another exception is unwinding
// terminates regardless. Throwing after the first failed release would also
// skip the second release. Cleanup is therefore best-effort: both handles are
// always attempted and the returned statuses are intentionally ignored.
GcmWrapper::~GcmWrapper()
{
#ifndef PLATFORM_UNIX
    // Release in reverse order of acquisition: the key was created from the
    // provider, so it goes first.
    if (this->hAesKey != nullptr) {
        (void)BCryptDestroyKey(this->hAesKey);
        this->hAesKey = nullptr;
    }
    if (this->hAesHandle != nullptr) {
        (void)BCryptCloseAlgorithmProvider(this->hAesHandle, 0);
        this->hAesHandle = nullptr;
    }
#else
#endif // !PLATFORM_UNIX
}
void GcmWrapper::SetKey(std::vector<unsigned char> &key)
{
#ifndef PLATFORM_UNIX
NTSTATUS bcryptResult = STATUS_SUCCESS;
bcryptResult = BCryptGenerateSymmetricKey(
this->hAesHandle,
&this->hAesKey,
this->objectValue.data(),
this->objectLength,
key.data(),
key.size(),
0
);
if (STATUS_SUCCESS != bcryptResult) {
// LibraryErrors class, Bcrypt subclass, keyError
throw BcryptError(bcryptResult, "BCryptGetProperty(BCRYPT_AUTH_TAG_LENGTH) failed",
ErrorCode::LibraryError_Bcrypt_keyError);
}
#else
#endif // !PLATFORM_UNIX
}
// Creates a GcmChainingInfo bound to this wrapper's algorithm provider,
// stores `nonce` in it and hands ownership to the caller.
//
// Any exception raised while constructing the chaining info (for example a
// BcryptError from its property queries) propagates to the caller unchanged.
// No catch-and-rethrow is used here: rethrowing a caught exception by value
// (`throw e;`) would copy it and slice derived exception types down to the
// caught type.
std::unique_ptr<AesChainingInfo> GcmWrapper::SetChainingInfo(const std::vector<unsigned char> &nonce)
{
    std::unique_ptr<AesChainingInfo> chainingInfo =
        std::make_unique<GcmChainingInfo>(this->hAesHandle);
    chainingInfo->SetNonce(nonce);
    return chainingInfo;
}
// Encrypts `data` with the key set through SetKey(), using chained
// BCryptEncrypt calls driven by the supplied chaining info.
//
// Parameters:
//   data         - plaintext bytes; may be empty.
//   chainingInfo - must be non-null and must actually be a GcmChainingInfo.
//                  Its cipher mode info is modified by this call (the
//                  chain-calls flag is cleared before the final call and is
//                  not set again here).
//
// Returns the ciphertext immediately followed by the authentication tag
// (authInfo->cbTag bytes).
//
// Throws std::exception when chainingInfo is null or of the wrong dynamic
// type, and BcryptError (CryptographyError_AES_encryptError) when any
// BCryptEncrypt call fails.
std::vector<unsigned char> GcmWrapper::Encrypt(const std::vector<unsigned char> &data, AesChainingInfo *chainingInfo) const
{
    if (chainingInfo == nullptr) {
        throw std::exception("Chaining info must be set before calling Encrypt");
    }
    // The dynamic_cast yields nullptr for any other AesChainingInfo subtype;
    // reject that explicitly instead of dereferencing a null pointer below.
    GcmChainingInfo* gcmChainingInfo = dynamic_cast<GcmChainingInfo*>(chainingInfo);
    if (gcmChainingInfo == nullptr) {
        throw std::exception("Chaining info passed to Encrypt must be a GcmChainingInfo");
    }
    BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO* authInfo = gcmChainingInfo->GetAuthInfo();

    DWORD bytesDone = 0;
    long long ciphertextSize = 0;
    long long ptxOffset = 0;
    const long long dataLength = data.size();
    // Work buffer: plaintext length rounded up to whole blocks, plus the tag.
    long long encryptedDataLength = __round_up(dataLength, this->blockLength);
    const long numBlocks = encryptedDataLength / this->blockLength;
    encryptedDataLength += authInfo->cbTag;
    std::vector<unsigned char> result(encryptedDataLength);
    std::vector<unsigned char> initVector = gcmChainingInfo->GetInitVector();
#ifndef PLATFORM_UNIX
    NTSTATUS bcryptResult = STATUS_SUCCESS;
    // Priming call with no plaintext and no output buffer ("init aad").
    bcryptResult = BCryptEncrypt(
        this->hAesKey,
        nullptr,
        0,
        authInfo,
        initVector.data(),
        initVector.size(),
        nullptr,
        0,
        &bytesDone,
        0
    );
    if (STATUS_SUCCESS != bcryptResult) {
        // CryptographyError class, AES subclass, encryptError
        throw BcryptError(bcryptResult, "BCryptEncrypt failed",
            ErrorCode::CryptographyError_AES_encryptError);
    }
    // The authenticated data must not be supplied again on later calls.
    authInfo->cbAuthData = 0;
    authInfo->pbAuthData = nullptr;

    // Every block except the last is encrypted as one full block with the
    // chain-calls flag still set.
    for (long i = 0; i < numBlocks - 1; i++) {
        bytesDone = 0;
        bcryptResult = BCryptEncrypt(
            this->hAesKey,
            const_cast<unsigned char*>(data.data()) + ptxOffset,
            this->blockLength,
            authInfo,
            initVector.data(),
            initVector.size(),
            result.data() + ciphertextSize,
            this->blockLength,
            &bytesDone,
            0
        );
        if (STATUS_SUCCESS != bcryptResult) {
            // CryptographyError class, AES subclass, encryptError
            throw BcryptError(bcryptResult, "BCryptEncrypt failed",
                ErrorCode::CryptographyError_AES_encryptError);
        }
        ciphertextSize += this->blockLength;
        ptxOffset += bytesDone;
    }

    // Final call: chaining is switched off and the remaining plaintext (which
    // may be shorter than a block, or empty) is passed in.
    bytesDone = 0;
    authInfo->dwFlags &= ~BCRYPT_AUTH_MODE_CHAIN_CALLS_FLAG;
    bcryptResult = BCryptEncrypt(
        this->hAesKey,
        const_cast<unsigned char*>(data.data()) + ptxOffset,
        data.size() - ptxOffset,
        authInfo,
        initVector.data(),
        initVector.size(),
        result.data() + ciphertextSize,
        this->blockLength,
        &bytesDone,
        0
    );
    if (STATUS_SUCCESS != bcryptResult) {
        // CryptographyError class, AES subclass, encryptError
        throw BcryptError(bcryptResult, "BCryptEncrypt failed",
            ErrorCode::CryptographyError_AES_encryptError);
    }
    ciphertextSize += bytesDone;

    // Trim the work buffer to the bytes actually produced and append the tag.
    result.resize(ciphertextSize + authInfo->cbTag);
    std::copy(
        authInfo->pbTag,
        authInfo->pbTag + authInfo->cbTag,
        result.data() + ciphertextSize
    );
#else
#endif // !PLATFORM_UNIX
    return result;
}
// Decrypts a buffer laid out as "ciphertext followed by authentication tag"
// with the key set through SetKey(), using chained BCryptDecrypt calls driven
// by the supplied chaining info. The tag is taken to be the trailing
// authTagLengths.dwMaxLength bytes of `ciphertext`.
//
// Parameters:
//   ciphertext   - encrypted bytes plus trailing tag; must be at least as
//                  long as the tag.
//   chainingInfo - must be non-null and must actually be a GcmChainingInfo.
//                  Its cipher mode info is modified by this call (the
//                  chain-calls flag is cleared before the final call and is
//                  not set again here).
//
// Returns the decrypted bytes.
//
// Throws std::exception when chainingInfo is null or of the wrong dynamic
// type, or when the input is shorter than the tag; throws BcryptError
// (CryptographyError_AES_decryptError) when any BCryptDecrypt call fails.
std::vector<unsigned char> GcmWrapper::Decrypt(const std::vector<unsigned char> &ciphertext, AesChainingInfo *chainingInfo) const
{
    if (chainingInfo == nullptr) {
        throw std::exception("Chaining info must be set before calling Decrypt");
    }
    // The dynamic_cast yields nullptr for any other AesChainingInfo subtype;
    // reject that explicitly instead of dereferencing a null pointer below.
    GcmChainingInfo* gcmChainingInfo = dynamic_cast<GcmChainingInfo*>(chainingInfo);
    if (gcmChainingInfo == nullptr) {
        throw std::exception("Chaining info passed to Decrypt must be a GcmChainingInfo");
    }
    // Without this check the subtraction below would wrap around for inputs
    // that are too short to even contain a tag.
    const auto tagLength = this->authTagLengths.dwMaxLength;
    if (ciphertext.size() < tagLength) {
        throw std::exception("Ciphertext is shorter than the authentication tag");
    }

    DWORD bytesDone = 0;
    long long returnDataLength = 0;
    long long ctxOffset = 0;
    const long long encryptedDataLength = ciphertext.size() - tagLength;
    std::vector<unsigned char> result(encryptedDataLength);
    const long numBlocks = __round_up(encryptedDataLength, this->blockLength) / this->blockLength;
    BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO* authInfo = gcmChainingInfo->GetAuthInfo();
    std::vector<unsigned char> initVector = gcmChainingInfo->GetInitVector();
#ifndef PLATFORM_UNIX
    NTSTATUS bcryptResult = STATUS_SUCCESS;
    // Peel off the auth tag
    std::vector<unsigned char> authTag(tagLength);
    std::copy(
        ciphertext.data() + encryptedDataLength,
        ciphertext.data() + encryptedDataLength + tagLength,
        authTag.data()
    );
    // Priming call with no ciphertext and no output buffer ("init aad").
    bcryptResult = BCryptDecrypt(
        this->hAesKey,
        nullptr,
        0,
        authInfo,
        initVector.data(),
        initVector.size(),
        nullptr,
        0,
        &bytesDone,
        0
    );
    if (STATUS_SUCCESS != bcryptResult) {
        // CryptographyError class, AES subclass, decryptError
        throw BcryptError(bcryptResult, "BCryptDecrypt failed",
            ErrorCode::CryptographyError_AES_decryptError);
    }
    // The authenticated data must not be supplied again on later calls.
    authInfo->cbAuthData = 0;
    authInfo->pbAuthData = nullptr;

    // Every block except the last is decrypted as one full block with the
    // chain-calls flag still set.
    for (long i = 0; i < numBlocks - 1; i++) {
        bytesDone = 0;
        bcryptResult = BCryptDecrypt(
            this->hAesKey,
            const_cast<unsigned char*>(ciphertext.data()) + ctxOffset,
            this->blockLength,
            authInfo,
            initVector.data(),
            initVector.size(),
            result.data() + returnDataLength,
            this->blockLength,
            &bytesDone,
            0
        );
        if (STATUS_SUCCESS != bcryptResult) {
            // CryptographyError class, AES subclass, decryptError
            throw BcryptError(bcryptResult, "BCryptDecrypt failed",
                ErrorCode::CryptographyError_AES_decryptError);
        }
        returnDataLength += bytesDone;
        ctxOffset += this->blockLength;
    }

    // Final call: chaining is switched off and the tag received with the
    // ciphertext is presented to BCryptDecrypt.
    authInfo->dwFlags &= ~BCRYPT_AUTH_MODE_CHAIN_CALLS_FLAG;
    bytesDone = 0;
    // authInfo belongs to the chaining info and outlives this call, while
    // authTag is a local. Remember the chaining info's own tag buffer so it
    // can be put back, otherwise authInfo would keep a dangling pointer.
    const auto ownedTag = authInfo->pbTag;
    const auto ownedTagLength = authInfo->cbTag;
    authInfo->pbTag = authTag.data();
    authInfo->cbTag = authTag.size();
    bcryptResult = BCryptDecrypt(
        this->hAesKey,
        const_cast<unsigned char*>(ciphertext.data()) + ctxOffset,
        encryptedDataLength - ctxOffset,
        authInfo,
        initVector.data(),
        initVector.size(),
        result.data() + returnDataLength,
        this->blockLength,
        &bytesDone,
        0
    );
    // Restore before the status check so the failure path is covered too.
    authInfo->pbTag = ownedTag;
    authInfo->cbTag = ownedTagLength;
    if (STATUS_SUCCESS != bcryptResult) {
        // CryptographyError class, AES subclass, decryptError
        throw BcryptError(bcryptResult, "BCryptDecrypt failed",
            ErrorCode::CryptographyError_AES_decryptError);
    }
    returnDataLength += bytesDone;
#else
#endif // !PLATFORM_UNIX
    return result;
}