c10::intrusive_ptr DeviceFairring::allReduce()

in fairring/device.cc [273:327]


// Schedules an all-reduce of `tensor` on this device's command queue and
// returns a future that is resolved once every slice has been handed off.
//
// The call itself only records a CUDA event on the caller's current stream
// for device `myDeviceIdxOnProcess_`, creates the future, and enqueues a
// closure on `cmdQueue_`; the slicing and the calls to `allReduceOneSlice`
// happen inside that closure.
//
// @param opType  Requested reduction.
//                NOTE(review): this value is not read anywhere in this
//                function; presumably only a single reduction is supported
//                downstream -- confirm.
// @param tensor  Tensor to reduce. It is cut along dim 0 using offsets that
//                are computed in elements, so this presumably expects a 1-D
//                (flattened) tensor -- TODO confirm against callers.
// @return A future of tensor-list type bound to this CUDA device. It is
//         completed with a one-element list holding `tensor`, or failed with
//         the first exception raised while processing any slice.
c10::intrusive_ptr<c10::ivalue::Future> DeviceFairring::allReduce(
    c10d::ReduceOp opType,
    at::Tensor tensor) {
  // Capture the state of the caller's stream at submission time; the event is
  // handed to the first slice below.
  at::cuda::CUDAEvent initialEvent;
  initialEvent.record(c10::cuda::getCurrentCUDAStream(myDeviceIdxOnProcess_));

  c10::intrusive_ptr<c10::ivalue::Future> future =
      c10::make_intrusive<c10::ivalue::Future>(
          c10::ListType::ofTensors(),
          std::vector<c10::Device>(
              {c10::Device(c10::kCUDA, myDeviceIdxOnProcess_)}));

  // The event is moved into a shared_ptr before being captured (presumably so
  // that the closure stays copyable for the queue -- confirm).
  cmdQueue_.enqueue([this,
                     tensor = std::move(tensor),
                     initialEvent = std::make_shared<at::cuda::CUDAEvent>(
                         std::move(initialEvent)),
                     future]() mutable {
    int64_t numElements = tensor.numel();
    MY_CHECK(kAlignment % tensor.element_size() == 0);
    int64_t maxSliceSizeInElems =
        layout_.sliceSizeInBytes / tensor.element_size();
    int64_t numSlices = ceilOfRatio(numElements, maxSliceSizeInElems);

    if (numSlices <= 0) {
      // An empty tensor produces no slices. Without this branch the loop
      // below would never run and the future would stay pending forever.
      c10::cuda::CUDAStreamGuard g(allGatherStream_);
      future->markCompleted(std::vector<at::Tensor>{std::move(tensor)});
      return;
    }

    // First failure seen on any slice. It is reported through the future once
    // the last slice has been processed, so that a failure on an earlier
    // slice is not hidden by a successful final slice.
    std::exception_ptr firstError;

    for (const auto sliceIdx : c10::irange(numSlices)) {
      // Every slice consumes one slot number, whether or not it succeeds.
      int64_t seqNum = nextSlot_++;
      int64_t offsetInElems = sliceIdx * maxSliceSizeInElems;
      int64_t sliceSizeInElems =
          std::min(maxSliceSizeInElems, numElements - offsetInElems);
      at::Tensor slice = tensor.slice(
          /*dim=*/0, offsetInElems, offsetInElems + sliceSizeInElems);
      const bool isLastSlice = sliceIdx == numSlices - 1;
      try {
        // Only the first slice receives the event recorded at submission.
        allReduceOneSlice(
            std::move(slice),
            sliceIdx == 0
                ? c10::optional<at::cuda::CUDAEvent>(std::move(*initialEvent))
                : c10::nullopt);
        if (isLastSlice && !firstError) {
          // The future is resolved with allGatherStream_ as current stream.
          c10::cuda::CUDAStreamGuard g(allGatherStream_);
          future->markCompleted(std::vector<at::Tensor>{std::move(tensor)});
        }
      } catch (const std::exception& e) {
        LOG(ERROR) << "Function for chunk #" << seqNum
                   << " threw exception: " << e.what();
        if (!firstError) {
          firstError = std::current_exception();
        }
      }
      if (isLastSlice && firstError) {
        c10::cuda::CUDAStreamGuard g(allGatherStream_);
        future->setError(firstError);
      }
    }
  });

  return future;
}