in gloo/cuda_allreduce_ring_chunked.cc [135:279]
template <typename T, typename W>
void CudaAllreduceRingChunked<T, W>::run() {
  CudaDeviceGuard guard;
  CudaStream& stream = *scratchStream_;

  // Kick off local reduction for each chunk.
  // The result is stored in scratch_ at the corresponding chunk offset.
  // Make sure to iterate over the chunks in the order they will be sent.
  for (auto i = 0; i < chunks_; i++) {
    const auto chunkOffset = getChunkOffset(i);
    if (chunkOffset < chunkContext_.size()) {
      auto& context = chunkContext_[chunkOffset];
      context.reduceOp->runAsync();
    }
  }
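
  // With a single process there is no ring to traverse: finish the local
  // reduction and broadcast the result straight back to the device buffers.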
  if (this->contextSize_ == 1) {
    // Wait for the local reduction to complete, then broadcast chunk to devices
    for (auto i = 0; i < chunks_; i++) {
      const auto chunkOffset = getChunkOffset(i);
      if (chunkOffset < chunkContext_.size()) {
        auto& context = chunkContext_[chunkOffset];
        context.reduceOp->wait();
        context.broadcastOp->runAsync();
      }
    }

    // Wait for broadcast to complete
    for (auto i = 0; i < chunks_; i++) {
      const auto chunkOffset = getChunkOffset(i);
      if (chunkOffset < chunkContext_.size()) {
        auto& context = chunkContext_[chunkOffset];
        context.broadcastOp->wait();
      }
    }
    return;
  }
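
  // More than one process: run the ring algorithm. In each round this process
  // folds the chunk delivered to its inbox by the left neighbor into its
  // scratch buffer and forwards the running sum to the right neighbor with
  // copyChunkAtOffset(). The inbox is double buffered (indexed by
  // chunkOffset & 1) so two chunks can be in flight at once.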
  // First pass reduces a chunk in each round
  for (auto round = 0; round < chunks_; round++) {
    const auto chunkOffset = getChunkOffset(round);
    if (chunkOffset < chunkContext_.size()) {
      auto& context = chunkContext_[chunkOffset];

      // Wait for the local reduction to complete
      // When using the host workspace this also makes sure the reduction
      // result is copied into the host side scratch buffer.
      context.reduceOp->wait();

      // Reduce chunk from previous round. Nothing to do for initial rounds.
      if (round >= 2) {
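        // The inbox slot for this chunk was filled by the left neighbor two
        // rounds earlier, so the first two rounds have nothing to receive;
        // the slot alternates with the chunk's parity (chunkOffset & 1).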
        // Wait for inbox write to complete
        recvDataBuf_[chunkOffset & 1]->waitRecv();

        // Reduce
        fn_->call(
            context.scratch,
            inbox_[chunkOffset & 1],
            context.scratch.getCount(),
            stream);
        stream.wait();
      }
    } else {
      // Empty chunk but still need to wait on the inbox write to ensure the
      // algorithm progresses. Nothing to do for initial rounds.
      if (round >= 2) {
        recvDataBuf_[chunkOffset & 1]->waitRecv();
      }
    }

    // Skip buffer passing notifications in initial rounds
    if (round >= 2) {
      // Send notification to node on the left that
      // this node is ready for an inbox write.
      sendNotificationBuf_->send();

      // Wait for notification from node on the right
      // to be sure this node can start an inbox write.
      recvNotificationBuf_->waitRecv();
    }

    // Forward the accumulated chunk to the right neighbor's inbox.
    copyChunkAtOffset(chunkOffset);
  }
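
  // After the first pass, the chunks this process reduced in its final two
  // rounds hold the fully reduced result; the second pass circulates those
  // completed chunks around the ring, broadcasting each one back to the local
  // device buffers as it arrives.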
  // Second pass around the ring to broadcast result.
  for (int round = 0; round < chunks_; round++) {
    const auto chunkOffset = getChunkOffset(round);
    if (chunkOffset < chunkContext_.size()) {
      auto& context = chunkContext_[chunkOffset];

      // End at chunks_-2 since that's where the accumulation
      // stopped in the previous set of rounds.
      if (round < (chunks_ - 2)) {
        // Wait for inbox write to complete
        recvDataBuf_[chunkOffset & 1]->waitRecv();

        // Copy chunk from inbox to scratch space
        stream.copyAsync(context.scratch, inbox_[chunkOffset & 1]);
        stream.wait();
      }

      // Broadcast chunk to devices. Do this in all rounds with a
      // non-empty chunk.
      context.broadcastOp->runAsync();
    } else {
      // Empty chunk but still need to wait on the inbox write to ensure the
      // algorithm progresses.
      if (round < (chunks_ - 2)) {
        recvDataBuf_[chunkOffset & 1]->waitRecv();
      }
    }

    // Skip the notification and forward in the final rounds: a chunk sent
    // here is consumed by the right neighbor two rounds later, and receives
    // stop at round chunks_ - 2, so a chunk sent in round chunks_ - 4 or
    // later would never be read.
    if (round < (chunks_ - 4)) {
      // Send notification to node on the left that
      // this node is ready for an inbox write.
      sendNotificationBuf_->send();

      // Wait for notification from node on the right
      // to be sure this node can start an inbox write.
      recvNotificationBuf_->waitRecv();

      // Forward the result chunk to the right neighbor's inbox.
      copyChunkAtOffset(chunkOffset);
    }
  }

  // Final barrier to make sure every node has finished.
  // Otherwise, a second allreduce call might interfere
  // with one that is still in progress on some nodes.
  sendNotificationBuf_->send();
  recvNotificationBuf_->waitRecv();

  // If running synchronously, wait for all chunk broadcasts to complete
  if (synchronizeDeviceOutputs_) {
    for (auto i = 0; i < chunks_; i++) {
      const auto chunkOffset = getChunkOffset(i);
      if (chunkOffset < chunkContext_.size()) {
        chunkContext_[chunkOffset].broadcastOp->wait();
      }
    }
  }
}