in fairring/device.cc [590:657]
// Runs the reduce-scatter for one slice of the input as a three-stage
// pipeline, each stage on its own CUDA stream and chained with CUDA events:
//   1. an intra-machine reduce-scatter on reduceScatterStream_,
//   2. a cross-machine "collect" exchange on collectStream_ (multi-machine
//      only),
//   3. a local sum of the collected per-machine partials on addStream_
//      (multi-machine only).
// `initialEvent` marks the point at which `input` is ready to be consumed;
// `output` receives this rank's reduced shard. Requires
// input.numel() == output.numel() * numDevicesPerMachine_ * numMachines_.
// This function only enqueues GPU work; it does not wait for completion.
void DeviceFairring::reduceScatterOneSlice(
    at::Tensor input,
    at::Tensor output,
    at::cuda::CUDAEvent initialEvent) {
  // All work below is issued against this process's device.
  c10::cuda::CUDAGuard g(myDeviceIdxOnProcess_);
  // Events used to order the streams: reduce-scatter -> collect -> add.
  at::cuda::CUDAEvent reduceScatterToCollectEvent;
  at::cuda::CUDAEvent collectToAddEvent;
  // The slice must split evenly across all devices of all machines, and
  // `output` must hold exactly this device's share of it.
  MY_CHECK(input.numel() % (numDevicesPerMachine_ * numMachines_) == 0);
  MY_CHECK(
      input.numel() == output.numel() * numDevicesPerMachine_ * numMachines_);
  // View the flat input as (device, machine, chunk). When the global rank
  // layout is not "favorable" the flat order is (machine, device, chunk)
  // instead, so transpose to recover the same logical indexing (the result
  // is then non-contiguous).
  at::Tensor input3d;
  if (deviceGlobalRankIsFavorable_) {
    input3d = input.view({numDevicesPerMachine_, numMachines_, -1});
  } else {
    input3d =
        input.view({numMachines_, numDevicesPerMachine_, -1}).transpose(0, 1);
  }
  MY_CHECK(numDevicesPerMachine_ >= 2);
  // Staging buffer for the collect stage: the next device's row of input3d,
  // i.e. shape (numMachines_, chunk). NOTE(review): this reuses part of the
  // caller's input as scratch — presumably safe because the reduce-scatter
  // stage has fully consumed that row before collect writes it; confirm
  // against doCollect's contract.
  at::Tensor input3dStaging =
      input3d[(myDeviceIdxOnMachine_ + 1) % numDevicesPerMachine_];
  // Don't start until the producer of `input` has finished.
  initialEvent.block(reduceScatterStream_);
  if (numMachines_ == 1) {
    // Single machine: a plain NCCL reduce-scatter over the local devices
    // writes the final result directly into `output`; the collect and add
    // stages below are skipped (both are gated on numMachines_ > 1).
    MY_CHECK(input3d.is_contiguous());
    NCCL_CHECK(ncclReduceScatter(
        input3d.data_ptr(),
        output.data_ptr(),
        output.numel(),
        torchToNcclDtype(output.scalar_type()),
        ncclSum,
        reduceScatterComm_.get(),
        reduceScatterStream_));
  } else if (numDevicesPerMachine_ > 1) {
    // Multi-machine: custom intra-machine reduce-scatter over input3d.
    // (This condition is always true here given the >= 2 check above.)
    NCCL_CHECK(ncclGroupStart());
    doReduceScatter(
        input3d,
        myDeviceIdxOnMachine_,
        reduceScatterComm_,
        reduceScatterStream_);
    NCCL_CHECK(ncclGroupEnd());
  }
  // Fence: the collect stream waits for the reduce-scatter stage to finish.
  reduceScatterToCollectEvent.record(reduceScatterStream_);
  reduceScatterToCollectEvent.block(collectStream_);
  if (numMachines_ > 1) {
    // Cross-machine exchange of this device's shard; gathers every machine's
    // partial sums into input3dStaging (see doCollect for exact semantics).
    NCCL_CHECK(ncclGroupStart());
    doCollect(
        input3d[myDeviceIdxOnMachine_],
        input3dStaging,
        myMachineIdx_,
        collectComm_,
        collectStream_);
    NCCL_CHECK(ncclGroupEnd());
  }
  // Fence: the add stream waits for the collect stage to finish.
  collectToAddEvent.record(collectStream_);
  collectToAddEvent.block(addStream_);
  if (numMachines_ > 1) {
    // Final local reduction on the add stream: sum the per-machine partials
    // (dim 0 of the staging view) into `output`.
    c10::cuda::CUDAStreamGuard g(addStream_);
    if (input3d.numel() > 0) {
      at::sum_out(output, input3dStaging, {0});
    }
  }
}