azure-protected-vm-secrets/SecretsProvisioningSample/SecretsProvisioningSample.cpp (193 lines of code) (raw):
// SecretsProvisioningSample.cpp : This file contains the 'main' function. Program execution begins and ends there.
//
#define UMDF_USING_NTSTATUS
#include <iostream>
#include "SecretsProvisioningLibrary.h"
#ifndef DYNAMIC_SAMPLE
#include "Tss2Wrapper.h"
#include "TpmError.h"
#include "AesWrapper.h"
#ifdef PLATFORM_UNIX
#include "Linux/OsslAesWrapper.h"
#include "Linux/OsslECDiffieHellman.h"
#include "Linux/OsslHKDF.h"
#include "Linux/OsslError.h"
#else
#include "Windows/BcryptAesWrapper.h"
#include "Windows/BcryptECDiffieHellman.h"
#include "Windows/BcryptHKDF.h"
#include "BcryptError.h"
#endif // PLATFORM_UNIX
// #include "ECDiffieHellman.h"
// #include "HKDF.h"
#include "JsonWebToken.h"
#include "System.h"
#include <nlohmann/json.hpp>
std::vector<BYTE> MakeRandomBytes(size_t a_Length)
{
std::vector<BYTE> result(a_Length);
for (size_t i = 0; i < result.size(); i++)
{
result[i] = (BYTE)rand();
}
return result;
}
void GenerateKey() {
std::unique_ptr<Tss2Wrapper> tss2Wrapper;
try {
tss2Wrapper = std::make_unique <Tss2Wrapper>();
tss2Wrapper->GenerateGuestKey();
}
catch (TpmError e) {
std::cout << "Error in TPM " << e.getTPMError() << std::endl;
}
}
void RemoveKey() {
std::unique_ptr<Tss2Wrapper> tss2Wrapper;
try {
tss2Wrapper = std::make_unique <Tss2Wrapper>();
tss2Wrapper->RemoveKey();
}
catch (TpmError e) {
std::cout << "Error in TPM " << e.getTPMError() << std::endl;
}
}
bool IsKeyPresent() {
std::unique_ptr<Tss2Wrapper> tss2Wrapper;
bool isKeyPresent = false;
try {
tss2Wrapper = std::make_unique <Tss2Wrapper>();
isKeyPresent = tss2Wrapper->IsKeyPresent();
}
catch (TpmError e) {
std::cout << "Error in TPM " << e.getTPMError() << std::endl;
}
return isKeyPresent;
}
void GetVmidFromSmbios() {
std::string uuid = GetSystemUuid();
std::cout << "UUID: " << uuid << std::endl;
}
std::string Encrypt(const char* data) {
std::vector<unsigned char> secretData(data, data + strlen(data) + 1);
std::vector<unsigned char> ciphertextData;
std::vector<unsigned char> wrappedAesKey;
std::vector<unsigned char> encryptedSecret, encryptedEcdhPrivate;
std::vector<unsigned char> dataNonce, wrappingNonce;
std::unique_ptr<Tss2Wrapper> tss2Wrapper;
std::unique_ptr<AesWrapper> aesWrapper;
std::unique_ptr<AesCreator> aesCreator;
std::unique_ptr<AesChainingInfo> aesChainingInfo;
std::vector<unsigned char> wrappingKey, aesKey;
std::vector<unsigned char> exportedPublicKeyData, exportedPrivateKeyData;
std::vector<unsigned char> saltData = MakeRandomBytes(32);
std::unique_ptr<JsonWebToken> jwt;
std::string token;
try {
std::string infoString = GetSystemUuid();
std::vector<unsigned char> infoData(infoString.begin(), infoString.end());
// Generate ECDH key pair
#ifdef PLATFORM_UNIX
std::unique_ptr<OsslECDiffieHellman> ecdhPrivate = std::make_unique<OsslECDiffieHellman>();
std::unique_ptr<OsslECDiffieHellman> ecdhPublic = std::make_unique <OsslECDiffieHellman>();
#else
std::unique_ptr<BcryptECDiffieHellman> ecdhPrivate = std::make_unique<BcryptECDiffieHellman>();
std::unique_ptr<BcryptECDiffieHellman> ecdhPublic = std::make_unique<BcryptECDiffieHellman>();
#endif
ecdhPrivate->GenerateKeyPair();
ecdhPublic->GenerateKeyPair();
exportedPublicKeyData = ecdhPublic->ExportSubjectPublicKeyInfo();
exportedPrivateKeyData = ecdhPrivate->ExportPkcs8PrivateKey();
// Derive shared secret
#ifdef PLATFORM_UNIX
OsslHKDF hkdf = OsslHKDF(ecdhPrivate->DeriveSecret(*ecdhPublic));
#else
BcryptHKDF hkdf = BcryptHKDF(ecdhPrivate->DeriveSecret(*ecdhPublic));
#endif
aesKey = hkdf.DeriveKey(saltData, infoData, 32);
if (aesKey.size() == 0) {
std::cout << "Failed to derive key" << std::endl;
return token;
}
std::cout << "Derived key size" << aesKey.size() << std::endl;
// Prep AesWrapper
#ifdef PLATFORM_UNIX
aesCreator = std::make_unique<OsslGcmCreator>();
#else
aesCreator = std::make_unique<GcmCreator>();
#endif // PLATFORM_UNIX
aesWrapper = aesCreator->CreateAesWrapper();
// Encrypt the secret data
aesWrapper->SetKey(aesKey);
dataNonce = MakeRandomBytes(12);
aesChainingInfo = aesWrapper->SetChainingInfo(dataNonce);
encryptedSecret = aesWrapper->Encrypt(secretData, aesChainingInfo.get());
if (encryptedSecret.size() == 0) {
std::cout << "Failed to encrypt data" << std::endl;
return token;
}
std::cout << "Encrypted data size" << encryptedSecret.size() << std::endl;
// Encrypt the private key
std::vector<unsigned char> wrappingKey = MakeRandomBytes(32);
aesWrapper->SetKey(wrappingKey);
wrappingNonce = MakeRandomBytes(12);
aesChainingInfo = aesWrapper->SetChainingInfo(wrappingNonce);
encryptedEcdhPrivate = aesWrapper->Encrypt(exportedPrivateKeyData, aesChainingInfo.get());
if (encryptedSecret.size() == 0) {
std::cout << "Failed to encrypt data" << std::endl;
return token;
}
// Encrypt AES key with EK
tss2Wrapper = std::make_unique <Tss2Wrapper>();
printf("Generated EK\nPreparing to encrypt\n");
wrappedAesKey = tss2Wrapper->Tss2RsaEncrypt(wrappingKey);
if (wrappedAesKey.size() == 0) {
std::cout << "Failed to wrap key" << std::endl;
return token;
}
}
catch (TpmError e) {
std::cout << "Error in TPM " << e.getTPMError() << std::endl;
return token;
}
#ifndef PLATFORM_UNIX
catch (BcryptError e) {
std::cout << "Error in Bcrypt " << e.getErrorInfo() << std::endl;
return token;
}
#else
catch (OsslError e) {
std::cout << "Error in Ossl " << e.getErrorInfo() << std::endl;
return token;
}
#endif // !PLATFORM_UNIX
catch (std::exception e) {
std::cout << "Failed to encrypt data" << e.what() << std::endl;
return token;
}
// Prepare jwt
jwt = std::make_unique<JsonWebToken>();
json header = {
{"alg", "RS256"},
{"typ", "JWT"}
};
jwt->SetHeader(header);
json payload = {
{"salt", encoders::base64_encode(saltData)},
{"dataNonce", encoders::base64_encode(dataNonce)},
{"keyNonce", encoders::base64_encode(wrappingNonce)},
{"wrappedAesTransportKey", encoders::base64_encode(wrappedAesKey)},
{"encryptedSecret", encoders::base64_encode(encryptedSecret)},
{"encryptedGuestEcdhPrivateKey", encoders::base64_encode(encryptedEcdhPrivate)},
{"ephemeralEcdhPublicKey", encoders::base64_encode(exportedPublicKeyData)}
};
jwt->SetPayload(payload);
token = jwt->CreateToken();
return token;
}
#endif
std::string Decrypt(const char* jwt) {
std::string secret;
char* output_secret = nullptr;
int jwtlen = strlen(jwt); // hacky way to get the length of the jwt
long result = unprotect_secret((char*)(jwt), jwtlen, &output_secret);
if (result <= 0) {
std::cout << "Failed to unprotect secret" << std::hex << result << std::endl;
return secret;
}
if (output_secret != nullptr) {
secret = std::string(output_secret, result);
std::cout << "\n\nSecret: " << secret.c_str() << std::endl;
free_secret(output_secret);
}
return secret;
}