bool MatchBoundConstraints()

in src/arith/iter_affine_map.cc [1300:1419]


bool MatchBoundConstraints(PrimExpr pred, Map<Var, Range>* input_iters,
                           std::vector<IterConstraint>* result) {
  // Peels comparison clauses (<, <=, >, >=) off `pred` one at a time. `pred` must be
  // either a bare comparison or an `&&` whose left or right operand is a comparison;
  // the other operand becomes the predicate for the next round.
  //
  // For each clause that references at least one variable of `input_iters`:
  //  - if the iterator side is exactly one such variable, its Range in `input_iters`
  //    is tightened in place;
  //  - otherwise an IterConstraint is appended to `result`.
  // Clauses that reference no input variable are skipped.
  //
  // Upper bounds are exclusive (`x <= b` is recorded as `b + 1`), lower bounds inclusive.
  //
  // Returns false as soon as a clause is not a supported comparison or compares
  // non-integer operands; updates made for earlier clauses are not rolled back.
  arith::PVar<PrimExpr> lhs, rhs, rest;

  // Set by `matches`: true when the comparison was the whole predicate (nothing left).
  bool last_clause = false;

  // Try `cmp` as one operand of the outermost `&&` (binding the other to `rest`),
  // then as the entire predicate.
  auto matches = [&](const auto& cmp) {
    if ((rest && cmp).Match(pred) || (cmp && rest).Match(pred)) {
      last_clause = false;
      return true;
    }
    if (cmp.Match(pred)) {
      last_clause = true;
      return true;
    }
    return false;
  };

  // Only signed/unsigned integer operands are accepted.
  auto is_integral = [](const PrimExpr& e) { return e->dtype.is_int() || e->dtype.is_uint(); };

  while (true) {
    bool greater = false;    // clause is `>` or `>=`
    bool inclusive = false;  // clause is `<=` or `>=`
    if (matches(lhs < rhs)) {
      // strict less-than: both flags keep their defaults
    } else if (matches(lhs <= rhs)) {
      inclusive = true;
    } else if (matches(lhs > rhs)) {
      greater = true;
    } else if (matches(lhs >= rhs)) {
      greater = true;
      inclusive = true;
    } else {
      return false;
    }

    PrimExpr lhs_expr = lhs.Eval();
    PrimExpr rhs_expr = rhs.Eval();
    if (!is_integral(lhs_expr) || !is_integral(rhs_expr)) {
      return false;
    }

    // Predicate telling whether a variable is one of the input iterators.
    auto f_use_itervar = [input_iters](const VarNode* v) {
      return input_iters->count(GetRef<Var>(v)) != 0;
    };
    const bool lhs_has_iter = UsesVar(lhs_expr, f_use_itervar);
    const bool rhs_has_iter = UsesVar(rhs_expr, f_use_itervar);

    if (lhs_has_iter || rhs_has_iter) {
      // Decide which side is the bound and which is the iterator expression.
      const bool bound_at_left = is_const_int(lhs_expr) || !lhs_has_iter;
      const bool rhs_is_bound = is_const_int(rhs_expr) || !rhs_has_iter;
      if (!bound_at_left && !rhs_is_bound) {
        // Both sides mention iterators: walk the +/- tree of (lhs - rhs), gathering
        // iterator-dependent terms on the left and everything else on the right.
        PrimExpr difference = lhs_expr - rhs_expr;
        lhs_expr = 0;
        rhs_expr = 0;
        std::function<void(const PrimExpr&, bool)> distribute;
        distribute = [&](const PrimExpr& term, bool positive) {
          if (const AddNode* add = term.as<AddNode>()) {
            distribute(add->a, positive);
            distribute(add->b, positive);
            return;
          }
          if (const SubNode* sub = term.as<SubNode>()) {
            distribute(sub->a, positive);
            distribute(sub->b, !positive);
            return;
          }
          if (UsesVar(term, f_use_itervar)) {
            // iterator term stays on the left and keeps its sign
            if (positive) {
              lhs_expr = lhs_expr + term;
            } else {
              lhs_expr = lhs_expr - term;
            }
          } else {
            // iterator-free term crosses to the right, so its sign flips
            if (positive) {
              rhs_expr = rhs_expr - term;
            } else {
              rhs_expr = rhs_expr + term;
            }
          }
        };
        distribute(difference, true);
        arith::Analyzer analyzer;
        lhs_expr = analyzer.Simplify(lhs_expr);
        rhs_expr = analyzer.Simplify(rhs_expr);
      }

      PrimExpr bound = bound_at_left ? lhs_expr : rhs_expr;
      PrimExpr iter = bound_at_left ? rhs_expr : lhs_expr;
      Optional<PrimExpr> lower_bound = NullOpt, upper_bound = NullOpt;
      if (greater != bound_at_left) {
        // iter > bound, iter >= bound, bound < iter, bound <= iter
        lower_bound = inclusive ? bound : bound + 1;
      } else {
        // iter < bound, iter <= bound, bound > iter, bound >= iter
        upper_bound = inclusive ? bound + 1 : bound;
      }

      if (auto single = iter.as<Var>()) {
        // The clause constrains one plain variable: narrow its range directly.
        Var var = single.value();
        auto it = input_iters->find(var);
        if (it != input_iters->end()) {
          Range current = (*it).second;
          PrimExpr new_min = current->min;
          PrimExpr new_end = current->min + current->extent;
          if (lower_bound.defined()) new_min = max(new_min, lower_bound.value());
          if (upper_bound.defined()) new_end = min(new_end, upper_bound.value());
          input_iters->Set(var, Range(new_min, new_end));
        }
      } else {
        // Compound iterator expression: hand the constraint to the caller.
        result->emplace_back(iter, lower_bound, upper_bound, 0);
      }
    }

    if (last_clause) {
      return true;
    }
    pred = rest.Eval();
  }
}