azure-protected-vm-secrets/Linux/OsslAesWrapper.cpp (112 lines of code) (raw):
#include <openssl/evp.h>
#include <openssl/rand.h>
#include <iostream>
#include <limits>
#include <vector>
#include <string>
#include <stdexcept>
#include <memory>
#include "OsslAesWrapper.h"
#include "OsslError.h"
/**
 * Fills a freshly allocated buffer with bytes from OpenSSL's RAND_bytes.
 *
 * @param num  Number of random bytes requested.
 * @return A vector of exactly `num` random bytes.
 * @throws std::runtime_error if `num` does not fit in the `int` length that
 *         RAND_bytes accepts, or if RAND_bytes reports failure.
 */
static std::vector<unsigned char> generate_random_bytes(size_t num) {
    // RAND_bytes takes an int length; an unchecked narrowing would fill only
    // part of the buffer and leave the remainder zero-initialized.
    if (num > static_cast<size_t>(std::numeric_limits<int>::max())) {
        throw std::runtime_error("Failed to generate random bytes");
    }
    std::vector<unsigned char> bytes(num);
    if (RAND_bytes(bytes.data(), static_cast<int>(num)) != 1) {
        throw std::runtime_error("Failed to generate random bytes");
    }
    return bytes;
}
// Constructs an empty chaining-info object; no explicit member setup is done here.
OsslGcmChainingInfo::OsslGcmChainingInfo() = default;
// Nothing to release explicitly; members clean themselves up.
OsslGcmChainingInfo::~OsslGcmChainingInfo() = default;
/**
 * Stores a copy of the supplied nonce in this object.
 *
 * NOTE: the function is noexcept, so a failed allocation while copying the
 * vector would terminate the program rather than propagate.
 *
 * @param nonce  Bytes to copy into the stored nonce.
 */
void OsslGcmChainingInfo::SetNonce(const std::vector<unsigned char> &nonce) noexcept {
    // The parameter shadows the member, so reach the member through `this`.
    auto &storedNonce = this->nonce;
    storedNonce = nonce;
}
std::vector<unsigned char> OsslGcmChainingInfo::GetNonce() noexcept {
return nonce;
}
/**
 * Stores a copy of the supplied initialization vector in this object.
 *
 * NOTE: the function is noexcept, so a failed allocation while copying the
 * vector would terminate the program rather than propagate.
 *
 * @param initVector  Bytes to copy into the stored initialization vector.
 */
void OsslGcmChainingInfo::SetInitVector(const std::vector<unsigned char> &initVector) noexcept {
    // The parameter shadows the member, so reach the member through `this`.
    auto &storedIv = this->initVector;
    storedIv = initVector;
}
std::vector<unsigned char> OsslGcmChainingInfo::GetInitVector() noexcept {
return initVector;
}
/**
 * Allocates the OpenSSL cipher context used by Encrypt and Decrypt.
 *
 * @throws std::runtime_error if EVP_CIPHER_CTX_new returns null.
 */
OsslGcmWrapper::OsslGcmWrapper() {
    ctx = EVP_CIPHER_CTX_new();
    if (ctx == nullptr) {
        throw std::runtime_error("Failed to create cipher context");
    }
}
/**
 * Frees the OpenSSL cipher context, if one is held, and clears the pointer.
 */
OsslGcmWrapper::~OsslGcmWrapper() {
    if (ctx == nullptr) {
        return;
    }
    EVP_CIPHER_CTX_free(ctx);
    ctx = nullptr;
}
/**
 * Stores a copy of the supplied key for later Encrypt/Decrypt calls.
 * No length validation is performed here.
 *
 * @param key  Key bytes to copy into this wrapper.
 */
void OsslGcmWrapper::SetKey(std::vector<unsigned char> &key) {
    // The parameter shadows the member, so reach the member through `this`.
    auto &storedKey = this->key;
    storedKey = key;
}
/**
 * Builds a new OsslGcmChainingInfo whose initialization vector is a copy of
 * the supplied bytes. Only the init vector is populated; SetNonce is not
 * called on the new object.
 *
 * @param nonce  Bytes stored as the initialization vector of the result.
 * @return Owning pointer to the newly created chaining info.
 */
std::unique_ptr<AesChainingInfo> OsslGcmWrapper::SetChainingInfo(const std::vector<unsigned char> &nonce) {
    std::unique_ptr<AesChainingInfo> info{std::make_unique<OsslGcmChainingInfo>()};
    info->SetInitVector(nonce);
    return info;
}
/**
 * Encrypts `data` with AES-256-GCM using the stored key and the IV supplied
 * by `chainingInfo`. No additional authenticated data is fed to the cipher.
 *
 * @param data          Plaintext bytes to encrypt.
 * @param chainingInfo  Provides the IV through GetInitVector(); must not be
 *                      null.
 * @return The ciphertext immediately followed by the 16-byte GCM tag.
 * @throws std::runtime_error if `chainingInfo` is null, the stored key is
 *         shorter than 32 bytes, or an input length does not fit in `int`.
 * @throws OsslError if any OpenSSL call fails.
 */
std::vector<unsigned char> OsslGcmWrapper::Encrypt(const std::vector<unsigned char> &data, AesChainingInfo* chainingInfo) const {
    constexpr size_t kAes256KeyBytes = 32;
    constexpr size_t kGcmTagBytes = 16;
    const size_t kMaxInt = static_cast<size_t>(std::numeric_limits<int>::max());

    if (!chainingInfo) {
        throw std::runtime_error("Chaining info must be set before calling Encrypt");
    }
    // OpenSSL reads a full 32-byte key for AES-256 regardless of the buffer
    // size, so a shorter key would be an out-of-bounds read.
    if (key.size() < kAes256KeyBytes) {
        throw std::runtime_error("AES-256-GCM requires a 32-byte key");
    }
    // The EVP update API takes an int length; reject inputs that would wrap.
    if (data.size() > kMaxInt) {
        throw std::runtime_error("Data is too large to encrypt in a single call");
    }

    // Fetch the IV once: GetInitVector() returns by value, so every call copies.
    const std::vector<unsigned char> iv = chainingInfo->GetInitVector();
    if (iv.size() > kMaxInt) {
        throw std::runtime_error("Initialization vector is too large");
    }

    // Two-step init: select the cipher, set the IV length, then supply key + IV.
    if (EVP_EncryptInit_ex(ctx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr) != 1) {
        throw OsslError(ERR_get_error(), "Failed to initialize encryption");
    }
    if (EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_IVLEN, static_cast<int>(iv.size()), nullptr) != 1) {
        throw OsslError(ERR_get_error(), "Failed to set IV length");
    }
    if (EVP_EncryptInit_ex(ctx, nullptr, nullptr, key.data(), iv.data()) != 1) {
        throw OsslError(ERR_get_error(), "Failed to set key and IV");
    }

    std::vector<unsigned char> ciphertext(data.size() + EVP_MAX_BLOCK_LENGTH);
    int len = 0;
    if (EVP_EncryptUpdate(ctx, ciphertext.data(), &len, data.data(), static_cast<int>(data.size())) != 1) {
        throw OsslError(ERR_get_error(), "Failed to encrypt");
    }
    int ciphertext_len = len;
    if (EVP_EncryptFinal_ex(ctx, ciphertext.data() + len, &len) != 1) {
        throw OsslError(ERR_get_error(), "Failed to finalize encryption");
    }
    ciphertext_len += len;

    std::vector<unsigned char> tag(kGcmTagBytes);
    if (EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_GET_TAG, static_cast<int>(tag.size()), tag.data()) != 1) {
        throw OsslError(ERR_get_error(), "Failed to get tag");
    }

    // Trim the scratch space, then append the tag so Decrypt can find it at the end.
    ciphertext.resize(static_cast<size_t>(ciphertext_len));
    ciphertext.insert(ciphertext.end(), tag.begin(), tag.end());
    return ciphertext;
}
/**
 * Decrypts and authenticates a buffer produced in the layout Encrypt emits:
 * ciphertext bytes followed by a trailing 16-byte GCM tag. Uses AES-256-GCM
 * with the stored key and the IV supplied by `chainingInfo`.
 *
 * @param ciphertext    Encrypted bytes with the 16-byte tag appended.
 * @param chainingInfo  Provides the IV through GetInitVector(); must not be
 *                      null.
 * @return The recovered plaintext.
 * @throws std::runtime_error if `chainingInfo` is null, the input is shorter
 *         than the tag, the stored key is shorter than 32 bytes, or an input
 *         length does not fit in `int`.
 * @throws OsslError if any OpenSSL call fails, including tag verification in
 *         EVP_DecryptFinal_ex.
 */
std::vector<unsigned char> OsslGcmWrapper::Decrypt(const std::vector<unsigned char> &ciphertext, AesChainingInfo* chainingInfo) const {
    constexpr size_t kAes256KeyBytes = 32;
    constexpr size_t kGcmTagBytes = 16;
    const size_t kMaxInt = static_cast<size_t>(std::numeric_limits<int>::max());

    if (!chainingInfo) {
        throw std::runtime_error("Chaining info must be set before calling Decrypt");
    }
    // Without this check `size() - 16` wraps around for short inputs, giving a
    // huge allocation request and a bogus length passed to OpenSSL.
    if (ciphertext.size() < kGcmTagBytes) {
        throw std::runtime_error("Ciphertext is too short to contain the GCM tag");
    }
    // OpenSSL reads a full 32-byte key for AES-256 regardless of the buffer
    // size, so a shorter key would be an out-of-bounds read.
    if (key.size() < kAes256KeyBytes) {
        throw std::runtime_error("AES-256-GCM requires a 32-byte key");
    }
    const size_t payload_len = ciphertext.size() - kGcmTagBytes;
    // The EVP update API takes an int length; reject inputs that would wrap.
    if (payload_len > kMaxInt) {
        throw std::runtime_error("Ciphertext is too large to decrypt in a single call");
    }

    // Fetch the IV once: GetInitVector() returns by value, so every call copies.
    const std::vector<unsigned char> iv = chainingInfo->GetInitVector();
    if (iv.size() > kMaxInt) {
        throw std::runtime_error("Initialization vector is too large");
    }

    // Two-step init: select the cipher, set the IV length, then supply key + IV.
    if (EVP_DecryptInit_ex(ctx, EVP_aes_256_gcm(), nullptr, nullptr, nullptr) != 1) {
        throw OsslError(ERR_get_error(), "Failed to initialize decryption");
    }
    if (EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_IVLEN, static_cast<int>(iv.size()), nullptr) != 1) {
        throw OsslError(ERR_get_error(), "Failed to set IV length");
    }
    if (EVP_DecryptInit_ex(ctx, nullptr, nullptr, key.data(), iv.data()) != 1) {
        throw OsslError(ERR_get_error(), "Failed to set key and IV");
    }

    std::vector<unsigned char> plaintext(payload_len);
    int len = 0;
    if (EVP_DecryptUpdate(ctx, plaintext.data(), &len, ciphertext.data(), static_cast<int>(payload_len)) != 1) {
        throw OsslError(ERR_get_error(), "Failed to decrypt");
    }
    int plaintext_len = len;

    // Copy the trailing tag into a mutable buffer: the ctrl call takes a
    // non-const pointer and the input vector is const.
    std::vector<unsigned char> tag(ciphertext.end() - static_cast<std::ptrdiff_t>(kGcmTagBytes), ciphertext.end());
    if (EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_TAG, static_cast<int>(tag.size()), tag.data()) != 1) {
        throw OsslError(ERR_get_error(), "Failed to set tag");
    }
    // Finalization is where the tag is verified; failure means the data must not be used.
    if (EVP_DecryptFinal_ex(ctx, plaintext.data() + len, &len) != 1) {
        throw OsslError(ERR_get_error(), "Failed to finalize decryption (tag mismatch?)");
    }
    plaintext_len += len;
    plaintext.resize(static_cast<size_t>(plaintext_len));
    return plaintext;
}