plugins/wasm-cpp/extensions/key_rate_limit/bucket.cc (178 lines of code) (raw):
// Copyright (c) 2022 Alibaba Group Holding Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "extensions/key_rate_limit/bucket.h"

#include <cstring>
#include <string>
#include <unordered_map>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
namespace {
// Maximum number of attempts getToken() makes to decrement a bucket before
// giving up and letting the request through.
constexpr int maxGetTokenRetry = 20;
// Shared-data keys are built as std::to_string(rule id) + prefix + limit key.
// The prefixes are const so nothing in this translation unit can accidentally
// change the key naming at runtime.
// Key-prefix for token bucket shared data.
const std::string tokenBucketPrefix = "mse.token_bucket";
// Key-prefix for token bucket last updated time.
const std::string lastRefilledPrefix = "mse.last_refilled";
}  // namespace
bool getToken(int rule_id, const std::string &key) {
WasmDataPtr token_bucket_data;
uint32_t cas;
std::string tokenBucketKey =
std::to_string(rule_id) + tokenBucketPrefix + key;
for (int i = 0; i < maxGetTokenRetry; i++) {
if (WasmResult::Ok !=
getSharedData(tokenBucketKey, &token_bucket_data, &cas)) {
continue;
}
uint64_t token_left =
*reinterpret_cast<const uint64_t *>(token_bucket_data->data());
LOG_DEBUG(absl::StrFormat(
"ratelimit get token: id:%d, tokenBucketKey:%s, token left:%u", rule_id,
tokenBucketKey, token_left));
if (token_left == 0) {
LOG_DEBUG(absl::StrFormat("get token failed, id:%d, tokenBucketKey:%s",
rule_id, tokenBucketKey));
return false;
}
token_left -= 1;
auto res = setSharedData(
tokenBucketKey,
{reinterpret_cast<const char *>(&token_left), sizeof(token_left)}, cas);
if (res == WasmResult::Ok) {
LOG_DEBUG(
absl::StrFormat("ratelimit token update success: id:%d, "
"tokenBucketKey:%s, token left:%u",
rule_id, tokenBucketKey, token_left));
return true;
}
if (res == WasmResult::CasMismatch) {
continue;
}
LOG_WARN(absl::StrFormat("got invalid result:%d, id:%d, tokenBucketKey:%s",
res, rule_id, tokenBucketKey));
return true;
}
LOG_WARN("get token failed with cas mismatch");
return true;
}
void refillToken(const std::vector<std::pair<int, LimitItem>> &rules) {
uint32_t last_update_cas;
WasmDataPtr last_update_data;
for (const auto &rule : rules) {
auto id = std::to_string(rule.first);
std::string lastRefilledKey = id + lastRefilledPrefix + rule.second.key;
std::string tokenBucketKey = id + tokenBucketPrefix + rule.second.key;
auto result =
getSharedData(lastRefilledKey, &last_update_data, &last_update_cas);
if (result != WasmResult::Ok) {
LOG_WARN(
absl::StrCat("failed to get last update time of the local rate limit "
"token bucket ",
toString(result)));
continue;
}
uint64_t last_update =
*reinterpret_cast<const uint64_t *>(last_update_data->data());
uint64_t now = getCurrentTimeNanoseconds();
if (now - last_update < rule.second.refill_interval_nanosec) {
continue;
}
LOG_DEBUG(
absl::StrFormat("ratelimit rule need refilled, id:%s, "
"lastRefilledKey:%s, now:%u, last_update:%u",
id, lastRefilledKey, now, last_update));
// Otherwise, try set last updated time. If updated failed because of cas
// mismatch, the bucket is going to be refilled by other VMs.
auto res = setSharedData(
lastRefilledKey, {reinterpret_cast<const char *>(&now), sizeof(now)},
last_update_cas);
if (res == WasmResult::CasMismatch) {
LOG_DEBUG(
absl::StrFormat("ratelimit update lastRefilledKey casmismatch, the "
"bucket is going to be refilled by other VMs, id:%s, "
"lastRefilledKey:%s",
id, lastRefilledKey));
continue;
}
do {
if (WasmResult::Ok !=
getSharedData(tokenBucketKey, &last_update_data, &last_update_cas)) {
LOG_WARN("failed to get current local rate limit token bucket");
break;
}
uint64_t token_left =
*reinterpret_cast<const uint64_t *>(last_update_data->data());
// Refill tokens, and update bucket with cas. If update failed because of
// cas mismatch, retry refilling.
token_left += rule.second.tokens_per_refill;
if (token_left > rule.second.max_tokens) {
token_left = rule.second.max_tokens;
}
if (WasmResult::CasMismatch ==
setSharedData(
tokenBucketKey,
{reinterpret_cast<const char *>(&token_left), sizeof(token_left)},
last_update_cas)) {
continue;
}
LOG_DEBUG(
absl::StrFormat("ratelimit token refilled: id:%s, "
"tokenBucketKey:%s, token left:%u",
id, tokenBucketKey, token_left));
break;
} while (true);
}
}
bool initializeTokenBucket(
const std::vector<std::pair<int, LimitItem>> &rules) {
uint32_t last_update_cas;
WasmDataPtr last_update_data;
uint64_t initial_value = 0;
for (const auto &rule : rules) {
auto id = std::to_string(rule.first);
std::string lastRefilledKey = id + lastRefilledPrefix + rule.second.key;
std::string tokenBucketKey = id + tokenBucketPrefix + rule.second.key;
auto res =
getSharedData(lastRefilledKey, &last_update_data, &last_update_cas);
if (res == WasmResult::NotFound) {
setSharedData(lastRefilledKey,
{reinterpret_cast<const char *>(&initial_value),
sizeof(initial_value)});
setSharedData(tokenBucketKey,
{reinterpret_cast<const char *>(&rule.second.max_tokens),
sizeof(uint64_t)});
LOG_INFO(absl::StrFormat(
"ratelimit rule created: id:%s, lastRefilledKey:%s, "
"tokenBucketKey:%s, max_tokens:%u",
id, lastRefilledKey, tokenBucketKey, rule.second.max_tokens));
continue;
}
// reconfigure
do {
if (WasmResult::Ok !=
getSharedData(lastRefilledKey, &last_update_data, &last_update_cas)) {
LOG_WARN("failed to get lastRefilled");
return false;
}
if (WasmResult::CasMismatch ==
setSharedData(lastRefilledKey,
{reinterpret_cast<const char *>(&initial_value),
sizeof(initial_value)},
last_update_cas)) {
continue;
}
break;
} while (true);
do {
if (WasmResult::Ok !=
getSharedData(tokenBucketKey, &last_update_data, &last_update_cas)) {
LOG_WARN("failed to get tokenBucket");
return false;
}
if (WasmResult::CasMismatch ==
setSharedData(
tokenBucketKey,
{reinterpret_cast<const char *>(&rule.second.max_tokens),
sizeof(uint64_t)},
last_update_cas)) {
continue;
}
break;
} while (true);
LOG_INFO(absl::StrFormat(
"ratelimit rule reconfigured: id:%s, lastRefilledKey:%s, "
"tokenBucketKey:%s, max_tokens:%u",
id, lastRefilledKey, tokenBucketKey, rule.second.max_tokens));
}
return true;
}