in src/arith/rewrite_simplify.cc [764:922]
/*!
 * \brief Simplify a Div node.
 *
 * Children are simplified first, then constant folding is attempted, then a
 * sequence of pattern rewrites. A division by a float constant is turned into a
 * multiplication by its reciprocal. Integer rules are written for C-style
 * truncated division, so most of them only fire when the relevant operands can
 * be proven non-negative.
 *
 * \param op The Div node to simplify.
 * \return The rewritten expression, or the child-simplified node when no rule applies.
 */
PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op) {
  PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
  op = ret.as<DivNode>();
  // Defensive: never dereference the node if the base mutator produced something else.
  if (op == nullptr) return ret;
  if (auto const_res = TryConstFold<Div>(op->a, op->b)) return const_res.value();
  // Pattern var to match any expression
  PVar<PrimExpr> x, y, z, b1;
  // Pattern var match IntImm
  PVar<IntImm> c1, c2, c3;
  // Pattern var for lanes in broadcast and ramp
  PVar<PrimExpr> lanes;
  // x / 2.0 = x * 0.5
  if (const FloatImmNode* ptr = op->b.as<FloatImmNode>()) {
    ICHECK(op->dtype.is_float() || op->dtype.is_bfloat16() ||
           datatype::Registry::Global()->GetTypeRegistered(op->dtype.code()));
    return op->a * make_const(op->b.dtype(), 1.0 / ptr->value);
  }
  // Vector rules
  if (op->dtype.is_scalable_or_fixed_length_vector()) {
    // NOTE: use div as the pattern also works for float.
    TVM_TRY_REWRITE(div(broadcast(x, lanes), broadcast(y, lanes)), broadcast(div(x, y), lanes));
    // ramp / bcast
    if ((div(ramp(b1, c1, lanes), broadcast(c2, lanes))).Match(ret)) {
      int64_t c1val = c1.Eval()->value;
      int64_t c2val = c2.Eval()->value;
      ICHECK(c2val != 0) << "division by zero";
      // NOTE(review): under truncated division this rewrite is only exact when all
      // ramp elements share a sign (e.g. ramp(-1, 2, 2) / 2 evaluates to [0, 0] but
      // rewrites to ramp(0, 1, 2) = [0, 1]). Existing behavior appears to rely on the
      // unconditional form, so it is kept as-is -- TODO confirm and gate on sign.
      if (c1val % c2val == 0) {
        return ramp(div(b1, c2), div(c1, c2), lanes).Eval();
      }
      // If all possible indices in ramp are the same.
      // Only attempted for fixed-length vectors, whose lane count is an IntImm.
      const auto* lanes_imm = lanes.Eval().as<IntImmNode>();
      if (lanes_imm != nullptr && CanProveGreaterEqual(b1.Eval(), 0)) {
        // b1 = coeff * k + base. When coeff is a multiple of c2, the quotient of
        // lane i only varies with the offset base + i * c1.
        ModularSet bmod = analyzer_->modular_set(b1.Eval());
        int64_t lanes_int = lanes_imm->value;
        int64_t ramp_first = bmod->base;
        int64_t ramp_last = bmod->base + (lanes_int - 1) * c1val;
        // Truncated division matches floor division (and is therefore monotone
        // across lanes) only while the offsets are non-negative. With a negative
        // stride the last offset can go below zero, and truncation toward zero would
        // hide a change of quotient: ramp(4 * k, -1, 2) / 4 is [k, k - 1], not [k, k].
        if (bmod->coeff % c2val == 0 && ramp_first >= 0 && ramp_last >= 0 &&
            ramp_first / c2val == ramp_last / c2val) {
          return broadcast(div(b1, c2), lanes).Eval();
        }
      }
    }
  }
  if (IsIndexType(op->dtype)) {
    // Be-aware of the division rules:
    // We adopt the default C division uses truncation instead of floordiv.
    // This means most rules need to check non-negativeness of the operands.
    // TryConstFold doesn't work for negative cases because it is also used by legacy
    // parts of tvm which still assume euclidean div. In this simplifier we assume that the division
    // is truncated, so perform const folding again.
    // NOTE: trunc div required
    if (truncdiv(c1, c2).Match(ret)) {
      int64_t c1val = c1.Eval()->value;
      int64_t c2val = c2.Eval()->value;
      return make_const(op->dtype, truncdiv(c1val, c2val));
    }
    // while it is always true for trunc div
    // restrict to common case(positive div)
    TVM_TRY_REWRITE_IF(truncdiv(truncdiv(x, c1), c2), truncdiv(x, c1 * c2),
                       c1.Eval()->value > 0 && c2.Eval()->value > 0);
    TVM_TRY_REWRITE_IF(truncdiv(truncdiv(x, c1) + c2, c3), truncdiv(x + c1 * c2, c1 * c3),
                       c1.Eval()->value > 0 && c2.Eval()->value >= 0 && c3.Eval()->value > 0 &&
                           CanProveGreaterEqual(x.Eval(), 0));
    if (truncdiv(x * c1, c2).Match(ret)) {
      int64_t c1val = c1.Eval()->value;
      int64_t c2val = c2.Eval()->value;
      if (c1val > 0 && c2val > 0) {
        if (c1val % c2val == 0) return (x * truncdiv(c1, c2)).Eval();
        if (c2val % c1val == 0) return truncdiv(x, truncdiv(c2, c1)).Eval();
      }
    }
    TVM_TRY_REWRITE(truncdiv(x, x), OneWithTypeLike(x));
    TVM_TRY_REWRITE(matches_one_of(truncdiv(x * c1, x), truncdiv(c1 * x, x)), c1);
    // Rules involving 2-operands.
    TVM_TRY_REWRITE_IF(truncdiv(x * c1 + y, c2), x * truncdiv(c1, c2) + truncdiv(y, c2),
                       c1.Eval()->value >= 0 && c2.Eval()->value > 0 &&
                           c1.Eval()->value % c2.Eval()->value == 0 &&
                           CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0));
    TVM_TRY_REWRITE_IF(truncdiv(min(x * c1, y), c2), min(x * truncdiv(c1, c2), truncdiv(y, c2)),
                       c1.Eval()->value >= 0 && c2.Eval()->value > 0 &&
                           c1.Eval()->value % c2.Eval()->value == 0 &&
                           CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0));
    TVM_TRY_REWRITE_IF(truncdiv(max(x * c1, y), c2), max(x * truncdiv(c1, c2), truncdiv(y, c2)),
                       c1.Eval()->value >= 0 && c2.Eval()->value > 0 &&
                           c1.Eval()->value % c2.Eval()->value == 0 &&
                           CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0));
    TVM_TRY_REWRITE_IF(truncdiv(y + x * c1, c2), truncdiv(y, c2) + x * truncdiv(c1, c2),
                       c1.Eval()->value >= 0 && c2.Eval()->value > 0 &&
                           c1.Eval()->value % c2.Eval()->value == 0 &&
                           CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0));
    TVM_TRY_REWRITE_IF(truncdiv(min(y, x * c1), c2), min(truncdiv(y, c2), x * truncdiv(c1, c2)),
                       c1.Eval()->value >= 0 && c2.Eval()->value > 0 &&
                           c1.Eval()->value % c2.Eval()->value == 0 &&
                           CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0));
    TVM_TRY_REWRITE_IF(truncdiv(max(y, x * c1), c2), max(truncdiv(y, c2), x * truncdiv(c1, c2)),
                       c1.Eval()->value >= 0 && c2.Eval()->value > 0 &&
                           c1.Eval()->value % c2.Eval()->value == 0 &&
                           CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0));
    // Rules involving 3-operands.
    TVM_TRY_REWRITE_IF(
        truncdiv(x * c1 + y + z, c2), x * truncdiv(c1, c2) + truncdiv(y + z, c2),
        c1.Eval()->value >= 0 && c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 &&
            CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((y + z).Eval(), 0));
    TVM_TRY_REWRITE_IF(
        truncdiv(x * c1 - y + z, c2), x * truncdiv(c1, c2) + truncdiv(z - y, c2),
        c1.Eval()->value >= 0 && c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 &&
            CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((z - y).Eval(), 0));
    TVM_TRY_REWRITE_IF(
        truncdiv(x * c1 + y - z, c2), x * truncdiv(c1, c2) + truncdiv(y - z, c2),
        c1.Eval()->value >= 0 && c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 &&
            CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((y - z).Eval(), 0));
    TVM_TRY_REWRITE_IF(
        truncdiv(y + x * c1 + z, c2), x * truncdiv(c1, c2) + truncdiv(y + z, c2),
        c1.Eval()->value > 0 && c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0 &&
            CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((y + z).Eval(), 0));
    TVM_TRY_REWRITE_IF(truncdiv(x + c1, c2), truncdiv(x, c2) + truncdiv(c1, c2),
                       c1.Eval()->value > 0 && c2.Eval()->value > 0 &&
                           c1.Eval()->value % c2.Eval()->value == 0 &&
                           CanProveGreaterEqual(x.Eval(), 0));
    TVM_TRY_REWRITE_IF(matches_one_of(truncdiv(x + y, x), truncdiv(y + x, x)), truncdiv(y, x) + 1,
                       CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0));
    TVM_TRY_REWRITE_IF(
        matches_one_of(truncdiv((x + y) + z, x), truncdiv((y + x) + z, x), truncdiv(y + (z + x), x),
                       truncdiv(y + (x + z), x)),
        truncdiv(y + z, x) + 1,
        CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual((y + z).Eval(), 0));
    TVM_TRY_REWRITE_IF(matches_one_of(truncdiv(x * y, y), truncdiv(y * x, y)), x,
                       CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0));
    TVM_TRY_REWRITE_IF(matches_one_of(truncdiv(x * z + y, z), truncdiv(z * x + y, z)),
                       x + truncdiv(y, z),
                       CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0) &&
                           CanProveGreaterEqual(z.Eval(), 0));
    TVM_TRY_REWRITE_IF(matches_one_of(truncdiv(y + x * z, z), truncdiv(y + z * x, z)),
                       truncdiv(y, z) + x,
                       CanProveGreaterEqual(x.Eval(), 0) && CanProveGreaterEqual(y.Eval(), 0) &&
                           CanProveGreaterEqual(z.Eval(), 0));
  }
  return ret;
}