in fairring/device.cc [405:588]
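// allReduceOneSlice pushes one slice through a multi-stage pipeline, each
// stage on its own CUDA stream and chained to the next by a CUDAEvent:
//   1. reduce-scatter across the devices of this machine,
//   2. collect the corresponding chunks from the other machines,
//   3. locally add the collected contributions,
//   4. diffuse the reduced chunk back to the other machines,
//   5. all-gather across the devices of this machine.
// If the slice's length is not divisible by numDevicesPerMachine_ *
// numMachines_, the leftover tail elements travel through a small padding
// buffer so every rank still operates on identically shaped chunks.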
void DeviceFairring::allReduceOneSlice(
at::Tensor slice,
c10::optional<at::cuda::CUDAEvent> initialEvent) {
c10::cuda::CUDAGuard g(myDeviceIdxOnProcess_);
c10::ScalarType dtype = slice.scalar_type();
int64_t elementSizeInBytes = slice.element_size();
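// Each stage runs on its own stream; these events order one stage after the
// previous one without blocking the host.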
at::cuda::CUDAEvent reduceScatterToCollectEvent;
at::cuda::CUDAEvent collectToAddEvent;
at::cuda::CUDAEvent addToDiffuseEvent;
at::cuda::CUDAEvent diffuseToAllGatherEvent;
at::Tensor slice3d;
c10::optional<at::Tensor> padding;
at::cuda::CUDAEvent* paddingEvent = nullptr;
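// View the slice as [numDevicesPerMachine_, numMachines_, chunk]. If its
// length is not evenly divisible, the tail (fewer than numDevicesPerMachine_ *
// numMachines_ elements) is routed through one slot of the padding buffer so
// the collectives still see identically shaped inputs.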
if (slice.numel() % (numDevicesPerMachine_ * numMachines_) == 0) {
slice3d = slice.view({numDevicesPerMachine_, numMachines_, -1});
} else {
int64_t sliceSizeInElems = roundDownToNearestMultiple(
slice.numel(), numDevicesPerMachine_ * numMachines_);
slice3d = slice.index({torch::indexing::Slice(0, sliceSizeInElems)})
.view({numDevicesPerMachine_, numMachines_, -1});
int64_t paddingSlotIdx = (nextPaddingSlot_++) % layout_.numPaddingSlots;
padding = paddingBuffer_[paddingSlotIdx]
.view(dtype)
.flatten()
.index({torch::indexing::Slice(
0, numDevicesPerMachine_ * numMachines_)})
.view({numDevicesPerMachine_, numMachines_});
paddingEvent = &paddingEvents_[paddingSlotIdx];
}
at::Tensor slice3dStaging;
c10::optional<at::Tensor> paddingStaging;
at::cuda::CUDAEvent* stagingEvent = nullptr;
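// The collect stage needs scratch space to receive the other machines'
// contributions. With several devices per machine, a sibling chunk of slice3d
// can serve as scratch: its contents are no longer needed after the
// reduce-scatter, and the final all-gather overwrites it regardless. With a
// single local device there is no such chunk, so a dedicated staging slot
// (protected by stagingEvent) is used instead.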
if (numDevicesPerMachine_ == 1) {
int64_t stagingSlotIdx = (nextStagingSlot_++) % layout_.numStagingSlots;
slice3dStaging = stagingBuffer_[stagingSlotIdx]
.view(dtype)
.flatten()
.index({torch::indexing::Slice(
0, slice3d[myDeviceIdxOnMachine_].numel())})
.view({numMachines_, -1});
if (padding) {
paddingStaging = paddingStagingBuffer_[stagingSlotIdx]
.view(dtype)
.flatten()
.index({torch::indexing::Slice(
0, (*padding)[myDeviceIdxOnMachine_].numel())})
.view({numMachines_});
}
stagingEvent = &stagingEvents_[stagingSlotIdx];
} else {
slice3dStaging =
slice3d[(myDeviceIdxOnMachine_ + 1) % numDevicesPerMachine_];
if (padding) {
paddingStaging =
(*padding)[(myDeviceIdxOnMachine_ + 1) % numDevicesPerMachine_];
}
}
if (initialEvent.has_value()) {
initialEvent.value().block(reduceScatterStream_);
}
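// Copy the slice's tail elements into the padding buffer so they travel
// through the same pipeline as the rest of the slice; they are copied back
// into the slice after the all-gather.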
if (padding) {
(*paddingEvent).block(reduceScatterStream_);
// No need to zero out the padding: we don't care what value it has/gets.
CUDA_CHECK(cudaMemcpyAsync(
(*padding).data_ptr(),
reinterpret_cast<uint8_t*>(slice.data_ptr()) +
slice3d.numel() * elementSizeInBytes,
(slice.numel() - slice3d.numel()) * elementSizeInBytes,
cudaMemcpyDeviceToDevice,
reduceScatterStream_));
}
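// Stage 1: intra-machine reduce-scatter. Afterwards this device holds the
// locally reduced values for its own chunk, slice3d[myDeviceIdxOnMachine_].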
if (numDevicesPerMachine_ > 1) {
NCCL_CHECK(ncclGroupStart());
doReduceScatter(
slice3d,
myDeviceIdxOnMachine_,
reduceScatterComm_,
reduceScatterStream_);
if (padding) {
doReduceScatter(
*padding,
myDeviceIdxOnMachine_,
reduceScatterComm_,
reduceScatterStream_);
}
NCCL_CHECK(ncclGroupEnd());
}
reduceScatterToCollectEvent.record(reduceScatterStream_);
if (stagingEvent) {
(*stagingEvent).block(collectStream_);
}
reduceScatterToCollectEvent.block(collectStream_);
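// Stage 2: inter-machine collect. The locally reduced chunks are exchanged
// across machines: values destined for the other machines leave from
// slice3d[myDeviceIdxOnMachine_], and their contributions for this machine's
// share arrive in the staging area.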
if (numMachines_ > 1) {
NCCL_CHECK(ncclGroupStart());
doCollect(
slice3d[myDeviceIdxOnMachine_],
slice3dStaging,
myMachineIdx_,
collectComm_,
collectStream_);
if (padding) {
doCollect(
(*padding)[myDeviceIdxOnMachine_],
*paddingStaging,
myMachineIdx_,
collectComm_,
collectStream_);
}
NCCL_CHECK(ncclGroupEnd());
}
collectToAddEvent.record(collectStream_);
collectToAddEvent.block(addStream_);
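// Stage 3: sum the staged per-machine contributions over the machine
// dimension into this device's own [myMachineIdx_] cell of slice3d.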
if (numMachines_ > 1) {
c10::cuda::CUDAStreamGuard g(addStream_);
// sum_out wants its first argument to be an lvalue (for no good reason)
if (slice3d.numel() > 0) {
auto out = slice3d[myDeviceIdxOnMachine_][myMachineIdx_];
at::sum_out(out, slice3dStaging, {0});
}
if (padding) {
auto paddingOut = (*padding)[myDeviceIdxOnMachine_][myMachineIdx_];
at::sum_out(paddingOut, (*paddingStaging), {0});
}
}
addToDiffuseEvent.record(addStream_);
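// Once the add has consumed the staging buffer, record stagingEvent so a
// later slice can reuse this staging slot.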
if (stagingEvent) {
(*stagingEvent).record(addStream_);
}
addToDiffuseEvent.block(diffuseStream_);
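// Stage 4: inter-machine diffuse. This device's reduced cell is sent to its
// peers on the other machines, and their reduced cells are received into the
// rest of slice3d[myDeviceIdxOnMachine_].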
if (numMachines_ > 1) {
NCCL_CHECK(ncclGroupStart());
doDiffuse(
slice3d[myDeviceIdxOnMachine_][myMachineIdx_],
slice3d[myDeviceIdxOnMachine_],
myMachineIdx_,
diffuseComm_,
diffuseStream_);
if (padding) {
doDiffuse(
(*padding)[myDeviceIdxOnMachine_][myMachineIdx_],
(*padding)[myDeviceIdxOnMachine_],
myMachineIdx_,
diffuseComm_,
diffuseStream_);
}
NCCL_CHECK(ncclGroupEnd());
}
diffuseToAllGatherEvent.record(diffuseStream_);
diffuseToAllGatherEvent.block(allGatherStream_);
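// Stage 5: intra-machine all-gather, so every device on this machine ends up
// with the complete reduced slice.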
if (numDevicesPerMachine_ > 1) {
NCCL_CHECK(ncclGroupStart());
doAllGather(
slice3d, myDeviceIdxOnMachine_, allGatherComm_, allGatherStream_);
if (padding) {
doAllGather(
*padding, myDeviceIdxOnMachine_, allGatherComm_, allGatherStream_);
}
NCCL_CHECK(ncclGroupEnd());
}
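// Copy the reduced tail back from the padding buffer into the slice, then
// record paddingEvent so this padding slot can be reused by a later slice.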
if (padding) {
CUDA_CHECK(cudaMemcpyAsync(
reinterpret_cast<uint8_t*>(slice.data_ptr()) +
slice3d.numel() * elementSizeInBytes,
(*padding).data_ptr(),
(slice.numel() - slice3d.numel()) * elementSizeInBytes,
cudaMemcpyDeviceToDevice,
allGatherStream_));
(*paddingEvent).record(allGatherStream_);
}
}