in fairring/device.cc [273:327]
c10::intrusive_ptr<c10::ivalue::Future> DeviceFairring::allReduce(
    c10d::ReduceOp opType,
    at::Tensor tensor) {
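  // Capture the point on the caller's current stream at which the input
  // tensor is ready, so the worker can wait on it before reading the tensor.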
  at::cuda::CUDAEvent initialEvent;
  initialEvent.record(c10::cuda::getCurrentCUDAStream(myDeviceIdxOnProcess_));
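  // Future returned to the caller; constructing it with this device lets
  // markCompleted record the CUDA events that waiters synchronize on.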
  c10::intrusive_ptr<c10::ivalue::Future> future =
      c10::make_intrusive<c10::ivalue::Future>(
          c10::ListType::ofTensors(),
          std::vector<c10::Device>(
              {c10::Device(c10::kCUDA, myDeviceIdxOnProcess_)}));
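  // Hand the actual work off to this device's command queue; the lambda takes
  // ownership of the tensor and the readiness event, plus a handle to the
  // future.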
  cmdQueue_.enqueue([this,
                     tensor = std::move(tensor),
                     initialEvent = std::make_shared<at::cuda::CUDAEvent>(
                         std::move(initialEvent)),
                     future]() mutable {
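    // Process the tensor in slices of at most layout_.sliceSizeInBytes each,
    // reduced one after the other.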
    int64_t numElements = tensor.numel();
    MY_CHECK(kAlignment % tensor.element_size() == 0);
    int64_t maxSliceSizeInElems =
        layout_.sliceSizeInBytes / tensor.element_size();
    int64_t numSlices = ceilOfRatio(numElements, maxSliceSizeInElems);
    for (const auto sliceIdx : c10::irange(numSlices)) {
      int64_t seqNum = nextSlot_++;
      int64_t offsetInElems = sliceIdx * maxSliceSizeInElems;
      int64_t sliceSizeInElems =
          std::min(maxSliceSizeInElems, numElements - offsetInElems);
      at::Tensor slice = tensor.slice(
          /*dim=*/0, offsetInElems, offsetInElems + sliceSizeInElems);
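      // Only the last slice carries the future: completing it signals that
      // the entire tensor has been reduced.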
      auto myFuture = sliceIdx == numSlices - 1
          ? std::move(future)
          : c10::intrusive_ptr<c10::ivalue::Future>();
      try {
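        // Only the first slice is given the readiness event; later slices
        // reuse the same internal streams and are ordered behind it.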
        allReduceOneSlice(
            std::move(slice),
            sliceIdx == 0
                ? c10::optional<at::cuda::CUDAEvent>(std::move(*initialEvent))
                : c10::nullopt);
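        // Complete the future under a guard for allGatherStream_, so that
        // waiting on it synchronizes with the stream on which the final
        // all-gather runs.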
        if (myFuture) {
          c10::cuda::CUDAStreamGuard g(allGatherStream_);
          myFuture->markCompleted(std::vector<at::Tensor>{std::move(tensor)});
        }
      } catch (const std::exception& e) {
        LOG(ERROR) << "Function for chunk #" << seqNum
                   << " threw exception: " << e.what();
        if (myFuture) {
          c10::cuda::CUDAStreamGuard g(allGatherStream_);
          myFuture->setError(std::current_exception());
        }
      }
    }
  });
  return future;
}
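
For reference, the per-slice bookkeeping above is plain ceiling division over fixed-size chunks. The standalone sketch below reproduces just that arithmetic; ceilOfRatio is assumed to be ordinary integer ceiling division (its definition is not part of this excerpt), and the tensor and slice sizes are made-up values.

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Assumed shape of the ceilOfRatio helper used above (it lives elsewhere in
// the repository): integer ceiling division for positive operands.
int64_t ceilOfRatio(int64_t num, int64_t den) {
  return (num + den - 1) / den;
}

int main() {
  // Made-up sizes, for illustration only.
  const int64_t numElements = 10000000;       // elements in the flat tensor
  const int64_t sliceSizeInBytes = 16777216;  // stand-in for layout_.sliceSizeInBytes
  const int64_t elementSize = 4;              // e.g. float32

  const int64_t maxSliceSizeInElems = sliceSizeInBytes / elementSize;
  const int64_t numSlices = ceilOfRatio(numElements, maxSliceSizeInElems);
  for (int64_t sliceIdx = 0; sliceIdx < numSlices; ++sliceIdx) {
    const int64_t offsetInElems = sliceIdx * maxSliceSizeInElems;
    const int64_t sliceSizeInElems =
        std::min(maxSliceSizeInElems, numElements - offsetInElems);
    std::printf(
        "slice %lld: offset %lld, %lld elements\n",
        static_cast<long long>(sliceIdx),
        static_cast<long long>(offsetInElems),
        static_cast<long long>(sliceSizeInElems));
  }
  return 0;
}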