in gloo/allreduce_ring_chunked.h [83:212]
void run() {
if (count_ == 0) {
return;
}
// Reduce specified pointers into ptrs_[0]
for (int i = 1; i < ptrs_.size(); i++) {
fn_->call(ptrs_[0], ptrs_[i], count_);
}
if (this->contextSize_ == 1) {
// Broadcast ptrs_[0]
for (int i = 1; i < ptrs_.size(); i++) {
memcpy(ptrs_[i], ptrs_[0], bytes_);
}
return;
}
// Kick off copying initial chunks
copyChunkAtOffset(2 * this->contextRank_);
copyChunkAtOffset(2 * this->contextRank_ + 1);
// Start with reduction of previously copied chunk
for (int round = 2; round < chunks_; round++) {
// We loop over all chunks starting at 2, since we just sent two
// chunks to fill both buffers. Imagine a square grid with
// chunks of memory laid out vertically and nodes horizontally.
// The diagonal of this grid marks which nodes sends which
// chunks of memory in the prelude. Processing happens by moving
// this diagonal forward and have it wrap around the edge. This
// means that node with rank 0 at round 2 will process the last
// chunk. This explains why we subtract the round in the offset
// equation below.
//
// Because we're dealing with double buffering in this
// implementation, we have twice the number of chunks and
// process them in pairs. This explains why we ignore the LSB on
// the round number when subtracting it. The LSB is later added
// to flip back and forth between the two buffers for this pair
// of chunks. The number of chunks is finally added to make sure
// we can wrap correctly (no modulo against negative number).
//
auto chunkOffset = ((2 * this->contextRank_) - (round & ~0x1) +
(round & 0x1) + chunks_) %
chunks_;
auto offset = chunkOffset * chunkSize_;
auto length = chunkSize_;
if (offset + length <= count_) {
// Chunk completely in range, copy full chunk.
} else if (offset < count_) {
// Chunk partially in range, copy partial chunk.
length = count_ - offset;
} else {
// Chunk out of range, copy nothing.
length = 0;
}
// Wait for inbox write to complete
recvDataBuf_[chunkOffset & 1]->waitRecv();
// Reduce
if (length > 0) {
fn_->call(&ptrs_[0][offset], inbox_[chunkOffset & 1], length);
}
// Send notification to node on the left that
// this node is ready for an inbox write.
sendNotificationBuf_->send();
// Wait for notification from node on the right
// to be sure this node can start an inbox write.
recvNotificationBuf_->waitRecv();
// Copy accumulated chunk
copyChunkAtOffset(chunkOffset);
}
// Second pass around the ring to broadcast result.
// End at chunks_-2 since that's where the accumulation
// stopped in the previous set of rounds.
for (int round = 0; round < (chunks_ - 2); round++) {
auto chunkOffset = ((2 * this->contextRank_) - (round & ~0x1) +
(round & 0x1) + chunks_) %
chunks_;
auto offset = chunkOffset * chunkSize_;
auto length = chunkSize_;
if (offset + length <= count_) {
// Chunk completely in range, copy full chunk.
} else if (offset < count_) {
// Chunk partially in range, copy partial chunk.
length = count_ - offset;
} else {
// Chunk out of range, copy nothing.
length = 0;
}
// Wait for inbox write to complete
recvDataBuf_[chunkOffset & 1]->waitRecv();
// Copy
if (length > 0) {
memcpy(&ptrs_[0][offset], inbox_[chunkOffset & 1], length * sizeof(T));
}
// Skip copying in the last two rounds
if (round < (chunks_ - 4)) {
// Send notification to node on the left that
// this node is ready for an inbox write.
sendNotificationBuf_->send();
// Wait for notification from node on the right
// to be sure this node can start an inbox write.
recvNotificationBuf_->waitRecv();
// Copy accumulated chunks
copyChunkAtOffset(chunkOffset);
}
}
// Final barrier to make sure every node has finished
// Otherwise, a second all reduce call might interfere
// with one that it still in progress on some nodes.
sendNotificationBuf_->send();
recvNotificationBuf_->waitRecv();
// Broadcast ptrs_[0]
for (int i = 1; i < ptrs_.size(); i++) {
memcpy(ptrs_[i], ptrs_[0], bytes_);
}
}