void CudaAllreduceHalvingDoubling::devicePointerInit()

in gloo/cuda_allreduce_halving_doubling.cc [411:453]


void CudaAllreduceHalvingDoubling<T, W>::devicePointerInit() {
  // Builds the sub-range pointers, via `range(offset, count)`, into scratch_
  // and into every entry of devicePtrs_ for each broadcast step and for the
  // first send/receive step.
  //
  // Populates:
  //   - scratchPtrForBroadcast_: one entry appended per step i in
  //     [0, stepsWithinBlock_).
  //   - devicePtrsForBroadcast_[i]: one entry appended per device pointer for
  //     every step whose offset lies inside the buffer.
  //     NOTE(review): devicePtrsForBroadcast_ is indexed, not appended to, so
  //     it must already hold at least stepsWithinBlock_ entries. Confirm the
  //     caller sizes it.
  //   - scratchPtrForFirstSend_ / scratchPtrForFirstRecv_ and
  //     devicePtrsForFirstSend_ / devicePtrsForFirstRecv_, using the step-0
  //     offsets and counts.
  //
  // NOTE(review): sendOffsets_[0] and recvOffsets_[0] are read
  // unconditionally. Confirm they are valid even when stepsWithinBlock_ == 0.
  for (int i = 0; i < stepsWithinBlock_; i++) {
    // The first broadcast (step 'stepsWithinBlock_ - 1') covers both the
    // local chunk result from reduce-scatter and the first received chunk.
    // It therefore starts at the lower of the two offsets and spans both
    // counts.
    const bool isFirstBroadcast = i == stepsWithinBlock_ - 1;
    const size_t offset = isFirstBroadcast
        ? std::min(recvOffsets_[i], sendOffsets_[i])
        : sendOffsets_[i];
    size_t numElements =
        isFirstBroadcast ? recvCounts_[i] + sendCounts_[i] : sendCounts_[i];

    if (offset > count_) {
      // The step starts past the end of the buffer. Push a default-constructed
      // placeholder so scratchPtrForBroadcast_ keeps one entry per step, and
      // leave devicePtrsForBroadcast_[i] untouched.
      scratchPtrForBroadcast_.push_back(typename W::Pointer());
      continue;
    }
    // Clamp so the range never extends past the count_-element buffer.
    if (offset + numElements > count_) {
      numElements = count_ - offset;
    }

    scratchPtrForBroadcast_.push_back(scratch_.range(offset, numElements));
    for (auto& devicePtr : devicePtrs_) {
      devicePtrsForBroadcast_[i].push_back(
          devicePtr.range(offset, numElements));
    }
  }

  // Step-0 send/receive ranges are only built when the chunk starts inside
  // the buffer. These conditions do not change across devices, so they are
  // evaluated once.
  const bool hasFirstSend = sendOffsets_[0] < count_;
  const bool hasFirstRecv = recvOffsets_[0] < count_;

  if (hasFirstSend) {
    scratchPtrForFirstSend_ = scratch_.range(sendOffsets_[0], sendCounts_[0]);
  }
  if (hasFirstRecv) {
    scratchPtrForFirstRecv_ = scratch_.range(recvOffsets_[0], recvCounts_[0]);
  }

  for (auto& devicePtr : devicePtrs_) {
    if (hasFirstSend) {
      devicePtrsForFirstSend_.push_back(
          devicePtr.range(sendOffsets_[0], sendCounts_[0]));
    }
    if (hasFirstRecv) {
      devicePtrsForFirstRecv_.push_back(
          devicePtr.range(recvOffsets_[0], recvCounts_[0]));
    }
  }
}