in src/tir/transforms/lower_thread_allreduce.cc [156:430]
Stmt MakeAllreduce(const CallNode* call) {
ICHECK(!reduce_combiner_.empty());
const CommReducerNode* combiner = reduce_combiner_.back();
size_t size = combiner->result.size();
const IntImmNode* size_of_args = call->args[0].as<IntImmNode>();
ICHECK(size_of_args) << call->args[0]->GetTypeKey();
ICHECK_EQ(size, size_of_args->value);
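// The call arguments are laid out as:
//   args[0]                      : number of reduction values (size)
//   args[1 .. size]              : the values being reduced
//   args[size + 1]               : the reduction predicate
//   args[size + 2 .. 2*size + 1] : loads from the result buffers
//   args[2*size + 2 .. ]         : the thread index variables to reduce over
// For example (illustrative), tvm_thread_allreduce(1, x, True, red_buf0[0], threadIdx.x)
// reduces a single value x over threadIdx.x with no predicate.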
Array<PrimExpr> inits = combiner->identity_element;
std::vector<PrimExpr> values(size);
std::vector<DataType> types(size);
PrimExpr cond = call->args[size + 1];
for (size_t idx = 0; idx < size; ++idx) {
values[idx] = call->args[1 + idx];
if (!is_one(cond)) {
values[idx] = Select(cond, values[idx], inits[idx]);
}
types[idx] = values[idx].dtype();
}
std::vector<Buffer> buffers(size);
for (size_t idx = 0; idx < size; ++idx) {
PrimExpr arg = call->args[2 + size + idx];
// Loads from boolean buffers may have cast nodes inserted by
// earlier passes.
if (auto cast = arg.as<CastNode>()) {
arg = cast->value;
}
buffers[idx] = Downcast<BufferLoad>(arg)->buffer;
}
std::unordered_set<const VarNode*> reduce_set;
for (size_t i = 2 + 2 * size; i < call->args.size(); ++i) {
const VarNode* v = call->args[i].as<VarNode>();
// A simple optimization replaces an iteration variable with a constant
// when the extent of the iteration is 1. Since a threaded IterVar always
// starts from 0, we can just ignore such a variable here.
if (v) {
reduce_set.insert(v);
} else {
ICHECK(call->args[i].as<IntImmNode>() && call->args[i].as<IntImmNode>()->value == 0)
<< "arg " << i << " should be a VarNode or an IntImm equal to 0";
}
}
size_t nmatch = 0;
std::vector<ThreadEntry> vred, vpar;
for (const AttrStmtNode* attr : thread_extents_) {
ThreadEntry e;
IterVar iv = Downcast<IterVar>(attr->node);
e.scope = runtime::ThreadScope::Create(iv->thread_tag);
e.iv = iv;
ICHECK_LE(e.scope.rank, 1);
ICHECK_GE(e.scope.dim_index, 0) << "vthread does not work with cross-thread reduction";
if (e.scope.rank == 1) {
const auto* ptr = attr->value.as<IntImmNode>();
ICHECK(ptr) << "Need constant extent for reduce set " << iv;
e.extent = static_cast<int>(ptr->value);
// skip thread axes with extent 1; their index is always 0
if (e.extent == 1) {
continue;
}
if (reduce_set.count(iv->var.get())) {
vred.push_back(e);
++nmatch;
} else {
vpar.push_back(e);
}
}
}
ICHECK_EQ(nmatch, reduce_set.size()) << "Not all reduce indices are present in the context";
std::sort(vred.begin(), vred.end());
std::sort(vpar.begin(), vpar.end());
// the flattened extent of the reduce and group thread indices.
int reduce_extent, group_extent;
PrimExpr reduce_index = FlattenThread(vred, &reduce_extent);
PrimExpr group_index = FlattenThread(vpar, &group_extent);
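// For example (illustrative), reducing over threadIdx.x (extent 32) while
// threadIdx.y (extent 4) stays parallel gives reduce_index = threadIdx.x,
// reduce_extent = 32, group_index = threadIdx.y, group_extent = 4; with
// several reduce threads, FlattenThread combines them into a single
// flattened index whose extent is the product of their extents.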
// the longest contiguous reduce extent after flattening
int contiguous_reduce_extent = 1;
std::vector<std::tuple<int, int, bool>> block_threads; // tuple(dim_index, extent, is_reduce)
for (const ThreadEntry& thr : vred) {
if (thr.scope.rank == 1) { // threadIdx
block_threads.emplace_back(thr.scope.dim_index, thr.extent, true);
}
}
for (const ThreadEntry& thr : vpar) {
if (thr.scope.rank == 1) { // threadIdx
block_threads.emplace_back(thr.scope.dim_index, thr.extent, false);
}
}
// sort according to dim_index
std::sort(block_threads.begin(), block_threads.end());
for (auto&& thr_attr : block_threads) {
auto [dim_index, extent, is_reduce] = thr_attr;
(void)dim_index; // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81767
if (is_reduce) {
contiguous_reduce_extent *= extent;
} else {
break;
}
}
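// For example (illustrative), with threadIdx.x (extent 32) in the reduce set
// and threadIdx.y (extent 4) parallel, the sorted list is
// [(0, 32, reduce), (1, 4, parallel)], so contiguous_reduce_extent = 32:
// only the reduce dimensions with the lowest dim_index, before the first
// parallel dimension, are contiguous in the flattened thread space.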
std::vector<Stmt> seq;
std::vector<Buffer> new_alloc_bufs;
//
// This is an optimization. For small reduction sizes, it may be beneficial
// for a single warp to perform the entire reduction. No trips to shared
// memory and no cross-warp synchronizations are required.
// The following code emits the reduction as follows:
//
// Allocate reduction vars v[i], i = 0..size-1
//
// for offset from WARP_SIZE down to 1, halving each iteration
//
// a <- load(v[i])
// b <- shuffle_down(load(v[i]), offset)
// v[i] <- reduction(a, b)
//
// broadcast the result from lane 0 to all other lanes and store
// the final reduction result to the proper location.
//
// When the thread extent is a multiple of the warp size, we can use a
// two-stage warp-level reduction. This is implemented by applying the
// algorithm above twice.
//
// For example, suppose we want to use 512 threads to reduce 512 elements
// and the warp size is 32. In this case there are (512 / 32) = 16 warps.
// In the first stage, each of the 16 warps reduces 32 elements. So after
// the stage, we have 16 remaining elements to be reduced, one for each warp.
// We store the 16 elements in shared memory, and start the second stage.
// In the second stage we use the first 16 lanes of the first warp to reduce
// the remaining elements, and this reduction can also be optimized by
// shuffle_down warp-level primitives.
PrimExpr zero_index = make_const(reduce_index->dtype, 0);
if (IsWarpReduction(types, group_extent, reduce_extent, contiguous_reduce_extent)) {
std::vector<PrimExpr> reduce_results;
DataType mask_dtype = DataType::UInt(32);
PrimExpr mask = Call(mask_dtype, builtin::tvm_warp_activemask(), {});
if (reduce_extent <= warp_size_) {
std::tie(reduce_results, new_alloc_bufs) = MakeWarpAllreduce(
values, types, combiner, reduce_index, reduce_extent, group_index, mask, NullOpt, &seq);
// Broadcast the reduction result from lane 0 to all other lanes.
// This avoids emitting predicated stores, as all threads uniformly
// write the same result.
for (size_t i = 0; i < size; ++i) {
Buffer buf = Downcast<BufferLoad>(reduce_results[i])->buffer;
PrimExpr val = BufferLoad(buf, {zero_index});
ICHECK_EQ(val->dtype, types[i]);
PrimExpr splat = WarpShuffle(builtin::tvm_warp_shuffle(), new_alloc_bufs.back(), val,
reduce_extent * group_index);
seq.push_back(BufferStore(buf, splat, {zero_index}));
}
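// Schematically, the emitted sequence for this branch looks like
// (illustrative, one reduction value, names simplified):
//   red_buf[0] = value
//   for (offset = ...; offset > 0; offset /= 2)   // inside MakeWarpAllreduce
//     red_buf[0] = combiner(red_buf[0], tvm_warp_shuffle_down(mask, red_buf[0], offset))
//   red_buf[0] = tvm_warp_shuffle(mask, red_buf[0], reduce_extent * group_index)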
} else {
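// Two-stage reduction, schematically (illustrative, one reduction value):
//   // stage 1: each warp reduces warp_size_ elements
//   red_buf[0] = warp_allreduce(value)
//   if (reduce_index % warp_size_ == 0)
//     red_buf_staging[group_index * n_warps + reduce_index / warp_size_] = red_buf[0]
//   __syncthreads()
//   // stage 2: the first n_warps lanes reduce the per-warp partial results
//   if (reduce_index < n_warps)
//     red_buf[0] = warp_allreduce(red_buf_staging[group_index * n_warps + reduce_index])
//   if (reduce_index == 0)
//     red_result[group_index] = red_buf[0]
//   __syncthreads()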
int n_warps = reduce_extent / warp_size_;
std::vector<Buffer> local_bufs;
// 1. Create the staging buffer in shared memory.
std::vector<Buffer> staging_shared_bufs;
staging_shared_bufs.reserve(size);
for (size_t i = 0; i < size; ++i) {
Buffer staging_shared_buf = decl_buffer(
/*shape=*/{make_const(reduce_index->dtype, n_warps * group_extent)},
/*dtype=*/buffers[i]->dtype, /*name=*/"red_buf_staging", /*storage_scope=*/"shared");
staging_shared_bufs.push_back(staging_shared_buf);
new_alloc_bufs.push_back(staging_shared_buf);
}
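// For example (illustrative), with reduce_extent = 512, warp_size_ = 32 and
// group_extent = 1, this allocates a 16-element shared staging buffer that
// holds one partial result per warp.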
// 2. First round of allreduce.
std::tie(reduce_results, local_bufs) = MakeWarpAllreduce(
values, types, combiner, reduce_index, warp_size_, group_index, mask, NullOpt, &seq);
new_alloc_bufs.insert(new_alloc_bufs.end(), local_bufs.begin(), local_bufs.end());
// 3. Write allreduce results to staging buffer.
std::vector<Stmt> write_staging_buf;
write_staging_buf.reserve(size);
for (size_t i = 0; i < size; ++i) {
new_alloc_bufs.push_back(Downcast<BufferLoad>(reduce_results[i])->buffer);
write_staging_buf.push_back(BufferStore(
/*buffer=*/staging_shared_bufs[i],
/*value=*/reduce_results[i],
/*indices=*/{group_index * n_warps + floordiv(reduce_index, warp_size_)}));
}
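// Only lane 0 of each warp holds that warp's reduction result, so the
// staging store is predicated on floormod(reduce_index, warp_size_) == 0.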
PrimExpr cond = floormod(reduce_index, warp_size_) == zero_index;
seq.push_back(IfThenElse(cond, SeqStmt::Flatten(write_staging_buf)));
seq.push_back(SyncThread("shared"));
// 4. Load staging buffer.
// Second round of allreduce.
for (size_t i = 0; i < size; ++i) {
values[i] = BufferLoad(/*buffer=*/staging_shared_bufs[i],
/*indices=*/{group_index * n_warps + reduce_index});
}
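// Only the first n_warps lanes of each group hold valid staged partial
// results, hence the predicate passed to the second warp-level reduction.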
std::tie(reduce_results, local_bufs) = MakeWarpAllreduce(
values, types, combiner, reduce_index, n_warps, group_index, mask,
/*predicate=*/reduce_index < make_const(reduce_index->dtype, n_warps), &seq);
new_alloc_bufs.insert(new_alloc_bufs.end(), local_bufs.begin(), local_bufs.end());
// 5. Create shared memory buffer(s) of `group_extent` elements, storing
// the allreduce results so that every thread can access them.
std::vector<Stmt> write_result;
write_result.reserve(size);
for (size_t i = 0; i < size; ++i) {
new_alloc_bufs.push_back(Downcast<BufferLoad>(reduce_results[i])->buffer);
Buffer broadcast_shared_buf = decl_buffer(
/*shape=*/{make_const(reduce_index->dtype, group_extent)},
/*dtype=*/buffers[i]->dtype, /*name=*/"red_result", /*storage_scope=*/"shared");
write_result.push_back(
BufferStore(broadcast_shared_buf, reduce_results[i], {group_index}));
// Update `reduce_results`, pointing to the value loaded from the shared memory buffer.
reduce_results[i] = BufferLoad(broadcast_shared_buf, {group_index});
}
seq.push_back(IfThenElse(reduce_index == zero_index, SeqStmt::Flatten(write_result)));
seq.push_back(SyncThread("shared"));
}
// Write back allreduce results and update existing allocations.
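// The remap tables record how the rest of the pass should rewrite the
// original reduction buffers: load_remap_ substitutes loads with the reduced
// result, while alloc_remap_, var_remap_ and buf_remap_ redirect the original
// allocation and buffer references to the new buffer.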
for (size_t i = 0; i < size; ++i) {
ICHECK(!load_remap_.count(buffers[i]->data.get()));
PrimExpr pred = const_true(types[i].lanes());
Buffer buf = Downcast<BufferLoad>(reduce_results[i])->buffer;
ICHECK_EQ(reduce_results[i]->dtype, types[i]);
load_remap_[buffers[i]->data.get()] = reduce_results[i];
auto node = Allocate(buf->data, types[i], buf->shape, pred, Evaluate(0));
alloc_remap_[buffers[i]->data.get()] = buf;
var_remap_[buffers[i]->data.get()] = buf->data;
buf_remap_[buffers[i].get()] = buf;
}
} else {
std::vector<Buffer> shared_bufs(size);
if (reduce_extent == 1) {
// special case, no reduction is needed.
std::vector<Stmt> stores;
for (size_t i = 0; i < size; ++i) {
stores.push_back(BufferStore(buffers[i], values[i], {0}));
}
return SeqStmt::Flatten(stores);
}
// This sync is necessary because there might be incomplete reads of the
// previous iteration on the same buffer.
seq.emplace_back(SyncThread("shared"));
for (size_t idx = 0; idx < size; ++idx) {
shared_bufs[idx] = decl_buffer({IntImm(group_index->dtype, group_extent * reduce_extent)},
types[idx], "red_buf" + std::to_string(idx), "shared");
seq.emplace_back(BufferStore(shared_bufs[idx], values[idx],
{BufIndex(reduce_index, group_index, reduce_extent)}));
}
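// The shared buffer uses a group-major layout (see BufIndex): the value of
// thread (group_index, reduce_index) is stored at roughly
// group_index * reduce_extent + reduce_index. For example (illustrative),
// group_extent = 4 and reduce_extent = 64 gives a 256-element buffer.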
seq.emplace_back(SyncThread("shared"));
seq.emplace_back(MakeBufAllreduce(combiner, types, shared_bufs, reduce_index, group_index,
reduce_extent, group_extent, contiguous_reduce_extent));
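// MakeBufAllreduce emits the in-shared-memory tree reduction: roughly, the
// number of active threads is halved each step with a shared-memory sync
// between steps, switching to warp-synchronous steps once the active span
// fits within a single warp and the contiguous reduce extent.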
for (size_t idx = 0; idx < size; ++idx) {
ICHECK(!load_remap_.count(buffers[idx]->data.get()));
PrimExpr pred = const_true(types[idx].lanes());
BufferLoad load(shared_bufs[idx],
{BufIndex(make_zero(reduce_index.dtype()), group_index, reduce_extent)});
ICHECK_EQ(load->dtype, types[idx]);
load_remap_[buffers[idx]->data.get()] = load;
alloc_remap_[buffers[idx]->data.get()] = shared_bufs[idx];
var_remap_[buffers[idx]->data.get()] = shared_bufs[idx]->data;
buf_remap_[buffers[idx].get()] = shared_bufs[idx];
}
}
// Now that all statements are built, wrap the body with the allocations of
// the newly created local/shared buffers.
Stmt body = SeqStmt::Flatten(seq);
for (Buffer buf : new_alloc_bufs) {
body = DeclBuffer(buf, body);
body = Allocate(buf->data, buf->dtype, buf->shape, const_true(buf->dtype.lanes()), body);
}
return body;
}