void CudaAllreduceBcube::run()

in gloo/cuda_allreduce_bcube.cc [119:214]


/*
 * Runs one bcube allreduce over the buffers this algorithm instance was
 * set up with.
 *
 * Phases, in the order implemented below:
 *   1. Local reduce: localReduceOp_->run() combines the local buffers
 *      (presumably into scratch_ -- TODO confirm against the init code).
 *   2. Single-node shortcut: when nodes_ == 1 there are no remote peers, so
 *      localBroadcastOp_ (enforced non-null) is run and we return early.
 *   3. Reduce-scatter, steps 0 .. steps_-1: for each peer returned by
 *      getPeersPerStep(myRank_, step), send the element range that
 *      getPtrOffsetPerStep/getNumElemsPerStep assign to that peer. Then,
 *      for each peer, reduce its incoming data into our own range of
 *      scratch_ via fn_->call.
 *   4. All-gather, steps steps_-1 .. 0 (reverse order): send our range for
 *      the step to each peer. Then copy each peer's range for the step into
 *      scratch_ at that peer's offset.
 *   5. Local broadcast via localBroadcastOp_ (runAsync + wait). This
 *      presumably copies scratch_ back to the local buffers -- confirm.
 *   6. Wait for notifications from the step-0 peers, so that the next call
 *      can send immediately without overwriting unread receive buffers.
 *
 * Flow control:
 *   - After consuming a peer's data we call sendNotificationBufs_[peer]->send().
 *   - Before sending all-gather data we block on
 *     recvNotificationBufs_[peer]->waitRecv().
 *   - Notifications sent in reduce-scatter step s appear to be the ones
 *     consumed by the all-gather waitRecv for the same step s.
 *   - The step-0 all-gather notifications appear to be consumed by the
 *     final loop.
 * Verify this pairing before reordering any send/wait in this function;
 * an unmatched notification deadlocks or corrupts a peer.
 *
 * NOTE(review): the DEBUG_PRINT_* macros are defined outside this view and
 * may reference local names used here (step, destRank, srcRank, sendCount,
 * recvCount, ptrOffset). Keep those names stable when editing.
 */
void CudaAllreduceBcube<T, W>::run() {
  // Presumably restores the caller's current CUDA device when run() exits
  // -- confirm in CudaDeviceGuard.
  CudaDeviceGuard guard;
  // All reductions and copies into scratch_ below are queued on this stream.
  CudaStream& stream = *scratchStream_;

  localReduceOp_->run();

  if (nodes_ == 1) {
    // No remote peers: only the local broadcast remains.
    GLOO_ENFORCE(
        localBroadcastOp_,
        "localBroadcastOp must be initialized for single machine");
    localBroadcastOp_->run();
    return;
  }

  // Reduce-scatter
  DEBUG_PRINT_STAGE("start");
  for (int step = 0; step < steps_; ++step) {
    const auto& peerRanks = getPeersPerStep(myRank_, step);
    for (int destRank : peerRanks) {
      // Send the range assigned to destRank for this step. Offsets and
      // counts are in elements of T and are converted to bytes for send().
      int sendCount = getNumElemsPerStep(destRank, step);
      int ptrOffset = getPtrOffsetPerStep(destRank, step);
      DEBUG_PRINT_SEND("reduce-scatter");
      sendDataBufs_[destRank]->send(
          ptrOffset * sizeof(T), sendCount * sizeof(T));
    } // sends within group

    for (int srcRank : peerRanks) {
      // Every peer sends us the same range for this step: our own
      // offset/count (neither value depends on srcRank).
      int recvCount = getNumElemsPerStep(myRank_, step);
      int ptrOffset = getPtrOffsetPerStep(myRank_, step);
      recvDataBufs_[srcRank]->waitRecv();
      DEBUG_PRINT_RECV("reduce-scatter");
      // The incoming data starts at offset 0 of the receive buffer
      // selected by recvBufIdx_[srcRank].
      auto recvBufAtOffset =
          recvBufs_[recvBufIdx_[srcRank]].range(0, recvCount);
      auto scratchAtOffset = scratch_.range(ptrOffset, recvCount);
      // Presumably accumulates recvBufAtOffset into scratchAtOffset
      // (first argument as destination) -- confirm fn_'s contract.
      fn_->call(scratchAtOffset, recvBufAtOffset, recvCount, stream);
      // Block until the reduction has consumed the receive buffer before
      // telling the peer it may reuse it.
      stream.wait();
      /*
       * Send notification to the pair we just received from that
       * we're done dealing with the receive buffer.
       */
      sendNotificationBufs_[srcRank]->send();
    } // recvs within group and reduces
  } // reduce-scatter steps

  DEBUG_PRINT_STAGE("reduce-scattered");

  // All-gather (steps visited in reverse order of the reduce-scatter)
  for (int step = steps_ - 1; step >= 0; --step) {
    const auto& peerRanks = getPeersPerStep(myRank_, step);
    for (int destRank : peerRanks) {
      // Send our own range for this step to every peer.
      int sendCount = getNumElemsPerStep(myRank_, step);
      int ptrOffset = getPtrOffsetPerStep(myRank_, step);
      /*
       * Wait for notification from the peer to make sure we can send data
       * without risking any overwrites in its receive buffer.
       */
      recvNotificationBufs_[destRank]->waitRecv();
      DEBUG_PRINT_SEND("all-gather");
      sendDataBufs_[destRank]->send(
          ptrOffset * sizeof(T), sendCount * sizeof(T));
    }

    for (int srcRank : peerRanks) {
      // Each peer sends its own range, so it is placed at srcRank's offset.
      int recvCount = getNumElemsPerStep(srcRank, step);
      int ptrOffset = getPtrOffsetPerStep(srcRank, step);
      recvDataBufs_[srcRank]->waitRecv();
      DEBUG_PRINT_RECV("all-gather");
      auto recvBufAtOffset =
          recvBufs_[recvBufIdx_[srcRank]].range(0, recvCount);
      auto scratchAtOffset = scratch_.range(ptrOffset, recvCount);
      // Presumably copies recvBufAtOffset into scratchAtOffset (dst, src);
      // the wait ensures the copy finished before any notification.
      stream.copyAsync(scratchAtOffset, recvBufAtOffset);
      stream.wait();
      if (step == 0) {
        /*
         * Send notification to the pair we just received from that
         * we're done dealing with the receive buffer.
         * Only the last all-gather step (step 0) notifies here. These
         * notifications appear to be matched by the final waitRecv loop
         * at the end of run() on the peer's side.
         */
        sendNotificationBufs_[srcRank]->send();
      }
    } // recvs within group and copies into scratch_
  } // all-gather steps

  DEBUG_PRINT_STAGE("all-reduced");

  // Presumably distributes the result in scratch_ back to the local
  // buffers -- confirm against how localBroadcastOp_ is constructed.
  localBroadcastOp_->runAsync();
  localBroadcastOp_->wait();

  /*
   * Wait for notifications from our peers within the block to make
   * sure we can send data immediately without risking overwriting
   * data in its receive buffer before it consumed that data.
   */
  for (int peerRank : getPeersPerStep(myRank_, 0)) {
    recvNotificationBufs_[peerRank]->waitRecv();
  }
}