plugins/wasm-cpp/extensions/key_rate_limit/plugin.cc (192 lines of code) (raw):
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "extensions/key_rate_limit/plugin.h"
#include <array>
#include <vector>
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "common/json_util.h"
using ::nlohmann::json;
using ::Wasm::Common::JsonArrayIterate;
using ::Wasm::Common::JsonGetField;
using ::Wasm::Common::JsonObjectIterate;
using ::Wasm::Common::JsonValueAs;
#ifdef NULL_PLUGIN
namespace proxy_wasm {
namespace null_plugin {
namespace key_rate_limit {
PROXY_WASM_NULL_PLUGIN_REGISTRY
#endif
static RegisterContextFactory register_KeyRateLimit(
CONTEXT_FACTORY(PluginContext), ROOT_FACTORY(PluginRootContext));
namespace {
constexpr uint64_t second_nano = 1000 * 1000 * 1000;
constexpr uint64_t minute_nano = 60 * second_nano;
constexpr uint64_t hour_nano = 60 * minute_nano;
constexpr uint64_t day_nano = 24 * hour_nano;
// tooManyRequest returns a 429 response code.
void tooManyRequest() {
sendLocalResponse(429, "Too many requests", "rate_limited", {});
}
} // namespace
bool PluginRootContext::parsePluginConfig(const json& configuration,
KeyRateLimitConfigRule& rule) {
if (!JsonArrayIterate(
configuration, "limit_keys", [&](const json& item) -> bool {
std::string key =
Wasm::Common::JsonGetField<std::string>(item, "key").value();
uint64_t qps =
Wasm::Common::JsonGetField<uint64_t>(item, "query_per_second")
.value_or(0);
if (qps > 0) {
rule.limit_keys.emplace(key, LimitItem{
key,
qps,
second_nano,
qps,
});
return true;
}
uint64_t qpm =
Wasm::Common::JsonGetField<uint64_t>(item, "query_per_minute")
.value_or(0);
if (qpm > 0) {
rule.limit_keys.emplace(key, LimitItem{
key,
qpm,
minute_nano,
qpm,
});
return true;
}
uint64_t qph =
Wasm::Common::JsonGetField<uint64_t>(item, "query_per_hour")
.value_or(0);
if (qph > 0) {
rule.limit_keys.emplace(key, LimitItem{
key,
qph,
hour_nano,
qph,
});
return true;
}
uint64_t qpd =
Wasm::Common::JsonGetField<uint64_t>(item, "query_per_day")
.value_or(0);
if (qpd > 0) {
rule.limit_keys.emplace(key, LimitItem{
key,
qpd,
day_nano,
qpd,
});
return true;
}
LOG_WARN(
"one of 'query_per_second', 'query_per_minute', "
"'query_per_hour' or 'query_per_day' must be set");
return false;
})) {
LOG_WARN("failed to parse configuration for limit_keys.");
return false;
}
if (rule.limit_keys.empty()) {
LOG_WARN("no limit keys found in configuration");
return false;
}
auto it = configuration.find("limit_by_header");
if (it != configuration.end()) {
auto limit_by_header = JsonValueAs<std::string>(it.value());
if (limit_by_header.second != Wasm::Common::JsonParserResultDetail::OK) {
LOG_WARN("cannot parse limit_by_header");
return false;
}
rule.limit_by_header = limit_by_header.first.value();
}
it = configuration.find("limit_by_param");
if (it != configuration.end()) {
auto limit_by_param = JsonValueAs<std::string>(it.value());
if (limit_by_param.second != Wasm::Common::JsonParserResultDetail::OK) {
LOG_WARN("cannot parse limit_by_param");
return false;
}
rule.limit_by_param = limit_by_param.first.value();
}
auto emptyHeader = rule.limit_by_header.empty();
auto emptyParam = rule.limit_by_param.empty();
if ((emptyHeader && emptyParam) || (!emptyHeader && !emptyParam)) {
LOG_WARN("only one of 'limit_by_param' and 'limit_by_header' can be set");
return false;
}
return true;
}
bool PluginRootContext::checkPlugin(int rule_id,
const KeyRateLimitConfigRule& config) {
const auto& headerKey = config.limit_by_header;
const auto& paramKey = config.limit_by_param;
std::string key;
if (!headerKey.empty()) {
GET_HEADER_VIEW(headerKey, header);
key = header;
} else {
// use paramKey which must not be empty
GET_HEADER_VIEW(":path", path);
const auto& params = Wasm::Common::Http::parseQueryString(path);
auto it = params.find(paramKey);
if (it != params.end()) {
key = it->second;
}
}
const auto& limit_keys = config.limit_keys;
if (limit_keys.find(key) == limit_keys.end()) {
return true;
}
if (!getToken(rule_id, key)) {
LOG_INFO(absl::StrCat("request rate limited by key: ", key));
tooManyRequest();
return false;
}
return true;
}
void PluginRootContext::onTick() { refillToken(limits_); }
bool PluginRootContext::onConfigure(size_t size) {
// Parse configuration JSON string.
if (size > 0 && !configure(size)) {
LOG_WARN("configuration has errors initialization will not continue.");
return false;
}
const auto& rules = getRules();
for (const auto& rule : rules) {
for (auto& keyItem : rule.second.get().limit_keys) {
limits_.emplace_back(rule.first, keyItem.second);
}
}
initializeTokenBucket(limits_);
proxy_set_tick_period_milliseconds(500);
return true;
}
bool PluginRootContext::configure(size_t configuration_size) {
auto configuration_data = getBufferBytes(WasmBufferType::PluginConfiguration,
0, configuration_size);
// Parse configuration JSON string.
auto result = ::Wasm::Common::JsonParse(configuration_data->view());
if (!result.has_value()) {
LOG_WARN(absl::StrCat("cannot parse plugin configuration JSON string: ",
configuration_data->view()));
return false;
}
if (!parseRuleConfig(result.value())) {
LOG_WARN(absl::StrCat("cannot parse plugin configuration JSON string: ",
configuration_data->view()));
return false;
}
return true;
}
FilterHeadersStatus PluginContext::onRequestHeaders(uint32_t, bool) {
auto* rootCtx = rootContext();
return rootCtx->checkRuleWithId([rootCtx](auto rule_id, const auto& config) {
return rootCtx->checkPlugin(rule_id, config);
})
? FilterHeadersStatus::Continue
: FilterHeadersStatus::StopIteration;
}
#ifdef NULL_PLUGIN
} // namespace key_rate_limit
} // namespace null_plugin
} // namespace proxy_wasm
#endif