in src/arith/rewrite_simplify.cc [1460:1651]
// Simplify a `max(a, b)` node by constant folding and term rewriting.
//
// The operands are first simplified by the base IRMutatorWithAnalyzer visitor
// and constant folding is attempted via TryConstFold<Max>. Otherwise rewrite
// rules are tried in order:
//   - vector/broadcast rules;
//   - for index types only: rules driven by constant integer bounds, constant
//     comparison, div/mod rounding, min/max absorption and distribution,
//     add/sub distribution, scaling by constants, vscale comparisons and
//     canonicalization;
//   - finally, a rule that pushes the max through two selects sharing a
//     condition.
// The TVM_TRY_* macros are defined elsewhere; they are assumed to return the
// rewritten expression from this function as soon as a rule matches.
//
// Returns the simplified expression, or the operand-simplified Max (`ret`)
// when no rule applies.
PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MaxNode* op) {
  PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
  // NOTE(review): assumes the base visitor always yields a Max node here;
  // `op` would be null otherwise.
  op = ret.as<MaxNode>();
  if (auto const_res = TryConstFold<Max>(op->a, op->b)) return const_res.value();
  // Pattern var to match any expression
  PVar<PrimExpr> x, y, z, s1, s2;
  // Pattern var match IntImm
  PVar<IntImm> c1, c2;
  PVar<PrimExpr> lanes;
  // vector rule: hoist the max inside broadcasts with matching lane counts.
  if (op->dtype.is_scalable_or_fixed_length_vector()) {
    TVM_TRY_REWRITE(max(broadcast(x, lanes), broadcast(y, lanes)), broadcast(max(x, y), lanes));
    TVM_TRY_REWRITE(max(max(x, broadcast(y, lanes)), broadcast(z, lanes)),
                    max(x, broadcast(max(y, z), lanes)));
  }
  if (IsIndexType(op->dtype)) {
    TVM_TRY_REWRITE(max(x, x), x);
    // constant int bound: if one operand's lower bound is not below the
    // other's upper bound, that operand is always the max.
    ConstIntBound a_bound = analyzer_->const_int_bound(op->a);
    ConstIntBound b_bound = analyzer_->const_int_bound(op->b);
    if (a_bound->min_value >= b_bound->max_value) {
      return op->a;
    }
    if (b_bound->min_value >= a_bound->max_value) {
      return op->b;
    }
    // constant comparison: operands that differ only by a constant offset.
    if (max(x + c1, x + c2).Match(ret)) {
      if (c1.Eval()->value > c2.Eval()->value) {
        return (x + c1).Eval();
      } else {
        return (x + c2).Eval();
      }
    }
    if (max(x + c1, x).Match(ret) || max(x, x + c1).Match(ret)) {
      if (c1.Eval()->value > 0) {
        return (x + c1).Eval();
      } else {
        return x.Eval();
      }
    }
    if (max(c1 - x, c2 - x).Match(ret)) {
      if (c1.Eval()->value > c2.Eval()->value) {
        return (c1 - x).Eval();
      } else {
        return (c2 - x).Eval();
      }
    }
    // DivMod rules
    // Divide up rounding: trunc div
    // NOTE: truncdiv(x, y) >= floordiv(x, y)
    // With c1 == c2 - 1 and c2 > 0, truncdiv(x + c1, c2) * c2 is at least x
    // rounded up to a multiple of c2, so it is never below x.
    TVM_TRY_REWRITE_IF((PMatchesOneOf{
                           max(truncdiv(x + c1, c2) * c2, x),
                           max(x, truncdiv(x + c1, c2) * c2),
                       }),
                       truncdiv(x + c1, c2) * c2,
                       c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value);
    // Divide up rounding: floor div
    TVM_TRY_REWRITE_IF((PMatchesOneOf{
                           max(floordiv(x + c1, c2) * c2, x),
                           max(x, floordiv(x + c1, c2) * c2),
                       }),
                       floordiv(x + c1, c2) * c2,
                       c2.Eval()->value > 0 && c1.Eval()->value + 1 == c2.Eval()->value);
    // Rounding x down to a multiple of a positive c2 never exceeds x.
    TVM_TRY_REWRITE_IF((PMatchesOneOf{
                           max(floordiv(x, c2) * c2, x),
                           max(x, floordiv(x, c2) * c2),
                       }),
                       x, c2.Eval()->value > 0);
    // min/max absorption: min(x, y) <= x.
    TVM_TRY_REWRITE((PMatchesOneOf{
                        max(min(x, y), x),
                        max(min(y, x), x),
                        max(x, min(x, y)),
                        max(x, min(y, x)),
                    }),
                    x);
    TVM_TRY_REWRITE((PMatchesOneOf{
                        max(min(x, y), max(x, y)),
                        max(min(x, y), max(y, x)),
                        max(max(x, y), min(x, y)),
                        max(max(x, y), min(y, x)),
                        max(max(x, y), x),
                        max(max(x, y), y),
                        max(x, max(x, y)),
                        max(y, max(x, y)),
                    }),
                    max(x, y));
    // y already appears inside the max chain, so taking it again is redundant.
    TVM_TRY_REWRITE(max(max(max(x, y), z), y), max(max(x, y), z));
    TVM_TRY_REWRITE(max(max(max(max(x, y), z), s1), y), max(max(max(x, y), z), s1));
    TVM_TRY_REWRITE(max(max(max(max(max(x, y), z), s1), s2), y),
                    max(max(max(max(x, y), z), s1), s2));
    // max/max cancelation: factor out the shared operand x.
    TVM_TRY_REWRITE((PMatchesOneOf{
                        max(max(x, y), max(x, z)),
                        max(max(x, y), max(z, x)),
                        max(max(y, x), max(x, z)),
                        max(max(y, x), max(z, x)),
                    }),
                    max(max(y, z), x));
    // max/min distribution: max(min(x, y), min(x, z)) == min(max(y, z), x).
    TVM_TRY_REWRITE((PMatchesOneOf{
                        max(min(x, y), min(x, z)),
                        max(min(x, y), min(z, x)),
                        max(min(y, x), min(x, z)),
                        max(min(y, x), min(z, x)),
                    }),
                    min(max(y, z), x));
    // add distribution
    TVM_TRY_REWRITE((PMatchesOneOf{
                        max(y + x, z + x),
                        max(y + x, x + z),
                        max(x + y, x + z),
                        max(x + y, z + x),
                    }),
                    max(y, z) + x);
    // sub distribution; subtracting the smaller value gives the larger result.
    TVM_TRY_REWRITE(max(y - x, z - x), max(y, z) - x);
    TVM_TRY_REWRITE(max(x - y, x - z), x - min(y, z));
    // constant folding rule: merge the two constant operands.
    TVM_TRY_REWRITE(max(max(x, c1), c2), max(x, max(c1, c2)));
    // scaling rule: division/multiplication by a positive constant preserves
    // the ordering, a negative constant reverses it (max becomes min).
    if (max(truncdiv(x, c1), truncdiv(y, c1)).Match(ret)) {
      if (c1.Eval()->value > 0) {
        return truncdiv(max(x, y), c1).Eval();
      } else {
        return truncdiv(min(x, y), c1).Eval();
      }
    }
    if (max(floordiv(x, c1), floordiv(y, c1)).Match(ret)) {
      if (c1.Eval()->value > 0) {
        return floordiv(max(x, y), c1).Eval();
      } else {
        return floordiv(min(x, y), c1).Eval();
      }
    }
    if (max(x * c1, y * c1).Match(ret)) {
      if (c1.Eval()->value > 0) {
        return (max(x, y) * c1).Eval();
      } else {
        return (min(x, y) * c1).Eval();
      }
    }
    if (max(x * c1, c2).Match(ret)) {
      int64_t c1val = c1.Eval()->value;
      int64_t c2val = c2.Eval()->value;
      if (c1val == 0) {
        // x * 0 == 0, so the result is max(0, c2); c1 itself is the literal 0.
        return c2val > 0 ? c2.Eval() : c1.Eval();
      }
      // With c1 == -1 the quotient c2 / c1 is -c2. That value is not
      // representable in the dtype when c2 is the dtype's minimum value, so
      // materialising it as a literal would fail. For int64 the expression
      // INT64_MIN % -1 is also undefined behaviour. Skip the rewrite in that
      // case. For any other non-zero c1 the quotient fits whenever c2 does.
      const int64_t type_max =
          static_cast<int64_t>((uint64_t{1} << (op->dtype.bits() - 1)) - 1);
      const bool quotient_fits = c1val != -1 || c2val >= -type_max;
      if (quotient_fits && c2val % c1val == 0) {
        if (c1val > 0) {
          return (max(x, c2val / c1val) * c1val).Eval();
        } else {
          return (min(x, c2val / c1val) * c1val).Eval();
        }
      }
    }
    // vscale expression comparison: ask the analyzer to prove an ordering.
    if (ContainsVscaleCall(op->a) || ContainsVscaleCall(op->b)) {
      if (analyzer_->CanProve(op->a >= op->b)) {
        return op->a;
      }
      if (analyzer_->CanProve(op->b >= op->a)) {
        return op->b;
      }
    }
    // canonicalization: move constants outward.
    // max(c1 - x, c2) == c1 - min(x, c1 - c2).
    TVM_TRY_RECURSIVE_REWRITE(max(max(x, c1), y), max(max(x, y), c1));
    TVM_TRY_RECURSIVE_REWRITE_IF(max(c1 - x, c2), c1 - min(x, c1 - c2), c2.Eval()->value != 0);
  }
  // condition rules: two selects on the same condition combine branch-wise.
  TVM_TRY_REWRITE(max(select(x, y, z), select(x, s1, s2)), select(x, max(y, s1), max(z, s2)));
  return ret;
}