void CudaAllreduceHalvingDoubling<T, W>::run()

in gloo/cuda_allreduce_halving_doubling.cc [246:408]

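This method implements a halving-doubling allreduce across CUDA devices: a reduce-scatter phase based on vector halving with distance doubling, followed by an allgather phase that replays the same exchanges in reverse. When the number of processes is not a power of two, processes are grouped into binary blocks; each block runs the power-of-two algorithm internally, and partial results are exchanged between neighboring blocks in between the two phases.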

template <typename T, typename W>
void CudaAllreduceHalvingDoubling<T, W>::run() {
  CudaDeviceGuard guard;
  CudaStream& stream = *scratchStream_;
  size_t bufferOffset = 0;
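  // numItems is the stride between consecutive per-step regions of
  // recvBuf_: the first exchange moves half of the chunk-padded buffer,
  // i.e. chunkSize_ << (steps_ - 1) items, and every later step moves
  // half as much as the one before.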
  size_t numItems = stepsWithinBlock_ > 0 ? chunkSize_ << (steps_ - 1) : count_;

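  // Reduce the inputs of all local devices into scratch_. In the
  // pipelined case only the segment needed for the first send is reduced
  // up front; the remainder (reduceBeforeFirstRecv_) is overlapped with
  // the first receive in the loop below.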
  if (pipelined_ && reduceBeforeFirstSend_) {
    reduceBeforeFirstSend_->run();
  } else if (localReduceOp_) {
    localReduceOp_->run();
  }

  if (this->contextSize_ == 1) {
    GLOO_ENFORCE(localBroadcastOp_,
            "localBroadcastOp must be initialized for single machine");
    localBroadcastOp_->run();
    return;
  }

  // Reduce-scatter
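  // Vector halving with distance doubling: in step i this rank exchanges
  // a region of the buffer with the peer at distance 2^i, keeping one
  // half and sending the other, then reduces the received chunk into
  // scratch_ at recvOffsets_[i]. The offset < count_ checks skip
  // exchanges that fall entirely into the padding beyond the actual
  // element count.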
  for (int i = 0; i < stepsWithinBlock_; i++) {
    if (sendOffsets_[i] < count_) {
      sendDataBufs_[i]->send(
          sendOffsets_[i] * sizeof(T), sendCounts_[i] * sizeof(T));
    }
    if (recvOffsets_[i] < count_) {
      if (pipelined_ && i == 0 && reduceBeforeFirstRecv_) {
        reduceBeforeFirstRecv_->runAsync();
      }
      recvDataBufs_[i]->waitRecv();
      if (pipelined_ && i == 0 && reduceBeforeFirstRecv_) {
        reduceBeforeFirstRecv_->wait();
      }
      auto recvBufAtOffset = recvBuf_.range(bufferOffset, recvCounts_[i]);
      auto scratchAtOffset = scratch_.range(recvOffsets_[i], recvCounts_[i]);
      fn_->call(scratchAtOffset, recvBufAtOffset, recvCounts_[i], stream);
      stream.wait();
    }
    sendNotificationBufs_[i]->send();
    bufferOffset += numItems;
    if (i != stepsWithinBlock_ - 1) {
      numItems >>= 1;
    }
  }

  // Communication across binary blocks for non-power-of-two number of
  // processes
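  // Each binary block corresponds to one set bit of the process count
  // (e.g. with 6 processes there are blocks of size 4 and 2). A smaller
  // block forwards its partially reduced data to the next larger block,
  // which folds it in and later returns the final result at the start of
  // the allgather.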

  // Receive from the next smaller block. The data sizes are the same as
  // in the last step of the intra-block reduce-scatter above.
  if (nextSmallerBlockSize_ != 0 && smallerBlockRecvDataBuf_ != nullptr) {
    smallerBlockRecvDataBuf_->waitRecv();
    auto recvBufAtOffset =
        recvBuf_.range(bufferOffset, recvCounts_[stepsWithinBlock_ - 1]);
    auto scratchAtOffset = scratch_.range(
        recvOffsets_[stepsWithinBlock_ - 1],
        recvCounts_[stepsWithinBlock_ - 1]);
    fn_->call(
        scratchAtOffset,
        recvBufAtOffset,
        recvCounts_[stepsWithinBlock_ - 1],
        stream);
    stream.wait();
  }

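  // After the reduce-scatter this rank holds a partially reduced segment
  // of recvCounts_[stepsWithinBlock_ - 1] items at
  // recvOffsets_[stepsWithinBlock_ - 1]; that segment is what is
  // exchanged with the larger block below.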
  const auto totalItemsToSend =
      stepsWithinBlock_ > 0 ? recvCounts_[stepsWithinBlock_ - 1] : count_;
  if (nextLargerBlockSize_ != 0 && totalItemsToSend != 0) {
    // Scatter this rank's reduced segment across the ranks of the next
    // larger block.
    const auto offset =
        stepsWithinBlock_ > 0 ? recvOffsets_[stepsWithinBlock_ - 1] : 0;
    const auto numSendsAndReceivesToLargerBlock =
        nextLargerBlockSize_ / myBinaryBlockSize_;
    for (int i = 0; i < numSendsAndReceivesToLargerBlock; i++) {
      if (sendCountToLargerBlock_ * i < totalItemsToSend) {
        largerBlockSendDataBufs_[i]->send(
            (offset + i * sendCountToLargerBlock_) * sizeof(T),
            std::min(
                sendCountToLargerBlock_,
                totalItemsToSend - sendCountToLargerBlock_ * i) *
                sizeof(T));
      }
    }
    // No notification is needed because the forward and backward messages
    // across blocks are serialized with respect to each other.

    // Receive the final, fully reduced result back from the ranks of the
    // larger block.
    for (int i = 0; i < numSendsAndReceivesToLargerBlock; i++) {
      if (sendCountToLargerBlock_ * i < totalItemsToSend) {
        largerBlockRecvDataBufs_[i]->waitRecv();
      }
    }
    auto recvBufAtOffset = recvBuf_.range(bufferOffset, totalItemsToSend);
    auto scratchAtOffset = scratch_.range(offset, totalItemsToSend);
    // The message from the larger block is the final result; no reduction
    // is needed, just a copy.
    stream.copyAsync(scratchAtOffset, recvBufAtOffset);
    stream.wait();
  }

  // Send to smaller block (technically the beginning of allgather)
  bool sentToSmallerBlock = false;
  if (nextSmallerBlockSize_ != 0) {
    if (recvOffsets_[stepsWithinBlock_ - 1] < count_) {
      sentToSmallerBlock = true;
      smallerBlockSendDataBuf_->send(
          recvOffsets_[stepsWithinBlock_ - 1] * sizeof(T),
          recvCounts_[stepsWithinBlock_ - 1] * sizeof(T));
    }
  }

  // Allgather
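  // Recursive doubling, replaying the reduce-scatter steps in reverse:
  // the per-step regions of recvBuf_ are walked backwards (bufferOffset
  // shrinks by numItems each step) and the received chunks are copied
  // rather than reduced, since they already contain final values.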
  numItems = chunkSize_ << (steps_ - stepsWithinBlock_);
  for (int i = stepsWithinBlock_ - 1; i >= 0; i--) {
    // Verify that the destination rank has received and processed this
    // rank's message from the reduce-scatter phase before sending more
    // data its way.
    recvNotificationBufs_[i]->waitRecv();
    if (recvOffsets_[i] < count_) {
      sendDataBufs_[i]->send(
          recvOffsets_[i] * sizeof(T), recvCounts_[i] * sizeof(T));
    }
    bufferOffset -= numItems;
    if (sendOffsets_[i] < count_) {
      recvDataBufs_[i]->waitRecv();
      auto recvBufAtOffset = recvBuf_.range(bufferOffset, sendCounts_[i]);
      auto scratchAtOffset = scratch_.range(sendOffsets_[i], sendCounts_[i]);
      stream.copyAsync(scratchAtOffset, recvBufAtOffset);
      stream.wait();
    }
    if (pipelined_ && broadcastOps_[i]) {
      broadcastOps_[i]->runAsync();
    }
    numItems <<= 1;

    // Send notification to the pair we just received from that
    // we're done dealing with the receive buffer.
    sendNotificationBufs_[i]->send();
  }

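  // In the pipelined case, broadcasts of the final chunks to the other
  // local devices were launched step by step during the allgather; wait
  // for all of them here. Otherwise, run a single local broadcast of the
  // fully gathered result.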
  if (pipelined_ && stepsWithinBlock_ > 0) {
    for (int i = stepsWithinBlock_ - 1; i >= 0; i--) {
      if (broadcastOps_[i]) {
        broadcastOps_[i]->wait();
      }
    }
  } else if (localBroadcastOp_) {
    localBroadcastOp_->runAsync();
    localBroadcastOp_->wait();
  }

  // Wait for notifications from our peers within the block to make
  // sure we can send data immediately without risking overwriting
  // data in a peer's receive buffer before it has consumed that data.
  for (int i = stepsWithinBlock_ - 1; i >= 0; i--) {
    recvNotificationBufs_[i]->waitRecv();
  }

  // We have to be sure the send to the smaller block (if any) has
  // completed before returning; otherwise our caller could modify the
  // buffer contents while the send is still in flight.
  if (sentToSmallerBlock) {
    smallerBlockSendDataBuf_->waitSend();
  }
}
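
The loop above consumes schedules (sendOffsets_, sendCounts_, recvOffsets_, recvCounts_) that are precomputed when the algorithm is constructed. As a minimal sketch of where such a schedule comes from, the following standalone program derives the per-step peers, offsets, and counts for the power-of-two case; the names and structure here are illustrative assumptions, not gloo's actual initialization code.

#include <cstddef>
#include <cstdio>
#include <vector>

// Illustrative schedule for vector-halving / distance-doubling
// reduce-scatter (the allgather replays it in reverse). This is a sketch
// of the idea, not gloo's actual setup code.
struct StepSchedule {
  int peer;          // rank to exchange with in this step
  size_t sendOffset; // in items, relative to the (padded) buffer
  size_t recvOffset; // where the received chunk is reduced into
  size_t count;      // items sent and received this step
};

std::vector<StepSchedule> halvingSchedule(int rank, int size,
                                          size_t chunkSize) {
  std::vector<StepSchedule> steps;
  size_t offset = 0;
  size_t len = chunkSize * size; // padded buffer length in items
  for (int bit = 1; bit < size; bit <<= 1) {
    StepSchedule s;
    s.peer = rank ^ bit;
    len /= 2; // vector halving: each step moves half the remaining region
    if ((rank & bit) == 0) {
      // Keep the lower half, send the upper half to the peer.
      s.recvOffset = offset;
      s.sendOffset = offset + len;
    } else {
      // Keep the upper half, send the lower half.
      s.sendOffset = offset;
      s.recvOffset = offset + len;
    }
    s.count = len;
    offset = s.recvOffset; // recurse into the half we kept
    steps.push_back(s);
  }
  return steps;
}

int main() {
  // For 4 ranks and chunkSize 1, rank r ends up owning chunk
  // bit-reverse(r): rank 1 owns chunk 2 and vice versa.
  for (int r = 0; r < 4; r++) {
    for (const auto& s : halvingSchedule(r, 4, 1)) {
      std::printf("rank %d <-> %d: send [%zu,+%zu) recv [%zu,+%zu)\n",
                  r, s.peer, s.sendOffset, s.count, s.recvOffset, s.count);
    }
  }
  return 0;
}

Running it for four ranks shows the bit-reversal property of distance doubling: rank 1 ends up owning chunk 2 and rank 2 chunk 1, matching the offsets the allgather walks back through in reverse step order.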