azure-protected-vm-secrets/SecretsProvisioningLibrary.cpp (162 lines of code) (raw):

//SecretsProvisioningLibrary.cpp : Defines the functions for the static library. // #define WIN32_LEAN_AND_MEAN #include "pch.h" #include <memory> #include <vector> #include <iostream> #include "LibraryLogger.h" #include "TpmError.h" #ifndef PLATFORM_UNIX #include "BcryptError.h" #endif // !PLATFORM_UNIX #include "AesWrapper.h" #include "Tpm.h" #include "JsonWebToken.h" #include "System.h" #ifdef PLATFORM_UNIX #include "Linux/OsslAesWrapper.h" #include "Linux/OsslECDiffieHellman.h" #include "Linux/OsslHKDF.h" #include "Linux/OsslError.h" #else #include "Windows/BcryptAesWrapper.h" #include "Windows/BcryptECDiffieHellman.h" #include "Windows/BcryptHKDF.h" #endif // PLATFORM_UNIX #ifdef __cplusplus extern "C" { #endif // __cplusplus using namespace SecretsLogger; // See header file for function description #ifdef DYNAMICSECRETSPROVISIONINGLIBRARY_EXPORTS __declspec(dllexport) #endif // DYNAMICSECRETSPROVISIONINGLIBRARY_EXPORTS long unprotect_secret(char* jwt, unsigned int jwtlen, char** output_secret) { std::unique_ptr<char*> inOutputSecret; long result = 0; std::unique_ptr<AesWrapper> aesWrapper; std::unique_ptr<AesCreator> aesCreator; std::unique_ptr<AesChainingInfo> aesChainingInfo; std::unique_ptr<JsonWebToken> jwtObj; std::vector<unsigned char> salt; std::vector<unsigned char> dataNonce; std::vector<unsigned char> wrappingNonce; std::vector<unsigned char> wrappedAesKey; std::vector<unsigned char> encryptedSecret; std::vector<unsigned char> encryptedEcdhPrivate; std::vector<unsigned char> exportedPublicKeyData; std::string infoString = GetSystemUuid(); std::vector<unsigned char> infoData(infoString.begin(), infoString.end()); std::string jwtStr(jwt, jwt + jwtlen); LIBSECRETS_LOG(LogLevel::Debug, "Unprotect Secret\n", "JWT %s", jwtStr.c_str()); jwtObj = std::make_unique<JsonWebToken>(); try { // Parse the JWT jwtObj->ParseToken(jwtStr, true); json claims = jwtObj->getClaims(); LIBSECRETS_LOG(LogLevel::Debug, "Unprotect Secret\n", "JWT claims\n %s", claims.dump(4).c_str()); salt = encoders::base64_decode(claims["salt"]); dataNonce = encoders::base64_decode(claims["dataNonce"]); wrappingNonce = encoders::base64_decode(claims["keyNonce"]); wrappedAesKey = encoders::base64_decode(claims["wrappedAesTransportKey"]); encryptedSecret = encoders::base64_decode(claims["encryptedSecret"]); encryptedEcdhPrivate = encoders::base64_decode(claims["encryptedGuestEcdhPrivateKey"]); exportedPublicKeyData = encoders::base64_decode(claims["ephemeralEcdhPublicKey"]); Tpm tpm{}; std::vector<unsigned char> aesKey = tpm.RsaDecrypt(wrappedAesKey); if (aesKey.size() == 0) { printf("Failed to decrypt data\n"); result = LONG_MIN; return result; } #ifndef PLATFORM_UNIX aesCreator = std::make_unique<GcmCreator>(); #else aesCreator = std::make_unique<OsslGcmCreator>(); #endif // !PLATFORM_UNIX aesWrapper = aesCreator->CreateAesWrapper(); aesWrapper->SetKey(aesKey); aesChainingInfo = aesWrapper->SetChainingInfo(wrappingNonce); std::vector<unsigned char> encodedEcdhPrivate = aesWrapper->Decrypt(encryptedEcdhPrivate, aesChainingInfo.get()); if (encodedEcdhPrivate.size() == 0) { LIBSECRETS_LOG(LogLevel::Error, "TPM Decrypt\n", "ptext len %d", encodedEcdhPrivate.size()); result = LONG_MIN; return result; } // Import ECDH keys #ifdef PLATFORM_UNIX std::unique_ptr<OsslECDiffieHellman> ecdhPrivate = std::make_unique<OsslECDiffieHellman>(); std::unique_ptr<OsslECDiffieHellman> ecdhPublic = std::make_unique<OsslECDiffieHellman>(); #else std::unique_ptr<BcryptECDiffieHellman> ecdhPrivate = std::make_unique<BcryptECDiffieHellman>(); std::unique_ptr<BcryptECDiffieHellman> ecdhPublic = std::make_unique<BcryptECDiffieHellman>(); #endif ecdhPrivate->ImportPkcs8PrivateKey(encodedEcdhPrivate); ecdhPublic->ImportSubjectPublicKeyInfo(exportedPublicKeyData); // Derive shared secret #ifdef PLATFORM_UNIX std::unique_ptr <OsslHKDF> hkdf = std::make_unique<OsslHKDF>(ecdhPrivate->DeriveSecret(*ecdhPublic)); #else std::unique_ptr <BcryptHKDF> hkdf = std::make_unique<BcryptHKDF>(ecdhPrivate->DeriveSecret(*ecdhPublic)); #endif aesKey = hkdf->DeriveKey(salt, infoData, 32); // Decrypt the secret aesWrapper->SetKey(aesKey); aesChainingInfo = aesWrapper->SetChainingInfo(dataNonce); std::vector<unsigned char> plaintextData = aesWrapper->Decrypt(encryptedSecret, aesChainingInfo.get()); // Marshal to C string std::unique_ptr<char[]> inOutputSecret(new char[plaintextData.size()]); if (inOutputSecret == nullptr) { LIBSECRETS_LOG(LogLevel::Warning, "Unprotect Secret\n", "Pointer allocation failed"); return LONG_MIN; } std::copy(plaintextData.begin(), plaintextData.end(), inOutputSecret.get()); *output_secret = inOutputSecret.release(); result = static_cast<long>(plaintextData.size()); } catch (TpmError err) { LIBSECRETS_LOG(LogLevel::Error, "TPM Decrypt\n", "TPM error 0x%x occurred\n Description %s", err.getReturnCode(), err.getTPMError()); result = (long)err.GetLibRC(); } #ifndef PLATFORM_UNIX catch (BcryptError err) { LIBSECRETS_LOG(LogLevel::Error, "Bcrypt Decrypt\n", "Bcrypt status 0x%x occurred\n Message %s\t Bcrypt Info%s", err.getStatusCode(), err.what(), err.getErrorInfo()); result = (long)err.GetLibRC(); } catch (WinCryptError err) { LIBSECRETS_LOG(LogLevel::Error, "WinCrypt Decode\n", "Message %s\t Bcrypt Info%s", err.what(), err.GetErrorMessage()); result = (long)err.GetLibRC(); } #else catch (OsslError err) { LIBSECRETS_LOG(LogLevel::Error, "Openssl Decode\n", "Message %s\t Openssl Info%s", err.what(), err.getErrorInfo()); result = (long)ErrorCode::CryptographyError; } #endif // !PLATFORM_UNIX catch (JwtError err) { LIBSECRETS_LOG(LogLevel::Error, "JWT Validation\n", "JWT error occurred\n Message %s", err.what()); result = (long)err.GetLibRC(); } catch (std::runtime_error e) { LIBSECRETS_LOG(LogLevel::Error, "runtime Exception\n", "error info %s", e.what()); result = LONG_MIN; } catch (std::exception e) { LIBSECRETS_LOG(LogLevel::Error, "Standard Exception\n", "error info %s", e.what()); result = LONG_MIN; } return result; } #ifdef DYNAMICSECRETSPROVISIONINGLIBRARY_EXPORTS __declspec(dllexport) #endif // DYNAMICSECRETSPROVISIONINGLIBRARY_EXPORTS void free_secret(char* secret) { if (secret != nullptr) delete[] secret; } #ifdef __cplusplus } #endif // __cplusplus