in src/arith/rewrite_simplify.cc [1014:1157]
// Rewrite-simplify a FloorDiv node (floor division: the quotient is rounded
// toward negative infinity).
//
// Steps:
//   1. Simplify both operands via the base mutator.
//   2. Return the constant-folded value if TryConstFold succeeds.
//   3. If the dtype is a (fixed-length or scalable) vector, try the
//      broadcast / ramp rules.
//   4. If the dtype is an index type, try the scalar rules.
// Rules are attempted in the order written. The TVM_TRY_REWRITE* macros are
// defined outside this excerpt. They are assumed to return the rewritten
// expression on the first match whose side condition holds, so earlier rules
// take precedence over later ones.
//
// Returns the simplified expression. If no rule applies, returns `ret`, the
// node with operand-simplified children.
PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) {
  PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
  // NOTE(review): assumes the base mutator always rebuilds a FloorDiv node, so
  // this cast is non-null.
  op = ret.as<FloorDivNode>();
  if (auto const_res = TryConstFold<FloorDiv>(op->a, op->b)) return const_res.value();
  // Pattern var to match any expression
  PVar<PrimExpr> x, y, z, b1;
  // Pattern var match IntImm
  PVar<IntImm> c1, c2, c3;
  // Pattern var for lanes in broadcast and ramp
  PVar<PrimExpr> lanes;
  // Vector rules
  if (op->dtype.is_scalable_or_fixed_length_vector()) {
    // bcast(x) // bcast(y) ==> bcast(x // y): every lane computes the same quotient.
    TVM_TRY_REWRITE(floordiv(broadcast(x, lanes), broadcast(y, lanes)),
                    broadcast(floordiv(x, y), lanes));
    // ramp // bcast
    if (floordiv(ramp(b1, c1, lanes), broadcast(c2, lanes)).Match(ret)) {
      int64_t c1val = c1.Eval()->value;
      int64_t c2val = c2.Eval()->value;
      ICHECK(c2val != 0) << "division by zero";
      // The stride is a multiple of the divisor, so lane i is exactly
      // (b1 + i * c1) // c2 == b1 // c2 + i * (c1 / c2). The result is again a ramp.
      if (c1val % c2val == 0) {
        return ramp(floordiv(b1, c2), floordiv(c1, c2), lanes).Eval();
      }
      // Otherwise, check whether every lane has the same quotient, in which case
      // the result is a broadcast. This is only attempted when `lanes` has no
      // vscale factor, because the lane count is read below as a constant IntImm.
      if (!arith::ExtractVscaleFactor(lanes.Eval())) {
        // Known structure of the base: b1 == coeff * k + base for some integer k.
        ModularSet bmod = analyzer_->modular_set(b1.Eval());
        // Quotients of the first and last lane offsets:
        //   base                       and   base + (lanes - 1) * c1.
        // The offsets change monotonically across lanes, so equal end quotients
        // mean every lane offset has the same quotient. The min/max names assume
        // c1 > 0; for c1 < 0 they swap, which does not affect the equality test.
        int64_t ramp_min = floordiv(bmod->base, c2val);
        auto lanes_int = lanes.Eval().as<IntImmNode>()->value;
        int64_t ramp_max = floordiv(bmod->base + (lanes_int - 1) * c1val, c2val);
        if (ramp_min == ramp_max) {
          // c2 divides coeff. The unknown part coeff * k of b1 is therefore a
          // multiple of c2 and shifts every lane's quotient by the same amount.
          // This also covers coeff == 0 (b1 is the constant `base`). In that case
          // we return here, before `c2val % bmod->coeff` below could divide by zero.
          if (bmod->coeff % c2val == 0) {
            return broadcast(floordiv(b1, c2), lanes).Eval();
          }
          // All indices are guaranteed to settle inside one coeff range: c2 is a
          // multiple of coeff and the last lane offset stays below coeff. Hence
          // no lane crosses a multiple of c2.
          if (c2val % bmod->coeff == 0 && bmod->base + (lanes_int - 1) * c1val < bmod->coeff) {
            return broadcast(floordiv(b1, c2), lanes).Eval();
          }
        }
      }
    }
  }
  if (IsIndexType(op->dtype)) {
    // Beware: this is floor division, so several rules below require positive
    // (or provably non-negative) divisors.
    // (x // c1) // c2 ==> x // (c1 * c2) for positive c1, c2.
    TVM_TRY_REWRITE_IF(floordiv(floordiv(x, c1), c2), floordiv(x, c1 * c2),
                       c1.Eval()->value > 0 && c2.Eval()->value > 0);
    // x // c1 + c2 == (x + c1 * c2) // c1. The two divisions then merge:
    // (x // c1 + c2) // c3 ==> (x + c1 * c2) // (c1 * c3) for positive c1, c3.
    TVM_TRY_REWRITE_IF(floordiv(floordiv(x, c1) + c2, c3), floordiv(x + c1 * c2, c1 * c3),
                       c1.Eval()->value > 0 && c3.Eval()->value > 0);
    // (x * c1 + y) // c2, (x * c1) // c2, and (y + x * c1) // c2.
    if (floordiv(x * c1 + y, c2).Match(ret) || floordiv(x * c1, c2).Match(ret) ||
        floordiv(y + x * c1, c2).Match(ret)) {
      int64_t c1val = c1.Eval()->value;
      int64_t c2val = c2.Eval()->value;
      // y is unbound when the `x * c1` form matched; treat it as 0.
      PrimExpr yval = y.EvalOr(Integer(0));
      // A zero divisor is left unsimplified, avoiding division by zero in the
      // folds below.
      if (c2val == 0) return ret;
      // Try to eliminate the residue part. Write
      //   c1 == c2 * (c1 // c2) + c1 % c2   and   y == c2 * (y // c2) + y % c2.
      // Then:
      //   (x * c1 + y) // c2
      //     == x * (c1 // c2) + y // c2 + (x * (c1 % c2) + y % c2) // c2
      // If the last term ("residue") has a constant bound (min == max), it is a
      // known constant and the division disappears.
      PrimExpr residue =
          floordiv(x.Eval() * floormod(c1.Eval(), c2val) + floormod(yval, c2val), c2val);
      // Drop the y // c2 term when it is provably zero.
      PrimExpr y_div = CanProveEqual(floordiv(yval, c2val), 0) ? 0 : floordiv(yval, c2val);
      auto bound = analyzer_->const_int_bound(residue);
      if (bound.defined() && bound->max_value == bound->min_value) {
        return x.Eval() * floordiv(c1val, c2.Eval()) + (y_div + Integer(bound->max_value));
      }
      // try simplify divisor
      if (c1val > 0 && c2val > 0 && c2val % c1val == 0 &&
          CanProveLess(floormod(yval, c2val), c1val)) {
        // assume c2 == a * c1, x == a * x' + b, y = d * c2 + e then
        // (x * c1 + y) // c2
        // ==> ((a * x' + b) * c1 + d * a * c1 + e) // (a * c1)
        // ==> x' + d + (b * c1 + e) // c2
        // ==> x' + d since 0 <= b * c1 <= (a-1) * c1, 0 <= e < c1
        // ==> x // (c2 // c1) + (y // c2)
        return floordiv(x.Eval(), floordiv(c2val, c1val)) + y_div;
      }
    }
    // x // x ==> 1 and (x * c1) // x ==> c1, valid for any nonzero x.
    TVM_TRY_REWRITE(floordiv(x, x), OneWithTypeLike(x));
    TVM_TRY_REWRITE(matches_one_of(floordiv(x * c1, x), floordiv(c1 * x, x)), c1);
    // floormod(x, 2) is 0 or 1, so (floormod(x, 2) + 1) // 2 equals floormod(x, 2).
    TVM_TRY_REWRITE(floordiv(floormod(x, 2) + 1, 2), floormod(x, 2));
    // Rules involving 2-operands.
    // For c2 > 0, floor division is monotonic and therefore distributes over
    // min/max. x * c1 divides exactly when c2 divides c1.
    TVM_TRY_REWRITE_IF(floordiv(min(x * c1, y), c2), min(x * floordiv(c1, c2), floordiv(y, c2)),
                       c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);
    TVM_TRY_REWRITE_IF(floordiv(max(x * c1, y), c2), max(x * floordiv(c1, c2), floordiv(y, c2)),
                       c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);
    TVM_TRY_REWRITE_IF(floordiv(min(y, x * c1), c2), min(floordiv(y, c2), x * floordiv(c1, c2)),
                       c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);
    TVM_TRY_REWRITE_IF(floordiv(max(y, x * c1), c2), max(floordiv(y, c2), x * floordiv(c1, c2)),
                       c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);
    // Rules involving 3-operands.
    // A term x * c1 with c2 | c1 divides exactly and can be pulled out of the sum.
    TVM_TRY_REWRITE_IF(floordiv(x * c1 + y + z, c2), x * floordiv(c1, c2) + floordiv(y + z, c2),
                       c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);
    // Suppose c2 == a * c1 with c1 > 0. For c1 > 0, (y + z) // c1 == 0 means
    // 0 <= y + z < c1, so the low part never carries:
    // (x * c1 + y + z) // c2 == x // a.
    TVM_TRY_REWRITE_IF(floordiv(x * c1 + y + z, c2), floordiv(x, floordiv(c2, c1)),
                       c1.Eval()->value > 0 && c2.Eval()->value > 0 &&
                           c2.Eval()->value % c1.Eval()->value == 0 &&
                           CanProveEqual(floordiv(y.Eval() + z.Eval(), c1.Eval()), 0));
    // Same as the first 3-operand rule, with y subtracted instead of added.
    TVM_TRY_REWRITE_IF(matches_one_of(floordiv(x * c1 - y + z, c2), floordiv(x * c1 + z - y, c2)),
                       x * floordiv(c1, c2) + floordiv(z - y, c2),
                       c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);
    // Same as the first 3-operand rule, with the x * c1 term in the middle.
    TVM_TRY_REWRITE_IF(floordiv(y + x * c1 + z, c2), x * floordiv(c1, c2) + floordiv(y + z, c2),
                       c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);
    // A constant offset divisible by c2 can be pulled out of the division.
    TVM_TRY_REWRITE_IF(floordiv(x + c1, c2), floordiv(x, c2) + floordiv(c1, c2),
                       c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);
    // The common factor x cancels. (x * c1) / (x * c2) has the exact rational
    // value c1 / c2 for any nonzero x, so the floors agree.
    TVM_TRY_REWRITE_IF(floordiv(x * c1, x * c2), floordiv(c1, c2), c2.Eval()->value > 0);
    // Adding the (provably non-negative) divisor x once adds exactly one to the
    // quotient.
    TVM_TRY_REWRITE_IF(matches_one_of(floordiv(x + y, x), floordiv(y + x, x)), floordiv(y, x) + 1,
                       CanProveGreaterEqual(x.Eval(), 0));
    TVM_TRY_REWRITE_IF(matches_one_of(floordiv((x + y) + z, x), floordiv((y + x) + z, x),
                                      floordiv(y + (z + x), x), floordiv(y + (x + z), x)),
                       floordiv(y + z, x) + 1, CanProveGreaterEqual(x.Eval(), 0));
    // Exact multiples of the divisor (y, z, or z * c1) come out of the division.
    // Each rule requires the divisor to be provably non-negative.
    TVM_TRY_REWRITE_IF(matches_one_of(floordiv(x * y, y), floordiv(y * x, y)), x,
                       CanProveGreaterEqual(y.Eval(), 0));
    TVM_TRY_REWRITE_IF(matches_one_of(floordiv(x * z + y, z), floordiv(z * x + y, z)),
                       x + floordiv(y, z), CanProveGreaterEqual(z.Eval(), 0));
    TVM_TRY_REWRITE_IF(matches_one_of(floordiv(y + x * z, z), floordiv(y + z * x, z)),
                       floordiv(y, z) + x, CanProveGreaterEqual(z.Eval(), 0));
    TVM_TRY_REWRITE_IF(floordiv(x * z * c1 + y, z * c1), x + floordiv(y, z * c1),
                       CanProveGreaterEqual(z.Eval() * c1.Eval(), 0));
    // x - floormod(x, c1) == c1 * (x // c1), so dividing by c1 gives x // c1.
    TVM_TRY_REWRITE_IF(floordiv(x - floormod(x, c1), c1), floordiv(x, c1), c1.Eval()->value != 0);
    // Scalable divisor
    // When the divisor y contains a vscale call and 0 <= x < y is provable, the
    // quotient is 0.
    TVM_TRY_REWRITE_IF(floordiv(x, y), ZeroWithTypeLike(x),
                       ContainsVscaleCall(y.Eval()) && CanProveGreaterEqual(x.Eval(), 0) &&
                           CanProveGreaterEqual(y.Eval(), 0) && CanProve(x.Eval() < y.Eval()));
  }
  return ret;
}