in src/tir/transforms/memhammer_intermediate_stage.cc [228:427]
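/*!
 * \brief Insert a cache stage (a data copy through a buffer of `storage_scope`) at
 *        `compute_location`.
 *
 * Two statements are generated: one that fills the cache buffer and one that consumes it,
 * sequenced so that the data flows through the cache. Loops at or above the compute location
 * are copied unchanged around the result.
 *
 * \return The rewritten statement together with the SeqStmt marking the insertion point.
 *         The newly allocated cache buffer is written to `*alloc_buffer`.
 */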
std::pair<Stmt, SeqStmt> InsertCacheStage(Stmt stmt, bool is_write_cache, String storage_scope,
Optional<For> compute_location,
const Array<For>& outer_loops, Buffer* alloc_buffer) {
Stmt body = stmt;
std::vector<const ForNode*> loops;
std::vector<const ForNode*> loops_under_compute_location;
std::vector<const ForNode*> relaxed_thread_loops;
bool need_relax = !compute_location.defined();
Map<Var, Range> var_range;
PrimExpr vector_bytes = -1;
// Step 1. Perform rank promotion on the buffer access, turning a strided-changing dimension into
// several contiguous-changing dimensions
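  // (Illustrative example, not code from this pass: an access such as A[i * 32 + j] where j only
  //  covers [0, 16) changes with stride 32 in i; rank promotion re-indexes it as a 2-D access
  //  with indices [i, j], so the cache can be a dense multi-dimensional buffer rather than a
  //  strided window of the original one.)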
// Step 1.1 collect loop var range for rank promotion
while (const ForNode* loop = body.as<ForNode>()) {
if (need_relax) {
var_range.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
loops_under_compute_location.push_back(loop);
} else {
loops.push_back(loop);
}
if (loop == compute_location.value_or(For()).get()) {
need_relax = true;
}
if (loop->kind == ForKind::kVectorized) {
vector_bytes = loop->extent;
}
body = loop->body;
}
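  // At this point `loops` holds the loops at or above the compute location (they are copied
  // verbatim around the result at the end), while `loops_under_compute_location` holds the loops
  // below it (all of them if no compute location is given); only the latter are relaxed and
  // rebuilt around the cache stage.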
Optional<PrimExpr> predicate;
if (const auto* op = body.as<IfThenElseNode>()) {
// the predicate is generated by coalescing
predicate = op->condition;
body = op->then_case;
}
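  // Scan the outer thread-bound loops: if the target storage scope is shared across a thread
  // (e.g. shared memory under threadIdx), that loop's extent becomes an explicit dimension of
  // the cache buffer instead of being private to each thread.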
for (const For& loop : outer_loops) {
if (loop->kind == ForKind::kThreadBinding) {
const String& thread_tag = loop->thread_binding.value()->thread_tag;
if (CanRelaxStorageUnderThread(runtime::StorageScope::Create(storage_scope),
runtime::ThreadScope::Create(thread_tag))) {
var_range.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
relaxed_thread_loops.push_back(loop.get());
}
}
}
arith::Analyzer analyzer;
const BufferLoadNode* target_buffer_load = nullptr;
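  // For a write cache, locate the single accumulator load (wmma.accumulator or m16n8k8.matrixC)
  // feeding the store below; the cache buffer is later derived from it so that the intermediate
  // stage keeps the accumulator's dtype.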
if (is_write_cache) {
tir::PreOrderVisit(stmt, [&](const ObjectRef& obj) {
if (const auto* buffer_load = obj.as<BufferLoadNode>()) {
if (buffer_load->buffer.scope() == "wmma.accumulator" ||
buffer_load->buffer.scope() == "m16n8k8.matrixC") {
if (target_buffer_load == nullptr) {
target_buffer_load = buffer_load;
} else {
CHECK(target_buffer_load->buffer.same_as(buffer_load->buffer))
<< "More than one target buffer found";
ICHECK(target_buffer_load->indices.size() == buffer_load->indices.size());
for (size_t i = 0; i < target_buffer_load->indices.size(); i++) {
CHECK(
analyzer.CanProveEqual(target_buffer_load->indices[i], buffer_load->indices[i]));
}
}
}
}
return true;
});
CHECK(target_buffer_load);
}
const BufferStoreNode* buf_store = TVM_TYPE_AS(body, BufferStoreNode);
Array<PrimExpr> cache_indices;
Array<PrimExpr> new_shape;
bool use_rank_promotion = false;
  // Rank promotion is only attempted for read caches: for a write cache the promoted shape
  // cannot be guaranteed to fit wmma.accumulator.
  if (!is_write_cache && buf_store->value.as<BufferLoadNode>()) {
    Array<PrimExpr> indices = buf_store->value.as<BufferLoadNode>()->indices;
    new_shape = IndexPatternFinder::getRankPromotedShape(indices, var_range, &cache_indices);
    if (!new_shape.empty()) {
      use_rank_promotion = true;
    }
  }
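  // Fallback (write cache, or rank promotion did not apply): shape the cache buffer by the
  // extents of the relaxed thread loops followed by the loops under the compute location, and
  // index it directly by those loop vars.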
Array<Var> new_loop_vars;
Map<Var, PrimExpr> subst_map;
if (!use_rank_promotion) {
cache_indices.clear();
for (const ForNode* loop : relaxed_thread_loops) {
new_shape.push_back(loop->extent);
}
for (const ForNode* loop : loops_under_compute_location) {
new_shape.push_back(loop->extent);
}
}
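  // Make fresh "_cache" copies of the loop vars for the generated cache stage; `subst_map`
  // rewrites the original vars to their copies.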
for (int i = 0; i < static_cast<int>(relaxed_thread_loops.size()); i++) {
const ForNode* loop = relaxed_thread_loops[i];
Var new_loop_var = loop->loop_var.copy_with_suffix("_cache");
new_loop_vars.push_back(new_loop_var);
subst_map.Set(loop->loop_var, new_loop_var);
if (!use_rank_promotion) {
cache_indices.push_back(loop->loop_var);
}
}
for (int i = 0; i < static_cast<int>(loops_under_compute_location.size()); i++) {
const ForNode* loop = loops_under_compute_location[i];
Var new_loop_var = loop->loop_var.copy_with_suffix("_cache");
new_loop_vars.push_back(new_loop_var);
subst_map.Set(loop->loop_var, new_loop_var);
if (!use_rank_promotion) {
cache_indices.push_back(loop->loop_var);
}
}
  // Rewrite the cache indices in terms of the fresh "_cache" loop vars.
  Array<PrimExpr> subst_cache_indices;
  for (PrimExpr e : cache_indices) {
    subst_cache_indices.push_back(Substitute(e, subst_map));
  }
Buffer new_buffer;
if (is_write_cache) {
// this is needed for global <- cast(load(wmma))
// shared stage should have the same dtype as wmma
new_buffer = WithScope(target_buffer_load->buffer, storage_scope);
} else {
new_buffer = WithScope(buf_store->buffer, storage_scope);
}
BufferNode* buffer_ptr = new_buffer.CopyOnWrite();
buffer_ptr->shape = new_shape;
*alloc_buffer = new_buffer;
Stmt generate_body;
if (is_write_cache) {
    // Redirect the wmma load in the original store to the new cache buffer, turning this into
    // the copy from the cache buffer to the original destination (any cast is preserved).
BufferLoad new_buffer_load{new_buffer, cache_indices};
generate_body =
BufferLoadReplacer(target_buffer_load->buffer, new_buffer_load)(GetRef<Stmt>(buf_store));
generate_body = Substitute(generate_body, subst_map);
} else {
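    // Read cache: store the source value (with loop vars substituted) into the cache buffer.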
generate_body =
BufferStore(new_buffer, Substitute(buf_store->value, subst_map), subst_cache_indices);
}
  if (predicate.defined()) {
    // The predicate comes from the coalescing rewrite, which is expected to leave exactly two
    // loops under the compute location; re-apply it in terms of the fresh loop vars.
    CHECK_EQ(loops_under_compute_location.size(), 2);
    PrimExpr subst_predicate = Substitute(predicate.value(), subst_map);
    generate_body = IfThenElse(subst_predicate, generate_body);
  }
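  // Wrap the cache-stage body in fresh loops: loops under the compute location keep their
  // original kind, while relaxed thread loops are rebuilt as plain serial loops with their
  // thread bindings and annotations dropped.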
for (int i = static_cast<int>(loops_under_compute_location.size()) - 1; i >= 0; i--) {
const ForNode* orig_loop = loops_under_compute_location[i];
ObjectPtr<ForNode> new_loop = make_object<ForNode>(*orig_loop);
new_loop->loop_var = new_loop_vars[i + relaxed_thread_loops.size()];
new_loop->body = generate_body;
generate_body = For(new_loop);
}
for (int i = static_cast<int>(relaxed_thread_loops.size()) - 1; i >= 0; i--) {
const ForNode* orig_loop = relaxed_thread_loops[i];
ObjectPtr<ForNode> new_loop = make_object<ForNode>(*orig_loop);
new_loop->loop_var = new_loop_vars[i];
new_loop->body = generate_body;
new_loop->kind = ForKind::kSerial;
new_loop->thread_binding = NullOpt;
new_loop->annotations = {};
generate_body = For(new_loop);
}
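  // Build the rewritten access that goes through the cache: a write cache stores the accumulator
  // load into the cache buffer, a read cache reads the cached value into the original
  // destination. Either way it is wrapped in copies of the original loops under the compute
  // location.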
Stmt rewrite_body;
if (is_write_cache) {
rewrite_body = BufferStore(new_buffer, GetRef<BufferLoad>(target_buffer_load), cache_indices);
} else {
rewrite_body =
BufferStore(buf_store->buffer, BufferLoad(new_buffer, cache_indices), buf_store->indices);
}
if (predicate.defined()) {
rewrite_body = IfThenElse(predicate.value(), rewrite_body);
}
for (int i = static_cast<int>(loops_under_compute_location.size()) - 1; i >= 0; i--) {
const ForNode* orig_loop = loops_under_compute_location[i];
ObjectPtr<ForNode> new_loop = make_object<ForNode>(*orig_loop);
new_loop->body = rewrite_body;
rewrite_body = For(new_loop);
}
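  // Sequence the two statements so that data flows through the cache: a read cache is filled
  // before the rewritten access consumes it; a write cache is filled from the accumulator before
  // the copy-out to the original destination. Finally, re-wrap everything in the loops at or
  // above the compute location.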
SeqStmt insert_location;
if (is_write_cache) {
generate_body = insert_location = SeqStmt({rewrite_body, generate_body});
} else {
generate_body = insert_location = SeqStmt({generate_body, rewrite_body});
}
generate_body = CopyLoopChain(loops, generate_body);
return std::make_pair(generate_body, insert_location);
}