bool Util::doSKR()

in cvm-securekey-release-app/AttestationUtil.cpp [707:854]


/**
 * Performs Secure Key Release (SKR) of a Key Vault / mHSM key and loads the
 * released private key into an OpenSSL EVP_PKEY.
 *
 * Flow:
 *   1. Obtain an attestation (MAA) token from attestation_url for nonce.
 *   2. Obtain a Key Vault access token: via the environment service principal
 *      (AAD) when akv_credential_source is EnvServicePrincipal, IMDS otherwise.
 *   3. Call the SKR endpoint derived from KEKUrl, split the returned "value"
 *      token into its three '.'-separated parts and decode the middle one
 *      (payload) to reach response.key.key.key_hsm.
 *   4. key_hsm is JSON whose "ciphertext" is: ModulusSize bytes of RSA-OAEP
 *      (SHA-1) wrapped transfer key, followed by the CMK wrapped with that
 *      transfer key.
 *   5. Decrypt the transfer key with the attestation client, AES-key-unwrap
 *      the CMK and parse it (DER, PKCS#8 per the original comment) into *pkey.
 *
 * @param attestation_url        MAA endpoint used to obtain the attestation token.
 * @param nonce                  Nonce passed to both the MAA and Key Vault requests.
 * @param KEKUrl                 URL of the Key Vault key to release.
 * @param pkey                   Out: receives the released key; must be non-null.
 *                               Presumably owned by the caller (EVP_PKEY_free).
 * @param akv_credential_source  How to authenticate against Key Vault.
 * @return true on success. Every failure prints a diagnostic and terminates the
 *         process via exit(); the function never returns false.
 */
bool Util::doSKR(const std::string &attestation_url,
                 const std::string &nonce,
                 std::string KEKUrl,
                 EVP_PKEY **pkey,
                 const Util::AkvCredentialSource &akv_credential_source)
{
    TRACE_OUT("Entering Util::doSKR()");

    try
    {
        std::string attest_token(Util::GetMAAToken(attestation_url, nonce));
        TRACE_OUT("MAA Token: %s", Util::reduct_log(attest_token).c_str());

        // Get Akv access token either using IMDS or Service Principal
        std::string access_token;
        if (akv_credential_source == Util::AkvCredentialSource::EnvServicePrincipal)
        {
            access_token = Util::GetAADToken(KEKUrl);
        }
        else
        {
            access_token = Util::GetIMDSToken(KEKUrl);
        }

        TRACE_OUT("AkvMsiAccessToken: %s", Util::reduct_log(access_token).c_str());

        std::string requestUri = Util::GetKeyVaultSKRurl(KEKUrl);
        std::string responseStr = Util::GetKeyVaultResponse(requestUri, access_token, attest_token, nonce);

        // Parse the response. at() throws a descriptive json exception for a
        // missing key (handled by the catch below) instead of inserting a null
        // and failing later on the string conversion.
        json skrJson = json::parse(responseStr.c_str());
        std::string skrToken = skrJson.at("value");
        TRACE_OUT("SKR token: %s", Util::reduct_log(skrToken).c_str());
        std::vector<std::string> tokenParts = Util::SplitString(skrToken, '.');
        if (tokenParts.size() != 3)
        {
            TRACE_ERROR_EXIT("Invalid SKR token")
        }

        std::vector<BYTE> tokenPayload(Util::base64url_to_binary(tokenParts[1]));
        std::string tokenPayloadStr(tokenPayload.begin(), tokenPayload.end());
        TRACE_OUT("SKR token payload: %s", Util::reduct_log(tokenPayloadStr).c_str());
        json skrPayloadJson = json::parse(tokenPayloadStr.c_str());
        std::vector<BYTE> key_hsm = Util::base64url_to_binary(
            skrPayloadJson.at("response").at("key").at("key").at("key_hsm"));
        TRACE_OUT("SKR key_hsm: %s", Util::reduct_log(Util::binary_to_base64url(key_hsm)).c_str());
        json cipherTextJson = json::parse(key_hsm);
        std::vector<BYTE> cipherText = Util::base64url_to_binary(cipherTextJson.at("ciphertext"));
        TRACE_OUT("Encrypted bytes length: %zu", cipherText.size());
        TRACE_OUT("Encrypted bytes: %s", Util::reduct_log(Util::binary_to_base64url(cipherText)).c_str());

        // Ciphertext layout: the first ModulusSize bytes are the RSA-2048
        // wrapped transfer key; the remaining bytes are the wrapped CMK.
        constexpr int RSASize = 2048;
        constexpr int ModulusSize = RSASize / 8;
        BYTE private_key[8192];

        // Validate the (untrusted) layout before using it: a short ciphertext
        // would make Decrypt read past the buffer and underflow the wrapped
        // length; an oversized one could overflow private_key.
        if (cipherText.size() <= static_cast<size_t>(ModulusSize))
        {
            printf("Invalid SKR ciphertext: %zu bytes, expected more than %d\n", cipherText.size(), ModulusSize);
            exit(1);
        }
        const size_t wrappedKeySize = cipherText.size() - static_cast<size_t>(ModulusSize);
        // NOTE(review): assumes decrypt_aes_key_unwrap writes at most
        // wrappedKeySize bytes (true for RFC 3394/5649 key unwrap) — confirm.
        if (wrappedKeySize > sizeof(private_key))
        {
            printf("Invalid SKR ciphertext: wrapped key of %zu bytes exceeds %zu byte buffer\n", wrappedKeySize, sizeof(private_key));
            exit(1);
        }

        AttestationClient *attestation_client = nullptr;
        AttestationLogger *log_handle = new Logger(Util::get_trace());

        // Initialize attestation client
        if (!Initialize(log_handle, &attestation_client))
        {
            printf("Failed to create attestation client object\n");
            Uninitialize();
            exit(1);
        }

        // Decrypt the transfer (AES) key held in the first ModulusSize bytes.
        uint8_t *decryptedAESBytes = nullptr;
        uint32_t decryptedBytesSize = 0;
        attest::AttestationResult result = attestation_client->Decrypt(attest::EncryptionType::NONE,
                                                                       cipherText.data(),
                                                                       ModulusSize,
                                                                       nullptr,
                                                                       0,
                                                                       &decryptedAESBytes,
                                                                       &decryptedBytesSize,
                                                                       attest::RsaScheme::RsaOaep, // mHSM uses RSA-OAEP wrapping
                                                                       attest::RsaHashAlg::RsaSha1 // mHSM uses SHA1 hashing
        );
        if (result.code_ != attest::AttestationResult::ErrorCode::SUCCESS)
        {
            printf("Failed to decrypt the AES key. Error code: %d, TPM error code=%d, Desc=%s\n", static_cast<int>(result.code_), result.tpm_error_code_, result.description_.c_str());
            exit(1);
        }
        std::vector<BYTE> decryptedAESBytesVec(decryptedAESBytes, decryptedAESBytes + decryptedBytesSize);
        TRACE_OUT("Decrypted Transfer key: %s\n", Util::reduct_log(Util::binary_to_base64url(decryptedAESBytesVec)).c_str());

        // The remaining bytes are the encrypted CMK bytes with the decrypted AES key.
        // use openssl AES to decrypt the CMK bytes.
        // NOTE(review): decryptedAESBytes is presumably allocated by Decrypt()
        // and is never released here; check whether AttestationClient exposes
        // a matching free routine.
        const int private_key_len = decrypt_aes_key_unwrap(decryptedAESBytes,
                                                           cipherText.data() + ModulusSize,
                                                           static_cast<int>(wrappedKeySize),
                                                           private_key);
        // <= 0 also rejects a negative error return, which would otherwise
        // become an invalid range for the vector below.
        if (private_key_len <= 0)
        {
            printf("Failed to decrypt the CMK\n");
            exit(1);
        }

        TRACE_OUT("CMK private key has length=%d", private_key_len);
        std::vector<BYTE> privateKeyVec(private_key, private_key + private_key_len);
        TRACE_OUT("Decrypted CMK in base64url: %s", Util::reduct_log(Util::binary_to_base64url(privateKeyVec)).c_str());
        TRACE_OUT("Decrypted CMK in hex: %s", Util::reduct_log(Util::binary_to_hex(privateKeyVec)).c_str());

        // PKCS#8
        BIO *bio_key = BIO_new_mem_buf(privateKeyVec.data(), static_cast<int>(privateKeyVec.size()));
        if (!bio_key)
        {
            std::cerr << "Error creating memory BIO" << std::endl;
            exit(-1);
        }
        *pkey = d2i_PrivateKey_bio(bio_key, nullptr);
        if (!*pkey)
        {
            // Capture the first queued error before ERR_print_errors_fp()
            // drains the queue; afterwards ERR_get_error() would return 0.
            const unsigned long first_error = ERR_get_error();

            // error handling
            std::cout << "Failed to load the priv key" << std::endl;
            ERR_print_errors_fp(stderr);

            // input data is not in correct format
            // ERR_error_string() requires a buffer of at least 256 bytes.
            char buf[256];
            ERR_error_string(first_error, buf);
            printf("PKCS8 format check failed: %s\n", buf);
            BIO_free(bio_key);
            exit(-1);
        }
        BIO_free(bio_key);

        // Cleanup: tear down the attestation client only after its Decrypt()
        // output (decryptedAESBytes) is no longer used, then free the logger.
        Uninitialize();
        delete log_handle;
        log_handle = nullptr;
    }
    catch (const std::exception &e)
    {
        printf("Exception occured. Details - %s\n", e.what());
        exit(1);
    }

    TRACE_OUT("Exiting Util::doSKR()");
    return true;
}