std::pair<Stmt, SeqStmt> InsertCacheStage()

in src/tir/transforms/memhammer_intermediate_stage.cc [228:427]
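
A minimal caller sketch (illustrative only: loop_nest, compute_at_loop, outer_loops and cache_buffer are placeholder variables, not names from this file), staging a read through shared memory:

  Buffer cache_buffer;
  auto [new_nest, insert_location] =
      InsertCacheStage(/*stmt=*/loop_nest, /*is_write_cache=*/false,
                       /*storage_scope=*/"shared", /*compute_location=*/compute_at_loop,
                       /*outer_loops=*/outer_loops, /*alloc_buffer=*/&cache_buffer);
  // new_nest replaces the original loop nest (cache-fill loops plus the rewritten copy),
  // cache_buffer must be allocated by the caller, and insert_location is the inner SeqStmt
  // holding the two copy stages.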


std::pair<Stmt, SeqStmt> InsertCacheStage(Stmt stmt, bool is_write_cache, String storage_scope,
                                          Optional<For> compute_location,
                                          const Array<For>& outer_loops, Buffer* alloc_buffer) {
  Stmt body = stmt;
  std::vector<const ForNode*> loops;
  std::vector<const ForNode*> loops_under_compute_location;
  std::vector<const ForNode*> relaxed_thread_loops;
  bool need_relax = !compute_location.defined();
  Map<Var, Range> var_range;
  PrimExpr vector_bytes = -1;
  // Step 1. Perform rank promotion on the buffer access, turning a dimension with strided access
  // into several dimensions with contiguous access.
  // Step 1.1 Collect loop variable ranges for rank promotion.
  while (const ForNode* loop = body.as<ForNode>()) {
    if (need_relax) {
      var_range.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
      loops_under_compute_location.push_back(loop);
    } else {
      loops.push_back(loop);
    }
    if (loop == compute_location.value_or(For()).get()) {
      need_relax = true;
    }
    if (loop->kind == ForKind::kVectorized) {
      vector_bytes = loop->extent;
    }
    body = loop->body;
  }
  Optional<PrimExpr> predicate;
  if (const auto* op = body.as<IfThenElseNode>()) {
    // the predicate is generated by coalescing
    predicate = op->condition;
    body = op->then_case;
  }
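  // Also collect the ranges of outer thread-bound loops when the target storage scope is shared
  // across those threads (CanRelaxStorageUnderThread): the cache buffer must cover their extents.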
  for (const For& loop : outer_loops) {
    if (loop->kind == ForKind::kThreadBinding) {
      const String& thread_tag = loop->thread_binding.value()->thread_tag;
      if (CanRelaxStorageUnderThread(runtime::StorageScope::Create(storage_scope),
                                     runtime::ThreadScope::Create(thread_tag))) {
        var_range.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
        relaxed_thread_loops.push_back(loop.get());
      }
    }
  }

  arith::Analyzer analyzer;
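  // For a write cache, find the accumulator load (wmma.accumulator / m16n8k8.matrixC) whose
  // value is being stored: all such loads must refer to the same buffer with provably equal
  // indices, and the cache buffer is later derived from it so that the dtypes match.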
  const BufferLoadNode* target_buffer_load = nullptr;
  if (is_write_cache) {
    tir::PreOrderVisit(stmt, [&](const ObjectRef& obj) {
      if (const auto* buffer_load = obj.as<BufferLoadNode>()) {
        if (buffer_load->buffer.scope() == "wmma.accumulator" ||
            buffer_load->buffer.scope() == "m16n8k8.matrixC") {
          if (target_buffer_load == nullptr) {
            target_buffer_load = buffer_load;
          } else {
            CHECK(target_buffer_load->buffer.same_as(buffer_load->buffer))
                << "More than one target buffer found";
            ICHECK(target_buffer_load->indices.size() == buffer_load->indices.size());
            for (size_t i = 0; i < target_buffer_load->indices.size(); i++) {
              CHECK(
                  analyzer.CanProveEqual(target_buffer_load->indices[i], buffer_load->indices[i]));
            }
          }
        }
      }
      return true;
    });
    CHECK(target_buffer_load);
  }

  const BufferStoreNode* buf_store = TVM_TYPE_AS(body, BufferStoreNode);
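  // For a read cache whose copied value is a plain BufferLoad, try rank promotion on the load
  // indices to obtain a compact shape and the corresponding access indices for the cache buffer.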
  Array<PrimExpr> cache_indices;
  Array<PrimExpr> new_shape;
  bool use_rank_promotion = false;
  if (!is_write_cache && buf_store->value.as<BufferLoadNode>()) {
    Array<PrimExpr> indices =
        is_write_cache ? buf_store->indices : buf_store->value.as<BufferLoadNode>()->indices;
    new_shape = IndexPatternFinder::getRankPromotedShape(indices, var_range, &cache_indices);
    // rank promotion is disabled for the write cache for now, because it cannot guarantee
    // that the promoted shape still fits wmma.accumulator
    if (!new_shape.empty()) {
      use_rank_promotion = true;
    }
  }
  Array<Var> new_loop_vars;
  Map<Var, PrimExpr> subst_map;
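  // Fallback without rank promotion: the cache shape is simply the extents of the relaxed
  // thread loops followed by the loops under the compute location, indexed by their loop vars.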
  if (!use_rank_promotion) {
    cache_indices.clear();
    for (const ForNode* loop : relaxed_thread_loops) {
      new_shape.push_back(loop->extent);
    }
    for (const ForNode* loop : loops_under_compute_location) {
      new_shape.push_back(loop->extent);
    }
  }

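  // Duplicate the relaxed loops with fresh "_cache" loop variables, and record the substitution
  // map used to rewrite the copied value, predicate, and indices in terms of the new variables.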
  for (int i = 0; i < static_cast<int>(relaxed_thread_loops.size()); i++) {
    const ForNode* loop = relaxed_thread_loops[i];
    Var new_loop_var = loop->loop_var.copy_with_suffix("_cache");
    new_loop_vars.push_back(new_loop_var);
    subst_map.Set(loop->loop_var, new_loop_var);
    if (!use_rank_promotion) {
      cache_indices.push_back(loop->loop_var);
    }
  }
  for (int i = 0; i < static_cast<int>(loops_under_compute_location.size()); i++) {
    const ForNode* loop = loops_under_compute_location[i];
    Var new_loop_var = loop->loop_var.copy_with_suffix("_cache");
    new_loop_vars.push_back(new_loop_var);
    subst_map.Set(loop->loop_var, new_loop_var);
    if (!use_rank_promotion) {
      cache_indices.push_back(loop->loop_var);
    }
  }
  Array<PrimExpr> subst_indices;
  Array<PrimExpr> subst_cache_indices;
  if (is_write_cache) {
    for (PrimExpr e : buf_store->indices) {
      subst_indices.push_back(Substitute(e, subst_map));
    }
  }
  for (PrimExpr e : cache_indices) {
    subst_cache_indices.push_back(Substitute(e, subst_map));
  }

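  // Create the cache buffer in the requested storage scope, reshaped to the shape computed above.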
  Buffer new_buffer;
  if (is_write_cache) {
    // this is needed for global <- cast(load(wmma))
    // shared stage should have the same dtype as wmma
    new_buffer = WithScope(target_buffer_load->buffer, storage_scope);
  } else {
    new_buffer = WithScope(buf_store->buffer, storage_scope);
  }
  BufferNode* buffer_ptr = new_buffer.CopyOnWrite();
  buffer_ptr->shape = new_shape;
  *alloc_buffer = new_buffer;

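  // Build the copy statement on the cache side of the split:
  //   read cache : store the loaded value into the cache buffer;
  //   write cache: the original store, with accumulator loads redirected to the cache buffer.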
  Stmt generate_body;
  if (is_write_cache) {
    // the original store, with loads of the accumulator buffer redirected to the cache buffer
    BufferLoad new_buffer_load{new_buffer, cache_indices};
    generate_body =
        BufferLoadReplacer(target_buffer_load->buffer, new_buffer_load)(GetRef<Stmt>(buf_store));
    generate_body = Substitute(generate_body, subst_map);
  } else {
    generate_body =
        BufferStore(new_buffer, Substitute(buf_store->value, subst_map), subst_cache_indices);
  }

  if (predicate.defined()) {
    // generated by coalescing
    CHECK_EQ(loops_under_compute_location.size(), 2);
    PrimExpr subst_value = 0;
    PrimExpr subst_predicate = Substitute(predicate.value(), subst_map);
    generate_body = IfThenElse(subst_predicate, generate_body);
  }

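  // Wrap the copy with the duplicated loops; the relaxed thread-bound loops become plain serial
  // loops (their thread bindings are dropped) because the cache covers all of their iterations.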
  for (int i = static_cast<int>(loops_under_compute_location.size()) - 1; i >= 0; i--) {
    const ForNode* orig_loop = loops_under_compute_location[i];
    ObjectPtr<ForNode> new_loop = make_object<ForNode>(*orig_loop);
    new_loop->loop_var = new_loop_vars[i + relaxed_thread_loops.size()];
    new_loop->body = generate_body;
    generate_body = For(new_loop);
  }
  for (int i = static_cast<int>(relaxed_thread_loops.size()) - 1; i >= 0; i--) {
    const ForNode* orig_loop = relaxed_thread_loops[i];
    ObjectPtr<ForNode> new_loop = make_object<ForNode>(*orig_loop);
    new_loop->loop_var = new_loop_vars[i];
    new_loop->body = generate_body;
    new_loop->kind = ForKind::kSerial;
    new_loop->thread_binding = NullOpt;
    new_loop->annotations = {};
    generate_body = For(new_loop);
  }
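  // Build the statement that replaces the original copy:
  //   read cache : the original store, now reading from the cache buffer;
  //   write cache: store the accumulator value into the cache buffer.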
  Stmt rewrite_body;
  if (is_write_cache) {
    BufferLoad new_buffer_load{new_buffer, cache_indices};
    rewrite_body = BufferStore(new_buffer, GetRef<BufferLoad>(target_buffer_load), cache_indices);
  } else {
    rewrite_body =
        BufferStore(buf_store->buffer, BufferLoad(new_buffer, cache_indices), buf_store->indices);
  }
  if (predicate.defined()) {
    rewrite_body = IfThenElse(predicate.value(), rewrite_body);
  }
  for (int i = static_cast<int>(loops_under_compute_location.size()) - 1; i >= 0; i--) {
    const ForNode* orig_loop = loops_under_compute_location[i];
    ObjectPtr<ForNode> new_loop = make_object<ForNode>(*orig_loop);
    new_loop->body = rewrite_body;
    rewrite_body = For(new_loop);
  }
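  // Chain the two stages: a write cache fills the cache from the accumulator first and then
  // drains it to the original destination; a read cache fills the cache first and then performs
  // the original copy out of it.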
  SeqStmt insert_location;
  if (is_write_cache) {
    generate_body = insert_location = SeqStmt({rewrite_body, generate_body});
  } else {
    generate_body = insert_location = SeqStmt({generate_body, rewrite_body});
  }
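  // Re-attach the outer loop chain (down to the compute location) so the result can replace stmt.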
  generate_body = CopyLoopChain(loops, generate_body);
  return std::make_pair(generate_body, insert_location);
}