in gloo/allreduce_bcube.h [338:434]
/**
 * Performs one allreduce over the local buffers in ptrs_, leaving the
 * reduced result in every ptrs_[i].
 *
 * Phases, in order:
 *  1. Local reduce: folds ptrs_[1..] into ptrs_[0] with fn_.
 *  2. If nodes_ == 1, copies ptrs_[0] into the other local pointers and
 *     returns without any communication.
 *  3. Reduce-scatter, steps 0..steps_-1: for each peer returned by
 *     getPeersPerStep(myRank_, step), sends that peer's slice
 *     (getPtrOffsetPerStep/getNumElemsPerStep of the peer), then receives
 *     from each peer and reduces into this rank's slice for the step.
 *  4. All-gather, steps steps_-1..0: sends this rank's slice for the step
 *     to each peer (after waiting for the peer's notification) and copies
 *     each peer's slice from its receive buffer into ptrs_[0].
 *  5. Copies ptrs_[0] into the other local pointers, then waits for the
 *     notifications from the step-0 peers.
 *
 * Returns immediately when totalNumElems_ == 0.
 *
 * NOTE(review): bytes_ is assumed to equal totalNumElems_ * sizeof(T);
 * confirm against the constructor.
 */
void run() {
  if (totalNumElems_ == 0) {
    return;
  }

  // Copies the contents of ptrs_[0] into every other local pointer.
  // A std::size_t index avoids a signed/unsigned comparison with size().
  const auto broadcastLocally = [&]() {
    for (std::size_t i = 1; i < ptrs_.size(); ++i) {
      std::memcpy(ptrs_[i], ptrs_[0], bytes_);
    }
  };

  // Local reduce operation: accumulate every local input into ptrs_[0].
  for (std::size_t i = 1; i < ptrs_.size(); ++i) {
    fn_->call(ptrs_[0], ptrs_[i], totalNumElems_);
  }

  if (nodes_ == 1) {
    // Single node: nothing to exchange, just broadcast ptrs_[0] locally.
    broadcastLocally();
    return;
  }

  // Reduce-scatter
  DEBUG_PRINT_STAGE("start");
  for (int step = 0; step < steps_; ++step) {
    for (int destRank : getPeersPerStep(myRank_, step)) {
      int sendCount = getNumElemsPerStep(destRank, step);
      int ptrOffset = getPtrOffsetPerStep(destRank, step);
      DEBUG_PRINT_SEND("reduce-scatter");
      sendDataBufs_[destRank]->send(
          ptrOffset * sizeof(T), sendCount * sizeof(T));
    } // sends within group

    for (int srcRank : getPeersPerStep(myRank_, step)) {
      int recvCount = getNumElemsPerStep(myRank_, step);
      int ptrOffset = getPtrOffsetPerStep(myRank_, step);
      recvDataBufs_[srcRank]->waitRecv();
      DEBUG_PRINT_RECV("reduce-scatter");
      fn_->call(
          &ptrs_[0][ptrOffset],
          &recvBufs_[recvBufIdx_[srcRank]][0],
          recvCount);
      /*
       * Send notification to the pair we just received from that
       * we're done dealing with the receive buffer.
       */
      sendNotificationBufs_[srcRank]->send();
    } // recvs within group and reduces
  } // reduce-scatter steps

  DEBUG_PRINT_STAGE("reduce-scattered");

  // All-gather, walking the steps in reverse order.
  for (int step = steps_ - 1; step >= 0; --step) {
    for (int destRank : getPeersPerStep(myRank_, step)) {
      int sendCount = getNumElemsPerStep(myRank_, step);
      int ptrOffset = getPtrOffsetPerStep(myRank_, step);
      /*
       * Wait for notification from the peer to make sure we can send data
       * without risking any overwrites in its receive buffer.
       */
      recvNotificationBufs_[destRank]->waitRecv();
      DEBUG_PRINT_SEND("all-gather");
      sendDataBufs_[destRank]->send(
          ptrOffset * sizeof(T), sendCount * sizeof(T));
    } // sends within group

    for (int srcRank : getPeersPerStep(myRank_, step)) {
      int recvCount = getNumElemsPerStep(srcRank, step);
      int ptrOffset = getPtrOffsetPerStep(srcRank, step);
      recvDataBufs_[srcRank]->waitRecv();
      DEBUG_PRINT_RECV("all-gather");
      std::memcpy(
          &ptrs_[0][ptrOffset],
          &recvBufs_[recvBufIdx_[srcRank]][0],
          recvCount * sizeof(T));
      if (step == 0) {
        /*
         * Send notification to the pair we just received from that
         * we're done dealing with the receive buffer. Only step-0 peers
         * are notified here, matching the wait loop at the end of run(),
         * which only waits on step-0 peers.
         */
        sendNotificationBufs_[srcRank]->send();
      }
    } // recvs within group and copies
  } // all-gather steps

  DEBUG_PRINT_STAGE("all-reduced");

  // Broadcast ptrs_[0] to the remaining local pointers.
  broadcastLocally();

  /*
   * Wait for notifications from our peers within the block to make
   * sure we can send data immediately without risking overwriting
   * data in its receive buffer before it consumed that data.
   * NOTE(review): presumably this protects the sends of the next run()
   * call; confirm.
   */
  for (int peerRank : getPeersPerStep(myRank_, 0)) {
    recvNotificationBufs_[peerRank]->waitRecv();
  }
}