in src/arith/rewrite_simplify.cc [2013:2158]
// Simplify a logical AND expression.
//
// The work is staged:
//   1. If kApplyConstraintsToBooleanBranches is enabled, each operand is
//      re-simplified under the assumption that the other operand holds,
//      alternating between them (at most four visits in total) until neither
//      changes.  Otherwise the operands are visited by the default mutator.
//   2. Constant folding and literal-constraint matching are attempted.
//   3. If kConvertBooleanToAndOfOrs is enabled and this is not already a
//      recursive boolean visit, the expression is delegated to
//      SimplifyAsAndOfOrs.
//   4. Otherwise a table of pattern rewrites is tried in order.
//
// Several of the pattern rewrites reason about floordiv/floormod.  Those
// identities only hold for a particular divisor sign and for remainder
// constants that floormod can actually produce, so each carries an explicit
// guard rather than relying on earlier simplification to have removed the
// degenerate cases.  The divisor sign is always tested first so that the
// C++ `%` used in the remaining conditions never sees a zero (or -1 with
// INT64_MIN) divisor.
PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) {
  PrimExpr ret = [&]() -> PrimExpr {
    // If this extension isn't enabled, just delegate out.
    if (!(enabled_extensions_ & kApplyConstraintsToBooleanBranches)) {
      return IRMutatorWithAnalyzer::VisitExpr_(op);
    }
    PrimExpr a = op->a;
    PrimExpr b = op->b;
    // Alternate which branch is used as the constraint, and which is
    // being simplified.  Because some sub-analyzers expect their
    // constraints to already be simplified, each branch may require
    // more than one update.  The loop condition allows each branch to
    // be visited up to twice, but only performs the second visit if
    // necessary.
    size_t iterations_since_update = 0;
    for (size_t i = 0; i < 4; i++) {
      PrimExpr& to_update = (i % 2 == 0) ? a : b;
      const PrimExpr& constraint = (i % 2 == 0) ? b : a;
      With<ConstraintContext> context(analyzer_, constraint);
      PrimExpr updated = VisitExpr(to_update);
      if (!to_update.same_as(updated)) {
        to_update = updated;
        iterations_since_update = 0;
      } else {
        iterations_since_update++;
        // Two consecutive visits without a change: both branches are stable.
        if (iterations_since_update >= 2) {
          break;
        }
      }
    }
    // Only construct a new object if a change has been made.
    // Otherwise, follow ExprMutator's convention of returning the
    // original object.
    if (a.same_as(op->a) && b.same_as(op->b)) {
      return GetRef<PrimExpr>(op);
    } else {
      return And(a, b);
    }
  }();
  op = ret.as<AndNode>();
  if (auto const_res = TryConstFold<And>(op->a, op->b)) return const_res.value();
  if (auto match = TryMatchLiteralConstraint(ret)) return match.value();
  if ((enabled_extensions_ & RewriteSimplifier::kConvertBooleanToAndOfOrs) &&
      !recursively_visiting_boolean_) {
    return SimplifyAsAndOfOrs(ret, analyzer_);
  }
  // Pattern var to match any expression
  PVar<PrimExpr> x, y, z;
  // Pattern var match IntImm
  PVar<IntImm> c1, c2, c3;
  PVar<PrimExpr> lanes;

  // True when `rem` is a value floormod(_, divisor) can produce:
  // [0, divisor) for a positive divisor, (divisor, 0] for a negative one,
  // and no value at all for a zero divisor.
  auto is_floormod_value = [](int64_t rem, int64_t divisor) {
    return divisor > 0 ? (0 <= rem && rem < divisor) : (divisor < rem && rem <= 0);
  };

  if (op->dtype.is_scalable_or_fixed_length_vector()) {
    TVM_TRY_REWRITE(broadcast(x, lanes) && broadcast(y, lanes), broadcast(x && y, lanes));
  }

  // Contradictory conditions on the same operands fold to false.
  auto cfalse = PConst<PrimExpr>(make_const(op->dtype, false));
  TVM_TRY_REWRITE(x == y && x != y, cfalse);
  TVM_TRY_REWRITE(x != y && x == y, cfalse);
  TVM_TRY_REWRITE(x && !x, cfalse);
  TVM_TRY_REWRITE(x <= y && y < x, cfalse);
  TVM_TRY_REWRITE(y < x && x <= y, cfalse);
  // Empty integer ranges: c2 < x < c1 has no integer solution when c2 + 1 >= c1.
  TVM_TRY_REWRITE_IF(x < c1 && c2 < x, cfalse, c2.Eval()->value + 1 >= c1.Eval()->value);
  TVM_TRY_REWRITE_IF(c2 < x && x < c1, cfalse, c2.Eval()->value + 1 >= c1.Eval()->value);
  TVM_TRY_REWRITE_IF((PMatchesOneOf{
                         x < c1 && c2 <= x,
                         c2 <= x && x < c1,
                         x <= c1 && c2 < x,
                         c2 < x && x <= c1,
                     }),
                     cfalse, c2.Eval()->value >= c1.Eval()->value);
  TVM_TRY_REWRITE_IF((PMatchesOneOf{
                         x <= c1 && c2 <= x,
                         c2 <= x && x <= c1,
                     }),
                     cfalse, c2.Eval()->value > c1.Eval()->value);
  TVM_TRY_REWRITE((x == c1) && (x == c2), (x == c1) && (c1 == c2));
  TVM_TRY_REWRITE(matches_one_of(x == c1 && x != c2, x != c2 && x == c1), x == c1 && c1 != c2);

  // floordiv(x, c2) == c1 && floormod(x, c2) == c3  <=>  x == c1 * c2 + c3.
  // Only valid when c3 is a remainder floormod can return; otherwise the
  // left side is always false while the right side need not be.
  TVM_TRY_RECURSIVE_REWRITE_IF(matches_one_of(floordiv(x, c2) == c1 && floormod(x, c2) == c3,
                                              floormod(x, c2) == c3 && floordiv(x, c2) == c1),
                               x == c1 * c2 + c3,
                               is_floormod_value(c3.Eval()->value, c2.Eval()->value));

  // 0 <= x - y*c1 < c1  <=>  y == floordiv(x, c1), for a positive divisor.
  TVM_TRY_RECURSIVE_REWRITE_IF((PMatchesOneOf{
                                   0 <= x - y * c1 && x - y * c1 < c1,
                                   x - y * c1 < c1 && 0 <= x - y * c1,
                               }),
                               y == floordiv(x, c1), c1.Eval()->value > 0);
  // c1 < x - y*c1 <= 0  <=>  y == floordiv(x, c1), for a negative divisor.
  // For c1 >= 0 the left side is unsatisfiable while the right side is not,
  // so the sign guard is required for soundness.
  TVM_TRY_RECURSIVE_REWRITE_IF((PMatchesOneOf{
                                   c1 < x - y * c1 && x - y * c1 <= 0,
                                   x - y * c1 <= 0 && c1 < x - y * c1,
                               }),
                               y == floordiv(x, c1), c1.Eval()->value < 0);
  // Same positive-divisor identity written as x + y*c2 with c2 == -c1.
  // Testing c1 > 0 first also keeps the negation of c1 from overflowing.
  TVM_TRY_RECURSIVE_REWRITE_IF((PMatchesOneOf{
                                   0 <= x + y * c2 && x + y * c2 < c1,
                                   x + y * c2 < c1 && 0 <= x + y * c2,
                               }),
                               y == floordiv(x, c1),
                               c1.Eval()->value > 0 && c2.Eval()->value == -c1.Eval()->value);

  // Tighten an upper bound on x using an upper bound on floormod(x, c2).
  // These identities assume a positive divisor.
  //
  // When c1 is a multiple of c2, the values just below c1 form a full residue
  // block [c1 - c2, c1) and only its first c3 entries pass the floormod test
  // (c3 <= c2 keeps the new bound from exceeding c1).
  TVM_TRY_RECURSIVE_REWRITE_IF(x < c1 && floormod(x, c2) < c3,
                               x < c1 - c2 + c3 && floormod(x, c2) < c3,
                               c2.Eval()->value > 0 && c3.Eval()->value <= c2.Eval()->value &&
                                   c1.Eval()->value % c2.Eval()->value == 0);
  // Otherwise the partial block [c1 - floormod(c1, c2), c1) can be trimmed when
  // it extends past the permitted remainders.
  TVM_TRY_RECURSIVE_REWRITE_IF(
      x < c1 && floormod(x, c2) < c3, x < c1 - floormod(c1, c2) + c3 && floormod(x, c2) < c3,
      c2.Eval()->value > 0 &&
          (c1.Eval()->value % c2.Eval()->value + c2.Eval()->value) % c2.Eval()->value >
              c3.Eval()->value);
  // The same two rules for a non-strict bound, x <= c1 (i.e. x < c1 + 1).
  TVM_TRY_RECURSIVE_REWRITE_IF(x <= c1 && floormod(x, c2) < c3,
                               x < c1 + 1 - c2 + c3 && floormod(x, c2) < c3,
                               c2.Eval()->value > 0 && c3.Eval()->value <= c2.Eval()->value &&
                                   (c1.Eval()->value + 1) % c2.Eval()->value == 0);
  TVM_TRY_RECURSIVE_REWRITE_IF(
      x <= c1 && floormod(x, c2) < c3, x < c1 + 1 - floormod(c1, c2) + c3 && floormod(x, c2) < c3,
      c2.Eval()->value > 0 &&
          (((c1.Eval()->value + 1) % c2.Eval()->value) + c2.Eval()->value) % c2.Eval()->value >
              c3.Eval()->value);

  // For a positive divisor, floordiv(x, c2) == c1 pins x to [c1*c2, (c1+1)*c2);
  // a bound on floormod(x, c2) then narrows that interval.  The limits on c3
  // reject remainder bounds that would widen the result into a neighbouring
  // block (out-of-range c3 values that make the left side trivially false
  // still map to an empty interval, so they remain allowed).
  TVM_TRY_RECURSIVE_REWRITE_IF(matches_one_of(floordiv(x, c2) == c1 && floormod(x, c2) < c3,
                                              floormod(x, c2) < c3 && floordiv(x, c2) == c1),
                               c1 * c2 <= x && x < c1 * c2 + c3,
                               c2.Eval()->value > 0 && c3.Eval()->value <= c2.Eval()->value);
  TVM_TRY_RECURSIVE_REWRITE_IF(matches_one_of(floordiv(x, c2) == c1 && floormod(x, c2) <= c3,
                                              floormod(x, c2) <= c3 && floordiv(x, c2) == c1),
                               c1 * c2 <= x && x <= c1 * c2 + c3,
                               c2.Eval()->value > 0 && c3.Eval()->value < c2.Eval()->value);
  TVM_TRY_RECURSIVE_REWRITE_IF(matches_one_of(floordiv(x, c2) == c1 && c3 <= floormod(x, c2),
                                              c3 <= floormod(x, c2) && floordiv(x, c2) == c1),
                               c1 * c2 + c3 <= x && x < (c1 + 1) * c2,
                               c2.Eval()->value > 0 && c3.Eval()->value >= 0);
  TVM_TRY_RECURSIVE_REWRITE_IF(matches_one_of(floordiv(x, c2) == c1 && c3 < floormod(x, c2),
                                              c3 < floormod(x, c2) && floordiv(x, c2) == c1),
                               c1 * c2 + c3 < x && x < (c1 + 1) * c2,
                               c2.Eval()->value > 0 && c3.Eval()->value >= -1);

  // Normalize nested conjunctions to left-associative form.
  TVM_TRY_RECURSIVE_REWRITE(x && (y && z), (x && y) && z);
  return ret;
}