in gloo/allreduce.cc [148:393]
void ring(
const detail::AllreduceOptionsImpl& opts,
ReduceRangeFunction reduceInputs,
BroadcastRangeFunction broadcastOutputs) {
const auto& context = opts.context;
const std::vector<std::unique_ptr<transport::UnboundBuffer>>& out = opts.out;
const auto slot = Slot::build(kAllreduceSlotPrefix, opts.tag);
const size_t totalBytes = opts.elements * opts.elementSize;
// Note: context->size > 1
const auto recvRank = (context->size + context->rank + 1) % context->size;
const auto sendRank = (context->size + context->rank - 1) % context->size;
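// With this orientation, each process receives from its successor
// (rank + 1) and sends to its predecessor (rank - 1), so segments
// travel around the ring in the direction of decreasing rank.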
GLOO_ENFORCE(
context->getPair(recvRank),
"missing connection between rank " + std::to_string(context->rank) +
" (this process) and rank " + std::to_string(recvRank));
GLOO_ENFORCE(
context->getPair(sendRank),
"missing connection between rank " + std::to_string(context->rank) +
" (this process) and rank " + std::to_string(sendRank));
// The ring algorithm works as follows.
//
// The given input is split into a number of chunks equal to the
// number of processes. Once the algorithm has finished, every
// process hosts one chunk of reduced output, in sequential order
// (rank 0 has chunk 0, rank 1 has chunk 1, etc.). As the input may
// not be divisible by the number of processes, the chunks on the
// final ranks may have partial output or may be empty.
//
// As a chunk is passed along the ring and contains the reduction of
// successively more ranks, we have to alternate between performing
// I/O for that chunk and computing the reduction between the
// received chunk and the local chunk. To avoid this alternating
// pattern, we split up a chunk into multiple segments (>= 2), and
// ensure we have one segment in flight while computing a reduction
// on the other. The segment size has an upper bound to minimize
// memory usage and avoid poor cache behavior. This means we may
// have many segments per chunk when dealing with very large inputs.
//
// The nomenclature here is reflected in the variable naming below
// (one chunk per rank and many segments per chunk).
//
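// For example (illustrative): with 4 processes, the output is
// conceptually laid out as [chunk 0 | chunk 1 | chunk 2 | chunk 3];
// after the reduce/scatter phase below, rank r holds the fully
// reduced chunk r, and the allgather phase circulates the reduced
// chunks so every rank ends up with the complete output.
//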
// Ensure that the maximum segment size is a multiple of the element
// size. Otherwise, the segment size can exceed the configured
// maximum after being rounded up to the nearest multiple of the
// element size. For example, if maxSegmentSize = 10 and
// elementSize = 4, then after rounding up: segmentSize = 12.
const size_t maxSegmentBytes = opts.elementSize *
std::max((size_t)1, opts.maxSegmentSize / opts.elementSize);
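// For example (illustrative numbers): with opts.maxSegmentSize = 10
// and opts.elementSize = 4, maxSegmentBytes = 4 * max(1, 10 / 4) = 8,
// which respects the configured maximum, whereas naively rounding 10
// up to a multiple of 4 would yield 12.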
// Compute how many segments make up the input buffer.
//
// Round up to the nearest multiple of the context size such that
// there is an equal number of segments per process and execution is
// symmetric across processes.
//
// The minimum is twice the context size, because the algorithm
// below overlaps sending/receiving a segment with computing the
// reduction of another segment.
//
const size_t numSegments = roundUp(
std::max(
(totalBytes + (maxSegmentBytes - 1)) / maxSegmentBytes,
(size_t)context->size * 2),
(size_t)context->size);
GLOO_ENFORCE_EQ(numSegments % context->size, 0);
GLOO_ENFORCE_GE(numSegments, context->size * 2);
const size_t numSegmentsPerRank = numSegments / context->size;
const size_t segmentBytes =
roundUp((totalBytes + numSegments - 1) / numSegments, opts.elementSize);
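// Continuing the illustrative numbers: with totalBytes = 1000,
// maxSegmentBytes = 8, elementSize = 4, and context->size = 4, we
// get ceil(1000 / 8) = 125 candidate segments, which rounds up to
// numSegments = 128 (a multiple of 4), numSegmentsPerRank = 32, and
// segmentBytes = roundUp(ceil(1000 / 128), 4) = 8. Note that
// 128 * 8 = 1024 >= totalBytes, so the final segments are partial
// or empty.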
// Allocate scratch space to hold two segments
std::unique_ptr<uint8_t[]> tmpAllocation(new uint8_t[segmentBytes * 2]);
std::unique_ptr<transport::UnboundBuffer> tmpBuffer =
context->createUnboundBuffer(tmpAllocation.get(), segmentBytes * 2);
transport::UnboundBuffer* tmp = tmpBuffer.get();
// Use dynamic lookup for segment offset in the temporary buffer.
// With two operations in flight we need two offsets.
// They can be indexed using the loop counter.
std::array<size_t, 2> segmentOffset;
segmentOffset[0] = 0;
segmentOffset[1] = segmentBytes;
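// On even iterations the segment in flight lands at offset 0, on odd
// iterations at offset segmentBytes. Since an operation issued on
// iteration i is waited for on iteration i + 2, both ends of the
// pipeline index the same slot via (i & 0x1).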
// Function computes the offsets and lengths of the segments to be
// sent and received for a given iteration during reduce/scatter.
auto computeReduceScatterOffsets = [&](size_t i) {
struct {
size_t sendOffset;
size_t recvOffset;
ssize_t sendLength;
ssize_t recvLength;
} result;
// Compute segment index to send from (to rank - 1) and segment
// index to receive into (from rank + 1). Multiply by the number
// of bytes in a segment to get to an offset. The offset is allowed
// to be out of range (>= totalBytes) and this is taken into
// account when computing the associated length.
result.sendOffset =
((((context->rank + 1) * numSegmentsPerRank) + i) * segmentBytes) %
(numSegments * segmentBytes);
result.recvOffset =
((((context->rank + 2) * numSegmentsPerRank) + i) * segmentBytes) %
(numSegments * segmentBytes);
// If the segment is entirely in range, the following expression
// evaluates to segmentBytes. If it isn't, it will be less, or
// even negative. This is why the ssize_t typecasts are needed.
result.sendLength = std::min(
(ssize_t)segmentBytes,
(ssize_t)totalBytes - (ssize_t)result.sendOffset);
result.recvLength = std::min(
(ssize_t)segmentBytes,
(ssize_t)totalBytes - (ssize_t)result.recvOffset);
return result;
};
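// For example (illustrative numbers): with context->size = 4 and
// numSegmentsPerRank = 2 (numSegments = 8), rank 1 at i = 0 sends
// segment 4 (sendOffset = 4 * segmentBytes, the first segment of
// chunk 2) and receives segment 6 (recvOffset = 6 * segmentBytes,
// the first segment of chunk 3).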
// Ring reduce/scatter.
//
// Number of iterations is computed as follows:
// - Take `numSegments` for the total number of segments,
// - Subtract `numSegmentsPerRank` because the final segments hold
// the partial result and must not be forwarded in this phase.
// - Add 2 because we pipeline send and receive operations: an
// operation issued on iteration i is waited for on iteration
// i + 2, so the final operations need two extra iterations to
// drain.
//
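// With the small illustrative numbers above (numSegments = 8,
// numSegmentsPerRank = 2), this works out to 8 - 2 + 2 = 8
// iterations.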
for (size_t i = 0; i < (numSegments - numSegmentsPerRank + 2); i++) {
if (i >= 2) {
// Compute send and receive offsets and lengths for the operations
// issued two iterations ago. Needed so we know whether to wait
// for an operation or to ignore it (when the offset was out of
// bounds), and where to reduce the contents of the temporary
// buffer into the output.
auto prev = computeReduceScatterOffsets(i - 2);
if (prev.recvLength > 0) {
// Prepare out[0]->ptr to hold the local reduction
reduceInputs(prev.recvOffset, prev.recvLength);
// Wait for segment from neighbor.
tmp->waitRecv(opts.timeout);
// Reduce segment from neighbor into out[0]->ptr.
opts.reduce(
static_cast<uint8_t*>(out[0]->ptr) + prev.recvOffset,
static_cast<const uint8_t*>(out[0]->ptr) + prev.recvOffset,
static_cast<const uint8_t*>(tmp->ptr) + segmentOffset[i & 0x1],
prev.recvLength / opts.elementSize);
}
if (prev.sendLength > 0) {
out[0]->waitSend(opts.timeout);
}
}
// Issue new send and receive operations in all but the final two
// iterations. At that point we have already sent all data we
// needed to and only have to wait for the final segments to be
// reduced into the output.
if (i < (numSegments - numSegmentsPerRank)) {
// Compute send and receive offsets and lengths for this iteration.
auto cur = computeReduceScatterOffsets(i);
if (cur.recvLength > 0) {
tmp->recv(recvRank, slot, segmentOffset[i & 0x1], cur.recvLength);
}
if (cur.sendLength > 0) {
// Prepare out[0]->ptr to hold the local reduction for this segment
if (i < numSegmentsPerRank) {
reduceInputs(cur.sendOffset, cur.sendLength);
}
out[0]->send(sendRank, slot, cur.sendOffset, cur.sendLength);
}
}
}
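// The reduce/scatter phase is now complete: on every rank, out[0]
// holds the fully reduced segments of that rank's own chunk. The
// allgather phase below circulates these reduced chunks around the
// ring so that every rank ends up with the complete reduced output.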
// Function computes the offsets and lengths of the segments to be
// sent and received for a given iteration during allgather.
auto computeAllgatherOffsets = [&](size_t i) {
struct {
size_t sendOffset;
size_t recvOffset;
ssize_t sendLength;
ssize_t recvLength;
} result;
result.sendOffset =
((((context->rank) * numSegmentsPerRank) + i) * segmentBytes) %
(numSegments * segmentBytes);
result.recvOffset =
((((context->rank + 1) * numSegmentsPerRank) + i) * segmentBytes) %
(numSegments * segmentBytes);
// If the segment is entirely in range, the following expression
// evaluates to segmentBytes. If it isn't, it will be less, or
// even negative. This is why the ssize_t typecasts are needed.
result.sendLength = std::min(
(ssize_t)segmentBytes,
(ssize_t)totalBytes - (ssize_t)result.sendOffset);
result.recvLength = std::min(
(ssize_t)segmentBytes,
(ssize_t)totalBytes - (ssize_t)result.recvOffset);
return result;
};
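// Continuing the illustrative numbers (context->size = 4,
// numSegmentsPerRank = 2): rank 1 at i = 0 sends segment 2 (the
// first segment of its own reduced chunk 1) and receives segment 4
// (the first segment of chunk 2, reduced by rank 2).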
// Ring allgather.
//
// Beware: totalBytes <= (numSegments * segmentBytes), so the last
// segments may be partial or empty. This is incompatible with the
// generic allgather algorithm, which assumes the contribution is
// identical across processes.
//
// See comment prior to reduce/scatter loop on how the number of
// iterations for this loop is computed.
//
for (size_t i = 0; i < (numSegments - numSegmentsPerRank + 2); i++) {
if (i >= 2) {
auto prev = computeAllgatherOffsets(i - 2);
if (prev.recvLength > 0) {
out[0]->waitRecv(opts.timeout);
// Broadcast received segments to output buffers.
broadcastOutputs(prev.recvOffset, prev.recvLength);
}
if (prev.sendLength > 0) {
out[0]->waitSend(opts.timeout);
}
}
// Issue new send and receive operations in all but the final two
// iterations. At that point we have already sent all data we
// needed to and only have to wait for the final segments to be
// sent to the output.
if (i < (numSegments - numSegmentsPerRank)) {
auto cur = computeAllgatherOffsets(i);
if (cur.recvLength > 0) {
out[0]->recv(recvRank, slot, cur.recvOffset, cur.recvLength);
}
if (cur.sendLength > 0) {
out[0]->send(sendRank, slot, cur.sendOffset, cur.sendLength);
// Broadcast first segments to output buffers.
if (i < numSegmentsPerRank) {
broadcastOutputs(cur.sendOffset, cur.sendLength);
}
}
}
}
}