void reduce()

in gloo/reduce.cc [21:247]


void reduce(ReduceOptions& opts) {
  if (opts.elements == 0) {
    return;
  }
  const auto& context = opts.context;
  transport::UnboundBuffer* in = opts.in.get();
  transport::UnboundBuffer* out = opts.out.get();
  const auto slot = Slot::build(kReduceSlotPrefix, opts.tag);

  // Sanity checks
  GLOO_ENFORCE(opts.elementSize > 0);
  GLOO_ENFORCE(opts.root >= 0 && opts.root < context->size);
  GLOO_ENFORCE(opts.reduce != nullptr);
  const auto recvRank = (context->size + context->rank + 1) % context->size;
  GLOO_ENFORCE(
      recvRank == context->rank || context->getPair(recvRank),
      "missing connection between rank " + std::to_string(context->rank) +
          " (this process) and rank " + std::to_string(recvRank));
  const auto sendRank = (context->size + context->rank - 1) % context->size;
  GLOO_ENFORCE(
      sendRank == context->rank || context->getPair(sendRank),
      "missing connection between rank " + std::to_string(context->rank) +
          " (this process) and rank " + std::to_string(sendRank));

  // If the input buffer is not specified, the output buffer also serves as the input
  if (in == nullptr) {
    in = out;
  }

  GLOO_ENFORCE_EQ(in->size, opts.elements * opts.elementSize);
  GLOO_ENFORCE_EQ(out->size, opts.elements * opts.elementSize);

  // Short circuit if there is only a single process.
  if (context->size == 1) {
    if (in != out) {
      memcpy(out->ptr, in->ptr, opts.elements * opts.elementSize);
    }
    return;
  }

  // The ring algorithm works as follows.
  //
  // The given input is split into a number of chunks equal to the
  // number of processes. Once the algorithm has finished, every
  // process hosts one chunk of reduced output, in sequential order
  // (rank 0 has chunk 0, rank 1 has chunk 1, etc.). As the input may
  // not be divisible by the number of processes, the chunks on the
  // final ranks may hold partial output or may be empty.
  //
  // As a chunk is passed along the ring and contains the reduction of
  // successively more ranks, we have to alternate between performing
  // I/O for that chunk and computing the reduction between the
  // received chunk and the local chunk. To avoid this alternating
  // pattern, we split up a chunk into multiple segments (>= 2), and
  // ensure we have one segment in flight while computing a reduction
  // on the other. The segment size has an upper bound to minimize
  // memory usage and avoid poor cache behavior. This means we may
  // have many segments per chunk when dealing with very large inputs.
  //
  // The nomenclature here is reflected in the variable naming below
  // (one chunk per rank and many segments per chunk).
  //
  const size_t totalBytes = opts.elements * opts.elementSize;

  // Ensure that the maximum segment size is a multiple of the element size.
  // Otherwise, the segment size can exceed the maximum segment size after
  // rounding it up to the nearest multiple of the element size.
  // For example, if maxSegmentSize = 10, and elementSize = 4,
  // then after rounding up: segmentSize = 12;
  const size_t maxSegmentSize =
      opts.elementSize * (opts.maxSegmentSize / opts.elementSize);
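  // Illustrative example (numbers not from the source): with
  // opts.maxSegmentSize = 10 and elementSize = 4 as above, the
  // truncation yields maxSegmentSize = 4 * (10 / 4) = 8, so rounding
  // up to a multiple of the element size can no longer exceed the cap.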

  // The number of bytes per segment must be a multiple of the bytes
  // per element for the reduction to work; round up if necessary.
  const size_t segmentBytes = roundUp(
      std::min(
          // Ceiling division, targeting >= 2 segments per chunk.
          (totalBytes + (context->size * 2 - 1)) / (context->size * 2),
          // Configurable segment size limit
          maxSegmentSize),
      opts.elementSize);

  // Compute how many segments make up the input buffer.
  //
  // Round up to the nearest multiple of the context size such that
  // there is an equal number of segments per process and execution is
  // symmetric across processes.
  //
  // The minimum is twice the context size, because the algorithm
  // below overlaps sending/receiving a segment with computing the
  // reduction of another segment.
  //
  const size_t numSegments = roundUp(
      std::max(
          (totalBytes + (segmentBytes - 1)) / segmentBytes,
          (size_t)context->size * 2),
      (size_t)context->size);
  GLOO_ENFORCE_EQ(numSegments % context->size, 0);
  GLOO_ENFORCE_GE(numSegments, context->size * 2);
  const size_t numSegmentsPerRank = numSegments / context->size;
  const size_t chunkBytes = numSegmentsPerRank * segmentBytes;
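
  // Worked example (illustrative numbers, not from the source): with
  // context->size = 4, opts.elements = 250, opts.elementSize = 4, and
  // maxSegmentSize = 100, we have totalBytes = 1000 and:
  //   segmentBytes       = roundUp(min(1007 / 8, 100), 4) = 100
  //   numSegments        = roundUp(max(1099 / 100, 8), 4) = 12
  //   numSegmentsPerRank = 12 / 4 = 3
  //   chunkBytes         = 3 * 100 = 300
  // Ranks 0-3 own byte ranges [0, 300), [300, 600), [600, 900), and
  // [900, 1200); the last chunk holds only 100 valid bytes, since
  // offsets at or beyond totalBytes = 1000 are skipped.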

  // Allocate scratch space to hold two segments (one per operation in flight)
  std::unique_ptr<uint8_t[]> tmpAllocation(new uint8_t[segmentBytes * 2]);
  std::unique_ptr<transport::UnboundBuffer> tmpBuffer =
      context->createUnboundBuffer(tmpAllocation.get(), segmentBytes * 2);
  transport::UnboundBuffer* tmp = tmpBuffer.get();

  // Use dynamic lookup for the segment offset in the temporary buffer.
  // With two operations in flight we need two offsets.
  // They can be indexed using the loop counter.
  std::array<size_t, 2> segmentOffset;
  segmentOffset[0] = 0;
  segmentOffset[1] = segmentBytes;

  // This function computes the offsets and lengths of the segments to
  // be sent and received for a given iteration.
  auto computeReduceScatterOffsets = [&](size_t i) {
    struct {
      size_t sendOffset;
      size_t recvOffset;
      ssize_t sendLength;
      ssize_t recvLength;
    } result;

    // Compute the segment index to send from (to rank - 1) and the
    // segment index to receive into (from rank + 1). Multiply by the
    // number of bytes in a segment to get an offset. The offset is
    // allowed to be out of range (>= totalBytes); this is taken into
    // account when computing the associated length.
    result.sendOffset =
        ((((context->rank + 1) * numSegmentsPerRank) + i) * segmentBytes) %
        (numSegments * segmentBytes);
    result.recvOffset =
        ((((context->rank + 2) * numSegmentsPerRank) + i) * segmentBytes) %
        (numSegments * segmentBytes);

    // If the segment is entirely in range, the following expressions
    // evaluate to segmentBytes. If it isn't, they will be smaller, or
    // even negative. This is why the ssize_t typecasts are needed.
    result.sendLength = std::min(
        (ssize_t)segmentBytes,
        (ssize_t)totalBytes - (ssize_t)result.sendOffset);
    result.recvLength = std::min(
        (ssize_t)segmentBytes,
        (ssize_t)totalBytes - (ssize_t)result.recvOffset);

    return result;
  };
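
  // Tracing the lambda with the worked example above (illustrative;
  // rank 2 of 4, numSegmentsPerRank = 3, segmentBytes = 100,
  // totalBytes = 1000, ring span numSegments * segmentBytes = 1200):
  //   i = 0: sendOffset = ((3 * 3 + 0) * 100) % 1200 = 900
  //          recvOffset = ((4 * 3 + 0) * 100) % 1200 = 0
  //   i = 2: sendOffset = 1100, so sendLength = min(100, 1000 - 1100)
  //          is negative and the send is skipped.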

  for (size_t i = 0; i < numSegments; i++) {
    if (i >= 2) {
      // Compute the send and receive offsets and lengths from two
      // iterations ago. We need them to know when to wait for an
      // operation and when to skip it (if the offset was out of
      // bounds), and where to reduce the contents of the temporary
      // buffer.
      auto prev = computeReduceScatterOffsets(i - 2);
      if (prev.recvLength > 0) {
        tmp->waitRecv(opts.timeout);
        opts.reduce(
            static_cast<uint8_t*>(out->ptr) + prev.recvOffset,
            static_cast<const uint8_t*>(in->ptr) + prev.recvOffset,
            static_cast<const uint8_t*>(tmp->ptr) + segmentOffset[i & 0x1],
            prev.recvLength / opts.elementSize);
      }
      if (prev.sendLength > 0) {
        if ((i - 2) < numSegmentsPerRank) {
          in->waitSend(opts.timeout);
        } else {
          out->waitSend(opts.timeout);
        }
      }
    }

    // Issue new send and receive operations in all but the final two
    // iterations. At that point we have already sent all the data we
    // need to send and only have to wait for the final segments to be
    // reduced into the output.
    if (i < (numSegments - 2)) {
      // Compute send and receive offsets and lengths for this iteration.
      auto cur = computeReduceScatterOffsets(i);
      if (cur.recvLength > 0) {
        tmp->recv(recvRank, slot, segmentOffset[i & 0x1], cur.recvLength);
      }
      if (cur.sendLength > 0) {
        if (i < numSegmentsPerRank) {
          in->send(sendRank, slot, cur.sendOffset, cur.sendLength);
        } else {
          out->send(sendRank, slot, cur.sendOffset, cur.sendLength);
        }
      }
    }
  }
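
  // The loop above forms a two-deep pipeline: iteration i posts the
  // send and receive for segment i, and iteration i + 2 waits for them
  // and reduces the received bytes into the output. With numSegments =
  // 12 (illustrative), i = 0 and i = 1 only post I/O, i = 2 through
  // i = 9 both complete segment i - 2 and post segment i, and i = 10
  // and i = 11 only drain the two operations still in flight.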

  // Gather to the root rank.
  //
  // Beware: totalBytes may be less than numSegments * segmentBytes, so
  // the chunks on the final ranks can be shorter than chunkBytes or
  // even empty. This is incompatible with the generic gather
  // algorithm, which assumes an identical contribution from every
  // process.
  //
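  // Continuing the worked example (illustrative): with root = 0, the
  // root receives chunkBytes = 300 bytes from rank 1 at offset 300 and
  // from rank 2 at offset 600, but only min(300, 1000 - 900) = 100
  // bytes from rank 3, whose chunk extends past totalBytes.
  //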
  if (context->rank == opts.root) {
    size_t numRecv = 0;
    for (size_t rank = 0; rank < context->size; rank++) {
      if (rank == context->rank) {
        continue;
      }
      size_t recvOffset = rank * numSegmentsPerRank * segmentBytes;
      ssize_t recvLength = std::min(
          (ssize_t)chunkBytes, (ssize_t)totalBytes - (ssize_t)recvOffset);
      if (recvLength > 0) {
        out->recv(rank, slot, recvOffset, recvLength);
        numRecv++;
      }
    }
    for (size_t i = 0; i < numRecv; i++) {
      out->waitRecv(opts.timeout);
    }
  } else {
    size_t sendOffset = context->rank * numSegmentsPerRank * segmentBytes;
    ssize_t sendLength = std::min(
        (ssize_t)chunkBytes, (ssize_t)totalBytes - (ssize_t)sendOffset);
    if (sendLength > 0) {
      out->send(opts.root, slot, sendOffset, sendLength);
      out->waitSend(opts.timeout);
    }
  }
}
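
A minimal usage sketch (not part of reduce.cc): it assumes the
ReduceOptions setters declared in gloo/reduce.h (setOutput, setRoot,
setReduceFunction) and an already-connected context; the sum lambda and
the exampleReduce wrapper are hand-written for illustration.

#include <cstddef>
#include <memory>
#include <vector>

#include "gloo/context.h"
#include "gloo/reduce.h"

// Sum a float vector across all ranks onto rank 0. A sketch assuming
// an already-connected gloo context (rendezvous setup not shown).
void exampleReduce(const std::shared_ptr<gloo::Context>& context) {
  std::vector<float> data(250, 1.0f);

  gloo::ReduceOptions opts(context);
  // With no input buffer set, the output buffer also serves as the
  // input (see the top of reduce() above).
  opts.setOutput(data.data(), data.size());
  opts.setRoot(0);

  // Element-wise sum matching the signature the algorithm expects:
  // (destination, operand a, operand b, number of elements).
  opts.setReduceFunction(
      [](void* dst, const void* a, const void* b, size_t n) {
        auto* d = static_cast<float*>(dst);
        auto* x = static_cast<const float*>(a);
        auto* y = static_cast<const float*>(b);
        for (size_t i = 0; i < n; i++) {
          d[i] = x[i] + y[i];
        }
      });

  gloo::reduce(opts);
  // On rank 0, data now holds the element-wise sum across all ranks.
}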