cvm-securekey-release-app/AttestationUtil.cpp (816 lines of code) (raw):

//------------------------------------------------------------------------------------------------- // <copyright file="AttestationUtil.cpp" company="Microsoft Corporation"> // Copyright (c) Microsoft Corporation. All rights reserved. // </copyright> //------------------------------------------------------------------------------------------------- // // TODO: Use OPENSSL_cleanse(buffer, sizeof(buffer)) to clear sensitive data from memory. #include <cstdlib> #include <ctime> #include <thread> #include <vector> #include <string> #include <sstream> #include <iomanip> #include <iostream> #include <boost/archive/iterators/base64_from_binary.hpp> #include <boost/archive/iterators/binary_from_base64.hpp> #include <boost/archive/iterators/transform_width.hpp> #include <boost/algorithm/string.hpp> #include <curl/curl.h> #include <nlohmann/json.hpp> #include <AttestationClient.h> #include "AttestationUtil.h" #include "Logger.h" #include "Constants.h" using namespace attest; using json = nlohmann::json; bool Util::isTraceOn = false; int Util::traceLevel = 1; /// \copydoc Util::base64_to_binary() std::vector<BYTE> Util::base64_to_binary(const std::string &base64_data) { using namespace boost::archive::iterators; using It = transform_width<binary_from_base64<std::string::const_iterator>, 8, 6>; return boost::algorithm::trim_right_copy_if(std::vector<BYTE>(It(std::begin(base64_data)), It(std::end(base64_data))), [](char c) { return c == '\0'; }); } /// \copydoc Util::binary_to_base64() std::string Util::binary_to_base64(const std::vector<BYTE> &binary_data) { using namespace boost::archive::iterators; using It = base64_from_binary<transform_width<std::vector<BYTE>::const_iterator, 6, 8>>; auto tmp = std::string(It(std::begin(binary_data)), It(std::end(binary_data))); return tmp.append((3 - binary_data.size() % 3) % 3, '='); } /// \copydoc Util::binary_to_hex() std::string Util::binary_to_hex(const std::vector<BYTE> &binary_data) { std::stringstream ss; ss << std::hex << std::setfill('0'); for (auto c : binary_data) { ss << std::setw(2) << static_cast<int>(c); } return ss.str(); } /// \copydoc Util::hex_to_binary() std::vector<BYTE> Util::hex_to_binary(const std::string &hex_data) { std::vector<BYTE> result; for (size_t i = 0; i < hex_data.length(); i += 2) { std::string byteString = hex_data.substr(i, 2); BYTE byte = (BYTE)strtol(byteString.c_str(), NULL, 16); result.push_back(byte); } return result; } /// \copydoc Util::binary_to_base64url() std::string Util::binary_to_base64url(const std::vector<BYTE> &binary_data) { using namespace boost::archive::iterators; using It = base64_from_binary<transform_width<std::vector<BYTE>::const_iterator, 6, 8>>; auto tmp = std::string(It(std::begin(binary_data)), It(std::end(binary_data))); // For encoding to base64url, replace "+" with "-" and "/" with "_" boost::replace_all(tmp, "+", "-"); boost::replace_all(tmp, "/", "_"); // We do not need to add padding characters while url encoding. return tmp; } /// \copydoc Util::base64url_to_binary() std::vector<BYTE> Util::base64url_to_binary(const std::string &base64_data) { std::string stringData = base64_data; // While decoding base64 url, replace - with + and _ with + and // use stanard base64 decode. we dont need to add padding characters. underlying library handles it. boost::replace_all(stringData, "-", "+"); boost::replace_all(stringData, "_", "/"); return base64_to_binary(stringData); } /// \copydoc Util::base64_decode() std::string Util::base64_decode(const std::string &data) { using namespace boost::archive::iterators; using It = transform_width<binary_from_base64<std::string::const_iterator>, 8, 6>; return boost::algorithm::trim_right_copy_if(std::string(It(std::begin(data)), It(std::end(data))), [](char c) { return c == '\0'; }); } /// \copydoc Util::url_encode() std::string Util::url_encode(const std::string &data) { std::string encoded_str{data}; CURL *curl = curl_easy_init(); if (!curl) { TRACE_ERROR_EXIT("curl_easy_init() failed") } char *output = curl_easy_escape(curl, data.c_str(), data.length()); if (output) { encoded_str = data; curl_free(output); } curl_easy_cleanup(curl); return encoded_str; } /// <summary> /// Callback for curl perform operation. /// </summary> size_t Util::CurlWriteCallback(char *data, size_t size, size_t nmemb, std::string *buffer) { size_t result = 0; if (buffer != NULL) { buffer->append(data, size * nmemb); result = size * nmemb; } return result; } /// Retrieve IMDS token retrieval URL for a resource url. /// eg, "http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2Fvault.azure.net"}; static inline std::string GetImdsTokenUrl(std::string url) { std::ostringstream oss; oss << Constants::IMDS_TOKEN_URL; oss << "?api-version=" << Constants::IMDS_API_VERSION; oss << "&resource=" << Util::url_encode(url); // Managed id is optional if there is only 1 client id registered for the VM. auto client_id = std::getenv("IMDS_CLIENT_ID"); if (client_id != nullptr && strlen(client_id) > 0) { oss << "&client_id=" << client_id; } else { auto object_id = std::getenv("IMDS_OBJECT_ID"); if (object_id != nullptr && strlen(object_id) > 0) { oss << "&object_id=" << object_id; } else { // If client id is not provided, msi_res_id (ARM resource id) could be provided. auto msi_res_id = std::getenv("IMDS_MSI_RES_ID"); if (msi_res_id != nullptr && strlen(msi_res_id) > 0) { oss << "&msi_res_id=" << Util::url_encode(msi_res_id); } } } TRACE_OUT("IMDS token URL: %s", oss.str().c_str()); return oss.str(); } // Define a utility method to determine the resource URL based on KEKUrl std::string getResourceUrl(const std::string &KEKUrl, bool isIMDS = true) { // Constants for suffixes and corresponding resource URLs const std::string AKV_URL_SUFFIX = Constants::AKV_URL_SUFFIX; const std::string MHSM_URL_SUFFIX = Constants::MHSM_URL_SUFFIX; const std::string AKV_RESOURCE_URL = Constants::AKV_RESOURCE_URL; const std::string MHSM_RESOURCE_URL = Constants::MHSM_RESOURCE_URL; // Check if AKV suffix is present in KEKUrl if (KEKUrl.find(AKV_URL_SUFFIX) != std::string::npos) { TRACE_OUT("AKV resource suffix found in KEKUrl"); return isIMDS ? AKV_RESOURCE_URL : AKV_RESOURCE_URL + "/.default"; } // If AKV suffix is not found, check if MHSM suffix is present else if (KEKUrl.find(MHSM_URL_SUFFIX) != std::string::npos) { TRACE_OUT("MHSM resource suffix found in KEKUrl"); return isIMDS ? MHSM_RESOURCE_URL : MHSM_RESOURCE_URL + "/.default"; } // If neither AKV nor MHSM suffix is found, throw an error else { TRACE_ERROR_EXIT("Invalid resource suffix found in KEKUrl: " + KEKUrl) } } /// \copydoc Util::GetIMDSToken() std::string Util::GetIMDSToken(const std::string &KEKUrl) { TRACE_OUT("Entering Util::GetIMDSToken()"); CURL *curl = curl_easy_init(); if (!curl) { TRACE_ERROR_EXIT("curl_easy_init() failed") } // AKV and mHSM has different audience need to be passed to IMDS. std::string resourceUrl = getResourceUrl(KEKUrl); CURLcode curlRet = curl_easy_setopt(curl, CURLOPT_URL, GetImdsTokenUrl(resourceUrl).c_str()); if (curlRet != CURLE_OK) { TRACE_ERROR_EXIT("curl_easy_setopt() failed") } // ByPassing proxy for IMDS. // ref: https://learn.microsoft.com/en-us/azure/virtual-machines/instance-metadata-service?tabs=windows curlRet = curl_easy_setopt(curl, CURLOPT_PROXY, ""); if (curlRet != CURLE_OK) { std::ostringstream oss; oss << "curl_easy_setopt() failed: " << curl_easy_strerror(curlRet); TRACE_ERROR_EXIT(oss.str().c_str()) } struct curl_slist *headers = NULL; headers = curl_slist_append(headers, "Metadata: true"); curlRet = curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers); if (curlRet != CURLE_OK) { TRACE_ERROR_EXIT("curl_easy_setopt() failed\n") } curlRet = curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, CurlWriteCallback); if (curlRet != CURLE_OK) { TRACE_ERROR_EXIT("curl_easy_setopt() failed") } std::string responseStr; curlRet = curl_easy_setopt(curl, CURLOPT_WRITEDATA, &responseStr); if (curlRet != CURLE_OK) { std::ostringstream oss; oss << "curl_easy_setopt() failed: " << curl_easy_strerror(curlRet); TRACE_ERROR_EXIT(oss.str().c_str()) } curlRet = curl_easy_perform(curl); if (curlRet != CURLE_OK) { std::ostringstream oss; oss << "curl_easy_perform() failed: " << curl_easy_strerror(curlRet); TRACE_ERROR_EXIT(oss.str().c_str()) } curl_easy_cleanup(curl); TRACE_OUT("Response: %s\n", Util::reduct_log(responseStr).c_str()); json json_object = json::parse(responseStr.c_str()); std::string access_token = json_object["access_token"].get<std::string>(); TRACE_OUT("Access Token: %s\n", Util::reduct_log(access_token).c_str()); TRACE_OUT("Exiting Util::GetIMDSToken()"); return access_token; } /// \copydoc Util::GetAADToken() std::string Util::GetAADToken(const std::string &KEKUrl) { TRACE_OUT("Entering Util::GetAADToken()"); auto clientId = std::getenv("AKV_SKR_CLIENT_ID"); auto clientSecret = std::getenv("AKV_SKR_CLIENT_SECRET"); auto tenantId = std::getenv("AKV_SKR_TENANT_ID"); std::string resourceUrl = getResourceUrl(KEKUrl, false); std::string tokenUrl = "https://login.microsoftonline.com/" + std::string(tenantId) + "/oauth2/v2.0/token"; std::string postData = "client_id=" + std::string(clientId) + "&client_secret=" + std::string(clientSecret) + "&grant_type=client_credentials&scope= " + resourceUrl; CURL *curl = curl_easy_init(); if (curl) { curl_easy_setopt(curl, CURLOPT_URL, tokenUrl.c_str()); curl_easy_setopt(curl, CURLOPT_POSTFIELDS, postData.c_str()); curl_easy_setopt(curl, CURLOPT_POSTFIELDSIZE, postData.length()); curl_slist *headers = nullptr; headers = curl_slist_append(headers, "Content-Type: application/x-www-form-urlencoded"); curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers); std::string response; curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, CurlWriteCallback); curl_easy_setopt(curl, CURLOPT_WRITEDATA, &response); CURLcode result = curl_easy_perform(curl); curl_slist_free_all(headers); curl_easy_cleanup(curl); if (result == CURLE_OK) { std::string token; json jsonResponse = json::parse(response); if (jsonResponse.contains("access_token")) { token = jsonResponse["access_token"].get<std::string>(); } else { TRACE_ERROR_EXIT("access_token not found in AAD auth response") } TRACE_OUT("Response: %s\n", token.c_str()); TRACE_OUT("Exiting Util::GetAADToken()"); return token; } else { TRACE_ERROR_EXIT("curl_easy_perform() failed for URL") } } else { TRACE_ERROR_EXIT("curl_easy_init() failed") } std::cerr << "Failed to obtain AKV AAD token" << std::endl; exit(-1); } /// \copydoc Util::GetMAAToken() // TODO: attestation server URL can be constructed from VM region if necessary. std::string Util::GetMAAToken(const std::string &attestation_url, const std::string &nonce) { TRACE_OUT("Entering Util::GetMAAToken()"); std::string attest_server_url; attest_server_url.assign(attestation_url); if (attest_server_url.empty()) { // use the default attestation url attest_server_url.assign(Constants::DEFAULT_ATTESTATION_URL); } std::string nonce_token; nonce_token.assign(nonce); if (nonce_token.empty()) { // use some random nonce nonce_token.assign(Constants::NONCE); } AttestationClient *attestation_client = nullptr; AttestationLogger *log_handle = new Logger(Util::get_trace()); // Initialize attestation client if (!Initialize(log_handle, &attestation_client)) { std::cerr << "Failed to create attestation client object" << std::endl; Uninitialize(); exit(-1); } // parameters for the Attest call attest::ClientParameters params = {}; params.attestation_endpoint_url = (PBYTE)attest_server_url.c_str(); std::string client_payload_str = "{\"nonce\": \"" + nonce_token + "\"}"; // nonce is optional params.client_payload = (PBYTE)client_payload_str.c_str(); params.version = CLIENT_PARAMS_VERSION; PBYTE jwt = nullptr; attest::AttestationResult result; bool is_cvm = false; bool attestation_success = true; std::string jwt_str; if ((result = attestation_client->Attest(params, &jwt)).code_ != attest::AttestationResult::ErrorCode::SUCCESS) { attestation_success = false; } if (attestation_success) { jwt_str = std::string(reinterpret_cast<char *>(jwt)); std::vector<std::string> tokens; boost::split(tokens, jwt_str, [](char c) { return c == '.'; }); if (tokens.size() < 3) { std::cerr << "Invalid JWT token" << std::endl; exit(-1); } json attestation_claims = json::parse(base64_decode(tokens[1])); try { std::string attestation_type = attestation_claims["x-ms-isolation-tee"]["x-ms-attestation-type"].get<std::string>(); std::string compliance_status = attestation_claims["x-ms-isolation-tee"]["x-ms-compliance-status"].get<std::string>(); if (boost::iequals(attestation_type, "sevsnpvm") && boost::iequals(compliance_status, "azure-compliant-cvm")) { is_cvm = true; } } catch (...) { } // sevsnp claim does not exist in the token attestation_client->Free(jwt); Uninitialize(); } TRACE_OUT("Exiting Util::GetMAAToken()"); return jwt_str; } /// \copydoc Util::SplitString() std::vector<std::string> Util::SplitString(const std::string &str, char delim) { TRACE_OUT("Entering Util::SplitString()"); std::vector<std::string> result; std::stringstream ss(str); std::string item; while (std::getline(ss, item, delim)) { result.push_back(item); } TRACE_OUT("Exiting Util::SplitString()"); return result; } /// Get the modulus size in bytes of RSA key. int RSA_get_size(EVP_PKEY *pkey) { int rsaModulusSize = 0; #if defined(OPENSSL_VERSION_MAJOR) && OPENSSL_VERSION_MAJOR >= 3 // It is OSSL >= 3.0 // TODO: investigate why EVP_PKEY_get_size causes SIGSEGV in OSSL 3.0 // rsaModulusSize = EVP_PKEY_get_size(pkey); // fallback to deprecated API until above issue is resolved. RSA *rsa = EVP_PKEY_get1_RSA(pkey); rsaModulusSize = RSA_size(rsa); #else RSA *rsa = EVP_PKEY_get1_RSA(pkey); rsaModulusSize = RSA_size(rsa); #endif return rsaModulusSize; } /// handle openssl errors static void handle_openssl_errors(void) { TRACE_OUT("Entering handle_openssl_errors()"); std::cerr << "Error in OpenSSL" << std::endl; ERR_print_errors_fp(stderr); unsigned long error; while ((error = ERR_get_error())) { char error_str[120]{}; ERR_error_string_n(error, error_str, sizeof(error_str)); std::cerr << "Error: " << error_str << std::endl; } TRACE_OUT("Exiting handle_openssl_errors()"); exit(-1); } /// Decrypt ciphertext using the key static int decrypt_aes_key_unwrap(PBYTE key, PBYTE ciphertext, int ciphertext_len, PBYTE plaintext) { TRACE_OUT("Entering decrypt_aes_key_unwrap()"); EVP_CIPHER_CTX *ctx; int len; int plaintext_len; /* Create and initialise the context */ if (!(ctx = EVP_CIPHER_CTX_new())) handle_openssl_errors(); EVP_CIPHER_CTX_set_flags(ctx, EVP_CIPHER_CTX_FLAG_WRAP_ALLOW); /* Initialise the decryption operation. */ if (1 != EVP_DecryptInit_ex(ctx, EVP_aes_256_wrap_pad(), NULL, NULL, NULL)) handle_openssl_errors(); if (1 != EVP_DecryptInit_ex(ctx, NULL, NULL, key, NULL)) handle_openssl_errors(); // Set padding to PKCS#8 /*if (1 != EVP_CIPHER_CTX_set_padding(ctx, 1)) { handle_openssl_errors(); }*/ if (1 != EVP_DecryptUpdate(ctx, plaintext, &len, ciphertext, ciphertext_len)) handle_openssl_errors(); plaintext_len = len; if (1 != EVP_DecryptFinal_ex(ctx, plaintext + len, &len)) handle_openssl_errors(); plaintext_len += len; EVP_CIPHER_CTX_free(ctx); TRACE_OUT("Exiting decrypt_aes_key_unwrap()"); return plaintext_len; } // Construct URL for secure key release. // Format: https://<keyvaultname>.vault.azure.net/keys/<keyname>/<keyversion>/release?api-version=7.3 std::string Util::GetKeyVaultSKRurl(const std::string &KEKUrl) { TRACE_OUT("Entering Util::GetKeyVaultSKRurl()"); std::ostringstream requestUri; requestUri << KEKUrl; requestUri << "/" << "release"; requestUri << "?" << "api-version"; requestUri << "=" << "7.3"; TRACE_OUT("Request URI: %s\n", requestUri.str().c_str()); TRACE_OUT("Exiting Util::GetKeyVaultSKRurl()"); return requestUri.str(); } std::string Util::GetKeyVaultResponse(const std::string &requestUri, const std::string &access_token, const std::string &attestation_token, const std::string &nonce) { TRACE_OUT("Entering Util::GetKeyVaultResponse()"); CURL *curl = curl_easy_init(); if (!curl) { TRACE_ERROR_EXIT("curl_easy_init() failed") } CURLcode curlRet = curl_easy_setopt(curl, CURLOPT_URL, requestUri.c_str()); if (curlRet != CURLE_OK) { TRACE_ERROR_EXIT("curl_easy_setopt() failed for URL") } curlRet = curl_easy_setopt(curl, CURLOPT_POST, 1L); if (curlRet != CURLE_OK) { TRACE_ERROR_EXIT("curl_easy_setopt() failed for POST") } curlRet = curl_easy_setopt(curl, CURLOPT_HTTP_VERSION, CURL_HTTP_VERSION_1_1); if (curlRet != CURLE_OK) { TRACE_ERROR_EXIT("curl_easy_setopt() failed for HTTP_VERSION") } struct curl_slist *headers = NULL; std::ostringstream bearerToken; bearerToken << "Authorization: Bearer " << access_token; headers = curl_slist_append(headers, bearerToken.str().c_str()); TRACE_OUT("Bearer token: %s", Util::reduct_log(bearerToken.str()).c_str()); headers = curl_slist_append(headers, "Content-Type: application/json"); headers = curl_slist_append(headers, "Accept: application/json"); headers = curl_slist_append(headers, "User-Agent: AzureDiskEncryption"); curlRet = curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers); if (curlRet != CURLE_OK) { TRACE_ERROR_EXIT("curl_easy_setopt() failed\n") } std::ostringstream requestBody; std::string nonce_token; nonce_token.assign(nonce); if (nonce_token.empty()) { // use some random nonce nonce_token.assign(Constants::NONCE); } requestBody << "{"; requestBody << "\"nonce\": \"" + nonce_token + "\","; requestBody << "\"target\": \"" << attestation_token << "\","; requestBody << "\"enc\": \"CKM_RSA_AES_KEY_WRAP\""; requestBody << "}"; std::string requestBodyStr(requestBody.str()); // TRACE_OUT("requestBody: size=%d, '%s'", requestBodyStr.size(), requestBodyStr.c_str()); curlRet = curl_easy_setopt(curl, CURLOPT_POSTFIELDS, requestBodyStr.c_str()); if (curlRet != CURLE_OK) { TRACE_ERROR_EXIT("curl_easy_setopt() failed for CURLOPT_POSTFIELDS\n") } curlRet = curl_easy_setopt(curl, CURLOPT_POSTFIELDSIZE, (long)requestBodyStr.size()); if (curlRet != CURLE_OK) { TRACE_ERROR_EXIT("curl_easy_setopt() failed for CURLOPT_POSTFIELDSIZE\n") } // Enable verbose output from curl for debugging. /* curlRet = curl_easy_setopt(curl, CURLOPT_VERBOSE, 1L); if (curlRet != CURLE_OK) { TRACE_ERROR_EXIT("curl_easy_setopt() failed for CURLOPT_VERBOSE\n") } */ char errbuf[CURL_ERROR_SIZE] = { 0, }; curlRet = curl_easy_setopt(curl, CURLOPT_ERRORBUFFER, errbuf); if (curlRet != CURLE_OK) { size_t len = strlen(errbuf); std::cerr << "libcurl: " << curlRet << std::endl; if (len) std::cerr << errbuf << (errbuf[len - 1] != '\n') ? "\n" : ""; std::cerr << curl_easy_strerror(curlRet) << std::endl; TRACE_ERROR_EXIT("curl_easy_setopt() failed for CURLOPT_ERRORBUFFER\n") } // DEBUG only, when a proxy is needed such as Fiddler. curl_easy_setopt(curl, CURLOPT_SSL_VERIFYPEER, 0L); curl_easy_setopt(curl, CURLOPT_SSL_VERIFYHOST, 0L); curlRet = curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, CurlWriteCallback); if (curlRet != CURLE_OK) { TRACE_ERROR_EXIT("curl_easy_setopt() failed") } #ifndef PLATFORM_UNIX curl_easy_setopt(curl, CURLOPT_CAINFO, "curl-ca-bundle.crt"); #endif std::string responseStr; curlRet = curl_easy_setopt(curl, CURLOPT_WRITEDATA, &responseStr); if (curlRet != CURLE_OK) { std::ostringstream oss; oss << "curl_easy_setopt() failed: " << curl_easy_strerror(curlRet); TRACE_ERROR_EXIT(oss.str().c_str()) } // Perform the request, check the return code curlRet = curl_easy_perform(curl); // Check for errors if (curlRet != CURLE_OK) { std::ostringstream oss; oss << "curl_easy_perform() failed: " << curl_easy_strerror(curlRet); TRACE_ERROR_EXIT(oss.str().c_str()) } /* switch (code) { case CURLE_COULDNT_RESOLVE_HOST: case CURLE_COULDNT_RESOLVE_PROXY: case CURLE_COULDNT_CONNECT: case CURLE_WRITE_ERROR: STATSCOUNTER_INC(indexConFail, mutIndexConFail); return RS_RET_SUSPENDED; default: STATSCOUNTER_INC(indexSubmit, mutIndexSubmit); return RS_RET_OK; } */ // Cleanup curl curl_slist_free_all(headers); curl_easy_cleanup(curl); TRACE_OUT("SKR response: %s", Util::reduct_log(responseStr).c_str()); TRACE_OUT("Exiting Util::GetKeyVaultResponse()"); return responseStr; } bool Util::doSKR(const std::string &attestation_url, const std::string &nonce, std::string KEKUrl, EVP_PKEY **pkey, const Util::AkvCredentialSource &akv_credential_source) { TRACE_OUT("Entering Util::doSKR()"); try { std::string attest_token(Util::GetMAAToken(attestation_url, nonce)); TRACE_OUT("MAA Token: %s", Util::reduct_log(attest_token).c_str()); // Get Akv access token either using IMDS or Service Principal std::string access_token; if (akv_credential_source == Util::AkvCredentialSource::EnvServicePrincipal) { access_token = std::move(Util::GetAADToken(KEKUrl)); } else { access_token = std::move(Util::GetIMDSToken(KEKUrl)); } TRACE_OUT("AkvMsiAccessToken: %s", Util::reduct_log(access_token).c_str()); std::string requestUri = Util::GetKeyVaultSKRurl(KEKUrl); std::string responseStr = Util::GetKeyVaultResponse(requestUri, access_token, attest_token, nonce); // Parse the response: json skrJson = json::parse(responseStr.c_str()); std::string skrToken = skrJson["value"]; TRACE_OUT("SKR token: %s", Util::reduct_log(skrToken).c_str()); std::vector<std::string> tokenParts = Util::SplitString(skrToken, '.'); if (tokenParts.size() != 3) { TRACE_ERROR_EXIT("Invalid SKR token") } std::vector<BYTE> tokenPayload(Util::base64url_to_binary(tokenParts[1])); std::string tokenPayloadStr(tokenPayload.begin(), tokenPayload.end()); TRACE_OUT("SKR token payload: %s", Util::reduct_log(tokenPayloadStr).c_str()); json skrPayloadJson = json::parse(tokenPayloadStr.c_str()); std::vector<BYTE> key_hsm = Util::base64url_to_binary(skrPayloadJson["response"]["key"]["key"]["key_hsm"]); TRACE_OUT("SKR key_hsm: %s", Util::reduct_log(Util::binary_to_base64url(key_hsm)).c_str()); json cipherTextJson = json::parse(key_hsm); std::vector<BYTE> cipherText = Util::base64url_to_binary(cipherTextJson["ciphertext"]); TRACE_OUT("Encrypted bytes length: %ld", cipherText.size()); std::string cipherTextStr(cipherText.begin(), cipherText.end()); TRACE_OUT("Encrypted bytes: %s", Util::reduct_log(Util::binary_to_base64url(cipherText)).c_str()); AttestationClient *attestation_client = nullptr; AttestationLogger *log_handle = new Logger(Util::get_trace()); // Initialize attestation client if (!Initialize(log_handle, &attestation_client)) { printf("Failed to create attestation client object\n"); Uninitialize(); exit(1); } // gsl::span<const BYTE> payload = { cipherText + headerSize, cipherText - headerSize }; attest::AttestationResult result; int RSASize = 2048; int ModulusSize = RSASize / 8; uint8_t *decryptedAESBytes = nullptr; uint32_t decryptedBytesSize = 0; result = attestation_client->Decrypt(attest::EncryptionType::NONE, cipherText.data(), ModulusSize, NULL, 0, &decryptedAESBytes, &decryptedBytesSize, attest::RsaScheme::RsaOaep, // mHSM uses RSA-OAEP wrapping attest::RsaHashAlg::RsaSha1 // mHSM uses SHA1 hashing ); if (result.code_ != attest::AttestationResult::ErrorCode::SUCCESS) { printf("Failed to decrypt the AES key. Error code: %d, TPM error code=%d, Desc=%s\n", static_cast<int>(result.code_), result.tpm_error_code_, result.description_.c_str()); exit(1); } else { std::vector<BYTE> decryptedAESBytesVec(decryptedAESBytes, decryptedAESBytes + decryptedBytesSize); TRACE_OUT("Decrypted Transfer key: %s\n", Util::reduct_log(Util::binary_to_base64url(decryptedAESBytesVec)).c_str()); } // The remaining bytes are the encrypted CMK bytes with the decrypted AES key. // use openssl AES to decrypt the CMK bytes. BYTE private_key[8192]; int private_key_len = 0; private_key_len = decrypt_aes_key_unwrap(decryptedAESBytes, cipherText.data() + ModulusSize, (int)(cipherText.size() - ModulusSize), private_key); if (private_key_len == 0) { printf("Failed to decrypt the CMK\n"); exit(1); } else { TRACE_OUT("CMK private key has length=%d", private_key_len); std::vector<BYTE> privateKeyVec(private_key, private_key + private_key_len); TRACE_OUT("Decrypted CMK in base64url: %s", Util::reduct_log(Util::binary_to_base64url(privateKeyVec)).c_str()); TRACE_OUT("Decrypted CMK in hex: %s", Util::reduct_log(Util::binary_to_hex(privateKeyVec)).c_str()); // PKCS#8 BIO *bio_key = BIO_new_mem_buf(privateKeyVec.data(), (int)privateKeyVec.size()); if (!bio_key) { std::cerr << "Error creating memory BIO" << std::endl; exit(-1); } *pkey = d2i_PrivateKey_bio(bio_key, NULL); if (!*pkey) { // error handling std::cout << "Failed to load the priv key" << std::endl; ERR_print_errors_fp(stderr); // input data is not in correct format char buf[120]; ERR_error_string(ERR_get_error(), buf); printf("PKCS8 format check failed: %s\n", buf); exit(-1); } BIO_free(bio_key); return true; } // Cleanup Uninitialize(); delete log_handle; log_handle = nullptr; } catch (std::exception &e) { printf("Exception occured. Details - %s", e.what()); exit(1); } TRACE_OUT("Exiting Util::doSKR()"); return true; } // A helper function to handle errors void handleErrors() { ERR_print_errors_fp(stderr); abort(); } // A function that encrypts a message with a public key using EVP_PKEY_encrypt int rsa_encrypt(EVP_PKEY *pkey, const PBYTE msg, size_t msglen, PBYTE *enc, size_t *enclen) { TRACE_OUT("Entering rsa_encrypt()"); int ret = -1; EVP_PKEY_CTX *ctx = NULL; size_t outlen; // Create the context for the encryption operation ctx = EVP_PKEY_CTX_new(pkey, NULL); if (!ctx) handleErrors(); // Initialize the encryption operation if (EVP_PKEY_encrypt_init(ctx) <= 0) handleErrors(); #if defined(OPENSSL_VERSION_MAJOR) && OPENSSL_VERSION_MAJOR >= 3 // TODO: investiagate why setting padding and md algorithms causing SIGSEGV in OSSL 3.x #else // Set the RSA padding mode to either PKCS #1 OAEP if (EVP_PKEY_CTX_set_rsa_padding(ctx, RSA_PKCS1_OAEP_PADDING) <= 0) handleErrors(); // Set RSA signature scheme to SHA256 if (EVP_PKEY_CTX_set_rsa_oaep_md(ctx, EVP_sha256()) <= 0) handleErrors(); #endif // Determine the buffer length for the encrypted data if (EVP_PKEY_encrypt(ctx, NULL, &outlen, msg, msglen) <= 0) handleErrors(); // Allocate memory for the encrypted data *enc = (PBYTE)OPENSSL_malloc(outlen); if (!*enc) handleErrors(); // Perform the encryption operation if (EVP_PKEY_encrypt(ctx, *enc, &outlen, msg, msglen) <= 0) handleErrors(); // Set the encrypted data length *enclen = outlen; // Clean up and return success ret = 0; EVP_PKEY_CTX_free(ctx); TRACE_OUT("Exiting rsa_encrypt()"); return ret; } // A function that encrypts a message with a public key using EVP_PKEY_encrypt int rsa_decrypt(EVP_PKEY *pkey, const PBYTE msg, size_t msglen, PBYTE *dec, size_t *declen) { TRACE_OUT("Entering rsa_decrypt()"); int ret = -1; EVP_PKEY_CTX *ctx = NULL; size_t outlen; // Create the context for the encryption operation ctx = EVP_PKEY_CTX_new(pkey, NULL); if (!ctx) handleErrors(); // Initialize the encryption operation if (EVP_PKEY_decrypt_init(ctx) <= 0) handleErrors(); #if defined(OPENSSL_VERSION_MAJOR) && OPENSSL_VERSION_MAJOR >= 3 // TODO: investiagate why setting padding and md algorithms causing SIGSEGV in OSSL 3.x #else // Set the RSA padding mode to PKCS #1 OAEP if (EVP_PKEY_CTX_set_rsa_padding(ctx, RSA_PKCS1_OAEP_PADDING) <= 0) handleErrors(); // Set RSA signature scheme to SHA256 if (EVP_PKEY_CTX_set_rsa_oaep_md(ctx, EVP_sha256()) <= 0) // TODO: can be a parameter handleErrors(); #endif // Determine the buffer length for the encrypted data if (EVP_PKEY_decrypt(ctx, NULL, &outlen, msg, msglen) <= 0) handleErrors(); // Allocate memory for the encrypted data *dec = (PBYTE)OPENSSL_malloc(outlen); if (!*dec) handleErrors(); // Perform the encryption operation if (EVP_PKEY_decrypt(ctx, *dec, &outlen, msg, msglen) <= 0) handleErrors(); // Set the encrypted data length *declen = outlen; // Clean up and return success ret = 0; EVP_PKEY_CTX_free(ctx); TRACE_OUT("Exiting rsa_decrypt()"); return ret; } std::string Util::WrapKey(const std::string &attestation_url, const std::string &nonce, const std::string &sym_key, const std::string &key_enc_key_url, const Util::AkvCredentialSource &akv_credential_source) { TRACE_OUT("Entering Util::WrapKey()"); EVP_PKEY *pkey = nullptr; if (!Util::doSKR(attestation_url, nonce, key_enc_key_url, &pkey, akv_credential_source)) { std::cerr << "Failed to release the private key" << std::endl; exit(-1); } int pkeyBaseId = EVP_PKEY_base_id(pkey); TRACE_OUT("Key release completed successfully. EVP_PKEY_base_id=%d", pkeyBaseId); // Check if the key is of type RSA. If not, exit because EC keys do not support wrapKey/unwrapKey^M if (pkeyBaseId != EVP_PKEY_RSA /* PKCS1 */ && pkeyBaseId != EVP_PKEY_RSA2 /* X500 */) { std::cerr << "The key is not of type RSA. Only RSA keys are supported for wrapKey/unwrapKey" << std::endl; exit(-1); } int rsaSize = RSA_get_size(pkey); TRACE_OUT("Wrapping: %s", Util::reduct_log(sym_key).c_str()); size_t encrypted_length = 0; PBYTE encryptedKey; if (rsa_encrypt(pkey, (const PBYTE)sym_key.c_str(), sym_key.size(), &encryptedKey, &encrypted_length) == -1) { std::cerr << "Failed to wrap the symmetric key: " << std::endl; handle_openssl_errors(); exit(-1); } TRACE_OUT("Wrapping the symmetric key succeeded: encrypted_length=%ld\n", encrypted_length); std::vector<BYTE> encryptedKeyVector(encryptedKey, encryptedKey + encrypted_length); std::string cipherText = Util::binary_to_base64(encryptedKeyVector); TRACE_OUT("Wrapped symmetric key in base64: %s\n", Util::reduct_log(cipherText).c_str()); // Cleanup OPENSSL_free(encryptedKey); EVP_PKEY_free(pkey); TRACE_OUT("Exiting Util::WrapKey()"); return cipherText; } std::string Util::UnwrapKey(const std::string &attestation_url, const std::string &nonce, const std::string &wrapped_key_base64, const std::string &key_enc_key_url, const Util::AkvCredentialSource &akv_credential_source) { TRACE_OUT("Entering Util::UnwrapKey()"); EVP_PKEY *pkey = nullptr; if (!Util::doSKR(attestation_url, nonce, key_enc_key_url, &pkey, akv_credential_source)) { std::cerr << "Failed to release the private key" << std::endl; exit(-1); } int pkeyBaseId = EVP_PKEY_base_id(pkey); TRACE_OUT("Key release completed successfully. EVP_PKEY_base_id=%d", pkeyBaseId); // Check if the key is of type RSA. If not, exit because EC keys do not support wrapKey/unwrapKey^M if (pkeyBaseId != EVP_PKEY_RSA /* PKCS1 */ && pkeyBaseId != EVP_PKEY_RSA2 /* X500 */) { std::cerr << "The key is not of type RSA. Only RSA keys are supported for wrapKey/unwrapKey" << std::endl; exit(-1); } int rsaSize = RSA_get_size(pkey); TRACE_OUT("Unwrapping: %s\n", wrapped_key_base64.c_str()); std::vector<BYTE> wrapped_key = Util::base64_to_binary(wrapped_key_base64); size_t decrypted_length = 0; PBYTE decryptedKey; if (rsa_decrypt(pkey, wrapped_key.data(), wrapped_key.size(), &decryptedKey, &decrypted_length) == -1) { std::cerr << "Failed to unwrap the symmetric key: " << std::endl; handle_openssl_errors(); exit(-1); } TRACE_OUT("Unwrapping the symmetric key succeeded: decrypted_length=%lud", decrypted_length); std::vector<BYTE> decryptedKeyVector(decryptedKey, decryptedKey + decrypted_length); std::string plainText = Util::binary_to_base64(decryptedKeyVector); TRACE_OUT("Unwrapped symmetric key in base64: %s", Util::reduct_log(plainText).c_str()); TRACE_OUT("Exiting Util::UnwrapKey()"); // Cleanup OPENSSL_free(decryptedKey); EVP_PKEY_free(pkey); return Util::base64_decode(plainText); } bool Util::ReleaseKey(const std::string &attestation_url, const std::string &nonce, const std::string &key_enc_key_url, const Util::AkvCredentialSource &akv_credential_source) { TRACE_OUT("Entering Util::ReleaseKey()"); EVP_PKEY *pkey = nullptr; if (!Util::doSKR(attestation_url, nonce, key_enc_key_url, &pkey, akv_credential_source)) { std::cerr << "Failed to release the private key" << std::endl; return false; } TRACE_OUT("Key release completed successfully."); // Check if the key is of type RSA. If not, exit because EC keys do not support wrapKey/unwrapKey switch (EVP_PKEY_base_id(pkey)) { case EVP_PKEY_RSA: case EVP_PKEY_RSA2: std::cout << "The released key is of type RSA. It can be used for wrapKey/unwrapKey operations." << std::endl; return true; case EVP_PKEY_EC: std::cout << "The released key is of type EC. It can be used for sign/verify operations." << std::endl; return true; default: std::cout << "The released key is of type " << EVP_PKEY_base_id(pkey) << ". Not sure what operations are supported." << std::endl; return false; } }