in gloo/reduce_scatter.h [322:435]
void run() {
size_t bufferOffset = 0;
size_t numItems =
stepsWithinBlock_ > 0 ? chunkSize_ << (steps_ - 1) : count_;
for (int i = 1; i < ptrs_.size(); i++) {
fn_->call(ptrs_[0], ptrs_[i], count_);
}
if (this->contextSize_ == 1) {
// Broadcast ptrs_[0]
for (int i = 1; i < ptrs_.size(); i++) {
memcpy(ptrs_[i], ptrs_[0], bytes_);
}
return;
}
// Reduce-scatter (within binary block).
for (int i = 0; i < stepsWithinBlock_; i++) {
if (sendOffsets_[i] < count_) {
sendDataBufs_[i]->send(
sendOffsets_[i] * sizeof(T), sendCounts_[i] * sizeof(T));
}
if (recvOffsets_[i] < count_) {
recvDataBufs_[i]->waitRecv();
fn_->call(
&ptrs_[0][recvOffsets_[i]],
&recvBuf_[bufferOffset],
recvCounts_[i]);
}
bufferOffset += numItems;
sendNotificationBufs_[i]->send();
numItems >>= 1;
}
// Communication across binary blocks for non-power-of-two number of
// processes
// receive from smaller block
// data sizes same as in the last step of intrablock reduce-scatter above
int sendNotifyOffset = stepsWithinBlock_;
if (nextSmallerBlockSize_ != 0 && smallerBlockRecvDataBuf_ != nullptr) {
smallerBlockRecvDataBuf_->waitRecv();
fn_->call(
&ptrs_[0][recvOffsets_[stepsWithinBlock_ - 1]],
&recvBuf_[bufferOffset],
recvCounts_[stepsWithinBlock_ - 1]);
}
const auto totalItemsToSend =
stepsWithinBlock_ > 0 ? recvCounts_[stepsWithinBlock_ - 1] : count_;
if (nextLargerBlockSize_ != 0 && totalItemsToSend != 0) {
// scatter to larger block
const auto offset =
stepsWithinBlock_ > 0 ? recvOffsets_[stepsWithinBlock_ - 1] : 0;
const auto numSendsAndReceivesToLargerBlock =
nextLargerBlockSize_ / myBinaryBlockSize_;
for (int i = 0; i < numSendsAndReceivesToLargerBlock; i++) {
if (sendCountToLargerBlock_ * i < totalItemsToSend) {
largerBlockSendDataBufs_[i]->send(
(offset + i * sendCountToLargerBlock_) * sizeof(T),
std::min(
sendCountToLargerBlock_,
totalItemsToSend - sendCountToLargerBlock_ * i) *
sizeof(T));
}
}
}
// Distribution phase: Scatter/distribute based on user specified
// distribution.
int index = 0;
for (const auto& distMap : distMapForSend_) {
const auto myRank = this->context_->rank;
const int destRank = distMap.rank;
if (myRank != destRank) {
distSendDataBufs_[index++]->send(
distMap.offset * sizeof(T), distMap.itemCount * sizeof(T));
}
}
index = 0;
bufferOffset = 0;
for (const auto& distMap : distMapForRecv_) {
const auto myRank = this->context_->rank;
const int srcRank = distMap.rank;
if (myRank != srcRank) {
distRecvDataBufs_[index++]->waitRecv();
memcpy(
&ptrs_[0][bufferOffset],
&recvBufDist_[distMap.offset],
distMap.itemCount * sizeof(T));
sendNotificationBufs_[sendNotifyOffset++]->send();
} else {
if (myRank != 0) { // Data already in-place for rank 0.
memcpy(
&ptrs_[0][bufferOffset],
&ptrs_[0][distMap.offset],
distMap.itemCount * sizeof(T));
}
}
bufferOffset += distMap.itemCount;
}
// Broadcast ptrs_[0]
for (int i = 1; i < ptrs_.size(); i++) {
memcpy(ptrs_[i], ptrs_[0], bytes_);
}
// Wait for all notifications to make sure we can send data immediately
// without risking overwriting data in its receive buffer before it
// consumed that data.
for (auto& recvNotificationBuf : recvNotificationBufs_) {
recvNotificationBuf->waitRecv();
}
}