plugins/wasm-cpp/extensions/jwt_auth/extractor.cc (210 lines of code) (raw):

// Copyright (c) 2022 Alibaba Group Holding Ltd. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // modified base on envoy/source/extensions/filters/http/jwt_authn/extractor.cc #include "extensions/jwt_auth/extractor.h" #include <memory> #include <tuple> #include <unordered_map> #include "absl/container/btree_map.h" #include "common/http_util.h" #include "extensions/jwt_auth/plugin.h" #ifdef NULL_PLUGIN namespace proxy_wasm { namespace null_plugin { namespace jwt_auth { #endif namespace { /** * Check Claims specified in Provider */ class JwtClaimChecker { public: JwtClaimChecker(const ClaimsMap& claims) : allowed_claims_(claims) {} // check if a jwt issuer is allowed bool check(const std::string& key, const std::string& value) const { if (allowed_claims_.empty()) { return true; } auto it = allowed_claims_.find(key); return it != allowed_claims_.end() && it->second == value; } private: // Only these specified claims are allowed. const ClaimsMap& allowed_claims_; }; using JwtClaimCheckerPtr = std::unique_ptr<JwtClaimChecker>; // A base JwtLocation object to store token and claim_checker. class JwtLocationBase : public JwtLocation { public: JwtLocationBase(const std::string& token, const JwtClaimChecker& claim_checker) : token_(token), claim_checker_(claim_checker) {} // Get the token string const std::string& token() const override { return token_; } // Check if an claim has specified the location. bool isClaimAllowed(const std::string& key, const std::string& value) const override { return claim_checker_.check(key, value); } void addClaimToHeader(const std::string& header, const std::string& value, bool override) const override { claims_to_headers_.emplace_back(header, value, override); } void claimsToHeaders() const override { for (const auto& claim_to_header : claims_to_headers_) { const auto& header_key = std::get<0>(claim_to_header); const auto& header_value = std::get<1>(claim_to_header); if (std::get<2>(claim_to_header)) { auto header_ptr = getRequestHeader(header_key); if (!header_ptr->view().empty()) { replaceRequestHeader(header_key, header_value); continue; } } addRequestHeader(header_key, header_value); } } private: mutable std::vector<std::tuple<std::string, std::string, bool>> claims_to_headers_; // Extracted token. const std::string token_; // Claim checker const JwtClaimChecker& claim_checker_; }; // The JwtLocation for header extraction. class JwtHeaderLocation : public JwtLocationBase { public: JwtHeaderLocation(const std::string& token, const JwtClaimChecker& claim_checker, const std::string& header) : JwtLocationBase(token, claim_checker), header_(header) {} void removeJwt() const override { removeRequestHeader(header_); } private: // the header name the JWT is extracted from. const std::string& header_; }; // The JwtLocation for param extraction. class JwtParamLocation : public JwtLocationBase { public: JwtParamLocation(const std::string& token, const JwtClaimChecker& claim_checker, const std::string&) : JwtLocationBase(token, claim_checker) {} void removeJwt() const override { // TODO(qiwzhang): remove JWT from parameter. } }; // The JwtLocation for cookie extraction. class JwtCookieLocation : public JwtLocationBase { public: JwtCookieLocation(const std::string& token, const JwtClaimChecker& claim_checker) : JwtLocationBase(token, claim_checker) {} void removeJwt() const override { // TODO(theshubhamp): remove JWT from cookies. } }; class ExtractorImpl : public Extractor { public: ExtractorImpl(const Consumer& provider); std::vector<JwtLocationConstPtr> extract() const override; private: // add a header config void addHeaderConfig(const ClaimsMap& claims, const std::string& header_name, const std::string& value_prefix); // add a query param config void addQueryParamConfig(const ClaimsMap& claims, const std::string& param); // add a query param config void addCookieConfig(const ClaimsMap& claims, const std::string& cookie); // ctor helper for a jwt provider config void addProvider(const Consumer& provider); // HeaderMap value type to store prefix and issuers that specified this // header. struct HeaderLocationSpec { HeaderLocationSpec(const std::string& header, const std::string& value_prefix) : header_(header), value_prefix_(value_prefix) {} // The header name. std::string header_; // The value prefix. e.g. for "Bearer <token>", the value_prefix is "Bearer // ". std::string value_prefix_; // Issuers that specified this header. JwtClaimCheckerPtr claim_checker_; }; using HeaderLocationSpecPtr = std::unique_ptr<HeaderLocationSpec>; // The map of (header + value_prefix) to HeaderLocationSpecPtr std::map<std::string, HeaderLocationSpecPtr> header_locations_; // ParamMap value type to store issuers that specified this header. struct ParamLocationSpec { // Issuers that specified this param. JwtClaimCheckerPtr claim_checker_; }; // The map of a parameter key to set of issuers specified the parameter std::map<std::string, ParamLocationSpec> param_locations_; // CookieMap value type to store issuers that specified this cookie. struct CookieLocationSpec { // Issuers that specified this param. JwtClaimCheckerPtr claim_checker_; }; // The map of a cookie key to set of issuers specified the cookie. absl::btree_map<std::string, CookieLocationSpec> cookie_locations_; }; ExtractorImpl::ExtractorImpl(const Consumer& provider) { addProvider(provider); } void ExtractorImpl::addProvider(const Consumer& provider) { for (const auto& header : provider.from_headers) { addHeaderConfig(provider.allowd_claims, header.header, header.value_prefix); } for (const std::string& param : provider.from_params) { addQueryParamConfig(provider.allowd_claims, param); } for (const std::string& cookie : provider.from_cookies) { addCookieConfig(provider.allowd_claims, cookie); } } void ExtractorImpl::addHeaderConfig(const ClaimsMap& claims, const std::string& header_name, const std::string& value_prefix) { const std::string map_key = header_name + value_prefix; auto& header_location_spec = header_locations_[map_key]; if (!header_location_spec) { header_location_spec = std::make_unique<HeaderLocationSpec>(header_name, value_prefix); } header_location_spec->claim_checker_ = std::make_unique<JwtClaimChecker>(claims); } void ExtractorImpl::addQueryParamConfig(const ClaimsMap& claims, const std::string& param) { auto& param_location_spec = param_locations_[param]; param_location_spec.claim_checker_ = std::make_unique<JwtClaimChecker>(claims); } void ExtractorImpl::addCookieConfig(const ClaimsMap& claims, const std::string& cookie) { auto& cookie_location_spec = cookie_locations_[cookie]; cookie_location_spec.claim_checker_ = std::make_unique<JwtClaimChecker>(claims); } std::vector<JwtLocationConstPtr> ExtractorImpl::extract() const { std::vector<JwtLocationConstPtr> tokens; // Check header locations first for (const auto& location_it : header_locations_) { const auto& location_spec = location_it.second; auto header = getRequestHeader(location_spec->header_)->toString(); if (!header.empty()) { const auto pos = header.find(location_spec->value_prefix_); if (pos == std::string::npos) { continue; } auto header_strip = header.substr(pos + location_spec->value_prefix_.length()); tokens.push_back(std::make_unique<const JwtHeaderLocation>( header_strip, *location_spec->claim_checker_, location_spec->header_)); } } // Check query parameter locations only if query parameter locations specified // and Path() is not null auto path = getRequestHeader(Wasm::Common::Http::Header::Path)->toString(); if (!param_locations_.empty() && !path.empty()) { const auto& params = Wasm::Common::Http::parseAndDecodeQueryString(path); for (const auto& location_it : param_locations_) { const auto& param_key = location_it.first; const auto& location_spec = location_it.second; const auto& it = params.find(param_key); if (it != params.end()) { tokens.push_back(std::make_unique<const JwtParamLocation>( it->second, *location_spec.claim_checker_, param_key)); } } } // Check cookie locations. if (!cookie_locations_.empty()) { const auto& cookies = Wasm::Common::Http::parseCookies([&](absl::string_view k) -> bool { return cookie_locations_.contains(k); }); for (const auto& location_it : cookie_locations_) { const auto& cookie_key = location_it.first; const auto& location_spec = location_it.second; const auto& it = cookies.find(cookie_key); if (it != cookies.end()) { tokens.push_back(std::make_unique<const JwtCookieLocation>( it->second, *location_spec.claim_checker_)); } } } return tokens; } } // namespace ExtractorConstPtr Extractor::create(const Consumer& provider) { return std::make_unique<ExtractorImpl>(provider); } #ifdef NULL_PLUGIN } // namespace jwt_auth } // namespace null_plugin } // namespace proxy_wasm #endif