azure-protected-vm-secrets/JsonWebToken.cpp (215 lines of code) (raw):

#include "pch.h"

#include <algorithm>
#include <iterator>
#include <string>
#include <vector>

#include <nlohmann/json.hpp>
#include <boost/archive/iterators/base64_from_binary.hpp>
#include <boost/archive/iterators/binary_from_base64.hpp>
#include <boost/archive/iterators/transform_width.hpp>
#include <boost/algorithm/string.hpp>

#include "JsonWebToken.h"
#ifndef PLATFORM_UNIX
#include "BcryptError.h"
#include "Windows/WincryptX509.h"
#else
#include "Linux/OsslError.h"
#include "Linux/OsslX509.h"
#endif // !PLATFORM_UNIX
#include "LibraryLogger.h"
#include "BaseX509.h"

using json = nlohmann::json;
using namespace SecretsLogger;

namespace encoders
{
    /**
     * Encodes bytes as standard base64 (alphabet A-Z a-z 0-9 + /), appending
     * '=' so the output length is a multiple of 4.
     */
    std::string base64_encode(std::vector<unsigned char> value)
    {
        using namespace boost::archive::iterators;
        using It = base64_from_binary<transform_width<std::vector<unsigned char>::const_iterator, 6, 8>>;
        std::string encoded(It(std::begin(value)), It(std::end(value)));
        // The iterator chain emits no padding characters; add them here.
        return encoded.append((3 - value.size() % 3) % 3, '=');
    }

    /**
     * Encodes bytes as base64url: standard base64 with '+' -> '-' and '/' -> '_'.
     *
     * NOTE(review): RFC 7515 base64url strips trailing '=' padding, but this
     * keeps it. Confirm what the consuming service expects before changing it.
     */
    std::string base64_url_encode(std::vector<unsigned char> value)
    {
        std::string base64 = base64_encode(value);
        boost::replace_all(base64, "+", "-");
        boost::replace_all(base64, "/", "_");
        return base64;
    }

    /**
     * Decodes standard base64 text into bytes.
     *
     * Each trailing '=' pad character decodes to an extra byte, which is
     * trimmed off afterwards (mirrors the padding added by base64_encode).
     * Invalid characters make the boost iterator throw.
     */
    std::vector<unsigned char> base64_decode(std::string base64)
    {
        using namespace boost::archive::iterators;
        using It = transform_width<binary_from_base64<std::string::const_iterator>, 8, 6>;
        std::vector<unsigned char> base64DecodedVector(It(std::begin(base64)), It(std::end(base64)));

        // Only trailing '=' characters are padding.
        const auto lastNonPad = std::find_if(base64.rbegin(), base64.rend(), [](char c) { return c != '='; });
        const std::size_t numPaddingChars = static_cast<std::size_t>(std::distance(base64.rbegin(), lastNonPad));

        // Never trim more than was decoded. Malformed input such as "==" would
        // otherwise underflow the size and request a resize to SIZE_MAX.
        base64DecodedVector.resize(base64DecodedVector.size() - std::min(numPaddingChars, base64DecodedVector.size()));
        return base64DecodedVector;
    }

    /**
     * Decodes base64url text: maps '-' -> '+' and '_' -> '/' and defers to
     * base64_decode.
     */
    std::vector<unsigned char> base64_url_decode(std::string base64Url)
    {
        boost::replace_all(base64Url, "-", "+");
        boost::replace_all(base64Url, "_", "/");
        return base64_decode(base64Url);
    }
}

/**
 * Creates a token whose header is {"alg": <alg>, "typ": "JWT"}.
 *
 * @param alg Algorithm name written to the "alg" header. Must not be null.
 * @throws JwtError if alg is null (building a json string from a null
 *         pointer would be undefined behaviour).
 */
JsonWebToken::JsonWebToken(const char *alg)
{
    if (alg == nullptr) {
        throw JwtError("JWT algorithm must not be null.");
    }
    this->header = json::object();
    // TODO: Validate alg is supported or raise exception
    this->header["alg"] = alg;
    this->header["typ"] = "JWT";
}
JsonWebToken::~JsonWebToken() { } json JsonWebToken::getClaims() { return this->jwt; } json JsonWebToken::getHeader() { return this->header; } std::vector<unsigned char> JsonWebToken::getSignature() { return this->signature; } void JsonWebToken::SetHeader(json header) { this->header = header; } void JsonWebToken::SetPayload(json payload) { this->jwt = payload; } void JsonWebToken::SetSignature(std::vector<unsigned char> signature) { this->signature = signature; } std::string JsonWebToken::CreateToken() { std::vector<unsigned char> token; std::string header = this->header.dump(); this->jwt["iat"] = time(0); this->addClaim("exp", time(0) + 1800); // 30 minutes. This mirrors the service. std::string payload = this->jwt.dump(); std::string headerBase64 = encoders::base64_url_encode(std::vector<unsigned char>(header.begin(), header.end())); std::string payloadBase64 = encoders::base64_url_encode(std::vector<unsigned char>(payload.begin(), payload.end())); std::string signatureBase64 = encoders::base64_url_encode(this->signature); std::string tokenString = headerBase64 + "." + payloadBase64 + "." 
+ signatureBase64; token = std::vector<unsigned char>(tokenString.begin(), tokenString.end()); return tokenString; } void JsonWebToken::ParseToken(std::string const&token, bool verify) { std::vector<unsigned char> tokenVector(token.begin(), token.end()); std::string headerBase64 = ""; std::string payloadBase64 = ""; std::string signatureBase64 = ""; std::string header = ""; std::string payload = ""; std::string signature = ""; size_t first = std::string::npos; size_t last = std::string::npos; if (tokenVector.size() < 2) { throw JwtError("Invalid JWT token."); } first = token.find_first_of('.'); last = token.find_last_of('.'); if (first == std::string::npos || last == std::string::npos || first == last) { throw JwtError("Invalid JWT token."); } headerBase64 = std::string(tokenVector.begin(), tokenVector.begin() + first); first++; payloadBase64 = std::string(tokenVector.begin() + first, tokenVector.begin() + last); last++; signatureBase64 = std::string(tokenVector.begin() + last, tokenVector.end()); if (!headerBase64.empty()) { try { std::vector<unsigned char> headerVector = encoders::base64_url_decode(headerBase64); header = std::string(headerVector.begin(), headerVector.end()); this->header = json::parse(header); } catch (json::parse_error& e) { LIBSECRETS_LOG(LogLevel::Error, "Json Parse Error\n", "Error message %s\n", e.what()); throw JwtError(e.what()); } catch (...) { LIBSECRETS_LOG(LogLevel::Error, "Json Parse Error\n", "Generic Parsing Error\n"); throw JwtError("Failed to parse header."); } } if (!payloadBase64.empty()) { std::vector<unsigned char> payloadVector = encoders::base64_url_decode(payloadBase64); payload = std::string(payloadVector.begin(), payloadVector.end()); try { this->jwt = json::parse(payload); if (this->jwt["exp"].is_number() && this->jwt["iat"].is_number()) { // Check if the token is expired or not yet valid // In Azure, the token expiration is set to 30 minutes after // the token is issued & signed. 
/*time_t exp = this->jwt["exp"]; time_t iat = this->jwt["iat"]; time_t now = time(0); if (now > exp) { // ParseError, Jwt subclass, timeError throw JwtError("Token has expired.", ErrorCode::ParsingError_Jwt_timeError); } if (now < iat) { // ParseError, Jwt subclass, timeError throw JwtError("Token is not yet valid.", ErrorCode::ParsingError_Jwt_timeError); }*/ } } catch (json::parse_error& e) { LIBSECRETS_LOG(LogLevel::Error, "Json Parse Error\n", "Error message %s\n", e.what()); throw JwtError(e.what()); } catch (...) { LIBSECRETS_LOG(LogLevel::Error, "Json Parse Error\n", "Generic Parsing Error\n"); throw JwtError("Failed to parse payload."); } } if (!signatureBase64.empty()) { this->signature = encoders::base64_url_decode(signatureBase64); if (verify) { std::string signed_portion = token.substr(0, token.find_last_of('.')); #ifndef PLATFORM_UNIX std::unique_ptr<WincryptX509> x509 = std::make_unique<WincryptX509>(); #else std::unique_ptr<OsslX509> x509 = std::make_unique<OsslX509>(); #endif try { std::for_each(std::begin(INTERMEDIATE_CERTS), std::end(INTERMEDIATE_CERTS), [&](const auto cert) { x509->LoadIntermediateCertificate(cert); } ); x509->LoadLeafCertificate(std::string(this->header["x5c"]).c_str()); if (!x509->VerifyCertChain()) { throw JwtError("Failed to verify certificate chain."); } else { LIBSECRETS_LOG( LogLevel::Debug, "Successfully Verified Certificate chain\n", ""); } std::vector<unsigned char> signed_data(signed_portion.begin(), signed_portion.end()); if (!x509->VerifySignature(signed_data, this->signature)) { throw JwtError("Failed to verify certificate chain."); } else { LIBSECRETS_LOG( LogLevel::Debug, "Successfully Verified Signature\n", ""); } } catch (json::out_of_range& e) { // No x5c header LIBSECRETS_LOG(LogLevel::Error, "Json Parse Error\n", "Error message %s\n", e.what()); throw JwtError(e.what()); } #ifndef PLATFORM_UNIX catch (WinCryptError& e) { // Certificate chain verification failed LIBSECRETS_LOG(LogLevel::Error, "WinCrypt 
Error\n", "Error message %s\n", e.what()); throw JwtError(e.what()); } catch (BcryptError &e) { // Signature verification failed LIBSECRETS_LOG(LogLevel::Error, "Bcrypt Verification\n", "Bcrypt status 0x%x occurred\n Message %s\t Bcrypt Info%s", e.getStatusCode(), e.what(), e.getErrorInfo()); throw JwtError(e.getErrorInfo()); } #else catch (OsslError& e) { // Certificate chain verification failed LIBSECRETS_LOG(LogLevel::Error, "Openssl Error\n", "Error message %s\n", e.what()); throw JwtError(e.what()); } #endif } } else { LIBSECRETS_LOG(LogLevel::Debug, "JWT Information\n", "No signature found in token\n"); } }