in src/tir/transforms/loop_partition.cc [573:749]
/*!
 * \brief Attempt to split the loop over `var` around a sub-range in which a set of
 *        conditions found in `body` has one known truth value.
 *
 * \param stmt The original loop statement; forwarded to MakeFor and, for thread
 *             scope, to ThreadPartitionInserter.
 * \param var The loop variable.
 * \param min Inclusive lower bound of `var`.
 * \param max Inclusive upper bound of `var` (the unrolled fallback uses extent
 *            `max - min + 1`).
 * \param body The loop body that is scanned for partitionable conditions.
 * \param partition_thread_scope When true no new loops are generated; the original
 *             statement is rewritten through ThreadPartitionInserter instead.
 *
 * \return An undefined Stmt() when no partitioning is performed; an unrolled For
 *         when no usable interval exists but the hint-driven unroll fallback
 *         applies; otherwise the SSA-converted partitioned statement.
 */
Stmt LoopPartitioner::TryPartition(const Stmt& stmt, Var var, PrimExpr min, PrimExpr max, Stmt body,
                                   bool partition_thread_scope) {
  using namespace arith;

  // The range of `var` is registered in hint_map_ only while the finder scans the body.
  hint_map_.insert({var.get(), IntSet::Interval(min, max)});
  bool var_is_hinted = selector.partition_hint_vars.count(var.get());
  PartitionFinder finder(var, hint_map_, relax_map_, var_is_hinted);
  finder(body);
  hint_map_.erase(var.get());

  if (finder.partitions.empty()) {
    return Stmt();
  }

  arith::IntervalSet loop_range(min, max);

  // Result of the search below:
  //   (interval, conditions, true/false) -> conditions hold / fail throughout the interval
  //   (Nothing,  empty,      false)      -> give up without trying the unroll fallback
  //   (default,  default,    nullopt)    -> no interval with a provable truth value
  auto [mid_range, elim_conds,
        proven_value] = [&]() -> std::tuple<IntSet, ExpressionSet, std::optional<bool>> {
    // First look for an interval where every condition on var is true, then for one
    // where every condition is false.
    const bool wanted_values[] = {true, false};
    for (bool wanted : wanted_values) {
      auto [range, conds] =
          GetIntervalAndCondset(finder.partitions, loop_range, wanted, var_is_hinted);
      if (!range.IsNothing()) {
        return {range, conds, wanted};
      }
    }

    // Neither search produced an interval. Inspect the partitions in iteration order:
    // the scan stops at the first set that is not a single point, or at the first
    // single point that is provably inside the loop range.
    for (const auto& entry : finder.partitions) {
      const auto& candidate = entry.second;
      if (!candidate.IsSinglePoint()) {
        // Not a single point: fall back to the default "cannot partition" result.
        return {{}, {}, std::nullopt};
      }
      auto point = candidate.PointValue();
      if (analyzer_.CanProve(point >= loop_range.min()) &&
          analyzer_.CanProve(point <= loop_range.max())) {
        // A single point inside the loop range is not expected at this stage.
        LOG(ERROR) << "unexpected case happened.";
        return {{}, {}, std::nullopt};
      }
    }

    // Every partition is a single point that could not be proven to lie inside the
    // loop range.
    return {IntSet::Nothing(), ExpressionSet(), false};
  }();

  // Only the "all single points" outcome pairs an empty interval with a false value.
  if (mid_range.IsNothing() && proven_value.has_value() && !proven_value.value()) {
    return Stmt();
  }

  if (!proven_value.has_value()) {
    // No partition interval. Optionally unroll a hinted loop whose bounds satisfy
    // max - min > 0, after mutating its body.
    const bool unroll_instead = var_is_hinted && unroll_loop_with_partition_hint_no_interval_ &&
                                analyzer_.CanProve(max - min > 0);
    if (!unroll_instead) {
      return Stmt();
    }
    auto new_body = VisitAndMutate(body);
    return For(var, min, max - min + 1, ForKind::kUnrolled, new_body);
  }

  const bool cond_value = *proven_value;
  IntervalSet mid_range_i = Downcast<IntervalSet>(mid_range);

  // The loop range is cut into three consecutive pieces:
  //   pre  = [min, body_begin)
  //   mid  = [body_begin, post_doubt_begin)   (conditions have value `cond_value`)
  //   post = [post_doubt_begin, max + 1)
  // MakeFor is only given an extent, so each body is rebased with var -> var + begin
  // (presumably MakeFor emits a zero-based loop -- confirm against its definition).

  // ---- pre sub-range ----
  PrimExpr body_begin;
  Stmt pre_stmt;
  bool pre_stmt_recurse = true;
  if (!mid_range_i->HasLowerBound()) {
    body_begin = min;
  } else {
    body_begin = analyzer_.Simplify(mid_range.min());
    if (!analyzer_.CanProve(body_begin == min)) {
      PrimExpr pre_extent = analyzer_.Simplify(body_begin - min);
      if (!analyzer_.CanProve(pre_extent > 0)) {
        // Extent not provably positive: clamp the start and do not recurse into
        // the pre loop later.
        body_begin = tvm::max(body_begin, min);
        pre_stmt_recurse = false;
      }
      // Emit the pre loop unless it is provably empty (never for thread scope).
      if (!analyzer_.CanProve(pre_extent <= 0) && !partition_thread_scope) {
        Stmt pre_body = Substitute(body, {{Var{var}, var + min}});
        pre_stmt = MakeFor(stmt.get(), body_begin - min, pre_body);
      }
    }
  }

  // ---- post sub-range ----
  PrimExpr post_doubt_begin;
  Stmt post_stmt;
  bool post_stmt_recurse = true;
  if (!mid_range_i->HasUpperBound()) {
    post_doubt_begin = max + 1;
  } else {
    post_doubt_begin = analyzer_.Simplify(mid_range.max() + 1);
    if (!analyzer_.CanProve(mid_range.max() == max)) {
      PrimExpr post_extent = analyzer_.Simplify(max - post_doubt_begin + 1);
      if (!analyzer_.CanProve(post_extent > 0)) {
        // Extent not provably positive: clamp the start and do not recurse into
        // the post loop later.
        post_doubt_begin = tvm::min(post_doubt_begin, max + 1);
        post_stmt_recurse = false;
      }
      // Emit the post loop unless it is provably empty (never for thread scope).
      if (!analyzer_.CanProve(post_extent <= 0) && !partition_thread_scope) {
        Stmt post_body = Substitute(body, {{Var{var}, var + post_doubt_begin}});
        post_stmt = MakeFor(stmt.get(), post_extent, post_body);
      }
    }
  }

  Stmt result;
  if (partition_thread_scope) {
    // Thread scope: build the predicate that selects the middle sub-range and hand
    // it, together with the condition set, to ThreadPartitionInserter.
    PrimExpr in_middle = const_true();
    if (!analyzer_.CanProve(body_begin == min)) {
      in_middle = in_middle && (var >= body_begin);
    }
    if (!analyzer_.CanProve(post_doubt_begin == (max + 1))) {
      in_middle = in_middle && (var < post_doubt_begin);
    }
    result = ThreadPartitionInserter(elim_conds, in_middle)(stmt);
  } else {
    // ---- middle sub-range ----
    Stmt mid_stmt;
    if (!analyzer_.CanProve(body_begin >= post_doubt_begin)) {
      // Apply ConditionEliminator with the proven value, then rebase the body.
      Stmt simplified_body = ConditionEliminator(elim_conds, cond_value)(body);
      Stmt shifted_body = Substitute(simplified_body, {{Var{var}, var + body_begin}});
      mid_stmt = MakeFor(stmt.get(), post_doubt_begin - body_begin, shifted_body);
      // The middle loop is always re-visited so remaining conditions get partitioned.
      mid_stmt = VisitAndMutate(mid_stmt);
      // Side loops are re-visited only if they exist and their extent was proven
      // positive above.
      if (pre_stmt.defined() && pre_stmt_recurse) {
        pre_stmt = VisitAndMutate(pre_stmt);
      }
      if (post_stmt.defined() && post_stmt_recurse) {
        post_stmt = VisitAndMutate(post_stmt);
      }
    }
    result = SeqStmt::Flatten(pre_stmt, mid_stmt, post_stmt);
  }

  return ConvertSSA(result);
}