include/FiniteAbstractDomain.h (268 lines of code) (raw):
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <bitset>
#include <cstddef>
#include <functional>
#include <initializer_list>
#include <ostream>
#include <sstream>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include "AbstractDomain.h"
namespace sparta {
/*
* This is the general interface for arbitrary encodings of a lattice. 'Element'
* is the type of the symbolic names for the lattice elements and 'Encoding' is
* the type of the actual encoding.
*/
template <typename Element, typename Encoding>
class LatticeEncoding {
 public:
  // Virtual, so that a concrete encoding can be destroyed through a pointer or
  // a reference to this interface.
  virtual ~LatticeEncoding() = default;

  // Conversions between a symbolic element name and its encoded form.
  virtual Encoding encode(const Element& element) const = 0;
  virtual Element decode(const Encoding& encoding) const = 0;

  // Predicates recognizing the encodings of the two extremal elements.
  virtual bool is_bottom(const Encoding& element) const = 0;
  virtual bool is_top(const Encoding& element) const = 0;

  // Equality and partial order, both expressed on encoded values.
  virtual bool equals(const Encoding& x, const Encoding& y) const = 0;
  virtual bool leq(const Encoding& x, const Encoding& y) const = 0;

  // Lattice operations; operands and result are all encoded values.
  virtual Encoding join(const Encoding& x, const Encoding& y) const = 0;
  virtual Encoding meet(const Encoding& x, const Encoding& y) const = 0;

  // Encodings of the extremal elements.
  virtual Encoding bottom() const = 0;
  virtual Encoding top() const = 0;
};
/*
* Example usage:
*
* Encoding the following lattice using bit vectors:
*
* TOP
* / \
* A B
* \ /
* BOTTOM
*
* enum Elements {BOTTOM, A, B, TOP};
* using Lattice = BitVectorLattice<Elements, 4, std::hash<int>>;
* Lattice lattice({BOTTOM, A, B, TOP},
* {{BOTTOM, A}, {BOTTOM, B}, {A, TOP}, {B, TOP}});
* using Domain =
* FiniteAbstractDomain<Elements, Lattice, Lattice::Encoding, &lattice>;
* ...
* Domain a(A), b(B);
* Domain x = a.join(b);
* ...
*
* Note: since 'lattice' is a template argument, this object must be statically
* defined, for example as a global variable. The lattice is instantiated just
* once at startup time.
*
*/
template <typename Element,
          typename Lattice,
          typename Encoding,
          Lattice* lattice>
class FiniteAbstractDomain final
    : public AbstractDomain<
          FiniteAbstractDomain<Element, Lattice, Encoding, lattice>> {
 public:
  /*
   * Builds the Top element. The AbstractDomain specification requires a
   * default constructor.
   */
  FiniteAbstractDomain() : m_encoding(lattice->top()) {}

  /*
   * Builds the domain value denoting the given symbolic lattice element.
   */
  explicit FiniteAbstractDomain(const Element& element)
      : m_encoding(lattice->encode(element)) {}

  ~FiniteAbstractDomain() {
    // Instantiating a class template is only guaranteed to create its
    // destructor, so the requirements on the template parameters are verified
    // here.
    static_assert(
        std::is_base_of<LatticeEncoding<Element, Encoding>, Lattice>::value,
        "Lattice doesn't derive from LatticeEncoding");
  }

  /*
   * Returns the symbolic name of the lattice element currently held.
   */
  Element element() const { return lattice->decode(m_encoding); }

  // Every operation below is forwarded to the statically bound lattice and
  // works on the encoded representation.

  bool is_bottom() const override { return lattice->is_bottom(m_encoding); }

  bool is_top() const override { return lattice->is_top(m_encoding); }

  bool leq(const FiniteAbstractDomain& other) const override {
    return lattice->leq(m_encoding, other.m_encoding);
  }

  bool equals(const FiniteAbstractDomain& other) const override {
    return lattice->equals(m_encoding, other.m_encoding);
  }

  void set_to_bottom() override { m_encoding = lattice->bottom(); }

  void set_to_top() override { m_encoding = lattice->top(); }

  void join_with(const FiniteAbstractDomain& other) override {
    m_encoding = lattice->join(m_encoding, other.m_encoding);
  }

  // Widening is implemented as the Join.
  void widen_with(const FiniteAbstractDomain& other) override {
    this->join_with(other);
  }

  void meet_with(const FiniteAbstractDomain& other) override {
    m_encoding = lattice->meet(m_encoding, other.m_encoding);
  }

  // Narrowing is implemented as the Meet.
  void narrow_with(const FiniteAbstractDomain& other) override {
    this->meet_with(other);
  }

  /*
   * Factory functions for the two extremal elements.
   */
  static FiniteAbstractDomain bottom() {
    return FiniteAbstractDomain(lattice->bottom());
  }

  static FiniteAbstractDomain top() {
    return FiniteAbstractDomain(lattice->top());
  }

 private:
  // Wraps an already encoded value; used by the factory functions above.
  FiniteAbstractDomain(const Encoding& encoding) : m_encoding(encoding) {}

  // Encoded form of the lattice element represented by this domain value.
  Encoding m_encoding;
};
} // namespace sparta
template <typename Element,
typename Lattice,
typename Encoding,
Lattice* lattice>
inline std::ostream& operator<<(
std::ostream& o,
const typename sparta::
FiniteAbstractDomain<Element, Lattice, Encoding, lattice>& x) {
o << x.element();
return o;
}
namespace sparta {
namespace fad_impl {
/*
* Our encoding of lattices is based on the following paper that proposes an
* efficient representation based on bit vectors:
*
* H. Aït-Kaci, R. Boyer, P. Lincoln, R. Nasr. Efficient implementation of
* lattice operations. In ACM Transactions on Programming Languages and
* Systems (TOPLAS), Volume 11, Issue 1, Jan. 1989, pages 115-146.
*
* The approach described in the paper only works with the Meet operation. The
* idea is to represent the Hasse diagram of a lattice using a Boolean matrix,
* as shown below:
*
* d a b c d
* / \ a 0 0 0 0
* b c b 1 0 0 0
* \ / c 1 0 0 0
* a d 0 1 1 0
*
* This matrix represents the "immediately greater than" relation in the
* lattice. The technique consists of computing the reflexive and transitive
* closure of that relation. Then, an element can be encoded by its
* corresponding row (i.e., a bit vector) in the resulting matrix. Computing the
* Meet simply amounts to performing the bitwise And operation on the bit
* vector representation. For the example above that gives:
*
* Reflexive-transitive closure:
*
* a b c d b Meet c = 1100 & 1010
* a 1 0 0 0 = 1000
* b 1 1 0 0 = a
* c 1 0 1 0
* d 1 1 1 1
*
* In order to compute the Join, we apply the same technique to the opposite
* lattice, i.e., the lattice in which the order relation has been reversed and
* the Top and Bottom elements have been swapped. The opposite lattice and the
* corresponding Boolean matrix are constructed as follows:
*
* a a b c d
* / \ a 0 1 1 0
* b c b 0 0 0 1
* \ / c 0 0 0 1
* d d 0 0 0 0
*
* It can be easily seen that the Meet in the opposite lattice is exactly the
* Join in the original lattice.
*
* Reflexive-transitive closure:
*
* a b c d b Meet c = 0101 & 0011
* a 1 1 1 1 = 0001
* b 0 1 0 1 = d
* c 0 0 1 1 = b Join c in the original
* d 0 0 0 1 lattice
*
* The template parameter 'construct_opposite_lattice' specifies the lattice to
* consider for the encoding.
*
* Note that constructing this representation has cubic time complexity in the
* number of elements of the lattice. Since the construction is done only once
* at startup time and finite lattices built this way are usually small, this
* should not be a problem in practice.
*
*/
template <typename Element,
          size_t cardinality,
          bool construct_opposite_lattice,
          typename Hash,
          typename Equal>
class BitVectorSemiLattice final {
 public:
  // std::bitset needs its width at compile time, which is why 'cardinality'
  // has to be a template parameter.
  using Encoding = std::bitset<cardinality>;

  BitVectorSemiLattice() = delete;

  /*
   * Builds the bit vector encoding from the complete list of lattice elements
   * (Top and Bottom included) and from the Hasse diagram of the partial order.
   * Each pair {x, y} of the diagram states that y is immediately greater than
   * x in the original lattice.
   */
  BitVectorSemiLattice(
      std::initializer_list<Element> elements,
      std::initializer_list<std::pair<Element, Element>> hasse_diagram) {
    RUNTIME_CHECK(elements.size() == cardinality,
                  invalid_argument()
                      << argument_name("elements")
                      << operation_name("BitVectorSemiLattice()"));

    // Number the elements in the order in which they are listed. The index of
    // an element is its row and its column in the Boolean matrix.
    Element element_at[cardinality];
    std::unordered_map<Element, size_t, Hash, Equal> index_of;
    size_t next_index = 0;
    for (const Element& listed : elements) {
      element_at[next_index] = listed;
      index_of[listed] = next_index;
      ++next_index;
    }

    // Record the "immediately greater than" relation, one edge of the Hasse
    // diagram at a time.
    Encoding matrix[cardinality];
    for (const auto& edge : hasse_diagram) {
      // The diagram describes the original lattice. For the opposite lattice
      // the two endpoints of every edge exchange their roles.
      const Element& smaller =
          construct_opposite_lattice ? edge.second : edge.first;
      const Element& greater =
          construct_opposite_lattice ? edge.first : edge.second;
      auto smaller_it = index_of.find(smaller);
      auto greater_it = index_of.find(greater);
      RUNTIME_CHECK(smaller_it != index_of.end() &&
                        greater_it != index_of.end(),
                    internal_error());
      // matrix[y][x] = 1 means that y is immediately greater than x.
      matrix[greater_it->second][smaller_it->second] = true;
    }

    // Reflexive closure: every element is related to itself.
    for (size_t i = 0; i < cardinality; ++i) {
      matrix[i][i] = true;
    }

    // Transitive closure using Warshall's algorithm. Whenever 'row' reaches
    // 'via', it also reaches everything 'via' reaches, so the whole row of
    // 'via' is merged into it with a single bitwise Or.
    for (size_t via = 0; via < cardinality; ++via) {
      for (size_t row = 0; row < cardinality; ++row) {
        if (matrix[row][via]) {
          matrix[row] |= matrix[via];
        }
      }
    }

    // The encoding of an element is its row in the closed matrix. Both
    // directions of the mapping are stored, and the extremal elements are
    // picked up along the way.
    for (size_t i = 0; i < cardinality; ++i) {
      const Element& element = element_at[i];
      const Encoding& encoding = matrix[i];
      m_element_to_encoding[element] = encoding;
      m_encoding_to_element[encoding] = element;
      if (is_bottom(encoding)) {
        m_bottom = encoding;
      }
      if (is_top(encoding)) {
        m_top = encoding;
      }
    }

    // Reject inputs that do not describe a semi-lattice.
    sanity_check(element_at);
  }

  /*
   * Returns the bit vector assigned to the given element. The lookup is
   * guarded by a RUNTIME_CHECK (undefined_operation) for unknown elements.
   */
  Encoding encode(const Element& element) const {
    auto found = m_element_to_encoding.find(element);
    RUNTIME_CHECK(found != m_element_to_encoding.end(), undefined_operation());
    return found->second;
  }

  /*
   * Returns the element a bit vector stands for. The lookup is guarded by a
   * RUNTIME_CHECK (undefined_operation) for bit vectors that encode nothing.
   */
  Element decode(const Encoding& encoding) const {
    auto found = m_encoding_to_element.find(encoding);
    RUNTIME_CHECK(found != m_encoding_to_element.end(), undefined_operation());
    return found->second;
  }

  bool is_bottom(const Encoding& x) const {
    // In the lower semi-lattice representation the Bottom element is the
    // unique bit vector with exactly one bit set, whereas in the opposite
    // semi-lattice it is the one with all bits set.
    if (construct_opposite_lattice) {
      return x.all();
    }
    return x.count() == 1;
  }

  bool is_top(const Encoding& x) const {
    // The Top element is defined as the dual of Bottom.
    if (construct_opposite_lattice) {
      return x.count() == 1;
    }
    return x.all();
  }

  Encoding bottom() const { return m_bottom; }

  Encoding top() const { return m_top; }

 private:
  // Verifies that the bitwise And of every pair of encodings (i.e., the Meet
  // or the Join of the two elements, depending on the lattice considered) is
  // itself the encoding of an element, and that exactly one encoding has all
  // bits set and exactly one has a single bit set. In other words, this makes
  // sure that the input Hasse diagram defines a semi-lattice.
  void sanity_check(Element* index_to_element) {
    size_t full_vectors = 0; // Encodings with every bit set.
    size_t singleton_vectors = 0; // Encodings with exactly one bit set.
    for (size_t i = 0; i < cardinality; ++i) {
      const Encoding x = m_element_to_encoding[index_to_element[i]];
      if (x.all()) {
        ++full_vectors;
      }
      if (x.count() == 1) {
        ++singleton_vectors;
      }
      for (size_t j = 0; j < cardinality; ++j) {
        const Encoding y = m_element_to_encoding[index_to_element[j]];
        RUNTIME_CHECK(m_encoding_to_element.find(x & y) !=
                          m_encoding_to_element.end(),
                      internal_error());
      }
    }
    RUNTIME_CHECK(full_vectors == 1 && singleton_vectors == 1,
                  internal_error()
                      << error_msg("Missing or duplicate extremal element"));
  }

  // Forward and reverse tables between elements and their encodings.
  std::unordered_map<Element, Encoding, Hash, Equal> m_element_to_encoding;
  std::unordered_map<Encoding, Element> m_encoding_to_element;

  // Encodings of the extremal elements, identified during construction.
  Encoding m_bottom;
  Encoding m_top;
};
} // namespace fad_impl
/*
* A lattice maintains two semi-lattices internally. Elements are always
* represented using the opposite semi-lattice encoding; the corresponding
* lower semi-lattice encoding is computed only when it is needed (i.e., for
* the Meet operation).
*/
template <typename Element,
          size_t cardinality,
          typename Hash = std::hash<Element>,
          typename Equal = std::equal_to<Element>>
class BitVectorLattice final
    : public LatticeEncoding<Element, std::bitset<cardinality>> {
 public:
  using Encoding = std::bitset<cardinality>;

  BitVectorLattice() = delete;

  /*
   * Builds both semi-lattice encodings from the same description: the complete
   * list of elements and the Hasse diagram of the partial order.
   */
  BitVectorLattice(
      std::initializer_list<Element> elements,
      std::initializer_list<std::pair<Element, Element>> hasse_diagram)
      : m_lower_semi_lattice(elements, hasse_diagram),
        m_opposite_semi_lattice(elements, hasse_diagram) {}

  ~BitVectorLattice() {
    // Instantiating a class template is only guaranteed to create its
    // destructor, so the requirements on the template parameters are verified
    // here.
    static_assert(std::is_default_constructible<Element>::value,
                  "Element is not default constructible");
    static_assert(std::is_copy_constructible<Element>::value,
                  "Element is not copy constructible");
    static_assert(std::is_copy_assignable<Element>::value,
                  "Element is not copy assignable");
  }

  // The Join is by far the dominant operation in a standard fixpoint
  // computation, hence domain elements are always built with the opposite
  // semi-lattice encoding, in which the Join is a plain bitwise And.
  Encoding encode(const Element& element) const override {
    return m_opposite_semi_lattice.encode(element);
  }

  // Decodes a value expressed in the (default) opposite semi-lattice encoding.
  Element decode(const Encoding& encoding) const override {
    return m_opposite_semi_lattice.decode(encoding);
  }

  // Decodes a value expressed in the lower semi-lattice encoding.
  Element decode_lower(const Encoding& encoding) const {
    return m_lower_semi_lattice.decode(encoding);
  }

  // In the opposite semi-lattice encoding Bottom has all bits set and Top has
  // exactly one bit set.
  bool is_bottom(const Encoding& x) const override {
    return m_opposite_semi_lattice.is_bottom(x);
  }

  bool is_top(const Encoding& x) const override {
    return m_opposite_semi_lattice.is_top(x);
  }

  bool equals(const Encoding& x, const Encoding& y) const override {
    return x == y;
  }

  // x is less than or equal to y exactly when joining them yields y.
  bool leq(const Encoding& x, const Encoding& y) const override {
    return join(x, y) == y;
  }

  Encoding join(const Encoding& x, const Encoding& y) const override {
    return x & y;
  }

  // The Meet is a bitwise And in the lower semi-lattice encoding only. Both
  // operands are converted to that encoding first and the result is converted
  // back to the opposite semi-lattice encoding before it is returned.
  Encoding meet(const Encoding& x, const Encoding& y) const override {
    const Encoding lower_x = get_lower_encoding(x);
    const Encoding lower_y = get_lower_encoding(y);
    return get_opposite_encoding(lower_x & lower_y);
  }

  Encoding bottom() const override { return m_opposite_semi_lattice.bottom(); }

  Encoding top() const override { return m_opposite_semi_lattice.top(); }

 private:
  // Opposite semi-lattice encoding -> lower semi-lattice encoding.
  Encoding get_lower_encoding(const Encoding& x) const {
    return m_lower_semi_lattice.encode(decode(x));
  }

  // Lower semi-lattice encoding -> opposite semi-lattice encoding.
  Encoding get_opposite_encoding(const Encoding& x) const {
    return m_opposite_semi_lattice.encode(decode_lower(x));
  }

  fad_impl::BitVectorSemiLattice<Element,
                                 cardinality,
                                 /* construct_opposite_lattice */ false,
                                 Hash,
                                 Equal>
      m_lower_semi_lattice;
  fad_impl::BitVectorSemiLattice<Element,
                                 cardinality,
                                 /* construct_opposite_lattice */ true,
                                 Hash,
                                 Equal>
      m_opposite_semi_lattice;
};
} // namespace sparta