void CudaAllreduceRingChunked::run()

in gloo/cuda_allreduce_ring_chunked.cc [135:279]


// Runs one chunked ring allreduce across the processes in this context.
//
// Structure, as visible below:
//   1. Asynchronously reduce each chunk of the local device buffers into the
//      scratch space (context.reduceOp->runAsync()).
//   2. Single-process shortcut (contextSize_ == 1): broadcast the locally
//      reduced chunks back to the devices and return — no ring traffic.
//   3. First ring pass (chunks_ rounds): combine the chunk received from the
//      neighbor into the local scratch chunk, then forward it along the ring.
//   4. Second ring pass: receive fully reduced chunks, copy them into scratch,
//      and broadcast each to the local devices.
//   5. Final notification exchange so that a subsequent allreduce cannot
//      overlap with a peer still executing this one.
//
// NOTE(review): this is a member of a class template; the
// `template <typename T, typename W>` header precedes this excerpt.
void CudaAllreduceRingChunked<T, W>::run() {
  // Scopes the active CUDA device for the calls below — exact semantics are
  // defined by CudaDeviceGuard (declared elsewhere); TODO confirm.
  CudaDeviceGuard guard;
  // All host-driven chunk reductions/copies in this function go through the
  // dedicated scratch stream.
  CudaStream& stream = *scratchStream_;

  // Kick off local reduction for each chunk.
  // The result is stored in scratch_ at the corresponding chunk offset.
  // Make sure to iterate over the chunks in the order they will be sent.
  for (auto i = 0; i < chunks_; i++) {
    const auto chunkOffset = getChunkOffset(i);
    // chunks_ may exceed the number of populated chunk contexts; rounds whose
    // offset falls past the end carry an empty chunk (see guards below).
    if (chunkOffset < chunkContext_.size()) {
      auto& context = chunkContext_[chunkOffset];
      context.reduceOp->runAsync();
    }
  }

  if (this->contextSize_ == 1) {

    // Single-process case: there are no peers to exchange with, so the
    // locally reduced result is already the final result.
    // Wait for the local reduction to complete then broadcast chunk to devices
    for (auto i = 0; i < chunks_; i++) {
      const auto chunkOffset = getChunkOffset(i);
      if (chunkOffset < chunkContext_.size()) {
        auto& context = chunkContext_[chunkOffset];
        context.reduceOp->wait();
        context.broadcastOp->runAsync();
      }
    }

    // Wait for broadcast to complete
    for (auto i = 0; i < chunks_; i++) {
      const auto chunkOffset = getChunkOffset(i);
      if (chunkOffset < chunkContext_.size()) {
        auto& context = chunkContext_[chunkOffset];
        context.broadcastOp->wait();
      }
    }
    return;
  }

  // First pass reduces a chunk in each round.
  // The receive side is double buffered: `chunkOffset & 1` selects which of
  // the two inbox buffers / receive operations a given chunk uses.
  for (auto round = 0; round < chunks_; round++) {
    const auto chunkOffset = getChunkOffset(round);

    if (chunkOffset < chunkContext_.size()) {
      auto& context = chunkContext_[chunkOffset];

      // Wait for the local reduction to complete
      // When using the host workspace this also makes sure the reduction
      // result is copied into the host side scratch buffer.
      context.reduceOp->wait();

      // Reduce chunk from previous round. Nothing to do for the first two
      // rounds: the pipeline needs two rounds of sends before the first
      // inbound chunk can have arrived in the inbox.
      if (round >= 2) {
        // Wait for inbox write to complete
        recvDataBuf_[chunkOffset & 1]->waitRecv();

        // Accumulate the received chunk into the local scratch chunk using
        // the configured reduction function.
        fn_->call(
            context.scratch,
            inbox_[chunkOffset & 1],
            context.scratch.getCount(),
            stream);
        // Block until the reduction on `stream` has finished before the
        // inbox buffer can be overwritten by the next round.
        stream.wait();
      }
    } else {
      // Empty chunk but still need to wait on the inbox write to ensure the
      // algorithm progresses. Nothing to do for initial rounds.
      if (round >= 2) {
        recvDataBuf_[chunkOffset & 1]->waitRecv();
      }
    }

    // Skip buffer passing notifications in initial rounds
    if (round >= 2) {
      // Send notification to node on the left that
      // this node is ready for an inbox write.
      sendNotificationBuf_->send();

      // Wait for notification from node on the right
      // to be sure this node can start an inbox write.
      recvNotificationBuf_->waitRecv();
    }

    // Copy accumulated chunk
    copyChunkAtOffset(chunkOffset);
  }

  // Second pass around the ring to broadcast result.
  for (int round = 0; round < chunks_; round++) {
    const auto chunkOffset = getChunkOffset(round);

    if (chunkOffset < chunkContext_.size()) {
      auto& context = chunkContext_[chunkOffset];

      // End at chunks_-2 since that's where the accumulation
      // stopped in the previous set of rounds.
      if (round < (chunks_ - 2)) {
        // Wait for inbox write to complete
        recvDataBuf_[chunkOffset & 1]->waitRecv();

        // Copy chunk from inbox to scratch space; the result for this chunk
        // is final, so no further reduction is needed here.
        stream.copyAsync(context.scratch, inbox_[chunkOffset & 1]);
        // Ensure the copy finished before the inbox can be reused and before
        // the broadcast below reads from scratch.
        stream.wait();
      }

      // Broadcast chunk to devices. Do this in all rounds with non-empty chunk.
      context.broadcastOp->runAsync();
    } else {
      // Empty chunk but still need to wait on the inbox write to ensure the
      // algorithm progresses.
      if (round < (chunks_ - 2)) {
        recvDataBuf_[chunkOffset & 1]->waitRecv();
      }
    }

    // Skip the notification exchange and the forwarding copy toward the end
    // of the pass.
    // NOTE(review): the original comment said "last two rounds", but this
    // guard disables copying for the last FOUR rounds (round >= chunks_ - 4).
    // Presumably those chunks have already completed the circuit; verify
    // against the CPU ring-chunked implementation before relying on the
    // wording.
    if (round < (chunks_ - 4)) {
      // Send notification to node on the left that
      // this node is ready for an inbox write.
      sendNotificationBuf_->send();

      // Wait for notification from node on the right
      // to be sure this node can start an inbox write.
      recvNotificationBuf_->waitRecv();

      // Copy accumulated chunks
      copyChunkAtOffset(chunkOffset);
    }
  }

  // Final barrier to make sure every node has finished
  // Otherwise, a second all reduce call might interfere
  // with one that it still in progress on some nodes.
  sendNotificationBuf_->send();
  recvNotificationBuf_->waitRecv();

  // If running synchronously, wait for all chunk broadcasts to complete
  // before returning control to the caller. When asynchronous, the
  // broadcasts launched above are left in flight.
  if (synchronizeDeviceOutputs_) {
    for (auto i = 0; i < chunks_; i++) {
      const auto chunkOffset = getChunkOffset(i);
      if (chunkOffset < chunkContext_.size()) {
        chunkContext_[chunkOffset].broadcastOp->wait();
      }
    }
  }
}