in src/tir/transforms/ir_utils.cc [597:673]
/*!
 * \brief Try to turn `condition_` into solved integer constraints.
 *
 * The simplified condition is walked through `And` nodes and through calls whose op is
 * `builtin::likely()`. Every comparison (>=, >, <=, <, ==, !=) whose operands consist only of
 * variables, integer immediates and Add/Sub/Mul/FloorDiv/FloorMod nodes, and which mentions at
 * least one int/uint variable, is collected as an equation. The collected equations are then
 * handed to `arith::SolveInequalitiesToRange` together with the known domains of the variables.
 *
 * \return The solver result, or NullOpt when
 *         - the simplified condition is a constant integer,
 *         - no equation / variable could be collected, or
 *         - the solver result still has a non-empty `relations` list.
 */
Optional<arith::IntConstraints> ConditionalBoundsContext::TrySolveCondition() {
  arith::Analyzer analyzer;
  const PrimExpr simplified = analyzer.Simplify(condition_);
  if (is_const_int(simplified)) {
    return NullOpt;
  }

  Array<PrimExpr> equations;
  Array<Var> vars;

  // Node kinds accepted as the root of a collectable equation.
  auto is_comparison = [](const PrimExpr& expr) {
    return expr->IsInstance<GENode>() || expr->IsInstance<GTNode>() ||
           expr->IsInstance<LENode>() || expr->IsInstance<LTNode>() ||
           expr->IsInstance<EQNode>() || expr->IsInstance<NENode>();
  };

  // Non-variable node kinds accepted below the root of a collectable equation.
  auto is_supported_operand = [](const ObjectRef& node) {
    return node->IsInstance<AddNode>() || node->IsInstance<SubNode>() ||
           node->IsInstance<MulNode>() || node->IsInstance<FloorDivNode>() ||
           node->IsInstance<FloorModNode>() || node->IsInstance<IntImmNode>();
  };

  std::function<void(const PrimExpr&)> collect = [&](const PrimExpr& expr) {
    // Conjunction: both operands are inspected, left first.
    if (const AndNode* conj = expr.as<AndNode>()) {
      collect(conj->a);
      collect(conj->b);
      return;
    }
    // Call: only `likely(x)` is looked through; any other call is ignored.
    if (const CallNode* call = expr.as<CallNode>()) {
      if (call->op.same_as(builtin::likely())) {
        collect(call->args[0]);
      }
      return;
    }
    if (!is_comparison(expr)) {
      return;
    }

    // Inspect everything below the comparison itself.
    std::vector<Var> int_vars;
    bool solvable = true;
    PostOrderVisit(expr, [&](const ObjectRef& node) {
      if (node.same_as(expr)) {
        return;  // the root comparison is not an operand
      }
      if (const VarNode* v = node.as<VarNode>()) {
        // Variables of other dtypes are neither recorded nor treated as unsupported.
        if (v->dtype.is_int() || v->dtype.is_uint()) {
          int_vars.push_back(GetRef<Var>(v));
        }
        return;
      }
      if (!is_supported_operand(node)) {
        solvable = false;
      }
    });
    if (!solvable || int_vars.empty()) {
      return;
    }

    // Record each variable once, keeping first-seen order.
    for (const Var& candidate : int_vars) {
      const bool already_known =
          std::any_of(vars.begin(), vars.end(),
                      [&candidate](const Var& known) { return known.same_as(candidate); });
      if (!already_known) {
        vars.push_back(candidate);
      }
    }
    equations.push_back(expr);
  };
  collect(simplified);

  if (equations.empty() || vars.empty()) {
    return NullOpt;
  }

  // Attach a range to every variable with a known domain. An entry in `relax_map_` takes
  // precedence; `hint_map_` is consulted only when the variable is absent from `relax_map_`.
  Map<Var, Range> ranges;
  for (const Var& v : vars) {
    arith::IntSet dom;
    auto relaxed = relax_map_->find(v.get());
    if (relaxed != relax_map_->end()) {
      dom = relaxed->second;
    } else {
      auto hinted = hint_map_->find(v.get());
      if (hinted != hint_map_->end()) {
        dom = hinted->second;
      }
    }
    if (!dom.defined()) {
      continue;
    }
    // [min, max] is inclusive, hence extent = max - min + 1.
    const PrimExpr extent = analyzer.Simplify(dom.max() - dom.min() + 1);
    ranges.Set(v, Range::FromMinExtent(dom.min(), extent));
  }

  // Solve; a result that still has entries in `relations` is rejected.
  arith::IntConstraints problem(vars, ranges, equations);
  arith::IntConstraints result = arith::SolveInequalitiesToRange(problem);
  if (!result->relations.empty()) {
    return NullOpt;
  }
  return std::move(result);
}