Stmt LoopPartitioner::TryPartition()

in src/tir/transforms/loop_partition.cc [573:749]


/*!
 * \brief Attempt to split the loop over `var` into up to three consecutive
 *        pieces (head, middle, tail) such that, inside the middle piece, a set
 *        of conditions found in `body` is uniformly true or uniformly false.
 *
 * \param stmt The original loop statement; handed to MakeFor and
 *             ThreadPartitionInserter as the template for generated code.
 * \param var  The loop variable being partitioned.
 * \param min  Lower bound of the loop range (used as an inclusive bound of the
 *             interval [min, max]).
 * \param max  Upper bound of the loop range (inclusive; the loop extent is
 *             computed as `max - min + 1`).
 * \param body The loop body that is scanned for partitionable conditions.
 * \param partition_thread_scope When true no head/middle/tail loops are
 *             generated; instead a guard condition on `var` is built and
 *             passed to ThreadPartitionInserter.
 *
 * \return The partitioned statement after ConvertSSA, an unrolled loop when a
 *         partition hint is present but no interval could be derived (and the
 *         unroll option is enabled), or a default-constructed Stmt when the
 *         loop is left unpartitioned.
 */
Stmt LoopPartitioner::TryPartition(const Stmt& stmt, Var var, PrimExpr min, PrimExpr max, Stmt body,
                                   bool partition_thread_scope) {
  using namespace arith;

  // The range of `var` is registered as a hint only for the duration of the scan.
  hint_map_.insert({var.get(), IntSet::Interval(min, max)});
  const bool hinted = selector.partition_hint_vars.count(var.get());
  PartitionFinder finder(var, hint_map_, relax_map_, hinted);
  finder(body);
  hint_map_.erase(var.get());

  if (finder.partitions.empty()) {
    return Stmt();
  }

  IntervalSet loop_range(min, max);

  // Pick the middle interval together with the conditions it settles and the
  // truth value those conditions take there. An empty optional means that no
  // interval with a provable truth value was found.
  auto [mid_set, mid_conds, mid_truth] =
      [&]() -> std::tuple<IntSet, ExpressionSet, std::optional<bool>> {
    // First look for an interval where every condition holds; failing that,
    // for an interval where every condition is false.
    const bool polarities[] = {true, false};
    for (const bool polarity : polarities) {
      auto [interval, conds] =
          GetIntervalAndCondset(finder.partitions, loop_range, polarity, hinted);
      if (!interval.IsNothing()) {
        return {interval, conds, polarity};
      }
    }

    // No interval found. Inspect the individual partitions: the special
    // "nothing / false" answer is given only if each of them is a single point
    // that cannot be proven to lie inside the loop range.
    for (const auto& entry : finder.partitions) {
      const auto& point_set = entry.second;
      if (!point_set.IsSinglePoint()) {
        // A non-point set means the default (undecided) handling applies.
        return {{}, {}, std::nullopt};
      }
      auto point = point_set.PointValue();
      if (analyzer_.CanProve(point >= loop_range.min()) &&
          analyzer_.CanProve(point <= loop_range.max())) {
        // A single point provably inside the range is reported and treated as undecided.
        LOG(ERROR) << "unexpected case happened.";
        return {{}, {}, std::nullopt};
      }
    }
    return {IntSet::Nothing(), ExpressionSet(), false};
  }();

  // Every partition was a single point not provably inside the loop range: leave the loop alone.
  if (mid_set.IsNothing() && mid_truth == false) {
    return Stmt();
  }

  if (!mid_truth.has_value()) {
    // Conditions could not be settled on any interval. With a partition hint
    // (and the option enabled) fall back to unrolling a loop that provably has
    // more than one iteration; otherwise give up.
    const bool unroll_instead = hinted && unroll_loop_with_partition_hint_no_interval_ &&
                                analyzer_.CanProve(max - min > 0);
    if (!unroll_instead) {
      return Stmt();
    }
    Stmt unrolled_body = VisitAndMutate(body);
    return For(var, min, max - min + 1, ForKind::kUnrolled, unrolled_body);
  }
  const bool cond_value = mid_truth.value();

  // mid_set is the part of the loop range on which the collected conditions
  // are all `cond_value`. What lies before it is the head, what lies after it
  // is the tail.
  IntervalSet mid_range = Downcast<IntervalSet>(mid_set);

  // ---- head: [min, mid_begin) ----
  PrimExpr mid_begin;
  Stmt head_loop;
  bool head_recurse = true;
  if (!mid_range->HasLowerBound()) {
    mid_begin = min;
  } else {
    mid_begin = analyzer_.Simplify(mid_set.min());
    if (!analyzer_.CanProve(mid_begin == min)) {
      PrimExpr head_extent = analyzer_.Simplify(mid_begin - min);
      if (!analyzer_.CanProve(head_extent > 0)) {
        // Length not provably positive: clamp the start of the middle part and
        // do not partition the head any further.
        mid_begin = tvm::max(mid_begin, min);
        head_recurse = false;
      }
      // Emit the head loop unless it is provably empty (never in thread-scope mode).
      if (!analyzer_.CanProve(head_extent <= 0) && !partition_thread_scope) {
        Stmt head_body = Substitute(body, {{Var{var}, var + min}});
        head_loop = MakeFor(stmt.get(), mid_begin - min, head_body);
      }
    }
  }

  // ---- tail: [tail_begin, max + 1) ----
  PrimExpr tail_begin;
  Stmt tail_loop;
  bool tail_recurse = true;
  if (!mid_range->HasUpperBound()) {
    tail_begin = max + 1;
  } else {
    tail_begin = analyzer_.Simplify(mid_set.max() + 1);
    if (!analyzer_.CanProve(mid_set.max() == max)) {
      // The extent is computed from the unclamped start of the tail.
      PrimExpr tail_extent = analyzer_.Simplify(max - tail_begin + 1);
      if (!analyzer_.CanProve(tail_extent > 0)) {
        // Length not provably positive: clamp the start of the tail and do not
        // partition the tail any further.
        tail_begin = tvm::min(tail_begin, max + 1);
        tail_recurse = false;
      }
      // Emit the tail loop unless it is provably empty (never in thread-scope mode).
      if (!analyzer_.CanProve(tail_extent <= 0) && !partition_thread_scope) {
        Stmt tail_body = Substitute(body, {{Var{var}, var + tail_begin}});
        tail_loop = MakeFor(stmt.get(), tail_extent, tail_body);
      }
    }
  }

  Stmt result;
  if (partition_thread_scope) {
    // Thread-scope mode: describe the middle range as a predicate on `var`
    // and let ThreadPartitionInserter apply it to the original statement.
    PrimExpr guard = const_true();
    if (!analyzer_.CanProve(mid_begin == min)) guard = guard && (var >= mid_begin);
    if (!analyzer_.CanProve(tail_begin == (max + 1))) guard = guard && (var < tail_begin);
    result = ThreadPartitionInserter(mid_conds, guard)(stmt);
  } else {
    // ---- middle: [mid_begin, tail_begin) ----
    Stmt mid_loop;
    if (!analyzer_.CanProve(mid_begin >= tail_begin)) {
      // Inside the middle range the settled conditions are replaced by cond_value.
      Stmt reduced_body = ConditionEliminator(mid_conds, cond_value)(body);
      Stmt shifted_body = Substitute(reduced_body, {{Var{var}, var + mid_begin}});
      mid_loop = MakeFor(stmt.get(), tail_begin - mid_begin, shifted_body);
      // Keep partitioning the middle loop until nothing is left to partition.
      mid_loop = VisitAndMutate(mid_loop);
      // Head and tail are only partitioned further when a middle loop exists
      // and their length was proven positive above.
      if (head_loop.defined() && head_recurse) {
        head_loop = VisitAndMutate(head_loop);
      }
      if (tail_loop.defined() && tail_recurse) {
        tail_loop = VisitAndMutate(tail_loop);
      }
    }
    result = SeqStmt::Flatten(head_loop, mid_loop, tail_loop);
  }
  result = ConvertSSA(result);
  return result;
}