mlir/lib/Analysis/Presburger/PresburgerSet.cpp (253 lines of code) (raw):
//===- PresburgerSet.cpp - MLIR PresburgerSet Class -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/Presburger/PresburgerSet.h"
#include "mlir/Analysis/Presburger/Simplex.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
using namespace mlir;
/// Construct a set living in the same space as `poly` (same number of
/// dimension and symbol identifiers) whose only disjunct is `poly`.
PresburgerSet::PresburgerSet(const IntegerPolyhedron &poly)
    : nDim(poly.getNumDimIds()), nSym(poly.getNumSymbolIds()) {
  unionPolyInPlace(poly);
}
/// Return the number of IntegerPolyhedrons (disjuncts) whose union is this
/// set.
unsigned PresburgerSet::getNumPolys() const {
  return integerPolyhedrons.size();
}
/// Return the number of dimension identifiers of this set's space.
unsigned PresburgerSet::getNumDims() const { return nDim; }
/// Return the number of symbol identifiers of this set's space.
unsigned PresburgerSet::getNumSyms() const { return nSym; }
/// Return a read-only view of all the IntegerPolyhedrons whose union forms
/// this set.
ArrayRef<IntegerPolyhedron> PresburgerSet::getAllIntegerPolyhedron() const {
  return integerPolyhedrons;
}
/// Return the disjunct stored at position `index`; `index` must be smaller
/// than getNumPolys().
const IntegerPolyhedron &
PresburgerSet::getIntegerPolyhedron(unsigned index) const {
  const bool inBounds = index < integerPolyhedrons.size();
  assert(inBounds && "index out of bounds!");
  (void)inBounds;
  return integerPolyhedrons[index];
}
/// Assert that the IntegerPolyhedron and PresburgerSet live in
/// compatible spaces, i.e., they have the same number of dimension
/// identifiers and the same number of symbol identifiers.
static void assertDimensionsCompatible(const IntegerPolyhedron &poly,
                                       const PresburgerSet &set) {
  // Adjacent string literals are concatenated verbatim, so the first piece of
  // each message needs a trailing space to keep the words apart.
  assert(poly.getNumDimIds() == set.getNumDims() &&
         "Number of dimensions of the IntegerPolyhedron and PresburgerSet "
         "do not match!");
  assert(poly.getNumSymbolIds() == set.getNumSyms() &&
         "Number of symbols of the IntegerPolyhedron and PresburgerSet "
         "do not match!");
}
/// Assert that the two PresburgerSets live in compatible spaces, i.e., they
/// have the same number of dimensions and the same number of symbols.
static void assertDimensionsCompatible(const PresburgerSet &setA,
                                       const PresburgerSet &setB) {
  assert(setA.getNumDims() == setB.getNumDims() &&
         "Number of dimensions of the PresburgerSets do not match!");
  assert(setA.getNumSyms() == setB.getNumSyms() &&
         "Number of symbols of the PresburgerSets do not match!");
}
/// Mutate this set, turning it into the union of this set and the given
/// IntegerPolyhedron. A copy of `poly` is appended as a new disjunct; no
/// emptiness check or simplification is performed.
void PresburgerSet::unionPolyInPlace(const IntegerPolyhedron &poly) {
  assertDimensionsCompatible(poly, *this);
  integerPolyhedrons.push_back(poly);
}
/// Mutate this set, turning it into the union of this set and the given set.
///
/// This is accomplished by simply adding all the Poly of the given set to this
/// set.
void PresburgerSet::unionSetInPlace(const PresburgerSet &set) {
assertDimensionsCompatible(set, *this);
for (const IntegerPolyhedron &poly : set.integerPolyhedrons)
unionPolyInPlace(poly);
}
/// Return a new set that is the union of this set and `set`; neither operand
/// is modified.
PresburgerSet PresburgerSet::unionSet(const PresburgerSet &set) const {
  assertDimensionsCompatible(set, *this);
  PresburgerSet combined(*this);
  combined.unionSetInPlace(set);
  return combined;
}
/// A point belongs to the union exactly when at least one disjunct contains
/// it; the disjuncts are checked in order and the search stops at the first
/// hit.
bool PresburgerSet::containsPoint(ArrayRef<int64_t> point) const {
  for (const IntegerPolyhedron &part : integerPolyhedrons)
    if (part.containsPoint(point))
      return true;
  return false;
}
/// Return the set of all points in a space with `nDim` dimensions and `nSym`
/// symbols, represented by a single universe IntegerPolyhedron.
PresburgerSet PresburgerSet::getUniverse(unsigned nDim, unsigned nSym) {
  PresburgerSet universe = getEmptySet(nDim, nSym);
  universe.unionPolyInPlace(IntegerPolyhedron::getUniverse(nDim, nSym));
  return universe;
}
/// Return the empty set in a space with `nDim` dimensions and `nSym` symbols.
/// It is a union with no disjuncts; getUniverse and intersect start from the
/// same two-argument constructor and only then add disjuncts.
PresburgerSet PresburgerSet::getEmptySet(unsigned nDim, unsigned nSym) {
  return PresburgerSet(nDim, nSym);
}
// Return the intersection of this set with the given set.
//
// Intersection distributes over union, so
//   (S_1 or S_2 ...) and (T_1 or T_2 ...)
// is computed directly as
//   (S_1 and T_1) or (S_1 and T_2) or ...
// Pairs whose intersection is empty are not added to the result.
//
// If S_i or T_j have local variables, then (S_i and T_j) contains the local
// variables of both.
PresburgerSet PresburgerSet::intersect(const PresburgerSet &set) const {
  assertDimensionsCompatible(set, *this);
  PresburgerSet result(nDim, nSym);
  for (const IntegerPolyhedron &lhsPoly : integerPolyhedrons) {
    for (const IntegerPolyhedron &rhsPoly : set.integerPolyhedrons) {
      // Work on copies: merging local ids modifies both operands.
      IntegerPolyhedron combined = lhsPoly;
      IntegerPolyhedron rhsCopy = rhsPoly;
      combined.mergeLocalIds(rhsCopy);
      combined.append(rhsCopy);
      if (combined.isEmpty())
        continue;
      result.unionPolyInPlace(combined);
    }
  }
  return result;
}
/// Return a copy of `coeffs` in which every element has been negated.
static SmallVector<int64_t, 8> getNegatedCoeffs(ArrayRef<int64_t> coeffs) {
  SmallVector<int64_t, 8> negated(coeffs.begin(), coeffs.end());
  for (int64_t &value : negated)
    value = -value;
  return negated;
}
/// Return the complement of the given inequality.
///
/// The complement of a_1 x_1 + ... + a_n x_n + c >= 0 is
/// a_1 x_1 + ... + a_n x_n + c < 0, i.e.,
/// -a_1 x_1 - ... - a_n x_n - c - 1 >= 0,
/// since all the variables are constrained to be integers.
///
/// `ineq` holds the coefficients followed by the constant term c as its last
/// element, so it must not be empty.
static SmallVector<int64_t, 8> getComplementIneq(ArrayRef<int64_t> ineq) {
  assert(!ineq.empty() && "inequality must contain at least a constant term");
  SmallVector<int64_t, 8> coeffs = getNegatedCoeffs(ineq);
  // Over the integers, `expr < 0` is `-expr - 1 >= 0`: after negating,
  // subtract one from the constant term.
  --coeffs.back();
  return coeffs;
}
/// Return the set difference b \ s and accumulate the result into `result`.
/// `simplex` must correspond to b.
///
/// In the following, U denotes union, ^ denotes intersection, \ denotes set
/// difference and ~ denotes complement.
/// Let b be the IntegerPolyhedron and s = (U_i s_i) be the set. We want
/// b \ (U_i s_i).
///
/// Let s_i = ^_j s_ij, where each s_ij is a single inequality. To compute
/// b \ s_i = b ^ ~s_i, we partition ~s_i based on the first violated
/// inequality:
/// ~s_i = (~s_i1) U (s_i1 ^ ~s_i2) U (s_i1 ^ s_i2 ^ ~s_i3) U ...
/// And the required result is (b ^ ~s_i1) U (b ^ s_i1 ^ ~s_i2) U ...
/// We recurse by subtracting U_{j > i} s_j from each of these parts and
/// returning the union of the results. Each equality is handled as a
/// conjunction of two inequalities.
///
/// Note that the same approach works even if an inequality involves a floor
/// division. For example, the complement of x <= 7*floor(x/7) is still
/// x > 7*floor(x/7). Since b \ s_i contains the inequalities of both b and s_i
/// (or the complements of those inequalities), b \ s_i may contain the
/// divisions present in both b and s_i. Therefore, we need to add the local
/// division variables of both b and s_i to each part in the result. This means
/// adding the local variables of both b and s_i, as well as the corresponding
/// division inequalities to each part. Since the division inequalities are
/// added to each part, we can skip the parts where the complement of any
/// division inequality is added, as these parts will become empty anyway.
///
/// As a heuristic, we try adding all the constraints and check if simplex
/// says that the intersection is empty. If it is, then subtracting s_i is a
/// no-op and we just skip it. Also, in the process we find out that some
/// constraints are redundant. These redundant constraints are ignored.
///
/// b and simplex are callee saved, i.e., their values on return are
/// semantically equivalent to their values when the function is called.
///
/// \param b the minuend. Extra local ids and constraints are appended to it
///        during the call and removed again before returning.
/// \param simplex a Simplex for `b`; rolled back before returning.
/// \param s the subtrahend.
/// \param i index of the next disjunct of `s` to subtract. The top-level call
///        passes 0 and each recursion level passes i + 1.
/// \param result every surviving part of `b` is unioned into this set.
static void subtractRecursively(IntegerPolyhedron &b, Simplex &simplex,
                                const PresburgerSet &s, unsigned i,
                                PresburgerSet &result) {
  // Every disjunct of s has been subtracted; what remains of b belongs to the
  // difference.
  if (i == s.getNumPolys()) {
    result.unionPolyInPlace(b);
    return;
  }
  // Taken by value because mergeLocalIds below also modifies sI.
  IntegerPolyhedron sI = s.getIntegerPolyhedron(i);
  // Below, we append some additional constraints and ids to b. We want to
  // rollback b to its initial state before returning, which we will do by
  // removing all constraints beyond the original number of inequalities
  // and equalities, so we store these counts first.
  const unsigned bInitNumIneqs = b.getNumInequalities();
  const unsigned bInitNumEqs = b.getNumEqualities();
  const unsigned bInitNumLocals = b.getNumLocalIds();
  // Similarly, we also want to rollback simplex to its original state.
  const unsigned initialSnapshot = simplex.getSnapshot();
  // Restores b and simplex to their state on entry. This lambda is not run
  // automatically; every path that leaves this function after modifying b or
  // simplex calls it explicitly.
  auto restoreState = [&]() {
    b.removeIdRange(IntegerPolyhedron::IdKind::Local, bInitNumLocals,
                    b.getNumLocalIds());
    b.removeInequalityRange(bInitNumIneqs, b.getNumInequalities());
    b.removeEqualityRange(bInitNumEqs, b.getNumEqualities());
    simplex.rollback(initialSnapshot);
  };
  // Find out which inequalities of sI correspond to division inequalities for
  // the local variables of sI.
  std::vector<llvm::Optional<std::pair<unsigned, unsigned>>> repr(
      sI.getNumLocalIds());
  sI.getLocalReprs(repr);
  // Add sI's locals to b, after b's locals. Also add b's locals to sI, before
  // sI's locals.
  b.mergeLocalIds(sI);
  // Mark which inequalities of sI are division inequalities and add all such
  // inequalities to b.
  llvm::SmallBitVector isDivInequality(sI.getNumInequalities());
  for (Optional<std::pair<unsigned, unsigned>> &maybePair : repr) {
    assert(maybePair &&
           "Subtraction is not supported when a representation of the local "
           "variables of the subtrahend cannot be found!");
    b.addInequality(sI.getInequality(maybePair->first));
    b.addInequality(sI.getInequality(maybePair->second));
    assert(maybePair->first != maybePair->second &&
           "Upper and lower bounds must be different inequalities!");
    isDivInequality[maybePair->first] = true;
    isDivInequality[maybePair->second] = true;
  }
  // Position at which simplex will start numbering sI's constraints once they
  // are intersected in below; used to query their redundancy.
  unsigned offset = simplex.getNumConstraints();
  // Give simplex one new variable per local id that mergeLocalIds added to b.
  unsigned numLocalsAdded = b.getNumLocalIds() - bInitNumLocals;
  simplex.appendVariable(numLocalsAdded);
  unsigned snapshotBeforeIntersect = simplex.getSnapshot();
  simplex.intersectIntegerPolyhedron(sI);
  if (simplex.isEmpty()) {
    // b ^ s_i is empty, so b \ s_i = b. We move directly to i + 1.
    // We are ignoring level i completely, so we restore the state
    // *before* going to level i + 1.
    restoreState();
    subtractRecursively(b, simplex, s, i + 1, result);
    return;
  }
  simplex.detectRedundant();
  // Equalities are added to simplex as a pair of inequalities.
  unsigned totalNewSimplexInequalities =
      2 * sI.getNumEqualities() + sI.getNumInequalities();
  // Record, for each constraint sI contributed to simplex, whether simplex
  // found it redundant. The loops below read entry j as sI's j-th inequality
  // and entries (numIneqs + 2j, numIneqs + 2j + 1) as the two halves of sI's
  // j-th equality, i.e. they assume intersectIntegerPolyhedron appends the
  // inequalities before the equalities — NOTE(review): confirm in Simplex.
  llvm::SmallBitVector isMarkedRedundant(totalNewSimplexInequalities);
  for (unsigned j = 0; j < totalNewSimplexInequalities; j++)
    isMarkedRedundant[j] = simplex.isMarkedRedundant(offset + j);
  // Drop sI's constraints from simplex again; only the redundancy information
  // is kept. sI's division inequalities were added to b above but are never
  // re-added to simplex (they are skipped below), so from here on simplex may
  // describe a looser region than b.
  simplex.rollback(snapshotBeforeIntersect);
  // Recurse with the part b ^ ~ineq. Note that b is modified throughout
  // subtractRecursively. At the time this function is called, the current b is
  // actually equal to b ^ s_i1 ^ s_i2 ^ ... ^ s_ij, and ineq is the next
  // inequality, s_{i,j+1}. This function recurses into the next level i + 1
  // with the part b ^ s_i1 ^ s_i2 ^ ... ^ s_ij ^ ~s_{i,j+1}.
  auto recurseWithInequality = [&, i](ArrayRef<int64_t> ineq) {
    size_t snapshot = simplex.getSnapshot();
    b.addInequality(ineq);
    simplex.addInequality(ineq);
    subtractRecursively(b, simplex, s, i + 1, result);
    b.removeInequality(b.getNumInequalities() - 1);
    simplex.rollback(snapshot);
  };
  // For each inequality ineq, we first recurse with the part where ineq
  // is not satisfied, and then add the ineq to b and simplex because
  // ineq must be satisfied by all later parts.
  auto processInequality = [&](ArrayRef<int64_t> ineq) {
    recurseWithInequality(getComplementIneq(ineq));
    b.addInequality(ineq);
    simplex.addInequality(ineq);
  };
  // Process all the inequalities, ignoring redundant inequalities and division
  // inequalities. The result is correct whether or not we ignore these, but
  // ignoring them makes the result simpler.
  for (unsigned j = 0, e = sI.getNumInequalities(); j < e; j++) {
    if (isMarkedRedundant[j])
      continue;
    if (isDivInequality[j])
      continue;
    processInequality(sI.getInequality(j));
  }
  // From here on, offset indexes isMarkedRedundant: the equality halves are
  // stored right after sI's inequalities.
  offset = sI.getNumInequalities();
  for (unsigned j = 0, e = sI.getNumEqualities(); j < e; ++j) {
    ArrayRef<int64_t> coeffs = sI.getEquality(j);
    // For each equality, process the positive and negative inequalities that
    // make up this equality. If Simplex found an inequality to be redundant, we
    // skip it as above to make the result simpler. Divisions are always
    // represented in terms of inequalities and not equalities, so we do not
    // check for division inequalities here.
    if (!isMarkedRedundant[offset + 2 * j])
      processInequality(coeffs);
    if (!isMarkedRedundant[offset + 2 * j + 1])
      processInequality(getNegatedCoeffs(coeffs));
  }
  restoreState();
}
/// Return the set difference poly \ set.
///
/// `poly` is taken by value because subtractRecursively temporarily extends
/// it with extra local ids and constraints. It is restored before that call
/// returns, but the modification means a const reference cannot be used.
PresburgerSet PresburgerSet::getSetDifference(IntegerPolyhedron poly,
                                              const PresburgerSet &set) {
  assertDimensionsCompatible(poly, set);
  const unsigned numDims = poly.getNumDimIds();
  const unsigned numSyms = poly.getNumSymbolIds();
  // If the GCD test reports `poly` as empty, there is nothing to subtract from.
  if (poly.isEmptyByGCDTest())
    return getEmptySet(numDims, numSyms);
  PresburgerSet difference(numDims, numSyms);
  Simplex simplex(poly);
  subtractRecursively(poly, simplex, set, 0, difference);
  return difference;
}
/// Return the complement of this set, computed as the universe of the same
/// space minus this set.
PresburgerSet PresburgerSet::complement() const {
  const unsigned numDims = getNumDims();
  const unsigned numSyms = getNumSyms();
  return getSetDifference(IntegerPolyhedron::getUniverse(numDims, numSyms),
                          *this);
}
/// Return the result of subtracting the given set from this set, i.e.,
/// return `this \ set`.
PresburgerSet PresburgerSet::subtract(const PresburgerSet &set) const {
  assertDimensionsCompatible(set, *this);
  // Set difference distributes over the union on the left:
  //   (U_i t_i) \ (U_j set_j) = U_i (t_i \ U_j set_j),
  // so the whole of `set` is subtracted from each disjunct separately.
  PresburgerSet difference(nDim, nSym);
  for (const IntegerPolyhedron &disjunct : integerPolyhedrons) {
    PresburgerSet piece = getSetDifference(disjunct, set);
    difference.unionSetInPlace(piece);
  }
  return difference;
}
/// Two sets S and T are equal iff S contains T and T contains S.
/// By "S contains T", we mean that S is a superset of or equal to T.
///
/// S contains T iff T \ S is empty, since if T \ S contains a
/// point then this is a point that is contained in T but not S.
///
/// Therefore, S is equal to T iff S \ T and T \ S are both empty.
bool PresburgerSet::isEqual(const PresburgerSet &set) const {
assertDimensionsCompatible(set, *this);
return this->subtract(set).isIntegerEmpty() &&
set.subtract(*this).isIntegerEmpty();
}
/// Return true if all the sets in the union are known to be integer empty,
/// false otherwise.
bool PresburgerSet::isIntegerEmpty() const {
// The set is empty iff all of the disjuncts are empty.
for (const IntegerPolyhedron &poly : integerPolyhedrons) {
if (!poly.isIntegerEmpty())
return false;
}
return true;
}
/// Search the disjuncts in order for an integer point. On success, store the
/// first point found in `sample` and return true; otherwise leave `sample`
/// untouched and return false.
bool PresburgerSet::findIntegerSample(SmallVectorImpl<int64_t> &sample) {
  for (const IntegerPolyhedron &disjunct : integerPolyhedrons) {
    Optional<SmallVector<int64_t, 8>> maybeSample =
        disjunct.findIntegerSample();
    if (!maybeSample)
      continue;
    sample = std::move(*maybeSample);
    return true;
  }
  return false;
}
/// Return an equivalent set with some redundant disjuncts removed. A disjunct
/// is dropped when Simplex reports it empty, or when it is a rational subset
/// of another disjunct that has not itself been dropped. Dropping either kind
/// keeps the set of integer points the same.
PresburgerSet PresburgerSet::coalesce() const {
  const unsigned numPolys = integerPolyhedrons.size();
  llvm::SmallBitVector isRedundant(getNumPolys());
  for (unsigned cur = 0; cur < numPolys; ++cur) {
    Simplex simplex(integerPolyhedrons[cur]);
    // An empty polytope is trivially redundant.
    if (simplex.isEmpty()) {
      isRedundant[cur] = true;
      continue;
    }
    // Otherwise it is redundant if some other, still-kept disjunct contains
    // it.
    for (unsigned other = 0; other < numPolys; ++other) {
      if (other == cur || isRedundant[other])
        continue;
      if (simplex.isRationalSubsetOf(integerPolyhedrons[other])) {
        isRedundant[cur] = true;
        break;
      }
    }
  }
  PresburgerSet newSet = getEmptySet(getNumDims(), getNumSyms());
  for (unsigned idx = 0; idx < numPolys; ++idx)
    if (!isRedundant[idx])
      newSet.unionPolyInPlace(integerPolyhedrons[idx]);
  return newSet;
}
/// Print the number of disjuncts, followed by each IntegerPolyhedron with a
/// newline after it.
void PresburgerSet::print(raw_ostream &os) const {
  const unsigned numPolys = getNumPolys();
  os << numPolys << " IntegerPolyhedron:\n";
  for (unsigned idx = 0; idx < numPolys; ++idx) {
    integerPolyhedrons[idx].print(os);
    os << '\n';
  }
}
/// Print this set to llvm::errs().
void PresburgerSet::dump() const { print(llvm::errs()); }