in src/tir/transforms/lower_cross_thread_reduction.cc [294:527]
/*!
 * \brief Rewrite one reduction block into its cross-thread reduction form.
 *
 * The block is replaced by at most four statements, which are then wrapped by every
 * thread-bound loop in `reduction_loops` (innermost loop wraps first):
 *  1. initialization of the in-thread reduction buffers to the reducer's identity
 *     (only when `it_buffers` is defined);
 *  2. the in-thread reduction, as produced by `InThreadReducerMaker::Make` (may be absent);
 *  3. a block calling the `tvm_thread_allreduce` builtin, whose destinations are `ct_buffers`;
 *  4. a write-back block storing `ct_buffers[i][0]` into `wb_buffers[i]` at `old_wb_indices`.
 *
 * \param realize The BlockRealize of the reduction block being transformed.
 * \param it_buffers Buffers for the in-thread reduction stage, or NullOpt when that stage is not
 *        generated. When defined, it is indexed in parallel with `wb_buffers`.
 * \param ct_buffers Buffers receiving the cross-thread reduction result (accessed at index 0).
 * \param wb_buffers The buffers of the original block that the result is written back to.
 * \param old_wb_indices Store indices into `wb_buffers`, expressed in the original block's
 *        iteration variables (they are re-mapped onto fresh variables below).
 * \param reducer The commutative reducer describing the reduction.
 * \param combiner_rhs Allreduce source values used when `it_buffers` is not defined.
 * \param reduction_loops The reduction loops around the block; must be non-empty.
 * \return The statement replacing the original reduction loop nest.
 */
Stmt TransformReductionBlock(const BlockRealizeNode* realize,            //
                             const Optional<Array<Buffer>>& it_buffers,  //
                             const Array<Buffer>& ct_buffers,            //
                             const Array<Buffer>& wb_buffers,            //
                             const Array<PrimExpr>& old_wb_indices,      //
                             const CommReducer& reducer,                 //
                             const Array<PrimExpr>& combiner_rhs,        //
                             const std::vector<const ForNode*>& reduction_loops) {
  // `reduction_loops[0]` is dereferenced when building the in-thread reduction below.
  ICHECK(!reduction_loops.empty());
  int n_buffers = wb_buffers.size();
  const BlockNode* block = realize->block.get();
  // In this function the in-thread and cross-thread buffers are only ever accessed at index 0,
  // so their access regions are the single range [0, 1).
  auto f_create_buffer_regions = [](const Array<Buffer>& buffers) {
    Array<BufferRegion> regions;
    regions.reserve(buffers.size());
    for (const Buffer& buffer : buffers) {
      regions.push_back(BufferRegion(buffer, {Range::FromMinExtent(0, 1)}));
    }
    return regions;
  };
  Array<BufferRegion> ct_buffer_regions = f_create_buffer_regions(ct_buffers);
  Optional<Array<BufferRegion>> it_buffer_regions = NullOpt;
  if (it_buffers.defined()) {
    it_buffer_regions = f_create_buffer_regions(it_buffers.value());
  }
  // In total, the block is transformed into at most 4 statements
  // - Stmt 1: initialize the buffer for in-thread reduction
  // - Stmt 2: do in-thread reduction
  // - Stmt 3: do cross-thread reduction
  // - Stmt 4: write cross-thread reduction result to the original buffer
  Array<Stmt> stmts;
  stmts.reserve(4);
  // Stmt 1: initialize the buffer for in-thread reduction
  if (it_buffers.defined()) {
    Array<Stmt> inits;
    inits.reserve(n_buffers);
    for (int i = 0; i < n_buffers; ++i) {
      inits.push_back(
          BufferStore(it_buffers.value()[i], reducer->identity_element[i], {Integer(0)}));
    }
    stmts.push_back(BlockRealize(/*iter_values=*/{},
                                 /*predicate=*/const_true(),
                                 /*block=*/
                                 Block(/*iter_vars=*/{},
                                       /*reads=*/{},
                                       /*writes=*/it_buffer_regions.value(),
                                       /*name_hint=*/block->name_hint + "_in_thread_init",
                                       /*body=*/n_buffers > 1 ? SeqStmt(inits) : inits[0])));
  }
  // Stmt 2: do in-thread reduction
  {
    Optional<BlockRealize> new_realize = NullOpt;
    // If need to generate in-thread reduction,
    // then replace `wb_buffers` with `it_buffers` accordingly in given BlockRealize
    // otherwise, directly remove given BlockRealize
    if (it_buffers.defined()) {
      // The copied block keeps its original `reads`; only writes, name, body and init change.
      ObjectPtr<BlockNode> new_block = make_object<BlockNode>(*block);
      new_block->writes = it_buffer_regions.value();
      new_block->name_hint = new_block->name_hint + "_in_thread";
      new_block->body =
          BufferReplacer::Run(wb_buffers, it_buffers.value(), std::move(new_block->body));
      new_block->init = NullOpt;
      ObjectPtr<BlockRealizeNode> n = make_object<BlockRealizeNode>(*realize);
      n->block = Block(new_block);
      new_realize = BlockRealize(n);
    }
    For loop = GetRef<For>(reduction_loops[0]);
    if (Optional<Stmt> stmt = InThreadReducerMaker::Make(realize, new_realize, std::move(loop))) {
      stmts.push_back(stmt.value());
    }
  }
  // Stmt 3: do cross-thread reduction
  {
    // Step 3.1. Create the parameters to the intrinsic
    Array<PrimExpr> parameters;
    // 1 (count) + n_buffers (sources) + 1 (predicate) + n_buffers (destinations) + thread vars.
    parameters.reserve(2 * n_buffers + 2 + reduction_loops.size());
    // 1-st argument: number of buffers
    parameters.push_back(make_const(DataType::UInt(32), n_buffers));
    // Next `n_buffers` arguments: sources
    if (it_buffers.defined()) {
      for (int i = 0; i < n_buffers; ++i) {
        parameters.push_back(BufferLoad(it_buffers.value()[i], {Integer(0)}));
      }
    } else {
      parameters.insert(parameters.end(), combiner_rhs.begin(), combiner_rhs.end());
    }
    // Next argument: predicate
    parameters.push_back(const_true());
    // Next `n_buffers` arguments: destinations
    for (int i = 0; i < n_buffers; ++i) {
      parameters.push_back(BufferLoad(ct_buffers[i], {Integer(0)}));
    }
    // Next arguments: all the reduction threads
    for (const ForNode* reduction_loop : reduction_loops) {
      if (reduction_loop->thread_binding.defined()) {
        parameters.push_back(reduction_loop->loop_var);
      }
    }
    // Step 3.2. Create the block and the block-realize.
    // With an in-thread stage the allreduce only reads the (iteration-free) in-thread buffers;
    // otherwise it inherits the original block's iteration space and reads.
    Array<IterVar> iter_vars;
    Array<PrimExpr> bindings;
    Array<BufferRegion> reads;
    if (it_buffers.defined()) {
      reads = it_buffer_regions.value();
    } else {
      iter_vars = block->iter_vars;
      bindings = realize->iter_values;
      reads = block->reads;
    }
    stmts.push_back(BlockRealize(
        /*iter_values=*/std::move(bindings),
        /*predicate=*/const_true(),
        /*block=*/
        Block(/*iter_vars=*/std::move(iter_vars),
              /*reads=*/std::move(reads),
              /*writes=*/ct_buffer_regions,
              /*name_hint=*/block->name_hint + "_cross_thread",
              /*body=*/
              AttrStmt(/*node=*/reducer,
                       /*attr_key=*/tir::attr::reduce_scope,
                       /*value=*/make_zero(DataType::Handle()),
                       /*body=*/
                       Evaluate(Call(/*dtype=*/DataType::Handle(),
                                     /*op=*/tir::builtin::tvm_thread_allreduce(),
                                     /*args=*/std::move(parameters)))))));
  }
  // Stmt 4: write cross-thread reduction result to the original buffer
  {
    ICHECK_EQ(block->iter_vars.size(), realize->iter_values.size());
    int n_iter = static_cast<int>(block->iter_vars.size());
    Array<IterVar> iter_vars;
    Array<PrimExpr> bindings;
    Map<Var, Var> var_map;
    iter_vars.reserve(n_iter);
    bindings.reserve(n_iter);
    // Keep only the non-reduction iteration variables, each re-created as a fresh Var so the
    // write-back block does not share variables with the cross-thread block.
    for (int i = 0; i < n_iter; ++i) {
      const IterVar& iter_var = block->iter_vars[i];
      const PrimExpr& binding = realize->iter_values[i];
      if (iter_var->iter_type != kCommReduce) {
        IterVar new_iter_var{nullptr};
        {
          ObjectPtr<IterVarNode> n = make_object<IterVarNode>(*iter_var.get());
          ObjectPtr<VarNode> v = make_object<VarNode>(*iter_var->var.get());
          n->var = Var(v);
          new_iter_var = IterVar(n);
        }
        iter_vars.push_back(new_iter_var);
        bindings.push_back(binding);
        var_map.Set(iter_var->var, new_iter_var->var);
      }
    }
    Array<Stmt> wb_updates;
    Array<BufferRegion> wb_regions;
    wb_updates.reserve(n_buffers);
    wb_regions.reserve(n_buffers);
    int n_dim = static_cast<int>(old_wb_indices.size());
    Array<Range> region = Substitute(block->writes[0]->region, var_map);
    Array<PrimExpr> wb_indices;
    wb_indices.reserve(n_dim);
    for (int d = 0; d < n_dim; ++d) {
      wb_indices.push_back(Substitute(old_wb_indices[d], var_map));
    }
    for (int i = 0; i < n_buffers; ++i) {
      wb_updates.push_back(
          BufferStore(wb_buffers[i], BufferLoad(ct_buffers[i], {Integer(0)}), wb_indices));
      wb_regions.push_back(BufferRegion(wb_buffers[i], region));
    }
    // Construct the predicate of the write-back block. It is the conjunction of
    // - each predicate clause of the original block which contains no reduction loop var, and
    // - `t == 0` for each reduction thread dim when the write-back buffer is not local.
    PrimExpr wb_predicate = const_true();
    std::unordered_set<const VarNode*> reduction_loop_vars;
    reduction_loop_vars.reserve(reduction_loops.size());
    for (const ForNode* reduction_loop : reduction_loops) {
      reduction_loop_vars.insert(reduction_loop->loop_var.get());
    }
    // Whether `expr` references any of the reduction loop variables.
    auto f_uses_reduction_loop_var = [&reduction_loop_vars](const PrimExpr& expr) {
      for (const Var& var : UndefinedVars(expr)) {
        if (reduction_loop_vars.count(var.get())) {
          return true;
        }
      }
      return false;
    };
    // Split the original predicate along its top-level `&&` structure only. This keeps a
    // predicate that is a single clause (no `&&` at all, e.g. a non-divisible split guard), and
    // never tears apart an `&&` nested under another operator such as `||` or `!`.
    std::vector<PrimExpr> pending{realize->predicate};
    while (!pending.empty()) {
      PrimExpr clause = pending.back();
      pending.pop_back();
      if (const auto* and_node = clause.as<AndNode>()) {
        // Push `b` first so that `a` is examined first, preserving left-to-right clause order.
        pending.push_back(and_node->b);
        pending.push_back(and_node->a);
      } else if (!f_uses_reduction_loop_var(clause)) {
        wb_predicate = wb_predicate && clause;
      }
    }
    if (wb_buffers[0].scope() != "local") {
      for (const ForNode* loop : reduction_loops) {
        if (loop->thread_binding.defined()) {
          wb_predicate = wb_predicate && (loop->loop_var == IntImm(loop->loop_var->dtype, 0));
        }
      }
    }
    stmts.push_back(BlockRealize(
        /*iter_values=*/std::move(bindings),
        /*predicate=*/wb_predicate,
        /*block=*/
        Block(/*iter_vars=*/std::move(iter_vars),
              /*reads=*/std::move(ct_buffer_regions),
              /*writes=*/std::move(wb_regions),
              /*name_hint=*/block->name_hint + "_write_back",
              /*body=*/n_buffers > 1 ? SeqStmt(wb_updates) : wb_updates[0])));
  }
  // Final step: Wrap all the above four statements with the reduction loops bound to threadIdx
  Stmt new_stmt = SeqStmt::Flatten(std::move(stmts));
  for (auto rit = reduction_loops.rbegin(); rit != reduction_loops.rend(); ++rit) {
    const ForNode* loop = *rit;
    if (loop->thread_binding.defined()) {
      ObjectPtr<ForNode> n = make_object<ForNode>(*loop);
      n->body = std::move(new_stmt);
      new_stmt = For(n);
    }
  }
  return new_stmt;
}