include/ylt/coro_io/load_balancer.hpp (178 lines of code) (raw):
/*
* Copyright (c) 2023, Alibaba Group Holding Limited;
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <async_simple/coro/Lazy.h>

#include <algorithm>
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <mutex>
#include <numeric>
#include <random>
#include <stdexcept>
#include <string_view>
#include <utility>
#include <variant>
#include <vector>

#include "client_pool.hpp"
#include "io_context_pool.hpp"
namespace coro_io {
/// Strategy a load_balancer uses to choose the host for each request.
enum class load_balance_algorithm {
  RR = 0,      ///< round-robin: hosts are used one after another
  WRR = 1,     ///< weighted round-robin: needs one weight per host
  random = 2,  ///< a host is drawn at random for every request
};
template <typename client_t, typename io_context_pool_t = io_context_pool>
class load_balancer {
using client_pool_t = client_pool<client_t, io_context_pool_t>;
using client_pools_t = client_pools<client_t, io_context_pool_t>;
public:
struct load_balancer_config {
typename client_pool_t::pool_config pool_config;
load_balance_algorithm lba = load_balance_algorithm::RR;
~load_balancer_config(){};
};
private:
struct RRLoadbalancer {
std::unique_ptr<std::atomic<uint32_t>> index =
std::make_unique<std::atomic<uint32_t>>();
async_simple::coro::Lazy<std::shared_ptr<client_pool_t>> operator()(
const load_balancer& load_balancer) {
auto i = index->fetch_add(1, std::memory_order_relaxed);
co_return load_balancer
.client_pools_[i % load_balancer.client_pools_.size()];
}
};
/*
Supposing that there is a server set ''S'' = {S0, S1, …, Sn-1};
W(Si) indicates the weight of Si;
''i'' indicates the server selected last time, and ''i'' is initialized with
-1;
''cw'' is the current weight in scheduling, and cw is initialized with zero;
max(S) is the maximum weight of all the servers in S;
gcd(S) is the greatest common divisor of all server weights in S;
while (true) {
i = (i + 1) mod n;
if (i == 0) {
cw = cw - gcd(S);
if (cw <= 0) {
cw = max(S);
if (cw == 0)
return NULL;
}
}
if (W(Si) >= cw)
return Si;
}
*/
struct WRRLoadbalancer {
WRRLoadbalancer(const std::vector<int>& weights) : weights_(weights) {
max_gcd_ = get_max_weight_gcd();
max_weight_ = get_max_weight();
}
async_simple::coro::Lazy<std::shared_ptr<client_pool_t>> operator()(
const load_balancer& load_balancer) {
int selected = select_host_with_weight_round_robin();
if (selected == -1) {
selected = 0;
}
wrr_current_ = selected;
co_return load_balancer
.client_pools_[selected % load_balancer.client_pools_.size()];
}
private:
int select_host_with_weight_round_robin() {
while (true) {
wrr_current_ = (wrr_current_ + 1) % weights_.size();
if (wrr_current_ == 0) {
weight_current_ = weight_current_ - max_gcd_;
if (weight_current_ <= 0) {
weight_current_ = max_weight_;
if (weight_current_ == 0) {
return -1; // can't find max weight server
}
}
}
if (weights_[wrr_current_] >= weight_current_) {
return wrr_current_;
}
}
}
int get_max_weight_gcd() {
int res = weights_[0];
int cur_max = 0, cur_min = 0;
for (size_t i = 0; i < weights_.size(); i++) {
cur_max = (std::max)(res, weights_[i]);
cur_min = (std::min)(res, weights_[i]);
res = std::gcd(cur_max, cur_min);
}
return res;
}
int get_max_weight() {
return *std::max_element(weights_.begin(), weights_.end());
}
std::vector<int> weights_;
int max_gcd_ = 0;
int max_weight_ = 0;
int wrr_current_ = -1;
int weight_current_ = 0;
};
struct RandomLoadbalancer {
async_simple::coro::Lazy<std::shared_ptr<client_pool_t>> operator()(
const load_balancer& load_balancer) {
static thread_local std::default_random_engine e(std::time(nullptr));
std::uniform_int_distribution rnd{std::size_t{0},
load_balancer.client_pools_.size() - 1};
co_return load_balancer.client_pools_[rnd(e)];
}
};
load_balancer() = default;
public:
load_balancer(load_balancer&& o)
: config_(std::move(o.config_)),
lb_worker(std::move(o.lb_worker)),
client_pools_(std::move(o.client_pools_)){};
load_balancer& operator=(load_balancer&& o) {
this->config_ = std::move(o.config_);
this->lb_worker = std::move(o.lb_worker);
this->client_pools_ = std::move(o.client_pools_);
return *this;
}
load_balancer(const load_balancer& o) = delete;
load_balancer& operator=(const load_balancer& o) = delete;
auto send_request(auto op, typename client_t::config& config)
-> decltype(std::declval<client_pool_t>().send_request(std::move(op),
std::string_view{},
config)) {
std::shared_ptr<client_pool_t> client_pool;
if (client_pools_.size() > 1) {
int cnt = 0;
do {
client_pool = co_await std::visit(
[this](auto& worker) {
return worker(*this);
},
lb_worker);
} while (!client_pool->is_alive() && ++cnt <= size() * 2);
}
else {
client_pool = client_pools_[0];
}
co_return co_await client_pool->send_request(
std::move(op), client_pool->get_host_name(), config);
}
auto send_request(auto op) {
return send_request(std::move(op), config_.pool_config.client_config);
}
static load_balancer create(
const std::vector<std::string_view>& hosts,
const load_balancer_config& config = {},
const std::vector<int>& weights = {},
client_pools_t& client_pools =
g_clients_pool<client_t, io_context_pool_t>()) {
load_balancer ch;
ch.init(hosts, config, weights, client_pools);
return ch;
}
/**
* @brief return the load_balancer's hosts size.
*
* @return std::size_t
*/
std::size_t size() const noexcept { return client_pools_.size(); }
private:
void init(const std::vector<std::string_view>& hosts,
const load_balancer_config& config, const std::vector<int>& weights,
client_pools_t& client_pools) {
config_ = config;
client_pools_.reserve(hosts.size());
for (auto& host : hosts) {
client_pools_.emplace_back(client_pools.at(host, config.pool_config));
}
switch (config_.lba) {
case load_balance_algorithm::RR:
lb_worker = RRLoadbalancer{};
break;
case load_balance_algorithm::WRR: {
if (hosts.empty() || weights.empty()) {
throw std::invalid_argument("host/weight list is empty!");
}
if (hosts.size() != weights.size()) {
throw std::invalid_argument("hosts count is not equal with weights!");
}
lb_worker = WRRLoadbalancer(weights);
} break;
case load_balance_algorithm::random:
default:
lb_worker = RandomLoadbalancer{};
}
return;
}
load_balancer_config config_;
std::variant<RRLoadbalancer, WRRLoadbalancer, RandomLoadbalancer> lb_worker;
std::vector<std::shared_ptr<client_pool_t>> client_pools_;
};
} // namespace coro_io