source/AccessPathTreeDomain.h (200 lines of code) (raw):

/* * Copyright (c) Meta Platforms, Inc. and affiliates. * * This source code is licensed under the MIT license found in the * LICENSE file in the root directory of this source tree. */ #pragma once #include <functional> #include <initializer_list> #include <vector> #include <AbstractDomain.h> #include <Show.h> #include <mariana-trench/AbstractTreeDomain.h> #include <mariana-trench/Access.h> #include <mariana-trench/RootPatriciaTreeAbstractPartition.h> namespace marianatrench { /** * An access path tree domain. * * This represents a map from roots to abstract trees. * * See `AbstractTreeDomain` for more information. */ template <typename Elements> class AccessPathTreeDomain final : public sparta::AbstractDomain<AccessPathTreeDomain<Elements>> { public: using AbstractTreeDomainT = AbstractTreeDomain<Elements>; private: using Map = RootPatriciaTreeAbstractPartition<AbstractTreeDomainT>; public: // C++ container concept member types using key_type = Root; using mapped_type = AbstractTreeDomainT; using value_type = std::pair<Root, AbstractTreeDomainT>; using iterator = typename Map::iterator; using const_iterator = iterator; using difference_type = std::ptrdiff_t; using size_type = std::size_t; using const_reference = const value_type&; using const_pointer = const value_type*; private: explicit AccessPathTreeDomain(Map map) : map_(std::move(map)) {} public: /* Return the bottom value (i.e, the empty tree). */ AccessPathTreeDomain() = default; explicit AccessPathTreeDomain( std::initializer_list<std::pair<AccessPath, Elements>> edges) { for (const auto& [access_path, elements] : edges) { write(access_path, elements, UpdateKind::Weak); } } explicit AccessPathTreeDomain( const std::vector<std::pair<AccessPath, Elements>>& edges) { for (const auto& [access_path, elements] : edges) { write(access_path, elements, UpdateKind::Weak); } } AccessPathTreeDomain(const AccessPathTreeDomain&) = default; AccessPathTreeDomain(AccessPathTreeDomain&&) = default; AccessPathTreeDomain& operator=(const AccessPathTreeDomain&) = default; AccessPathTreeDomain& operator=(AccessPathTreeDomain&&) = default; static AccessPathTreeDomain bottom() { return AccessPathTreeDomain(Map::bottom()); } static AccessPathTreeDomain top() { return AccessPathTreeDomain(Map::top()); } bool is_bottom() const override { return map_.is_bottom(); } bool is_top() const override { return map_.is_top(); } void set_to_bottom() override { map_.set_to_bottom(); } void set_to_top() override { map_.set_to_top(); } bool leq(const AccessPathTreeDomain& other) const override { return map_.leq(other.map_); } bool equals(const AccessPathTreeDomain& other) const override { return map_.equals(other.map_); } void join_with(const AccessPathTreeDomain& other) override { map_.join_with(other.map_); } void widen_with(const AccessPathTreeDomain& other) override { join_with(other); } void meet_with(const AccessPathTreeDomain& /*other*/) override { mt_unreachable(); // Not implemented. } void narrow_with(const AccessPathTreeDomain& other) override { meet_with(other); } /* Write elements at the given access path. */ void write(const AccessPath& access_path, Elements elements, UpdateKind kind) { map_.update(access_path.root(), [&](const AbstractTreeDomainT& tree) { auto copy = tree; copy.write(access_path.path(), std::move(elements), kind); return copy; }); } /* Write a tree at the given access path. */ void write( const AccessPath& access_path, AbstractTreeDomainT tree, UpdateKind kind) { map_.update(access_path.root(), [&](const AbstractTreeDomainT& subtree) { auto copy = subtree; copy.write(access_path.path(), std::move(tree), kind); return copy; }); } const AbstractTreeDomainT& read(Root root) const { return map_.get(root); } template <typename Propagate> AbstractTreeDomainT read( const AccessPath& access_path, const Propagate& propagate) const { return map_.get(access_path.root()).read(access_path.path(), propagate); } AbstractTreeDomainT read(const AccessPath& access_path) const { return map_.get(access_path.root()).read(access_path.path()); } AbstractTreeDomainT raw_read(const AccessPath& access_path) const { return map_.get(access_path.root()).raw_read(access_path.path()); } /** * Iterate on all non-empty elements in the tree. * * When visiting the tree, elements do not include their ancestors. */ void visit( std::function<void(const AccessPath&, const Elements&)> visitor) const { mt_assert(!is_top()); for (const auto& [root, tree] : map_) { auto access_path = AccessPath(root); visit_internal(access_path, tree, visitor); } } private: static void visit_internal( AccessPath& access_path, const AbstractTreeDomainT& tree, std::function<void(const AccessPath&, const Elements&)>& visitor) { if (!tree.root().is_bottom()) { visitor(access_path, tree.root()); } for (const auto& [path_element, subtree] : tree.successors()) { access_path.append(path_element); visit_internal(access_path, subtree, visitor); access_path.pop_back(); } } public: /** * Return the list of pairs (access path, elements) in the tree. * * Elements are returned by reference. * Elements do not contain their ancestors. */ std::vector<std::pair<AccessPath, const Elements&>> elements() const { std::vector<std::pair<AccessPath, const Elements&>> results; visit([&](const AccessPath& access_path, const Elements& element) { results.push_back({access_path, element}); }); return results; } /* Apply the given function on all elements. */ void map(const std::function<void(Elements&)>& f) { map_.map([&](const AbstractTreeDomainT& tree) { auto copy = tree; copy.map(f); return copy; }); } /* Return the begin iterator over the pairs (root, tree). */ iterator begin() const { return map_.begin(); } /* Return the end iterator over the pairs (root, tree). */ iterator end() const { return map_.end(); } /** * When a path is invalid, collapse its taint into its parent's. * See AbstractTreeDomain::collapse_invalid_paths. */ template <typename Accumulator> void collapse_invalid_paths( const std::function< std::pair<bool, Accumulator>(const Accumulator&, Path::Element)>& is_valid, const std::function<Accumulator(const Root&)>& initial_accumulator) { Map new_map; for (const auto& [root, tree] : map_) { auto copy = tree; copy.collapse_invalid_paths( is_valid, /* accumulator */ initial_accumulator(root)); new_map.set(root, std::move(copy)); } map_ = new_map; } /* Collapse children that have more than `max_leaves` leaves. */ void limit_leaves(std::size_t max_leaves) { map_.map([=](const AbstractTreeDomainT& tree) { auto copy = tree; copy.limit_leaves(max_leaves); return copy; }); } friend std::ostream& operator<<( std::ostream& out, const AccessPathTreeDomain& tree) { out << "AccessPathTree{"; for (auto iterator = tree.map_.begin(), end = tree.map_.end(); iterator != end;) { out << iterator->first << " -> " << iterator->second; ++iterator; if (iterator != end) { out << ", "; } } return out << "}"; } private: Map map_; }; } // namespace marianatrench