in src/arith/rewrite_simplify.cc [536:723]
// Simplify a subtraction node `a - b`.
//
// The operands are first simplified by IRMutatorWithAnalyzer, then
// TryConstFold<Sub> is attempted on the new operands. Otherwise the rewrite
// rules below are tried in order. The TVM_TRY_REWRITE* macros are expected to
// return the rewritten expression from the first rule whose pattern matches
// (and whose side-condition holds, for the *_IF variants); if no rule fires,
// the operand-simplified expression `ret` is returned unchanged.
PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) {
  PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
  op = ret.as<SubNode>();
  if (auto const_res = TryConstFold<Sub>(op->a, op->b)) return const_res.value();
  // Pattern var to match any expression
  PVar<PrimExpr> x, y, z, b1, b2, s1, s2;
  // Pattern var match IntImm
  PVar<IntImm> c1, c2, c3;
  // Pattern var for lanes in broadcast and ramp
  PVar<PrimExpr> lanes;
  // Vector rules
  if (op->dtype.is_scalable_or_fixed_length_vector()) {
    TVM_TRY_REWRITE(ramp(b1, s1, lanes) - ramp(b2, s2, lanes), ramp(b1 - b2, s1 - s2, lanes));
    TVM_TRY_REWRITE(ramp(b1, s1, lanes) - broadcast(x, lanes), ramp(b1 - x, s1, lanes));
    TVM_TRY_REWRITE(broadcast(x, lanes) - ramp(b1, s1, lanes), ramp(x - b1, 0 - s1, lanes));
    TVM_TRY_REWRITE(broadcast(x, lanes) - broadcast(y, lanes), broadcast(x - y, lanes));
  }
  if (IsIndexType(op->dtype)) {
    // Index rules
    // cancellation rules
    TVM_TRY_REWRITE(matches_one_of((x + y) - y, (y + x) - y), x);
    TVM_TRY_REWRITE(matches_one_of(x - (y + x), x - (x + y)), 0 - y);
    TVM_TRY_REWRITE(matches_one_of(min(x, y) - y, x - max(y, x)), min(x - y, 0));
    TVM_TRY_REWRITE(matches_one_of(x - max(x, y), min(y, x) - y), min(0, x - y));
    TVM_TRY_REWRITE(matches_one_of(max(x, y) - y, x - min(y, x)), max(x - y, 0));
    TVM_TRY_REWRITE(matches_one_of(x - min(x, y), max(y, x) - y), max(0, x - y));
    // mul coefficient folding
    TVM_TRY_REWRITE(x - x, ZeroWithTypeLike(x));
    TVM_TRY_REWRITE(matches_one_of(x * y - x, y * x - x), x * (y - 1));
    TVM_TRY_REWRITE(matches_one_of(x - y * x, x - x * y), x * (1 - y));
    TVM_TRY_REWRITE(matches_one_of(x * y - x * z, y * x - x * z, x * y - z * x, y * x - z * x),
                    x * (y - z));
    // constant cancellation
    TVM_TRY_REWRITE((x + c1) - c2, x + (c1 - c2));
    TVM_TRY_REWRITE((c1 - x) - (c2 - y), (y - x) + (c1 - c2));
    // cancellation rules involving 4 operands
    TVM_TRY_REWRITE(
        matches_one_of((x + y) - (x + z), (x + y) - (z + x), (y + x) - (z + x), (y + x) - (x + z)),
        y - z);
    TVM_TRY_REWRITE(matches_one_of(min(x + y, z) - x, min(y + x, z) - x), min(y, z - x));
    TVM_TRY_REWRITE(matches_one_of(min(z, x + y) - x, min(z, y + x) - x), min(z - x, y));
    TVM_TRY_REWRITE(matches_one_of(max(x + y, z) - x, max(y + x, z) - x), max(y, z - x));
    TVM_TRY_REWRITE(matches_one_of(max(z, x + y) - x, max(z, y + x) - x), max(z - x, y));
    TVM_TRY_REWRITE(matches_one_of(x - min(x + y, z), x - min(y + x, z)), max(0 - y, x - z));
    TVM_TRY_REWRITE(matches_one_of(x - min(z, x + y), x - min(z, y + x)), max(x - z, 0 - y));
    TVM_TRY_REWRITE(matches_one_of(x - max(x + y, z), x - max(y + x, z)), min(0 - y, x - z));
    TVM_TRY_REWRITE(matches_one_of(x - max(z, x + y), x - max(z, y + x)), min(x - z, 0 - y));
    TVM_TRY_REWRITE(min(x, y) - min(y, x), ZeroWithTypeLike(x));
    TVM_TRY_REWRITE(max(x, y) - max(y, x), ZeroWithTypeLike(x));
    // If both operand pairs differ by the same amount d, min/max shift by d too.
    TVM_TRY_REWRITE_IF(matches_one_of(min(b1, b2) - min(s1, s2), min(b1, b2) - min(s2, s1)),
                       b1 - s1, CanProveEqual(((b1 - s1) - (b2 - s2)).Eval(), 0));
    TVM_TRY_REWRITE_IF(matches_one_of(max(b1, b2) - max(s1, s2), max(b1, b2) - max(s2, s1)),
                       b1 - s1, CanProveEqual(((b1 - s1) - (b2 - s2)).Eval(), 0));
    // DivMod rules
    // truncdiv
    // NOTE: c*(x/c) + x % c == x holds in all division modes.
    TVM_TRY_REWRITE_IF(x - truncdiv(x, c1) * c1, truncmod(x, c1), c1.Eval()->value != 0);
    TVM_TRY_REWRITE_IF(truncdiv(x, c1) * c1 - x, 0 - truncmod(x, c1), c1.Eval()->value != 0);
    TVM_TRY_REWRITE_IF(x - (truncdiv(x + y, c1)) * c1, truncmod(x + y, c1) - y,
                       c1.Eval()->value != 0);
    TVM_TRY_REWRITE_IF((truncdiv(x + y, c1)) * c1 - x, y - truncmod(x + y, c1),
                       c1.Eval()->value != 0);
    TVM_TRY_REWRITE_IF(x - truncdiv(x - y, c1) * c1, truncmod(x - y, c1) + y,
                       c1.Eval()->value != 0);
    TVM_TRY_REWRITE_IF(truncdiv(x - y, c1) * c1 - x, 0 - truncmod(x - y, c1) - y,
                       c1.Eval()->value != 0);
    // Same identities scaled by c2, where c3 == c1 * c2.
    TVM_TRY_REWRITE_IF(
        x * c2 - truncdiv(x, c1) * c3, truncmod(x, c1) * c2,
        c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
    TVM_TRY_REWRITE_IF(
        truncdiv(x, c1) * c3 - x * c2, 0 - truncmod(x, c1) * c2,
        c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
    TVM_TRY_REWRITE_IF(
        x * c2 - truncdiv(x + y, c1) * c3, (truncmod(x + y, c1) - y) * c2,
        c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
    TVM_TRY_REWRITE_IF(
        truncdiv(x + y, c1) * c3 - x * c2, (y - truncmod(x + y, c1)) * c2,
        c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
    TVM_TRY_REWRITE_IF(
        x * c2 - truncdiv(x - y, c1) * c3, (truncmod(x - y, c1) + y) * c2,
        c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
    TVM_TRY_REWRITE_IF(
        truncdiv(x - y, c1) * c3 - x * c2, (0 - truncmod(x - y, c1) - y) * c2,
        c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
    // Let y = x + c2 and d = c1 - c2. When y >= 0 and d >= 0 every truncdiv
    // operand is non-negative, so truncdiv behaves like floordiv and
    //   (y + d) / c3 - y / c3 == (y % c3 + d) / c3.
    // NOTE: floormod(c2, c3) is used instead of c2 to keep the constant small.
    // truncmod(x + floormod(c2, c3), c3) only equals truncmod(x + c2, c3) when
    // x + floormod(c2, c3) >= 0 as well. That does NOT follow from
    // x + c2 >= 0 when c2 > 0 (e.g. x = -4, c1 = 8, c2 = 5, c3 = 4 yields 1 on
    // the left but 0 on the right), so it is checked explicitly.
    // c3 > 0 is tested first so that folding floormod(c2, c3) never divides by 0.
    TVM_TRY_REWRITE_IF(truncdiv(x + c1, c3) - truncdiv(x + c2, c3),
                       truncdiv(truncmod(x + floormod(c2, c3), c3) + (c1 - c2), c3),
                       c3.Eval()->value > 0 && c1.Eval()->value >= c2.Eval()->value &&
                           CanProveGreaterEqual(x.Eval(), -c2.Eval()->value) &&
                           CanProveGreaterEqual((x + floormod(c2, c3)).Eval(), 0));
    TVM_TRY_REWRITE_IF(
        truncdiv(x + c1, c3) - truncdiv(x, c3), truncdiv(truncmod(x, c3) + c1, c3),
        CanProveGreaterEqual(x.Eval(), 0) && c1.Eval()->value >= 0 && c3.Eval()->value > 0);
    // floordiv
    TVM_TRY_REWRITE_IF(x - floordiv(x, c1) * c1, floormod(x, c1), c1.Eval()->value != 0);
    TVM_TRY_REWRITE_IF(floordiv(x, c1) * c1 - x, 0 - floormod(x, c1), c1.Eval()->value != 0);
    TVM_TRY_REWRITE_IF(x - floordiv(x + y, c1) * c1, floormod(x + y, c1) - y,
                       c1.Eval()->value != 0);
    TVM_TRY_REWRITE_IF(floordiv(x + y, c1) * c1 - x, y - floormod(x + y, c1),
                       c1.Eval()->value != 0);
    TVM_TRY_REWRITE_IF(x - floordiv(x - y, c1) * c1, floormod(x - y, c1) + y,
                       c1.Eval()->value != 0);
    TVM_TRY_REWRITE_IF(floordiv(x - y, c1) * c1 - x, 0 - floormod(x - y, c1) - y,
                       c1.Eval()->value != 0);
    // Division by 2: with x = 2a + r (r in {0, 1}) and c = 2p + q,
    // floordiv(x + c, 2) == a + p + r * q.
    TVM_TRY_RECURSIVE_REWRITE(
        floordiv(x + c1, 2) - floordiv(x + c2, 2),
        floormod(x, 2) * (floormod(c1, 2) - floormod(c2, 2)) + (floordiv(c1, 2) - floordiv(c2, 2)));
    TVM_TRY_RECURSIVE_REWRITE(floordiv(x, 2) - floordiv(x + c2, 2),
                              floormod(x, 2) * (0 - floormod(c2, 2)) - floordiv(c2, 2));
    TVM_TRY_RECURSIVE_REWRITE(floordiv(x + c1, 2) - floordiv(x, 2),
                              floormod(x, 2) * floormod(c1, 2) + floordiv(c1, 2));
    // Same identities scaled by c2, where c3 == c1 * c2.
    TVM_TRY_REWRITE_IF(
        x * c2 - floordiv(x, c1) * c3, floormod(x, c1) * c2,
        c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
    TVM_TRY_REWRITE_IF(
        floordiv(x, c1) * c3 - x * c2, 0 - floormod(x, c1) * c2,
        c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
    TVM_TRY_REWRITE_IF(
        x * c2 - floordiv(x + y, c1) * c3, (floormod(x + y, c1) - y) * c2,
        c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
    TVM_TRY_REWRITE_IF(
        floordiv(x + y, c1) * c3 - x * c2, (y - floormod(x + y, c1)) * c2,
        c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
    TVM_TRY_REWRITE_IF(
        x * c2 - floordiv(x - y, c1) * c3, (floormod(x - y, c1) + y) * c2,
        c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
    TVM_TRY_REWRITE_IF(
        floordiv(x - y, c1) * c3 - x * c2, (0 - floormod(x - y, c1) - y) * c2,
        c1.Eval()->value != 0 && c3.Eval()->value == c1.Eval()->value * c2.Eval()->value);
    TVM_TRY_RECURSIVE_REWRITE(floordiv(x + 1, 2) - floormod(x, 2), floordiv(x, 2));
    // floordiv needs no sign condition: with y = x + c2,
    // floordiv(y + d, c3) - floordiv(y, c3) == floordiv(floormod(y, c3) + d, c3),
    // and floormod(x + c2, c3) == floormod(x + floormod(c2, c3), c3) for any x.
    TVM_TRY_REWRITE_IF(floordiv(x + c1, c3) - floordiv(x + c2, c3),
                       floordiv(floormod(x + floormod(c2, c3), c3) + (c1 - c2), c3),
                       c3.Eval()->value > 0);
    TVM_TRY_REWRITE_IF(floordiv(x + c1, c3) - floordiv(x, c3), floordiv(floormod(x, c3) + c1, c3),
                       c3.Eval()->value > 0);
    // canonicalization rule
    // will try rewrite again after canonicalization.
    TVM_TRY_REWRITE(x - c1, x + (0 - c1));
    TVM_TRY_RECURSIVE_REWRITE((x + c1) - y, (x - y) + c1);
    TVM_TRY_RECURSIVE_REWRITE(x - (y + c1), (x - y) + (0 - c1));
    TVM_TRY_RECURSIVE_REWRITE(x - (y - z), (x + z) - y);
    TVM_TRY_RECURSIVE_REWRITE(x - y * c1, x + y * (0 - c1));
  } else {
    // Cancellation rules. Deliberately off of the integer path, to
    // avoid introducing checks on the side effects for the fast path.
    //
    // These simplifications do not preserve NaN/Inf that may occur in
    // the inputs. For IEEE floats, `NaN - NaN` is `NaN`, and does
    // not cancel out. However, since models should not encounter NaN
    // in the first place, this allows better simplification for the
    // supported path.
    //
    // Each rule drops one occurrence of an operand, so it only fires when the
    // dropped operand's side effect is at most kReadState.
    TVM_TRY_REWRITE_IF(x - x, ZeroWithTypeLike(x),
                       SideEffect(x.Eval()) <= CallEffectKind::kReadState);
    TVM_TRY_REWRITE_IF((x + y) - y, x, SideEffect(y.Eval()) <= CallEffectKind::kReadState);
    TVM_TRY_REWRITE_IF((x + y) - x, y, SideEffect(x.Eval()) <= CallEffectKind::kReadState);
    TVM_TRY_REWRITE_IF(x - (y + x), 0 - y, SideEffect(x.Eval()) <= CallEffectKind::kReadState);
    TVM_TRY_REWRITE_IF(x - (x + y), 0 - y, SideEffect(x.Eval()) <= CallEffectKind::kReadState);
  }
  // condition rules.
  TVM_TRY_REWRITE(select(x, b1, b2) - select(x, s1, s2), select(x, b1 - s1, b2 - s2));
  TVM_TRY_REWRITE(select(x, y, z) - z, select(x, y - z, ZeroWithTypeLike(z)));
  TVM_TRY_REWRITE(select(x, y, z) - y, select(x, ZeroWithTypeLike(y), z - y));
  return ret;
}