source/AbstractTreeDomain.h (557 lines of code) (raw):

/* * Copyright (c) Meta Platforms, Inc. and affiliates. * * This source code is licensed under the MIT license found in the * LICENSE file in the root directory of this source tree. */ #pragma once #include <functional> #include <initializer_list> #include <iterator> #include <vector> #include <AbstractDomain.h> #include <PatriciaTreeMap.h> #include <mariana-trench/Access.h> #include <mariana-trench/Assert.h> #include <mariana-trench/Heuristics.h> namespace marianatrench { enum class UpdateKind { /* Perform a strong update, i.e previous elements are replaced. */ Strong, /* Perform a weak update, i.e elements are joined. */ Weak, }; /** * An abstract tree domain. * * This is mainly used with a source set or a sink set as `Elements`, to store * the taint on each access paths. * * Elements on nodes are implicitly propagated to their children. */ template <typename Elements> class AbstractTreeDomain final : public sparta::AbstractDomain<AbstractTreeDomain<Elements>> { public: using PathElement = typename Path::Element; private: struct ValueInterface { using type = AbstractTreeDomain; static AbstractTreeDomain default_value() { return AbstractTreeDomain::bottom(); } static bool is_default_value(const AbstractTreeDomain& x) { return x.is_bottom(); } static bool equals( const AbstractTreeDomain& x, const AbstractTreeDomain& y) { // This is a structural equality, because this is used in // `sparta::PatriciaTreeMap`'s implementation to avoid node duplication. return x.elements_.equals(y.elements_) && x.children_.reference_equals(y.children_); } static bool leq( const AbstractTreeDomain& /*x*/, const AbstractTreeDomain& /*y*/) { mt_unreachable(); // Never used. } }; public: using Map = sparta::PatriciaTreeMap<PathElement, AbstractTreeDomain, ValueInterface>; public: /* Return the bottom value (i.e, the empty tree). */ AbstractTreeDomain() : elements_(Elements::bottom()) {} explicit AbstractTreeDomain(Elements elements) : elements_(std::move(elements)) {} explicit AbstractTreeDomain( std::initializer_list<std::pair<Path, Elements>> edges) : elements_(Elements::bottom()) { for (const auto& [path, elements] : edges) { write(path, elements, UpdateKind::Weak); } } AbstractTreeDomain(const AbstractTreeDomain&) = default; AbstractTreeDomain(AbstractTreeDomain&&) = default; AbstractTreeDomain& operator=(const AbstractTreeDomain&) = default; AbstractTreeDomain& operator=(AbstractTreeDomain&&) = default; static AbstractTreeDomain bottom() { return AbstractTreeDomain(); } static AbstractTreeDomain top() { mt_unreachable(); // Not implemented. } bool is_bottom() const override { return elements_.is_bottom() && children_.empty(); } bool is_top() const override { return false; } void set_to_bottom() override { elements_.set_to_bottom(); children_.clear(); } void set_to_top() override { mt_unreachable(); // Not implemented. } const Elements& root() const { return elements_; } const Map& successors() const { return children_; } const AbstractTreeDomain& successor(PathElement path_element) const { return children_.at(path_element); } bool leq(const AbstractTreeDomain& other) const override { if (!elements_.leq(other.elements_)) { return false; } if (children_.reference_equals(other.children_)) { return true; } for (const auto& [path_element, subtree] : children_) { auto other_subtree = other.children_.at(path_element); // Read semantics: we propagate the elements to the children. other_subtree.elements_.join_with(other.elements_); if (!subtree.leq(other_subtree)) { return false; } } return true; } bool equals(const AbstractTreeDomain& other) const override { if (!elements_.equals(other.elements_)) { return false; } if (children_.reference_equals(other.children_)) { return true; } for (const auto& [path_element, subtree] : children_) { auto subtree_copy = subtree; auto other_subtree = other.children_.at(path_element); // Read semantics: we propagate the elements to the children. subtree_copy.elements_.join_with(elements_); other_subtree.elements_.join_with(other.elements_); if (!subtree_copy.equals(other_subtree)) { return false; } } for (const auto& [path_element, other_subtree] : other.children_) { auto subtree = children_.at(path_element); if (!subtree.is_bottom()) { continue; // Already handled. } // Read semantics: we propagate the elements to the children. auto other_subtree_copy = other_subtree; other_subtree_copy.elements_.join_with(other.elements_); subtree = AbstractTreeDomain(elements_); if (!subtree.equals(other_subtree_copy)) { return false; } } return true; } void join_with(const AbstractTreeDomain& other) override { mt_if_expensive_assert(auto previous = *this); if (other.is_bottom()) { return; } else if (is_bottom()) { *this = other; } else { join_with_internal(other, Elements::bottom()); } mt_expensive_assert(previous.leq(*this) && other.leq(*this)); } private: void join_with_internal( const AbstractTreeDomain& other, const Elements& accumulator) { // The read semantics implies that an element on a node is implicitly // propagated to all its children. The `accumulator` contains all elements // of the ancestors/parents. If the elements on a child are included in // the accumulator, we can remove them. elements_.join_with(other.elements_); elements_.difference_with(accumulator); if (children_.reference_equals(other.children_)) { return; } const auto new_accumulator_tree = AbstractTreeDomain{accumulator.join(elements_)}; Map new_children; for (const auto& [path_element, subtree] : children_) { const auto& other_subtree = other.children_.at(path_element); if (!other_subtree.is_bottom()) { auto subtree_copy = subtree; subtree_copy.join_with_internal( other_subtree, new_accumulator_tree.elements_); if (!subtree_copy.is_bottom()) { new_children.insert_or_assign(path_element, std::move(subtree_copy)); } } else { if (!subtree.leq(new_accumulator_tree)) { new_children.insert_or_assign(path_element, subtree); } } } for (const auto& [path_element, other_subtree] : other.children_) { const auto& subtree = children_.at(path_element); if (!subtree.is_bottom()) { continue; // Already handled. } if (!other_subtree.leq(new_accumulator_tree)) { new_children.insert_or_assign(path_element, other_subtree); } } children_ = new_children; } public: void widen_with(const AbstractTreeDomain& other) override { mt_if_expensive_assert(auto previous = *this); if (other.is_bottom()) { return; } else if (is_bottom()) { *this = other; } else { widen_with_internal( other, Elements::bottom(), /* max_height */ Heuristics::kAbstractTreeWideningHeight); } mt_expensive_assert(previous.leq(*this) && other.leq(*this)); } private: void widen_with_internal( const AbstractTreeDomain& other, const Elements& accumulator, std::size_t max_height) { if (max_height == 0) { collapse_inplace(); other.collapse_into(elements_); elements_.difference_with(accumulator); return; } // The read semantics implies that an element on a node is implicitly // propagated to all its children. The `accumulator` contains all elements // of the ancestors/parents. If the elements on a child are included in // the accumulator, we can remove them. elements_.join_with(other.elements_); elements_.difference_with(accumulator); if (children_.reference_equals(other.children_)) { collapse_deeper_than(max_height); return; } const auto new_accumulator_tree = AbstractTreeDomain{accumulator.join(elements_)}; Map new_children; for (const auto& [path_element, subtree] : children_) { const auto& other_subtree = other.children_.at(path_element); if (!other_subtree.is_bottom()) { auto subtree_copy = subtree; subtree_copy.widen_with_internal( other_subtree, new_accumulator_tree.elements_, max_height - 1); if (!subtree_copy.is_bottom()) { new_children.insert_or_assign(path_element, std::move(subtree_copy)); } } else { if (!subtree.leq(new_accumulator_tree)) { auto subtree_copy = subtree; subtree_copy.collapse_deeper_than(max_height - 1); new_children.insert_or_assign(path_element, subtree_copy); } } } for (const auto& [path_element, other_subtree] : other.children_) { const auto& subtree = children_.at(path_element); if (!subtree.is_bottom()) { continue; // Already handled. } if (!other_subtree.leq(new_accumulator_tree)) { auto other_subtree_copy = other_subtree; other_subtree_copy.collapse_deeper_than(max_height - 1); new_children.insert_or_assign(path_element, other_subtree_copy); } } children_ = new_children; } public: void meet_with(const AbstractTreeDomain& /*other*/) override { mt_unreachable(); // Not implemented. } void narrow_with(const AbstractTreeDomain& other) override { meet_with(other); } /* Return all elements in the tree. */ Elements collapse() const { Elements elements = elements_; for (const auto& [path_element, subtree] : children_) { subtree.collapse_into(elements); } return elements; } /* Collapse the tree into a singleton, in place. */ void collapse_inplace() { for (const auto& [path_element, subtree] : children_) { subtree.collapse_into(elements_); } children_.clear(); } /* Collapse the tree into the given set of elements. */ void collapse_into(Elements& elements) const { elements.join_with(elements_); for (const auto& [path_eleemnt, subtree] : children_) { subtree.collapse_into(elements); } } /* Collapse the tree to the given maximum height. */ void collapse_deeper_than(std::size_t height) { if (height == 0) { collapse_inplace(); } else { children_.map([=](const AbstractTreeDomain& subtree) { auto copy = subtree; copy.collapse_deeper_than(height - 1); return copy; }); } } /* Remove the given elements from the tree. */ void prune(Elements accumulator) { elements_.difference_with(accumulator); accumulator.join_with(elements_); prune_children(accumulator); } /* Remove the given elements from the subtrees. */ void prune_children(const Elements& accumulator) { children_.map([&](const AbstractTreeDomain& subtree) { auto copy = subtree; copy.prune(accumulator); return copy; }); } /** * When a path is invalid, collapse its taint into its parent's. * * A path is invalid if `is_valid().first` is `false`. If valid, the * Accumulator contains information about visited paths so far. */ template <typename Accumulator> void collapse_invalid_paths( const std::function< std::pair<bool, Accumulator>(const Accumulator&, PathElement)>& is_valid, const Accumulator& accumulator) { Map new_children; for (const auto& [path_element, subtree] : children_) { const auto& [valid, accumulator_for_subtree] = is_valid(accumulator, path_element); if (!valid) { // Invalid path, collapse subtree into current tree. elements_.join_with(subtree.collapse()); } else { auto subtree_copy = subtree; subtree_copy.collapse_invalid_paths(is_valid, accumulator_for_subtree); new_children.insert_or_assign(path_element, std::move(subtree_copy)); } } children_ = new_children; } /* Collapse children that have more than `max_leaves` leaves. */ void limit_leaves(std::size_t max_leaves) { auto depth = depth_exceeding_max_leaves(max_leaves); if (!depth) { return; } collapse_deeper_than(*depth); } /* Return the depth at which the tree exceeds the given number of leaves. */ std::optional<std::size_t> depth_exceeding_max_leaves( std::size_t max_leaves) const { // Set of trees at the current depth. std::vector<const AbstractTreeDomain*> trees = {this}; std::size_t depth = 0; // Breadth-first search. while (!trees.empty()) { std::vector<const AbstractTreeDomain*> new_trees; for (const auto* tree : trees) { for (const auto& [path_element, subtree] : tree->children_) { if (subtree.children_.empty()) { if (max_leaves > 0) { max_leaves--; } else { return depth; } } else { new_trees.push_back(&subtree); } } } if (new_trees.size() > max_leaves) { return depth; } depth++; trees = std::move(new_trees); } return std::nullopt; } /* Write the given elements at the given path. */ void write(const Path& path, Elements elements, UpdateKind kind) { write_internal( path.begin(), path.end(), std::move(elements), Elements::bottom(), kind); } private: void write_internal( Path::ConstIterator begin, Path::ConstIterator end, Elements elements, Elements accumulator, UpdateKind kind) { if (begin == end) { switch (kind) { case UpdateKind::Strong: { elements_ = std::move(elements); children_.clear(); break; } case UpdateKind::Weak: { elements_.join_with(elements); accumulator.join_with(elements_); prune_children(accumulator); break; } } return; } accumulator.join_with(elements_); elements.difference_with(accumulator); if (elements.is_bottom() && kind == UpdateKind::Weak) { return; } auto path_head = *begin; ++begin; children_.update( [begin, end, &elements, &accumulator, kind](const auto& subtree) { auto new_subtree = subtree; new_subtree.write_internal( begin, end, std::move(elements), std::move(accumulator), kind); return new_subtree; }, path_head); } public: /* Write the given tree at the given path. */ void write(const Path& path, AbstractTreeDomain tree, UpdateKind kind) { write_internal( path.begin(), path.end(), std::move(tree), Elements::bottom(), kind); } private: void write_internal( Path::ConstIterator begin, Path::ConstIterator end, AbstractTreeDomain tree, Elements accumulator, UpdateKind kind) { if (begin == end) { switch (kind) { case UpdateKind::Strong: { *this = std::move(tree); prune(std::move(accumulator)); break; } case UpdateKind::Weak: { join_with_internal(tree, accumulator); break; } } return; } accumulator.join_with(elements_); auto path_head = *begin; ++begin; children_.update( [begin, end, &tree, &accumulator, kind](const auto& subtree) { auto new_subtree = subtree; new_subtree.write_internal( begin, end, std::move(tree), std::move(accumulator), kind); return new_subtree; }, path_head); } public: /** * Return the subtree at the given path. * * `propagate` is a function that is called when propagating elements down to * a child. This is mainly used to attach the correct access path to * artificial sources. */ template <typename Propagate> AbstractTreeDomain read(const Path& path, const Propagate& propagate) const { return read_internal(path.begin(), path.end(), propagate); } /** * Return the subtree at the given path. * * Elements are propagated down to children unchanged. */ AbstractTreeDomain read(const Path& path) const { return read_internal( path.begin(), path.end(), [](Elements elements, Path::Element /*path_element*/) -> Elements { return elements; }); } private: template <typename Propagate> AbstractTreeDomain read_internal( Path::ConstIterator begin, Path::ConstIterator end, const Propagate& propagate) const { if (begin == end) { return *this; } auto path_head = *begin; ++begin; auto subtree = children_.at(path_head); if (subtree.is_bottom()) { auto result = propagate(elements_, path_head); for (; begin != end; ++begin) { result = propagate(result, *begin); } return AbstractTreeDomain(result); } subtree.elements_.join_with(propagate(elements_, path_head)); return subtree.read_internal(begin, end, propagate); } public: /** * Return the subtree at the given path. * * Elements are NOT propagated down to children. */ AbstractTreeDomain raw_read(const Path& path) const { return raw_read_internal(path.begin(), path.end()); } private: AbstractTreeDomain raw_read_internal( Path::ConstIterator begin, Path::ConstIterator end) const { if (is_bottom() || begin == end) { return *this; } return children_.at(*begin).raw_read_internal(std::next(begin), end); } public: /** * Iterate on all non-empty elements in the tree. * * When visiting the tree, elements do not include their ancestors. */ void visit(std::function<void(const Path&, const Elements&)> visitor) const { Path path; visit_internal(path, visitor); } private: void visit_internal( Path& path, std::function<void(const Path&, const Elements&)>& visitor) const { if (!elements_.is_bottom()) { visitor(path, elements_); } for (const auto& [path_element, subtree] : children_) { path.append(path_element); subtree.visit_internal(path, visitor); path.pop_back(); } } public: /** * Return the list of all pairs (path, elements) in the tree. * * Elements are returned by reference. * Elements do not contain their ancestors. */ std::vector<std::pair<Path, const Elements&>> elements() const { std::vector<std::pair<Path, const Elements&>> results; visit([&](const Path& path, const Elements& elements) { results.push_back({path, elements}); }); return results; } /* Apply the given function on all elements. */ void map(const std::function<void(Elements&)>& f) { map_internal(f, Elements::bottom()); } private: void map_internal( const std::function<void(Elements&)>& f, Elements accumulator) { if (!elements_.is_bottom()) { f(elements_); elements_.difference_with(accumulator); accumulator.join_with(elements_); } children_.map([&](const AbstractTreeDomain& tree) { auto copy = tree; copy.map_internal(f, accumulator); return copy; }); } public: friend std::ostream& operator<<( std::ostream& out, const AbstractTreeDomain& tree) { return tree.write(out, ""); } private: std::ostream& write(std::ostream& out, const std::string& indent) const { out << "AbstractTree{"; if (is_bottom()) { return out << "}"; } else if (!elements_.is_bottom() && children_.empty()) { return out << elements_ << "}"; } else { auto new_indent = indent + " "; if (!elements_.is_bottom()) { out << "\n" << new_indent << elements_; } for (const auto& [path_element, subtree] : children_) { out << "\n" << new_indent << "`" << show(path_element) << "` -> "; subtree.write(out, new_indent); } return out << "\n" << indent << "}"; } } private: // The abstract elements at this node. // In theory, this includes all the elements from the ancestors. // In practice, we only store new elements. Elements elements_; // The edges to the child nodes. Map children_; }; } // namespace marianatrench