PrimExpr RewriteSimplifier::Impl::VisitExpr_()

in src/arith/rewrite_simplify.cc [1276:1458]


// Simplify a Min node.
//
// The operands are simplified first (IRMutatorWithAnalyzer::VisitExpr_), then the
// node is constant-folded when both operands allow it. Otherwise it is matched
// against, in this order:
//   1. vector rules            (only for scalable / fixed-length vector dtypes),
//   2. integer / index rules   (only when IsIndexType(op->dtype)),
//   3. the select-distribution rule (any dtype).
//
// NOTE(review): the TVM_TRY_REWRITE* macros are defined elsewhere; they are
// assumed to return the rewritten expression on the first successful match, so
// the order of the rules below determines their precedence.
//
// \param op  The Min node to simplify.
// \return    The simplified expression, or `ret` (the Min node with simplified
//            operands) when no rule applies.
PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MinNode* op) {
  PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
  op = ret.as<MinNode>();
  if (auto const_res = TryConstFold<Min>(op->a, op->b)) return const_res.value();

  // Pattern var to match any expression
  PVar<PrimExpr> x, y, z, s1, s2;
  // Pattern var match IntImm
  PVar<IntImm> c1, c2;
  PVar<PrimExpr> lanes;

  // vector rule
  if (op->dtype.is_scalable_or_fixed_length_vector()) {
    TVM_TRY_REWRITE(min(broadcast(x, lanes), broadcast(y, lanes)), broadcast(min(x, y), lanes));
    TVM_TRY_REWRITE(min(min(x, broadcast(y, lanes)), broadcast(z, lanes)),
                    min(x, broadcast(min(y, z), lanes)));
  }
  if (IsIndexType(op->dtype)) {
    TVM_TRY_REWRITE(min(x, x), x);

    // constant int bound: if one operand's range lies entirely below the
    // other's, that operand is the minimum.
    ConstIntBound a_bound = analyzer_->const_int_bound(op->a);
    ConstIntBound b_bound = analyzer_->const_int_bound(op->b);
    if (a_bound->max_value <= b_bound->min_value) {
      return op->a;
    }
    if (b_bound->max_value <= a_bound->min_value) {
      return op->b;
    }

    // constant comparison: same base expression, only the constants differ.
    if (min(x + c1, x + c2).Match(ret)) {
      if (c1.Eval()->value < c2.Eval()->value) {
        return (x + c1).Eval();
      } else {
        return (x + c2).Eval();
      }
    }
    if (min(x + c1, x).Match(ret) || min(x, x + c1).Match(ret)) {
      // x + c1 is the smaller side exactly when c1 is negative.
      if (c1.Eval()->value < 0) {
        return (x + c1).Eval();
      } else {
        return x.Eval();
      }
    }
    if (min(c1 - x, c2 - x).Match(ret)) {
      if (c1.Eval()->value < c2.Eval()->value) {
        return (c1 - x).Eval();
      } else {
        return (c2 - x).Eval();
      }
    }

    // DivMod rules
    // NOTE: truncdiv(x, y) >= floordiv(x, y)
    // With c2 > 0 and c1 == c2 - 1, div(x + c1, c2) * c2 rounds x *up* to a
    // multiple of c2, so it is never smaller than x. The condition is written
    // as c1 == c2 - 1 (checked after c2 > 0) rather than c1 + 1 == c2 so it
    // cannot overflow when c1 == INT64_MAX.
    TVM_TRY_REWRITE_IF(
        matches_one_of(min(truncdiv(x + c1, c2) * c2, x), min(x, truncdiv(x + c1, c2) * c2),
                       min(floordiv(x + c1, c2) * c2, x), min(x, floordiv(x + c1, c2) * c2)),
        x, c2.Eval()->value > 0 && c1.Eval()->value == c2.Eval()->value - 1);

    // Same round-up argument: for x >= 1 the rounded-up value is at least c2
    // and at least x, hence at least max(x, c2).
    TVM_TRY_REWRITE_IF(matches_one_of(min(truncdiv(x + c1, c2) * c2, max(x, c2)),
                                      min(max(x, c2), truncdiv(x + c1, c2) * c2),
                                      min(floordiv(x + c1, c2) * c2, max(x, c2)),
                                      min(max(x, c2), floordiv(x + c1, c2) * c2)),
                       max(x, c2),
                       c2.Eval()->value > 0 && c1.Eval()->value == c2.Eval()->value - 1 &&
                           CanProveGreaterEqual(x.Eval(), 1));

    // floordiv(x, c2) * c2 rounds x *down* to a multiple of c2 (c2 > 0).
    TVM_TRY_REWRITE_IF(matches_one_of(min(x, floordiv(x, c2) * c2), min(floordiv(x, c2) * c2, x)),
                       floordiv(x, c2) * c2, c2.Eval()->value > 0);

    // Absorption: an operand already bounded by min(x, y) adds nothing.
    TVM_TRY_REWRITE((PMatchesOneOf{
                        min(max(x, y), min(x, y)),
                        min(max(x, y), min(y, x)),
                        min(min(x, y), max(x, y)),
                        min(min(x, y), max(y, x)),
                        min(min(x, y), x),
                        min(min(x, y), y),
                        min(x, min(x, y)),
                        min(y, min(x, y)),
                    }),
                    min(x, y));

    TVM_TRY_REWRITE((PMatchesOneOf{
                        min(max(x, y), x),
                        min(max(y, x), x),
                        min(x, max(x, y)),
                        min(x, max(y, x)),
                    }),
                    x);

    // Drop a repeated operand y from a left-leaning chain of mins.
    TVM_TRY_REWRITE(min(min(min(x, y), z), y), min(min(x, y), z));
    TVM_TRY_REWRITE(min(min(min(min(x, y), z), s1), y), min(min(min(x, y), z), s1));
    TVM_TRY_REWRITE(min(min(min(min(min(x, y), z), s1), s2), y),
                    min(min(min(min(x, y), z), s1), s2));

    // Factor out a shared operand: min(max(x, y), max(x, z)) == max(min(y, z), x).
    TVM_TRY_REWRITE((PMatchesOneOf{
                        min(max(x, y), max(x, z)),
                        min(max(x, y), max(z, x)),
                        min(max(y, x), max(x, z)),
                        min(max(y, x), max(z, x)),
                    }),
                    max(min(y, z), x));

    TVM_TRY_REWRITE((PMatchesOneOf{
                        min(min(x, y), min(x, z)),
                        min(min(x, y), min(z, x)),
                        min(min(y, x), min(x, z)),
                        min(min(y, x), min(z, x)),
                    }),
                    min(min(y, z), x));

    TVM_TRY_REWRITE((PMatchesOneOf{
                        min(y + x, z + x),
                        min(y + x, x + z),
                        min(x + y, x + z),
                        min(x + y, z + x),
                    }),
                    min(y, z) + x);

    // sub distribution
    TVM_TRY_REWRITE(min(y - x, z - x), min(y, z) - x);
    TVM_TRY_REWRITE(min(x - y, x - z), x - max(y, z));

    // constant folding rule: combine two constant bounds into one.
    TVM_TRY_REWRITE(min(min(x, c1), c2), min(x, min(c1, c2)));

    // scaling rule: division/multiplication by a positive constant preserves
    // order, so min commutes with it; a negative constant reverses order,
    // turning min into max.
    if (min(truncdiv(x, c1), truncdiv(y, c1)).Match(ret)) {
      if (c1.Eval()->value > 0) {
        return truncdiv(min(x, y), c1).Eval();
      } else {
        return truncdiv(max(x, y), c1).Eval();
      }
    }
    if (min(floordiv(x, c1), floordiv(y, c1)).Match(ret)) {
      if (c1.Eval()->value > 0) {
        return floordiv(min(x, y), c1).Eval();
      } else {
        return floordiv(max(x, y), c1).Eval();
      }
    }
    if (min(x * c1, y * c1).Match(ret)) {
      if (c1.Eval()->value > 0) {
        return (min(x, y) * c1).Eval();
      } else {
        return (max(x, y) * c1).Eval();
      }
    }
    if (min(x * c1, c2).Match(ret)) {
      int64_t c1val = c1.Eval()->value;
      int64_t c2val = c2.Eval()->value;
      if (c1val == 0) {
        // min(x * 0, c2) == min(0, c2); c1 is the zero constant here.
        return c2val < 0 ? c2.Eval() : c1.Eval();
      }
      // INT64_MIN / -1 and INT64_MIN % -1 are undefined behavior in C++ (the
      // quotient is not representable; x86-64 traps on it), and the rewrite
      // would need -INT64_MIN anyway, so leave that single case alone.
      bool quotient_overflows = (c1val == -1 && c2val == INT64_MIN);
      if (!quotient_overflows && c2val % c1val == 0) {
        // c2 is an exact multiple of c1: pull the scale out of the min.
        if (c1val > 0) {
          return (min(x, c2val / c1val) * c1val).Eval();
        } else {
          return (max(x, c2val / c1val) * c1val).Eval();
        }
      }
    }

    // vscale expression comparison: fall back to the analyzer's prover when
    // either side contains a vscale call.
    if (ContainsVscaleCall(op->a) || ContainsVscaleCall(op->b)) {
      if (analyzer_->CanProve(op->a <= op->b)) {
        return op->a;
      }
      if (analyzer_->CanProve(op->b <= op->a)) {
        return op->b;
      }
    }

    // canonicalization: move the constant operand to the outermost position.
    // NOTE(review): the RECURSIVE variants presumably re-simplify the rewritten
    // result; confirm against the macro definitions.
    TVM_TRY_RECURSIVE_REWRITE(min(min(x, c1), y), min(min(x, y), c1));
    TVM_TRY_RECURSIVE_REWRITE_IF(min(c1 - x, c2), c1 - max(x, c1 - c2), c2.Eval()->value != 0);
  }

  // condition rules.
  TVM_TRY_REWRITE(min(select(x, y, z), select(x, s1, s2)), select(x, min(y, s1), min(z, s2)));
  return ret;
}