PrimExpr RewriteSimplifier::Impl::VisitExpr_()

in src/arith/rewrite_simplify.cc [2013:2158]


// Simplify a logical AND.
//
// When kApplyConstraintsToBooleanBranches is enabled, each operand is
// simplified under the assumption that the other operand holds (see
// the loop below).  Otherwise the children are simplified by the base
// mutator.  The resulting And is then passed through constant folding,
// literal-constraint matching, optional conversion to an and-of-ors
// (kConvertBooleanToAndOfOrs), and finally the pattern rewrite rules
// below.  If no rule fires, the expression produced by the first stage
// is returned.
PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) {
  PrimExpr ret = [&]() -> PrimExpr {
    // If this extension isn't enabled, just delegate out.
    if (!(enabled_extensions_ & kApplyConstraintsToBooleanBranches)) {
      return IRMutatorWithAnalyzer::VisitExpr_(op);
    }

    PrimExpr a = op->a;
    PrimExpr b = op->b;

    // Alternate which branch is used as the constraint, and which is
    // being simplified.  Because some sub-analyzers expect their
    // constraints to already be simplified, each branch may require
    // more than one update.  The loop condition allows each branch to
    // be visited up to twice, but only performs the second visit if
    // necessary.
    size_t iterations_since_update = 0;
    for (size_t i = 0; i < 4; i++) {
      PrimExpr& to_update = (i % 2 == 0) ? a : b;
      const PrimExpr& constraint = (i % 2 == 0) ? b : a;

      With<ConstraintContext> context(analyzer_, constraint);
      PrimExpr updated = VisitExpr(to_update);

      if (!to_update.same_as(updated)) {
        to_update = updated;
        iterations_since_update = 0;
      } else {
        iterations_since_update++;
        // Two consecutive visits with no change: both branches are stable.
        if (iterations_since_update >= 2) {
          break;
        }
      }
    }

    // Only construct a new object if a change has been made.
    // Otherwise, follow ExprMutator's convention of returning the
    // original object.
    if (a.same_as(op->a) && b.same_as(op->b)) {
      return GetRef<PrimExpr>(op);
    } else {
      return And(a, b);
    }
  }();

  op = ret.as<AndNode>();

  if (auto const_res = TryConstFold<And>(op->a, op->b)) return const_res.value();
  if (auto match = TryMatchLiteralConstraint(ret)) return match.value();
  if ((enabled_extensions_ & RewriteSimplifier::kConvertBooleanToAndOfOrs) &&
      !recursively_visiting_boolean_) {
    return SimplifyAsAndOfOrs(ret, analyzer_);
  }

  // Pattern var to match any expression
  PVar<PrimExpr> x, y, z;
  // Pattern var match IntImm
  PVar<IntImm> c1, c2, c3;
  PVar<PrimExpr> lanes;

  if (op->dtype.is_scalable_or_fixed_length_vector()) {
    TVM_TRY_REWRITE(broadcast(x, lanes) && broadcast(y, lanes), broadcast(x && y, lanes));
  }

  auto cfalse = PConst<PrimExpr>(make_const(op->dtype, false));
  TVM_TRY_REWRITE(x == y && x != y, cfalse);
  TVM_TRY_REWRITE(x != y && x == y, cfalse);
  TVM_TRY_REWRITE(x && !x, cfalse);
  TVM_TRY_REWRITE(x <= y && y < x, cfalse);
  TVM_TRY_REWRITE(y < x && x <= y, cfalse);

  // Empty integer intervals.  c1/c2 are IntImm, so x is an integer
  // expression and (c2, c1) is empty exactly when c2 + 1 >= c1.
  TVM_TRY_REWRITE_IF(x < c1 && c2 < x, cfalse, c2.Eval()->value + 1 >= c1.Eval()->value);
  TVM_TRY_REWRITE_IF(c2 < x && x < c1, cfalse, c2.Eval()->value + 1 >= c1.Eval()->value);

  TVM_TRY_REWRITE_IF((PMatchesOneOf{
                         x < c1 && c2 <= x,
                         c2 <= x && x < c1,
                         x <= c1 && c2 < x,
                         c2 < x && x <= c1,
                     }),
                     cfalse, c2.Eval()->value >= c1.Eval()->value);

  TVM_TRY_REWRITE_IF((PMatchesOneOf{
                         x <= c1 && c2 <= x,
                         c2 <= x && x <= c1,
                     }),
                     cfalse, c2.Eval()->value > c1.Eval()->value);

  TVM_TRY_REWRITE((x == c1) && (x == c2), (x == c1) && (c1 == c2));
  TVM_TRY_REWRITE(matches_one_of(x == c1 && x != c2, x != c2 && x == c1), x == c1 && c1 != c2);

  TVM_TRY_RECURSIVE_REWRITE(matches_one_of(floordiv(x, c2) == c1 && floormod(x, c2) == c3,
                                           floormod(x, c2) == c3 && floordiv(x, c2) == c1),
                            x == c1 * c2 + c3);

  // Recognize "x - y*c1 lies in the remainder range of floordiv by c1".
  // The remainder range is [0, c1) for a positive divisor and (c1, 0]
  // for a negative one.  Each rule must be guarded on the divisor's
  // sign: with the wrong sign the conjunction is unsatisfiable (false),
  // while `y == floordiv(x, c1)` may be true.
  TVM_TRY_RECURSIVE_REWRITE_IF((PMatchesOneOf{
                                   0 <= x - y * c1 && x - y * c1 < c1,
                                   x - y * c1 < c1 && 0 <= x - y * c1,
                               }),
                               y == floordiv(x, c1), c1.Eval()->value > 0);

  TVM_TRY_RECURSIVE_REWRITE_IF((PMatchesOneOf{
                                   c1 < x - y * c1 && x - y * c1 <= 0,
                                   x - y * c1 <= 0 && c1 < x - y * c1,
                               }),
                               y == floordiv(x, c1), c1.Eval()->value < 0);
  // Same as the positive-divisor rule above, with the product written as
  // y * c2 where c2 == -c1.
  TVM_TRY_RECURSIVE_REWRITE_IF((PMatchesOneOf{
                                   0 <= x + y * c2 && x + y * c2 < c1,
                                   x + y * c2 < c1 && 0 <= x + y * c2,
                               }),
                               y == floordiv(x, c1),
                               c1.Eval()->value > 0 && c2.Eval()->value == -c1.Eval()->value);

  // Tighten an upper bound on x using a bound on floormod(x, c2).  The
  // derivations assume floormod(x, c2) is in [0, c2), i.e. c2 > 0.  The
  // `c2 > 0` check is placed first so that it also short-circuits the C++
  // `%` below, which is undefined for a zero divisor.  In the conditions,
  // `(v % c2 + c2) % c2` computes floormod(v, c2) for c2 > 0.
  TVM_TRY_RECURSIVE_REWRITE_IF(x < c1 && floormod(x, c2) < c3,
                               x < c1 - c2 + c3 && floormod(x, c2) < c3,
                               c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);
  TVM_TRY_RECURSIVE_REWRITE_IF(
      x < c1 && floormod(x, c2) < c3, x < c1 - floormod(c1, c2) + c3 && floormod(x, c2) < c3,
      c2.Eval()->value > 0 &&
          (c1.Eval()->value % c2.Eval()->value + c2.Eval()->value) % c2.Eval()->value >
              c3.Eval()->value);

  // The same two rules for `x <= c1`, i.e. `x < c1 + 1`.
  TVM_TRY_RECURSIVE_REWRITE_IF(
      x <= c1 && floormod(x, c2) < c3, x < c1 + 1 - c2 + c3 && floormod(x, c2) < c3,
      c2.Eval()->value > 0 && (c1.Eval()->value + 1) % c2.Eval()->value == 0);
  // Use floormod(c1 + 1, c2), matching the condition, so the new bound is
  // the tight one: `x < c1 + 1 - floormod(c1 + 1, c2) + c3`.
  TVM_TRY_RECURSIVE_REWRITE_IF(
      x <= c1 && floormod(x, c2) < c3,
      x < c1 + 1 - floormod(c1 + 1, c2) + c3 && floormod(x, c2) < c3,
      c2.Eval()->value > 0 &&
          (((c1.Eval()->value + 1) % c2.Eval()->value) + c2.Eval()->value) % c2.Eval()->value >
              c3.Eval()->value);

  // Convert a (quotient, remainder-bound) pair into a range on x.  With a
  // positive divisor, floordiv(x, c2) == c1 means x is in [c1*c2, (c1+1)*c2)
  // and floormod(x, c2) == x - c1*c2.  Neither holds for c2 <= 0, so these
  // rules are restricted to c2 > 0.
  TVM_TRY_RECURSIVE_REWRITE_IF((PMatchesOneOf{
                                   floordiv(x, c2) == c1 && floormod(x, c2) < c3,
                                   floormod(x, c2) < c3 && floordiv(x, c2) == c1,
                               }),
                               c1 * c2 <= x && x < c1 * c2 + c3, c2.Eval()->value > 0);
  TVM_TRY_RECURSIVE_REWRITE_IF((PMatchesOneOf{
                                   floordiv(x, c2) == c1 && floormod(x, c2) <= c3,
                                   floormod(x, c2) <= c3 && floordiv(x, c2) == c1,
                               }),
                               c1 * c2 <= x && x <= c1 * c2 + c3, c2.Eval()->value > 0);

  TVM_TRY_RECURSIVE_REWRITE_IF((PMatchesOneOf{
                                   floordiv(x, c2) == c1 && c3 <= floormod(x, c2),
                                   c3 <= floormod(x, c2) && floordiv(x, c2) == c1,
                               }),
                               c1 * c2 + c3 <= x && x < (c1 + 1) * c2, c2.Eval()->value > 0);
  TVM_TRY_RECURSIVE_REWRITE_IF((PMatchesOneOf{
                                   floordiv(x, c2) == c1 && c3 < floormod(x, c2),
                                   c3 < floormod(x, c2) && floordiv(x, c2) == c1,
                               }),
                               c1 * c2 + c3 < x && x < (c1 + 1) * c2, c2.Eval()->value > 0);

  TVM_TRY_RECURSIVE_REWRITE(x && (y && z), (x && y) && z);

  return ret;
}