glean/rts/ownership/triearray.h (217 lines of code) (raw):
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include "glean/rts/ownership/pool.h"
#include <cstdint>
#include <vector>
namespace facebook {
namespace glean {
namespace rts {
/**
* A datastructure which maps ranges of Ids to values. The current
* implementation is a hack which only supports Ids which fit in 32 bits.
*
* The implementation is a bit trie with a large initial fanout (64k) and then a
* smaller inner fanout (16), giving a maximum depth of 4 (for 32-bit values).
*
* This should most likely be switched to some form of Patricia tree.
*/
template<typename T>
class TrieArray {
public:
TrieArray() : trees_(new ForestN<FANOUT_TOP>(Tree::null())) {}
/**
* Insert a sorted sequence of non-overlapping Id ranges by combining the
* previously stored values via `get`.
*
* This tries to split the tree as little as possible. We also guarantee to
* call `get` exactly once for each previous value (including `nullptr` for
* "no previous value").
*/
template<typename Get>
void insert(const OwnershipUnit::Ids *start, const OwnershipUnit::Ids *finish, Get&& get) {
if (start == finish) {
return;
}
minkey_ = std::min(minkey_, start->start.toWord());
maxkey_ = std::max(maxkey_, finish[-1].finish.toWord());
// only 32-bit keys are supported; this property is assumed later
CHECK(maxkey_ <= std::numeric_limits<uint32_t>::max());
// Algorithm:
//
// - Collect previously existing values in `values`.
// - When we see a value for the first time, set its `link` field to point
// to the first `Tree` that contains it and add it to `values`.
// - When we see a value again, temporarily make the current `Tree` point
// to the previous tree that contained the value (from the value's `link`)
// and update the value's `link` to point to the current tree. This
// effectively maintains a linked list of trees which contain a value,
// with the value's `link` being the root.
// - Do the same for newly inserted trees which don't have a previous value
// via `null_link`.
// - Once we've collected everything, for each value in `values` compute the
// new value via `get` and then traverse the linked list of trees, storing
// a pointer to the new value in each one.
// - Do the same for `null_link`.
std::vector<T*> values;
Tree *null_link = nullptr;
while (start != finish) {
const auto [first_id, last_id] = *start++;
if (first_id <= last_id) {
traverse(
first_id.toWord(),
last_id.toWord() - first_id.toWord() + 1,
[&](Tree& tree, uint64_t key, uint64_t size, size_t block) {
if (size == block) {
if (const auto value = tree.value()) {
const auto prev = static_cast<Tree *>(value->link());
value->link(&tree);
tree = Tree::link(prev);
if (prev == nullptr) {
values.push_back(value);
} else {
value->use(-1);
}
} else {
tree = Tree::link(null_link);
null_link = &tree;
}
} else {
if (auto value = tree.value()) {
value->use(FANOUT-1);
}
tree = Tree::forest(pool_.alloc(tree));
}
});
}
}
const auto unlink = [&](Tree *tree, T * FOLLY_NULLABLE value) {
auto upd = get(value, 1);
uint32_t refs = 0;
while (tree != nullptr) {
auto next = tree->link();
*tree = Tree::value(upd);
tree = next;
++refs;
}
upd->use(refs-1);
};
if (null_link) {
unlink(null_link, nullptr);
}
for (auto value : values) {
Tree *tree = static_cast<Tree*>(value->link());
value->link(nullptr);
unlink(tree, value);
}
}
template<typename F>
void foreach(F&& f) {
traverse([&](Tree& tree, uint64_t key, uint64_t size, uint64_t block) {
if (auto *value = tree.value()) {
if (auto new_value = f(value)) {
tree = Tree::value(new_value);
}
}
});
}
std::vector<T*> flatten() {
if (maxkey_ < minkey_) {
return std::vector<T*>(0);
}
std::vector<T*> vec(maxkey_+1, nullptr);
traverse([&](const Tree& tree, uint64_t key, uint64_t size, uint64_t block) {
auto *value = tree.value();
std::fill(vec.begin() + key, vec.begin() + key + size, value);
if (value) {
value->use(size-1);
}
});
return vec;
}
private:
static constexpr size_t FANOUT_TOP = 65536;
static constexpr size_t FANOUT = 16;
static constexpr size_t BLOCK = (size_t(std::numeric_limits<uint32_t>::max()) + 1) / FANOUT_TOP;
static constexpr size_t blockSize(uint8_t level) {
auto size = BLOCK;
while (level != 0) {
size /= FANOUT;
--level;
}
return size;
}
template<uint32_t N> struct ForestN;
using Forest = ForestN<FANOUT>;
// A tagged pointer based sum of nothing (`nullptr`), a non-null pointer to a
// value and a pointer to a forest.
struct Tree {
uintptr_t ptr;
static Tree null() {
Tree t;
t.ptr = 0;
return t;
}
static Tree value(T *x) {
Tree t;
t.ptr = reinterpret_cast<uintptr_t>(x);
assert((t.ptr&1) == 0);
assert(t.ptr != 0);
return t;
}
static Tree forest(Forest *forest) {
Tree t;
t.ptr = reinterpret_cast<uintptr_t>(forest);
assert((t.ptr&1) == 0);
t.ptr |= 1;
return t;
}
static Tree link(Tree *x) {
Tree t;
t.ptr = reinterpret_cast<uintptr_t>(x);
return t;
}
bool empty() const {
return ptr == 0;
}
T * FOLLY_NULLABLE value() const {
return (ptr&1) == 0 ? reinterpret_cast<T *>(ptr) : nullptr;
}
Forest * FOLLY_NULLABLE forest() {
return (ptr&1) == 1 ? reinterpret_cast<Forest *>(ptr-1) : nullptr;
}
Tree *link() const {
return reinterpret_cast<Tree *>(ptr);
}
bool isForest() const {
return (ptr&1) == 1;
}
};
template<uint32_t N>
struct ForestN {
Tree trees_[N];
explicit ForestN(Tree tree) {
std::fill(trees_, trees_+N, tree);
}
Tree& at(uint32_t i) {
return trees_[i];
}
Tree at(uint32_t i) const {
return trees_[i];
}
};
static std::pair<uint64_t, uint64_t> location(uint64_t key) {
assert(key <= std::numeric_limits<uint32_t>::max());
return {key / BLOCK, key % BLOCK};
}
template<typename F>
void traverse(F&& f) {
if (minkey_ <= maxkey_) {
traverse(minkey_, maxkey_-minkey_+1, f);
}
}
// NOTE: We traverse the tree via type-level recursion rather than a runtime
// loop since there is a very small, statically knows bound on its depth.
//
// `F` gets called for each leaf tree (with or without a value). It can modify
// the tree to become a forest in which case `traverse` will descend into it.
template<typename F>
void traverse(uint64_t start, uint64_t size, F&& f) {
auto [first_block, first_index] = location(start);
auto [last_block, last_index] = location(start + (size - 1));
uint64_t key = start;
while (first_block < last_block) {
const auto n = BLOCK - first_index;
traverse<0>(trees_->at(first_block), key, first_index, n, f);
key += n;
++first_block;
first_index = 0;
}
traverse<0>(trees_->at(first_block), key, first_index, last_index - first_index + 1, f);
}
template<size_t level, typename F>
static void traverse(Tree& tree, uint64_t key, uint64_t start, uint64_t size, F& f) {
if (!tree.isForest()) {
f(tree, key, size, blockSize(level));
}
if (auto forest = tree.forest()) {
constexpr auto block = blockSize(level+1);
if constexpr (block != 0) {
auto t = start / block;
auto i = start % block;
while (size != 0) {
const auto n = std::min(size, block-i);
traverse<level+1>(forest->at(t), key, i, n, f);
i = 0;
key += n;
size -= n;
++t;
}
}
}
}
std::unique_ptr<ForestN<FANOUT_TOP>> trees_;
Pool<Forest> pool_;
uint64_t minkey_ = std::numeric_limits<uint64_t>::max();
uint64_t maxkey_ = 0;
};
}
}
}