include/ylt/util/map_sharded.hpp (207 lines of code) (raw):
#pragma once
#include <atomic>
#include <cstddef>
#include <iterator>
#include <memory>
#include <mutex>
#include <type_traits>
#include <utility>
#include <vector>
namespace ylt::util {
namespace internal {
/// One shard of map_sharded_t: a `Map` guarded by its own mutex.
///
/// The map is allocated lazily by the first try_emplace_with_op call, so an
/// untouched shard only owns a mutex. Every public member holds the mutex for
/// its whole duration. Callbacks (the `op` of try_emplace_with_op and the
/// visitors of for_each/erase_if) run while the mutex is held, so they must not
/// call back into the same shard (std::mutex is not recursive).
///
/// `mapped_type` must expose `element_type` and be convertible to
/// `std::shared_ptr<element_type>` (presumably it is a std::shared_ptr).
template <typename Map>
class map_lock_t {
 public:
  using key_type = typename Map::key_type;
  using value_type = typename Map::value_type;
  using mapped_type = typename Map::mapped_type;

  map_lock_t() : mtx_(std::make_unique<std::mutex>()) {}

  /// Returns the value stored under `key`, or nullptr if it is absent.
  std::shared_ptr<typename mapped_type::element_type> find(
      const key_type& key) const {
    std::lock_guard lock(*mtx_);
    if (!map_) [[unlikely]] {
      return nullptr;
    }
    auto it = map_->find(key);
    if (it == map_->end()) {
      return nullptr;
    }
    return it->second;
  }

  /// Calls `Map::try_emplace(key, args...)`, then `op(result)` with the
  /// resulting (iterator, inserted) pair while still holding the lock.
  /// @return the stored value (new or pre-existing) and whether it was inserted.
  template <typename Op, typename... Args>
  std::pair<std::shared_ptr<typename mapped_type::element_type>, bool>
  try_emplace_with_op(const key_type& key, Op&& op, Args&&... args) {
    std::lock_guard lock(*mtx_);
    auto result = visit_map().try_emplace(key, std::forward<Args>(args)...);
    op(result);
    return {result.first->second, result.second};
  }

  /// Removes `key`; returns the number of erased entries (0 or 1 for unique maps).
  size_t erase(const key_type& key) {
    std::lock_guard lock(*mtx_);
    if (!map_) [[unlikely]] {
      return 0;
    }
    return map_->erase(key);
  }

  /// Removes every entry `e` for which `op(e)` is true; returns how many were
  /// removed. `op` is invoked by reference (never moved from), so the caller
  /// may safely reuse the same predicate object for other shards.
  template <typename Func>
  size_t erase_if(Func&& op) {
    std::lock_guard guard(*mtx_);
    if (!map_) [[unlikely]] {
      return 0;
    }
    size_t erased = 0;
    for (auto it = map_->begin(); it != map_->end();) {
      if (op(*it)) {
        it = map_->erase(it);
        ++erased;
      }
      else {
        ++it;
      }
    }
    return erased;
  }

  /// Visits every entry under the lock.
  ///
  /// If `op(e) == true` is a valid expression, the visitor's result is used as
  /// a continue flag: a falsy result stops the iteration and for_each returns
  /// false. Otherwise the result is ignored and all entries are visited.
  /// @return false if the visitor stopped early, true otherwise.
  template <typename Func>
  bool for_each(Func&& op) {
    std::lock_guard guard(*mtx_);
    if (!map_) [[unlikely]] {
      return true;
    }
    for (auto& e : *map_) {
      if (!visit_one(op, e)) {
        return false;
      }
    }
    return true;
  }

  /// Const overload of for_each; the visitor receives `const value_type&`.
  template <typename Func>
  bool for_each(Func&& op) const {
    std::lock_guard guard(*mtx_);
    if (!map_) [[unlikely]] {
      return true;
    }
    for (const auto& e : *map_) {
      if (!visit_one(op, e)) {
        return false;
      }
    }
    return true;
  }

 private:
  // True when `f(e) == true` is well-formed, i.e. the visitor returns something
  // that should be interpreted as a continue/stop flag.
  template <typename F, typename E, typename = void>
  struct returns_flag : std::false_type {};
  template <typename F, typename E>
  struct returns_flag<
      F, E,
      std::void_t<decltype(std::declval<F&>()(std::declval<E&>()) == true)>>
      : std::true_type {};

  // Invokes the visitor on one element; returns false when iteration must stop.
  template <typename Func, typename Elem>
  static bool visit_one(Func& op, Elem& e) {
    if constexpr (returns_flag<Func, Elem>::value) {
      return static_cast<bool>(op(e));
    }
    else {
      op(e);
      return true;
    }
  }

  // Returns the map, allocating it on first use. Caller must hold *mtx_.
  Map& visit_map() {
    if (!map_) [[unlikely]] {
      map_ = std::make_unique<Map>();
    }
    return *map_;
  }

  // Held through a pointer so map_lock_t stays movable inside std::vector.
  std::unique_ptr<std::mutex> mtx_;
  std::unique_ptr<Map> map_;
};
}  // namespace internal
template <typename Map, typename Hash>
/// Hash map split into independently locked shards.
///
/// A key is routed to shard `Hash{}(key) % shard_count`. Each shard is an
/// internal::map_lock_t with its own mutex, so operations on keys that land in
/// different shards do not contend. `mapped_type` must expose `element_type`
/// (presumably it is a std::shared_ptr).
///
/// The element count is kept in a separate atomic that is updated after the
/// shard operation has returned (and released its lock), so size() is only an
/// approximation while other threads are inserting or erasing.
class map_sharded_t {
 public:
  using key_type = typename Map::key_type;
  using value_type = typename Map::value_type;
  using mapped_type = typename Map::mapped_type;

  /// @param shard_num number of shards. 0 is treated as 1, because shard
  ///        selection takes the hash modulo the shard count.
  map_sharded_t(size_t shard_num)
      : shards_(shard_num == 0 ? size_t{1} : shard_num) {}

  /// Inserts `key` with a value constructed from `args` if it is absent.
  /// @return the stored value (new or pre-existing) and whether it was inserted.
  template <typename KeyType, typename... Args>
  std::pair<std::shared_ptr<typename mapped_type::element_type>, bool>
  try_emplace(KeyType&& key, Args&&... args) {
    return try_emplace_with_op(
        std::forward<KeyType>(key), [](auto&&) {}, std::forward<Args>(args)...);
  }

  /// Like try_emplace, but also calls `func(result)` on the underlying
  /// (iterator, inserted) pair while the shard lock is held.
  template <typename Op, typename... Args>
  std::pair<std::shared_ptr<typename mapped_type::element_type>, bool>
  try_emplace_with_op(const key_type& key, Op&& func, Args&&... args) {
    auto ret = get_sharded(Hash{}(key))
                   .try_emplace_with_op(key, std::forward<Op>(func),
                                        std::forward<Args>(args)...);
    if (ret.second) {
      size_.fetch_add(1);
    }
    return ret;
  }

  /// Approximate number of entries. The counter can transiently go negative
  /// when an erase's decrement lands before the matching insert's increment,
  /// so negative values are reported as 0.
  size_t size() const {
    const int64_t val = size_.load();
    return val < 0 ? size_t{0} : static_cast<size_t>(val);
  }

  /// Returns the value stored under `key`, or nullptr if it is absent.
  std::shared_ptr<typename mapped_type::element_type> find(
      const key_type& key) const {
    return get_sharded(Hash{}(key)).find(key);
  }

  /// Removes `key`; returns the number of erased entries.
  size_t erase(const key_type& key) {
    const size_t erased = get_sharded(Hash{}(key)).erase(key);
    if (erased != 0) {
      size_.fetch_sub(static_cast<int64_t>(erased));
    }
    return erased;
  }

  /// Removes every entry matching `op` in every shard; returns the total.
  /// Shards are processed one at a time, so this is not atomic across shards.
  template <typename Func>
  size_t erase_if(Func&& op) {
    size_t total = 0;
    for (auto& shard : shards_) {
      // Pass `op` as an lvalue: forwarding it inside the loop would let an
      // rvalue predicate be moved from by the first shard and reused after.
      const size_t erased = shard.erase_if(op);
      if (erased != 0) {
        total += erased;
        size_.fetch_sub(static_cast<int64_t>(erased));
      }
    }
    return total;
  }

  /// Scans shards in order and stops at the first shard where `op` matched.
  /// Note: every matching entry inside that shard is removed, so the result
  /// can be greater than 1.
  template <typename Func>
  size_t erase_one(Func&& op) {
    for (auto& shard : shards_) {
      // Lvalue on purpose; see erase_if.
      const size_t erased = shard.erase_if(op);
      if (erased != 0) {
        size_.fetch_sub(static_cast<int64_t>(erased));
        return erased;
      }
    }
    return 0;
  }

  /// Visits every shard in order, stopping as soon as a shard's for_each
  /// returns false (the per-shard visitor's stop signal).
  template <typename Func>
  void for_each(Func&& op) {
    for (auto& shard : shards_) {
      if (!shard.for_each(op)) {
        break;
      }
    }
  }

  /// Copies into a vector<T> every `e.second` for which `op(e.second)` is true.
  /// Each shard is read under its own lock, so the result is not a consistent
  /// snapshot across shards.
  template <typename T>
  std::vector<T> copy(auto&& op) const {
    std::vector<T> ret;
    ret.reserve(size());
    for (auto& shard : shards_) {
      shard.for_each([&ret, &op](auto& e) {
        if (op(e.second)) {
          ret.push_back(e.second);
        }
      });
    }
    return ret;
  }

  /// Copies every stored value into a vector<T>.
  template <typename T>
  std::vector<T> copy() const {
    return copy<T>([](auto&) {
      return true;
    });
  }

 private:
  internal::map_lock_t<Map>& get_sharded(size_t hash) {
    return shards_[hash % shards_.size()];
  }
  const internal::map_lock_t<Map>& get_sharded(size_t hash) const {
    return shards_[hash % shards_.size()];
  }

  // Never empty (the constructor clamps the shard count to >= 1).
  std::vector<internal::map_lock_t<Map>> shards_;
  // Signed so that transient negative values (see size()) do not wrap.
  std::atomic<int64_t> size_{0};
};
} // namespace ylt::util