in gloo/cuda_allreduce_bcube.cc [119:214]
// Runs one all-reduce across all participating ranks.
//
// Outline:
//   1. localReduceOp_ combines the local inputs first (NOTE(review):
//      presumably reduces the per-device input buffers into scratch_ --
//      confirm where the op is constructed).
//   2. Single node (nodes_ == 1): nothing is exchanged; only the local
//      broadcast runs.
//   3. Reduce-scatter, steps 0 .. steps_-1: at each step every peer is sent
//      the slice of scratch_ that the peer owns at that step, and the slice
//      this rank owns is reduced (via fn_) with what each peer sent.
//   4. All-gather, steps steps_-1 .. 0: at each step this rank's slice is
//      sent to every peer, and each peer's slice is copied into scratch_.
//   5. localBroadcastOp_ distributes the result (NOTE(review): presumably
//      from scratch_ to the local output buffers -- confirm).
//   6. Waits for the step-0 peers' final notifications before returning.
//
// Flow control: after consuming a peer's data during reduce-scatter step s,
// a notification is sent to that peer. Before sending all-gather data at
// step s, a notification from that step's peer is awaited. Step-0
// all-gather notifications are consumed by the wait loop at the end.
// The pairing only lines up if the peer relation returned by
// getPeersPerStep() is symmetric (B is A's peer at step s iff A is B's
// peer at step s) -- TODO confirm.
//
// NOTE(review): the DEBUG_PRINT_* macros are defined elsewhere and are
// invoked with only a stage string. They may reference the locals `step`,
// `destRank`, `srcRank`, `ptrOffset`, `sendCount` and `recvCount` by name.
// Rename those locals only together with the macros.
void CudaAllreduceBcube<T, W>::run() {
  // NOTE(review): presumably restores the caller's current CUDA device when
  // it goes out of scope -- confirm in the CudaDeviceGuard definition.
  CudaDeviceGuard guard;
  // All reductions and copies into scratch_ below are issued on this stream.
  // Each one is synchronized with stream.wait() before the corresponding
  // receive buffer is released back to the peer.
  CudaStream& stream = *scratchStream_;
  localReduceOp_->run();
  if (nodes_ == 1) {
    // Single machine: no inter-node exchange. Broadcast the locally reduced
    // result and finish.
    GLOO_ENFORCE(
        localBroadcastOp_,
        "localBroadcastOp must be initialized for single machine");
    localBroadcastOp_->run();
    return;
  }
  // Reduce-scatter
  DEBUG_PRINT_STAGE("start");
  for (int step = 0; step < steps_; ++step) {
    const auto& peerRanks = getPeersPerStep(myRank_, step);
    // Send each peer the slice of scratch_ that the peer is responsible for
    // at this step. Offsets and counts are element counts of T; send() is
    // passed them scaled by sizeof(T).
    for (int destRank : peerRanks) {
      int sendCount = getNumElemsPerStep(destRank, step);
      int ptrOffset = getPtrOffsetPerStep(destRank, step);
      DEBUG_PRINT_SEND("reduce-scatter");
      sendDataBufs_[destRank]->send(
          ptrOffset * sizeof(T), sendCount * sizeof(T));
    } // sends within group
    // Receive this rank's own slice from each peer and reduce it into the
    // same slice of scratch_.
    for (int srcRank : peerRanks) {
      int recvCount = getNumElemsPerStep(myRank_, step);
      int ptrOffset = getPtrOffsetPerStep(myRank_, step);
      recvDataBufs_[srcRank]->waitRecv();
      DEBUG_PRINT_RECV("reduce-scatter");
      // Incoming data sits at the start of the receive buffer assigned to
      // srcRank, not at ptrOffset.
      auto recvBufAtOffset =
          recvBufs_[recvBufIdx_[srcRank]].range(0, recvCount);
      auto scratchAtOffset = scratch_.range(ptrOffset, recvCount);
      fn_->call(scratchAtOffset, recvBufAtOffset, recvCount, stream);
      // Block until the reduction has finished reading the receive buffer;
      // only then is it safe to let the peer overwrite it.
      stream.wait();
      /*
       * Send notification to the pair we just received from that
       * we're done dealing with the receive buffer.
       */
      sendNotificationBufs_[srcRank]->send();
    } // recvs within group and reduces
  } // reduce-scatter steps
  DEBUG_PRINT_STAGE("reduce-scattered");
  // All-gather: walks the steps in the reverse order of the reduce-scatter.
  for (int step = steps_ - 1; step >= 0; --step) {
    const auto& peerRanks = getPeersPerStep(myRank_, step);
    // Send this rank's slice for this step to every peer.
    for (int destRank : peerRanks) {
      int sendCount = getNumElemsPerStep(myRank_, step);
      int ptrOffset = getPtrOffsetPerStep(myRank_, step);
      /*
       * Wait for notification from the peer to make sure we can send data
       * without risking any overwrites in its receive buffer.
       * (With symmetric peers, this is the notification the peer sent
       * after consuming our data in reduce-scatter step `step`.)
       */
      recvNotificationBufs_[destRank]->waitRecv();
      DEBUG_PRINT_SEND("all-gather");
      sendDataBufs_[destRank]->send(
          ptrOffset * sizeof(T), sendCount * sizeof(T));
    }
    // Copy each peer's slice for this step from its receive buffer into the
    // matching slice of scratch_ (a plain copy, no reduction).
    for (int srcRank : peerRanks) {
      int recvCount = getNumElemsPerStep(srcRank, step);
      int ptrOffset = getPtrOffsetPerStep(srcRank, step);
      recvDataBufs_[srcRank]->waitRecv();
      DEBUG_PRINT_RECV("all-gather");
      auto recvBufAtOffset =
          recvBufs_[recvBufIdx_[srcRank]].range(0, recvCount);
      auto scratchAtOffset = scratch_.range(ptrOffset, recvCount);
      stream.copyAsync(scratchAtOffset, recvBufAtOffset);
      // Make sure the copy out of the receive buffer has completed.
      stream.wait();
      if (step == 0) {
        /*
         * Send notification to the pair we just received from that
         * we're done dealing with the receive buffer.
         */
        // Only step 0 notifies here; the peer consumes these in its final
        // wait loop at the end of run(). NOTE(review): for steps > 0 no
        // notification is sent after the all-gather copy, so reuse of those
        // receive buffers by the next run() relies on ordering that is not
        // visible in this function -- confirm.
        sendNotificationBufs_[srcRank]->send();
      }
    } // recvs within group and copies
  } // all-gather steps
  DEBUG_PRINT_STAGE("all-reduced");
  localBroadcastOp_->runAsync();
  localBroadcastOp_->wait();
  /*
   * Wait for notifications from our peers within the block to make
   * sure we can send data immediately without risking overwriting
   * data in its receive buffer before it consumed that data.
   */
  for (int peerRank : getPeersPerStep(myRank_, 0)) {
    recvNotificationBufs_[peerRank]->waitRecv();
  }
}