source/GroupHashedSetAbstractDomain.h (279 lines of code) (raw):

/* * Copyright (c) Meta Platforms, Inc. and affiliates. * * This source code is licensed under the MIT license found in the * LICENSE file in the root directory of this source tree. */ #pragma once #include <ostream> #include <type_traits> #include <unordered_set> #include <utility> #include <boost/iterator/transform_iterator.hpp> #include <AbstractDomain.h> #include <mariana-trench/Assert.h> namespace marianatrench { namespace detail { /** * This is a wrapper around a mutable value. It allows getting a mutable * reference on a constant instance. This is a type safe alternative to * `const_cast`, since mutating a reference produced by `const_cast` is * undefined behavior. */ template <typename Value> class MutableValue { public: explicit MutableValue(Value value) : value_(std::move(value)) {} MutableValue(const MutableValue&) = default; MutableValue(MutableValue&&) = default; MutableValue& operator=(const MutableValue&) = default; MutableValue& operator=(MutableValue&&) = default; ~MutableValue() = default; Value& get() { return value_; } const Value& get() const { return value_; } Value& get_unsafe() const { return value_; } private: mutable Value value_; }; template <typename Element> struct GroupDifference { void operator()(Element& left, const Element& right) const { if (left.leq(right)) { left.set_to_bottom(); } } }; } // namespace detail /** * A powerset abstract domain with grouping implemented using hash tables. * * `GroupHash` and `GroupEqual` describe how elements are grouped together. * * The implementation is mostly based on `sparta::HashedSetAbstractDomain`. */ template < typename Element, typename GroupHash, typename GroupEqual, typename GroupDifference = detail::GroupDifference<Element>> class GroupHashedSetAbstractDomain final : public sparta::AbstractDomain<GroupHashedSetAbstractDomain< Element, GroupHash, GroupEqual, GroupDifference>> { public: static_assert(std::is_same_v< decltype(GroupHash()(std::declval<const Element>())), std::size_t>); static_assert(std::is_same_v< decltype(GroupEqual()( std::declval<const Element>(), std::declval<const Element>())), bool>); static_assert(std::is_same_v< decltype(GroupDifference()( std::declval<Element&>(), std::declval<const Element>())), void>); private: using MutableElement = detail::MutableValue<Element>; struct MutableElementHash { std::size_t operator()(const MutableElement& element) const { return GroupHash()(element.get()); } }; struct MutableElementEqual { bool operator()(const MutableElement& left, const MutableElement& right) const { return GroupEqual()(left.get(), right.get()); } }; struct ExposeElement { const Element& operator()(const MutableElement& element) const { return element.get(); } }; using Set = std:: unordered_set<MutableElement, MutableElementHash, MutableElementEqual>; using ConstIterator = boost::transform_iterator<ExposeElement, typename Set::const_iterator>; public: // C++ container concept member types using iterator = ConstIterator; using const_iterator = ConstIterator; using value_type = Element; using difference_type = std::ptrdiff_t; using size_type = std::size_t; using const_reference = const Element&; using const_pointer = const Element*; public: /* Create the bottom (i.e, empty) abstract set. */ GroupHashedSetAbstractDomain() = default; explicit GroupHashedSetAbstractDomain(const Element& element) { add(element); } explicit GroupHashedSetAbstractDomain( std::initializer_list<Element> elements) { for (const auto& element : elements) { add(element); } } GroupHashedSetAbstractDomain(const GroupHashedSetAbstractDomain&) = default; GroupHashedSetAbstractDomain(GroupHashedSetAbstractDomain&&) = default; GroupHashedSetAbstractDomain& operator=(const GroupHashedSetAbstractDomain&) = default; GroupHashedSetAbstractDomain& operator=(GroupHashedSetAbstractDomain&&) = default; static GroupHashedSetAbstractDomain bottom() { return GroupHashedSetAbstractDomain(); } static GroupHashedSetAbstractDomain top() { mt_unreachable(); // Not implemented. } bool is_bottom() const override { return set_.empty(); } bool is_top() const override { return false; } void set_to_bottom() override { set_.clear(); } void set_to_top() override { mt_unreachable(); // Not implemented. } std::size_t size() const { return set_.size(); } bool empty() const { return set_.empty(); } ConstIterator begin() const { return boost::make_transform_iterator(set_.cbegin(), ExposeElement()); } ConstIterator end() const { return boost::make_transform_iterator(set_.cend(), ExposeElement()); } bool contains(const Element& element) const { if (element.is_bottom()) { return true; } auto found = set_.find(MutableElement(element)); return found != set_.end() && element.leq(found->get()); } void add(const Element& element) { if (element.is_bottom()) { return; } auto result = set_.emplace(element); if (!result.second) { // This is safe as long as `join_with` does not change the grouping result.first->get_unsafe().join_with(element); } } void remove(const Element& element) { if (element.is_bottom()) { return; } auto found = set_.find(MutableElement(element)); if (found != set_.end() && found->get().leq(element)) { set_.erase(found); } } void clear() { set_.clear(); } bool leq(const GroupHashedSetAbstractDomain& other) const override { if (set_.size() > other.set_.size()) { return false; } for (const MutableElement& mutable_element : set_) { const Element& element = mutable_element.get(); auto found = other.set_.find(mutable_element); if (found == other.set_.end() || !element.leq(found->get())) { return false; } } return true; } bool equals(const GroupHashedSetAbstractDomain& other) const override { if (set_.size() != other.set_.size()) { return false; } for (const MutableElement& mutable_element : set_) { const Element& element = mutable_element.get(); auto found = other.set_.find(mutable_element); if (found == other.set_.end() || !(element == found->get())) { return false; } } return true; } void join_with(const GroupHashedSetAbstractDomain& other) override { for (const MutableElement& mutable_element : other.set_) { auto result = set_.insert(mutable_element); if (!result.second) { // This is safe as long as `join_with` does not change the grouping. result.first->get_unsafe().join_with(mutable_element.get()); } } } void widen_with(const GroupHashedSetAbstractDomain& other) override { join_with(other); } void meet_with(const GroupHashedSetAbstractDomain& /*other*/) override { mt_unreachable(); // Not implemented. } void narrow_with(const GroupHashedSetAbstractDomain& other) override { meet_with(other); } void difference_with(const GroupHashedSetAbstractDomain& other) { // For performance, we iterate on the smallest set. if (set_.size() <= other.set_.size()) { for (auto iterator = set_.begin(), end = set_.end(); iterator != end;) { auto found = other.set_.find(*iterator); if (found != other.set_.end()) { GroupDifference()(iterator->get_unsafe(), found->get()); if (iterator->get().is_bottom()) { iterator = set_.erase(iterator); } else { ++iterator; } } else { ++iterator; } } } else { for (const MutableElement& element : other.set_) { auto found = set_.find(element); if (found != set_.end()) { GroupDifference()(found->get_unsafe(), element.get()); if (found->get().is_bottom()) { set_.erase(found); } } } } } /* Update all elements without affecting the grouping. */ void map(const std::function<void(Element&)>& f) { for (auto iterator = set_.begin(), end = set_.end(); iterator != end;) { // This is safe as long as `f` does not change the grouping. const MutableElement& mutable_element = *iterator; auto previous_hash = MutableElementHash()(mutable_element); f(mutable_element.get_unsafe()); if (mutable_element.get().is_bottom()) { iterator = set_.erase(iterator); } else { auto current_hash = MutableElementHash()(mutable_element); mt_assert_log(current_hash == previous_hash, "group hash has changed"); ++iterator; } } } /* Remove all elements that do not match the given predicate. */ void filter(const std::function<bool(const Element&)>& predicate) { for (auto iterator = set_.begin(), end = set_.end(); iterator != end;) { if (!predicate(iterator->get())) { iterator = set_.erase(iterator); } else { iterator++; } } } friend std::ostream& operator<<( std::ostream& out, const GroupHashedSetAbstractDomain& value) { out << "{"; for (auto it = value.begin(); it != value.end();) { out << *it++; if (it != value.end()) { out << ", "; } } return out << "}"; } private: Set set_; }; } // namespace marianatrench