in src/arith/iter_affine_map.cc [1300:1419]
/*!
 * \brief Peel comparison clauses off an `&&`-chained predicate and turn them into bounds.
 *
 * Each loop iteration matches `pred` against `rest && cmp`, `cmp && rest`, or a lone `cmp`,
 * where `cmp` is one of `<`, `<=`, `>`, `>=`. A lone comparison ends the walk; otherwise the
 * walk continues on `rest`.
 *
 * For every clause that references at least one variable of `input_iters`:
 *  - if the iterator side is a single Var present in `input_iters`, that Var's Range is
 *    tightened in place;
 *  - otherwise a constraint (iter, lower_bound, upper_bound, 0) is appended to `result`.
 * Lower bounds are inclusive and upper bounds are exclusive. Clauses that reference no
 * variable of `input_iters` are skipped.
 *
 * \param pred The predicate to decompose.
 * \param input_iters Map from iterator variable to its range; may be updated in place.
 * \param result Output vector receiving constraints on non-trivial iterator expressions.
 * \return false if a clause is not one of the supported comparisons or compares
 *         non-integer operands; true once the final clause has been processed.
 *
 * \note When false is returned, updates made for clauses handled earlier in the walk
 *       remain in `input_iters` and `result`.
 */
bool MatchBoundConstraints(PrimExpr pred, Map<Var, Range>* input_iters,
                           std::vector<IterConstraint>* result) {
  arith::PVar<PrimExpr> cmp_lhs, cmp_rhs, remainder;
  bool last_clause = false;
  while (!last_clause) {
    // Try to extract one comparison. The order of the attempts below is significant.
    bool greater = false;    // the operator is `>` or `>=`
    bool inclusive = false;  // the operator is `<=` or `>=`
    if ((remainder && (cmp_lhs < cmp_rhs)).Match(pred) ||
        ((cmp_lhs < cmp_rhs) && remainder).Match(pred)) {
      // `<` with more clauses pending: every flag keeps its default value.
    } else if ((cmp_lhs < cmp_rhs).Match(pred)) {
      last_clause = true;
    } else if ((remainder && (cmp_lhs <= cmp_rhs)).Match(pred) ||
               ((cmp_lhs <= cmp_rhs) && remainder).Match(pred)) {
      inclusive = true;
    } else if ((cmp_lhs <= cmp_rhs).Match(pred)) {
      inclusive = true;
      last_clause = true;
    } else if ((remainder && (cmp_lhs > cmp_rhs)).Match(pred) ||
               ((cmp_lhs > cmp_rhs) && remainder).Match(pred)) {
      greater = true;
    } else if ((cmp_lhs > cmp_rhs).Match(pred)) {
      greater = true;
      last_clause = true;
    } else if ((remainder && (cmp_lhs >= cmp_rhs)).Match(pred) ||
               ((cmp_lhs >= cmp_rhs) && remainder).Match(pred)) {
      greater = true;
      inclusive = true;
    } else if ((cmp_lhs >= cmp_rhs).Match(pred)) {
      greater = true;
      inclusive = true;
      last_clause = true;
    } else {
      // Not a supported comparison (or conjunction shape).
      return false;
    }

    PrimExpr left = cmp_lhs.Eval();
    PrimExpr right = cmp_rhs.Eval();

    // Only predicates over integers are accepted.
    auto is_integer = [](const PrimExpr& e) { return e->dtype.is_int() || e->dtype.is_uint(); };
    if (!is_integer(left) || !is_integer(right)) {
      return false;
    }

    // Tells whether a variable is one of the input iterators.
    auto is_input_iter = [&input_iters](const VarNode* v) {
      return input_iters->count(GetRef<Var>(v));
    };
    const bool left_has_iter = UsesVar(left, is_input_iter);
    const bool right_has_iter = UsesVar(right, is_input_iter);

    if (left_has_iter || right_has_iter) {
      // Decide which operand is the iterator expression and which is the bound.
      PrimExpr iter;
      PrimExpr bound;
      bool bound_on_left;
      if (is_const_int(left) || !left_has_iter) {
        bound_on_left = true;
        bound = left;
        iter = right;
      } else if (is_const_int(right) || !right_has_iter) {
        bound_on_left = false;
        iter = left;
        bound = right;
      } else {
        // Both operands mention input iterators. Flatten (left - right) over +/- and move
        // the iterator-dependent terms to the left and the iterator-free terms, negated,
        // to the right, so the bound accumulates on the right-hand side.
        bound_on_left = false;
        PrimExpr iter_terms = 0;
        PrimExpr free_terms = 0;
        std::function<void(const PrimExpr&, bool)> split_terms =
            [&iter_terms, &free_terms, is_input_iter, &split_terms](const PrimExpr& term,
                                                                    bool positive) {
              if (const AddNode* add = term.as<AddNode>()) {
                split_terms(add->a, positive);
                split_terms(add->b, positive);
              } else if (const SubNode* sub = term.as<SubNode>()) {
                split_terms(sub->a, positive);
                split_terms(sub->b, !positive);
              } else if (UsesVar(term, is_input_iter)) {
                iter_terms = positive ? iter_terms + term : iter_terms - term;
              } else {
                // Crossing to the other side of the comparison flips the sign.
                free_terms = positive ? free_terms - term : free_terms + term;
              }
            };
        split_terms(left - right, true);
        arith::Analyzer analyzer;
        iter = analyzer.Simplify(iter_terms);
        bound = analyzer.Simplify(free_terms);
      }

      // `bound > iter` and `iter < bound` (and their inclusive forms) cap the iterator from
      // above; the two remaining combinations cap it from below.
      Optional<PrimExpr> lower_bound = NullOpt, upper_bound = NullOpt;
      if (greater == bound_on_left) {
        upper_bound = inclusive ? bound + 1 : bound;
      } else {
        lower_bound = inclusive ? bound : bound + 1;
      }

      if (auto maybe_var = iter.as<Var>()) {
        // The clause constrains a single variable: tighten its range if it is an input iter.
        auto var = maybe_var.value();
        auto entry = input_iters->find(var);
        if (entry != input_iters->end()) {
          Range dom = (*entry).second;
          PrimExpr iter_min = dom->min;
          PrimExpr iter_max = dom->min + dom->extent;
          if (lower_bound.defined()) iter_min = max(iter_min, lower_bound.value());
          if (upper_bound.defined()) iter_max = min(iter_max, upper_bound.value());
          input_iters->Set(var, Range(iter_min, iter_max));
        }
      } else {
        result->emplace_back(iter, lower_bound, upper_bound, 0);
      }
    }

    if (!last_clause) {
      // Continue with the remaining conjunction.
      pred = remainder.Eval();
    }
  }
  return true;
}