plugins/wasm-cpp/extensions/jwt_auth/plugin.cc (387 lines of code) (raw):

// Copyright (c) 2022 Alibaba Group Holding Ltd. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "extensions/jwt_auth/plugin.h" #include <algorithm> #include <array> #include <cstdint> #include <string> #include <unordered_set> #include <utility> #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" #include "common/common_util.h" #include "common/http_util.h" #include "common/json_util.h" using ::nlohmann::json; using ::Wasm::Common::JsonArrayIterate; using ::Wasm::Common::JsonGetField; using ::Wasm::Common::JsonObjectIterate; using ::Wasm::Common::JsonValueAs; #ifdef NULL_PLUGIN namespace proxy_wasm { namespace null_plugin { namespace jwt_auth { PROXY_WASM_NULL_PLUGIN_REGISTRY #endif namespace { constexpr absl::string_view InvalidTokenErrorString = ", error=\"invalid_token\""; constexpr uint32_t MaximumUriLength = 256; constexpr std::string_view kRcDetailJwtAuthnPrefix = "jwt_authn_access_denied"; std::string generateRcDetails(std::string_view error_msg) { // Replace space with underscore since RCDetails may be written to access log. // Some log processors assume each log segment is separated by whitespace. return absl::StrCat(kRcDetailJwtAuthnPrefix, "{", absl::StrJoin(absl::StrSplit(error_msg, ' '), "_"), "}"); } } // namespace static RegisterContextFactory register_JwtAuth(CONTEXT_FACTORY(PluginContext), ROOT_FACTORY(PluginRootContext)); #define JSON_FIND_FIELD(dict, field) \ auto dict##_##field##_json = dict.find(#field); \ if (dict##_##field##_json == dict.end()) { \ LOG_WARN("can't find '" #field "' in " #dict); \ return false; \ } #define JSON_VALUE_AS(type, src, dst, err_msg) \ auto dst##_v = JsonValueAs<type>(src); \ if (dst##_v.second != Wasm::Common::JsonParserResultDetail::OK || \ !dst##_v.first) { \ LOG_WARN(#err_msg); \ return false; \ } \ auto& dst = dst##_v.first.value(); #define JSON_FIELD_VALUE_AS(type, dict, field) \ JSON_VALUE_AS(type, dict##_##field##_json.value(), dict##_##field, \ "'" #field "' field in " #dict "convert to " #type " failed") bool PluginRootContext::parsePluginConfig(const json& configuration, JwtAuthConfigRule& rule) { std::unordered_set<std::string> name_set; if (!JsonArrayIterate( configuration, "consumers", [&](const json& consumer) -> bool { Consumer c; JSON_FIND_FIELD(consumer, name); JSON_FIELD_VALUE_AS(std::string, consumer, name); if (name_set.count(consumer_name) != 0) { LOG_WARN("consumer already exists: " + consumer_name); return false; } c.name = consumer_name; JSON_FIND_FIELD(consumer, jwks); JSON_FIELD_VALUE_AS(std::string, consumer, jwks); c.jwks = google::jwt_verify::Jwks::createFrom( consumer_jwks, google::jwt_verify::Jwks::JWKS); if (c.jwks->getStatus() != Status::Ok) { LOG_WARN(absl::StrFormat( "jwks is invalid, consumer:%s, status:%s, jwks:%s", consumer_name, google::jwt_verify::getStatusString(c.jwks->getStatus()), consumer_jwks)); return false; } std::unordered_map<std::string, std::string> claims; auto consumer_claims_json = consumer.find("claims"); if (consumer_claims_json != consumer.end()) { JSON_FIELD_VALUE_AS(Wasm::Common::JsonObject, consumer, claims); if (!JsonObjectIterate( consumer_claims, [&](std::string key) -> bool { auto claims_claim_json = consumer_claims.find(key); JSON_FIELD_VALUE_AS(std::string, claims, claim); claims.emplace(std::make_pair( key, Wasm::Common::trim(claims_claim))); return true; })) { LOG_WARN("failed to parse 'claims' in consumer: " + consumer_name); return false; } } auto consumer_issuer_json = consumer.find("issuer"); if (consumer_issuer_json != consumer.end()) { JSON_FIELD_VALUE_AS(std::string, consumer, issuer); claims.emplace( std::make_pair("iss", Wasm::Common::trim(consumer_issuer))); } c.allowd_claims = std::move(claims); std::vector<FromHeader> from_headers; if (!JsonArrayIterate( consumer, "from_headers", [&](const json& from_header) -> bool { JSON_FIND_FIELD(from_header, name); JSON_FIELD_VALUE_AS(std::string, from_header, name); std::string header_value_prefix; auto from_header_value_prefix_json = from_header.find("value_prefix"); if (from_header_value_prefix_json != from_header.end()) { JSON_FIELD_VALUE_AS(std::string, from_header, value_prefix); header_value_prefix = from_header_value_prefix; } from_headers.push_back( FromHeader{from_header_name, header_value_prefix}); return true; })) { LOG_WARN("failed to parse 'from_headers' in consumer: " + consumer_name); return false; } std::vector<std::string> from_params; if (!JsonArrayIterate(consumer, "from_params", [&](const json& from_param_json) -> bool { JSON_VALUE_AS(std::string, from_param_json, from_param, "invalid item"); from_params.push_back(from_param); return true; })) { LOG_WARN("failed to parse 'from_params' in consumer: " + consumer_name); return false; } std::vector<std::string> from_cookies; if (!JsonArrayIterate(consumer, "from_cookies", [&](const json& from_cookie_json) -> bool { JSON_VALUE_AS(std::string, from_cookie_json, from_cookie, "invalid item"); from_cookies.push_back(from_cookie); return true; })) { LOG_WARN("failed to parse 'from_cookies' in consumer: " + consumer_name); return false; } if (!from_headers.empty() || !from_params.empty() || !from_cookies.empty()) { c.from_headers = std::move(from_headers); c.from_params = std::move(from_params); c.from_cookies = std::move(from_cookies); } std::unordered_map<std::string, ClaimToHeader> claims_to_headers; if (!JsonArrayIterate( consumer, "claims_to_headers", [&](const json& item_json) -> bool { JSON_VALUE_AS(Wasm::Common::JsonObject, item_json, item, "invalid item"); JSON_FIND_FIELD(item, claim); JSON_FIELD_VALUE_AS(std::string, item, claim); auto c2h_it = claims_to_headers.find(item_claim); if (c2h_it != claims_to_headers.end()) { LOG_WARN("claim to header already exists: " + item_claim); return false; } auto& c2h = claims_to_headers[item_claim]; JSON_FIND_FIELD(item, header); JSON_FIELD_VALUE_AS(std::string, item, header); c2h.header = std::move(item_header); auto item_override_json = item.find("override"); if (item_override_json != item.end()) { JSON_FIELD_VALUE_AS(bool, item, override); c2h.override = item_override; } return true; })) { LOG_WARN("failed to parse 'claims_to_headers' in consumer: " + consumer_name); return false; } c.claims_to_headers = std::move(claims_to_headers); auto consumer_clock_skew_seconds_json = consumer.find("clock_skew_seconds"); if (consumer_clock_skew_seconds_json != consumer.end()) { JSON_FIELD_VALUE_AS(uint64_t, consumer, clock_skew_seconds); c.clock_skew = consumer_clock_skew_seconds; } auto consumer_keep_token_json = consumer.find("keep_token"); if (consumer_keep_token_json != consumer.end()) { JSON_FIELD_VALUE_AS(bool, consumer, keep_token); c.keep_token = consumer_keep_token; } c.extractor = Extractor::create(c); rule.consumers.push_back(std::move(c)); name_set.insert(consumer_name); return true; })) { LOG_WARN("failed to parse configuration for consumers."); return false; } if (rule.consumers.empty()) { LOG_INFO("at least one consumer has to be configured for a rule."); return false; } std::vector<std::string> enable_headers; if (!JsonArrayIterate(configuration, "enable_headers", [&](const json& enable_header_json) -> bool { JSON_VALUE_AS(std::string, enable_header_json, enable_header, "invalid item"); enable_headers.push_back(enable_header); return true; })) { LOG_WARN("failed to parse 'enable_headers'"); return false; } rule.enable_headers = std::move(enable_headers); return true; } Status PluginRootContext::consumerVerify( const Consumer& consumer, uint64_t now, std::vector<JwtLocationConstPtr>& jwt_tokens) { auto tokens = consumer.extractor->extract(); if (tokens.empty()) { return Status::JwtMissed; } for (auto& token : tokens) { google::jwt_verify::Jwt jwt; Status status = jwt.parseFromString(token->token()); if (status != Status::Ok) { LOG_INFO(absl::StrFormat( "jwt parse failed, consumer:%s, token:%s, status:%s", consumer.name, token->token(), google::jwt_verify::getStatusString(status))); return status; } StructUtils payload_getter(jwt.payload_pb_); if (!consumer.allowd_claims.empty()) { for (const auto& claim : consumer.allowd_claims) { std::string value; if (payload_getter.GetString(claim.first, &value) == StructUtils::WRONG_TYPE) { LOG_INFO(absl::StrFormat( "jwt payload invalid, consumer:%s, token:%s, claim:%s", consumer.name, jwt.payload_str_, claim.first)); return Status::JwtVerificationFail; } if (value != claim.second) { LOG_INFO(absl::StrFormat( "jwt payload invalid, consumer:%s, claim:%s, value:%s, expect:%s", consumer.name, claim.first, value, claim.second)); return Status::JwtVerificationFail; } } } status = jwt.verifyTimeConstraint(now, consumer.clock_skew); if (status != Status::Ok) { LOG_DEBUG(absl::StrFormat( "jwt verify time failed, consumer:%s, token:%s, status:%s", consumer.name, token->token(), google::jwt_verify::getStatusString(status))); return status; } status = google::jwt_verify::verifyJwtWithoutTimeChecking(jwt, *consumer.jwks); if (status != Status::Ok) { LOG_DEBUG(absl::StrFormat( "jwt verify failed, consumer:%s, token:%s, status:%s", consumer.name, token->token(), google::jwt_verify::getStatusString(status))); return status; } for (const auto& claim_to_header : consumer.claims_to_headers) { std::string value; if (payload_getter.GetString(claim_to_header.first, &value) != StructUtils::WRONG_TYPE) { token->addClaimToHeader(claim_to_header.second.header, value, claim_to_header.second.override); } else { uint64_t num_value; if (payload_getter.GetUInt64(claim_to_header.first, &num_value) != StructUtils::WRONG_TYPE) { token->addClaimToHeader(claim_to_header.second.header, std::to_string((unsigned long long)num_value), claim_to_header.second.override); } } } } jwt_tokens = std::move(tokens); return Status::Ok; } bool PluginRootContext::checkPlugin( const JwtAuthConfigRule& rule, const std::optional<std::unordered_set<std::string>>& allow_set) { if (!rule.enable_headers.empty()) { bool skip_auth = true; for (const auto& enable_header : rule.enable_headers) { auto header_ptr = getRequestHeader(enable_header); if (header_ptr->size() > 0) { LOG_DEBUG("enable by header: " + header_ptr->toString()); skip_auth = false; break; } } if (skip_auth) { return true; } } std::optional<Status> err_status; bool verified = false; uint64_t now = getCurrentTimeNanoseconds() / 1e9; for (const auto& consumer : rule.consumers) { std::vector<JwtLocationConstPtr> tokens; auto status = consumerVerify(consumer, now, tokens); if (status == Status::Ok) { verified = true; // global config without allow_set field allows any consumers if (!allow_set || allow_set.value().find(consumer.name) != allow_set.value().end()) { addRequestHeader("X-Mse-Consumer", consumer.name); for (auto& token : tokens) { if (!consumer.keep_token) { token->removeJwt(); } token->claimsToHeaders(); } return true; } } // use the first status if (!err_status) { err_status = status; } } if (!verified) { auto status = err_status ? err_status.value() : Status::JwtMissed; auto err_str = google::jwt_verify::getStatusString(status); auto authn_value = absl::StrCat( "Bearer realm=\"", Wasm::Common::Http::buildOriginalUri(MaximumUriLength), "\""); if (status != Status::JwtMissed) { absl::StrAppend(&authn_value, InvalidTokenErrorString); } sendLocalResponse(401, generateRcDetails(err_str), err_str, {{"WWW-Authenticate", authn_value}}); } else { sendLocalResponse(403, kRcDetailJwtAuthnPrefix, "Access Denied", {}); } return false; } bool PluginRootContext::onConfigure(size_t size) { // Parse configuration JSON string. if (size > 0 && !configure(size)) { LOG_WARN("configuration has errors initialization will not continue."); return false; } return true; } bool PluginRootContext::configure(size_t configuration_size) { auto configuration_data = getBufferBytes(WasmBufferType::PluginConfiguration, 0, configuration_size); // Parse configuration JSON string. auto result = ::Wasm::Common::JsonParse(configuration_data->view()); if (!result) { LOG_WARN(absl::StrCat("cannot parse plugin configuration JSON string: ", configuration_data->view())); return false; } if (!parseAuthRuleConfig(result.value())) { LOG_WARN(absl::StrCat("cannot parse plugin configuration JSON string: ", configuration_data->view())); return false; } return true; } FilterHeadersStatus PluginContext::onRequestHeaders(uint32_t, bool) { auto* rootCtx = rootContext(); return rootCtx->checkAuthRule( [rootCtx](const auto& config, const auto& allow_set) { return rootCtx->checkPlugin(config, allow_set); }) ? FilterHeadersStatus::Continue : FilterHeadersStatus::StopIteration; } #ifdef NULL_PLUGIN } // namespace jwt_auth } // namespace null_plugin } // namespace proxy_wasm #endif