in fairring/device.cc [329:365]
// Asynchronously reduce-scatters `input` into `output` on this device.
//
// The actual work is enqueued on cmdQueue_ and runs later on its worker via
// reduceScatterOneSlice(). The returned future is either completed with a
// one-element tensor list holding `output`, or failed with the exception that
// the work threw. In both cases this happens while addStream_ is the current
// stream.
//
// Preconditions are validated synchronously, on the calling thread, before
// anything is enqueued. A violation is therefore reported by throwing from
// this function. It never surfaces on the worker, where it would escape the
// lambda's try/catch and leave the returned future forever pending.
//
// NOTE(review): the shapes/dtypes of `input` vs `output` are not validated
// here; presumably reduceScatterOneSlice enforces them — confirm.
c10::intrusive_ptr<c10::ivalue::Future> DeviceFairring::reduceScatter(
    at::Tensor input,
    at::Tensor output) {
  // Because reduce-scatter cannot be sliced
  MY_CHECK(numDevicesPerMachine_ >= 2);
  // Must be checked here rather than inside the enqueued lambda: there it ran
  // outside the try/catch, so a failure never reached `future`.
  MY_CHECK(kAlignment % input.element_size() == 0);

  // Mark the point on the caller's current stream up to which `input` has
  // been produced; it is handed to reduceScatterOneSlice (which presumably
  // waits on it before reading `input` — confirm).
  at::cuda::CUDAEvent initialEvent;
  initialEvent.record(c10::cuda::getCurrentCUDAStream(myDeviceIdxOnProcess_));

  c10::intrusive_ptr<c10::ivalue::Future> future =
      c10::make_intrusive<c10::ivalue::Future>(
          c10::ListType::ofTensors(),
          std::vector<c10::Device>(
              {c10::Device(c10::kCUDA, myDeviceIdxOnProcess_)}));

  cmdQueue_.enqueue([this,
                     input = std::move(input),
                     output = std::move(output),
                     // CUDAEvent is move-only; holding it via shared_ptr keeps
                     // the closure copyable (presumably required by the
                     // callable type cmdQueue_ stores — confirm).
                     initialEvent = std::make_shared<at::cuda::CUDAEvent>(
                         std::move(initialEvent)),
                     future]() mutable {
    // Used below only to identify this operation in error logs.
    int64_t seqNum = nextSlot_++;
    try {
      reduceScatterOneSlice(input, output, std::move(*initialEvent));
      // Complete the future while addStream_ is current, so that the
      // completion is associated with the work enqueued on that stream.
      c10::cuda::CUDAStreamGuard g(addStream_);
      future->markCompleted(std::vector<at::Tensor>{output});
    } catch (const std::exception& e) {
      LOG(ERROR) << "Function for chunk #" << seqNum
                 << " threw exception: " << e.what();
      c10::cuda::CUDAStreamGuard g(addStream_);
      future->setError(std::current_exception());
    }
  });
  return future;
}