include/FlatSet.h (180 lines of code) (raw):

/* * Copyright (c) Meta Platforms, Inc. and affiliates. * * This source code is licensed under the MIT license found in the * LICENSE file in the root directory of this source tree. */ #pragma once #include <algorithm> #include <functional> #include <initializer_list> #include <limits> #include <ostream> #include <vector> #include "PatriciaTreeUtil.h" namespace sparta { /* * Represents a set implemented with a sorted vector. * * It is similar to `boost::container::flat_set` but provides set operations * such as union, intersection and difference, using the same interface as * `PatriciaTreeSet`. */ template <typename Element, typename Compare = std::less<Element>, typename Equal = std::equal_to<Element>> class FlatSet final { public: // C++ container concept member types using iterator = typename std::vector<Element>::const_iterator; using const_iterator = iterator; using value_type = Element; using difference_type = std::ptrdiff_t; using size_type = std::size_t; using const_reference = const Element&; using const_pointer = const Element*; FlatSet() = default; explicit FlatSet(std::initializer_list<Element> l) { for (const Element& x : l) { insert(x); } } template <typename InputIterator> FlatSet(InputIterator first, InputIterator last) { for (auto it = first; it != last; ++it) { insert(*it); } } bool empty() const { return m_vector.empty(); } std::size_t size() const { return m_vector.size(); } std::size_t max_size() const { return m_vector.max_size(); } iterator begin() const { return m_vector.begin(); } iterator end() const { return m_vector.end(); } bool contains(const Element& key) const { auto it = std::lower_bound(m_vector.begin(), m_vector.end(), key, Compare()); return it != m_vector.end() && Equal()(*it, key); } bool is_subset_of(const FlatSet& other) const { // This is optimized for `this.size() << other.size()`. auto it = m_vector.begin(), end = m_vector.end(); auto other_it = other.m_vector.begin(), other_end = other.m_vector.end(); while (it != end) { if (std::distance(it, end) > std::distance(other_it, other_end)) { return false; } other_it = std::lower_bound(other_it, other_end, *it, Compare()); if (other_it == other_end || !Equal()(*it, *other_it)) { return false; } ++it; ++other_it; } return true; } bool equals(const FlatSet& other) const { return std::equal(m_vector.begin(), m_vector.end(), other.m_vector.begin(), other.m_vector.end(), Equal()); } friend bool operator==(const FlatSet& s1, const FlatSet& s2) { return s1.equals(s2); } friend bool operator!=(const FlatSet& s1, const FlatSet& s2) { return !s1.equals(s2); } FlatSet& insert(Element key) { auto it = std::lower_bound(m_vector.begin(), m_vector.end(), key, Compare()); if (it == m_vector.end() || !Equal()(key, *it)) { m_vector.insert(it, std::move(key)); } return *this; } FlatSet& remove(const Element& key) { auto it = std::lower_bound(m_vector.begin(), m_vector.end(), key, Compare()); if (it != m_vector.end() && Equal()(key, *it)) { m_vector.erase(it); } return *this; } FlatSet& filter(const std::function<bool(const Element&)>& predicate) { m_vector.erase( std::remove_if(m_vector.begin(), m_vector.end(), [&](const Element& e) { return !predicate(e); }), m_vector.end()); return *this; } FlatSet& union_with(const FlatSet& other) { // This is optimized for `this.size() >> other.size()`. auto it = m_vector.begin(); auto other_it = other.m_vector.begin(), other_end = other.m_vector.end(); while (other_it != other_end) { it = std::lower_bound(it, m_vector.end(), *other_it, Compare()); if (it == m_vector.end() || !Equal()(*it, *other_it)) { it = m_vector.insert(it, *other_it); } ++it; ++other_it; } return *this; } FlatSet& intersection_with(const FlatSet& other) { // This is optimized for `this.size() << other.size()`. auto first = m_vector.begin(); // Where to write the next element to keep. auto it = m_vector.begin(), end = m_vector.end(); auto other_it = other.m_vector.begin(), other_end = other.m_vector.end(); while (it != end) { other_it = std::lower_bound(other_it, other_end, *it, Compare()); if (other_it != other_end && Equal()(*it, *other_it)) { if (first != it) { *first = std::move(*it); } ++first; ++other_it; } ++it; } m_vector.erase(first, end); return *this; } FlatSet& difference_with(const FlatSet& other) { // This is optimized for `this.size() >> other.size()`. auto it = m_vector.begin(); auto other_it = other.m_vector.begin(), other_end = other.m_vector.end(); while (other_it != other_end) { it = std::lower_bound(it, m_vector.end(), *other_it); if (it != m_vector.end() && Equal()(*it, *other_it)) { it = m_vector.erase(it); ++it; } ++other_it; } return *this; } FlatSet get_union_with(const FlatSet& other) const { if (m_vector.size() > other.m_vector.size()) { auto result = *this; result.union_with(other); return result; } else { auto result = other; result.union_with(*this); return result; } } FlatSet get_intersection_with(const FlatSet& other) const { if (m_vector.size() < other.m_vector.size()) { auto result = *this; result.intersection_with(other); return result; } else { auto result = other; result.intersection_with(*this); return result; } } FlatSet get_difference_with(const FlatSet& other) const { auto result = *this; result.difference_with(other); return result; } void clear() { m_vector.clear(); } friend std::ostream& operator<<(std::ostream& o, const FlatSet<Element>& s) { o << "{"; for (auto it = s.begin(), end = s.end(); it != end;) { o << pt_util::Dereference<Element>()(*it); ++it; if (it != end) { o << ", "; } } o << "}"; return o; } private: std::vector<Element> m_vector; }; } // namespace sparta