Optional<arith::IntConstraints> ConditionalBoundsContext::TrySolveCondition()

in src/tir/transforms/ir_utils.cc [597:673]


/*!
 * \brief Try to solve the (simplified) branch condition into per-variable integer ranges.
 *
 * Only "simple" integral comparisons are collected: GE/GT/LE/LT/EQ/NE nodes whose
 * operands are built solely from integer-typed vars, IntImm constants and
 * Add/Sub/Mul/FloorDiv/FloorMod.  Conjunctions (`a && b`) are split into their
 * operands and `likely(c)` is unwrapped; any other sub-condition is ignored.
 *
 * Domains for the collected vars are looked up in relax_map_ first, then in
 * hint_map_.  Only domains with both a finite lower and a finite upper bound are
 * forwarded to the solver; other vars are left without a range.
 *
 * \return The solved constraints, or NullOpt when the condition simplifies to a
 *         constant, contains no collectable comparison, or the solver result still
 *         carries residual relations.
 */
Optional<arith::IntConstraints> ConditionalBoundsContext::TrySolveCondition() {
  // extract equations and related vars from condition expression.
  // currently only extract simple integral equations which could be solvable.
  arith::Analyzer analyzer;
  PrimExpr condition = analyzer.Simplify(condition_);
  if (is_const_int(condition)) {
    // Statically decided condition: there is nothing to solve.
    return NullOpt;
  }
  Array<PrimExpr> equations;
  Array<Var> vars;
  std::function<void(const PrimExpr&)> fvisit = [&equations, &vars, &fvisit](const PrimExpr& e) {
    if (e->IsInstance<GENode>() || e->IsInstance<GTNode>() || e->IsInstance<LENode>() ||
        e->IsInstance<LTNode>() || e->IsInstance<EQNode>() || e->IsInstance<NENode>()) {
      bool is_simple = true;
      std::vector<Var> cand_vars;
      PostOrderVisit(e, [&cand_vars, &is_simple, &e](const ObjectRef& obj) {
        if (obj.same_as(e)) {
          // The comparison node itself is allowed; only its operands are checked.
          return;
        } else if (const VarNode* var = obj.as<VarNode>()) {
          if (var->dtype.is_int() || var->dtype.is_uint()) {
            cand_vars.push_back(GetRef<Var>(var));
          }
        } else {
          is_simple &= obj->IsInstance<AddNode>() || obj->IsInstance<SubNode>() ||
                       obj->IsInstance<MulNode>() || obj->IsInstance<FloorDivNode>() ||
                       obj->IsInstance<FloorModNode>() || obj->IsInstance<IntImmNode>();
        }
      });
      if (is_simple && !cand_vars.empty()) {
        // Record each var once, keeping first-seen order.
        for (const Var& new_var : cand_vars) {
          if (!std::any_of(vars.begin(), vars.end(),
                           [&new_var](const Var& v) { return v.same_as(new_var); })) {
            vars.push_back(new_var);
          }
        }
        equations.push_back(e);
      }
    } else if (const AndNode* op = e.as<AndNode>()) {
      // Every operand of a conjunction holds, so each may be solved independently.
      fvisit(op->a);
      fvisit(op->b);
    } else if (const CallNode* op = e.as<CallNode>()) {
      if (op->op.same_as(builtin::likely())) {
        fvisit(op->args[0]);
      }
    }
  };
  fvisit(condition);
  if (equations.empty() || vars.empty()) {
    return NullOpt;
  }
  // build dom ranges for related vars
  Map<Var, Range> ranges;
  for (const Var& v : vars) {
    arith::IntSet dom;
    auto relax_it = relax_map_->find(v.get());
    if (relax_it != relax_map_->end()) {
      dom = relax_it->second;
    } else {
      auto hint_it = hint_map_->find(v.get());
      if (hint_it != hint_map_->end()) {
        dom = hint_it->second;
      }
    }
    // A Range needs finite endpoints. For a set without a finite lower/upper bound,
    // min()/max() are not finite expressions, and the extent arithmetic below would
    // be ill-formed. Leave such vars without a range, exactly like vars that have no
    // known domain at all.
    if (dom.defined() && dom.HasLowerBound() && dom.HasUpperBound()) {
      ranges.Set(v, Range::FromMinExtent(dom.min(), analyzer.Simplify(dom.max() - dom.min() + 1)));
    }
  }
  // solve constraints
  arith::IntConstraints constraint(vars, ranges, equations);
  arith::IntConstraints result = arith::SolveInequalitiesToRange(constraint);
  if (!result->relations.empty()) {
    // Leftover relations mean the condition was presumably not fully reduced to
    // variable ranges; give up rather than return a partial result.
    return NullOpt;
  }
  return std::move(result);
}