void CudaAllreduceHalvingDoubling::initReductionsAndBroadcasts()

in gloo/cuda_allreduce_halving_doubling.cc [531:601]


// Host-workspace variant of initReductionsAndBroadcasts; the enable_if on the
// unnamed parameter restricts it to U == CudaHostWorkspace<T>.
//
// Builds the local ops used by this algorithm instance:
//  * reduceBeforeFirstSend_ / reduceBeforeFirstRecv_: created only when
//    devicePtrsForFirstSend_ / devicePtrsForFirstRecv_ is non-empty, using
//    the matching scratch pointer, fn_, offset 0 and sendCounts_[0] /
//    recvCounts_[0] elements respectively.
//  * broadcastOps_: exactly one entry is appended per step in
//    [0, stepsWithinBlock_) -- nullptr when that step has no device pointers,
//    otherwise a broadcast op over devicePtrsForBroadcast_[step] and
//    scratchPtrForBroadcast_[step].
//
// Every op uses the host implementation when its payload in bytes is below
// kOnDeviceThreshold, and the device implementation otherwise.
// When stepsWithinBlock_ == 0 nothing is built.
// NOTE(review): broadcastOps_ is appended to, not cleared, so this presumably
// runs once per instance -- confirm with the caller.
void CudaAllreduceHalvingDoubling<T, W>::initReductionsAndBroadcasts(
    typename std::enable_if<
        std::is_same<U, CudaHostWorkspace<T>>::value,
        typename U::Pointer>::type*) {
  if (stepsWithinBlock_ == 0) {
    return;
  }

  // A single size test picks the path for BOTH first-step reductions; it is
  // based on sendCounts_[0] even for the receive-side reduction.
  const bool reduceOnHost = sendCounts_[0] * sizeof(T) < kOnDeviceThreshold;

  if (!devicePtrsForFirstSend_.empty()) {
    if (reduceOnHost) {
      reduceBeforeFirstSend_ = cudaHostReduce(
          streams_, devicePtrsForFirstSend_, scratchPtrForFirstSend_,
          fn_, 0, sendCounts_[0]);
    } else {
      reduceBeforeFirstSend_ = cudaDeviceReduce(
          streams_, devicePtrsForFirstSend_, scratchPtrForFirstSend_,
          fn_, 0, sendCounts_[0]);
    }
  }

  if (!devicePtrsForFirstRecv_.empty()) {
    if (reduceOnHost) {
      reduceBeforeFirstRecv_ = cudaHostReduce(
          streams_, devicePtrsForFirstRecv_, scratchPtrForFirstRecv_,
          fn_, 0, recvCounts_[0]);
    } else {
      reduceBeforeFirstRecv_ = cudaDeviceReduce(
          streams_, devicePtrsForFirstRecv_, scratchPtrForFirstRecv_,
          fn_, 0, recvCounts_[0]);
    }
  }

  for (int step = 0; step < stepsWithinBlock_; ++step) {
    auto& bcastPtrs = devicePtrsForBroadcast_[step];

    // Append a placeholder so that (if broadcastOps_ started empty) entry
    // `step` still corresponds to this step.
    if (bcastPtrs.empty()) {
      broadcastOps_.push_back(nullptr);
      continue;
    }

    // The last step's broadcast covers its send count plus its recv count;
    // every other step covers only its send count.
    size_t bcastCount;
    if (step == stepsWithinBlock_ - 1) {
      bcastCount = sendCounts_[step] + recvCounts_[step];
    } else {
      bcastCount = sendCounts_[step];
    }

    auto& bcastScratch = scratchPtrForBroadcast_[step];
    const bool bcastOnHost = bcastCount * sizeof(T) < kOnDeviceThreshold;
    if (bcastOnHost) {
      broadcastOps_.push_back(cudaHostBroadcast(
          streams_, bcastPtrs, bcastScratch, 0, bcastCount));
    } else {
      broadcastOps_.push_back(cudaDeviceBroadcast(
          streams_, bcastPtrs, bcastScratch, 0, bcastCount));
    }
  }
}