in fairring/device.cc [659:708]
// All-gathers one slice: afterwards `output` holds
// numDevicesPerMachine_ * numMachines_ chunks of input.numel() elements each,
// one per participant. All communication is only enqueued on this object's
// CUDA streams (diffuseStream_, then allGatherStream_); nothing here waits on
// the host.
//
// Strategy:
//   - more than one machine: a cross-machine "diffuse" phase (doDiffuse on
//     diffuseComm_), followed, if there are several devices per machine, by an
//     intra-machine phase (doAllGather on allGatherComm_);
//   - a single machine: one plain ncclAllGather on allGatherComm_.
//
// Params:
//   input: this device's contribution. Must have the same dtype as `output`.
//     Must be contiguous when numMachines_ == 1, because it is then handed to
//     NCCL as a flat buffer.
//   output: destination; its numel must equal
//     input.numel() * numDevicesPerMachine_ * numMachines_, and it must be
//     viewable as a 3-D tensor (see output3d below).
//   initialEvent: diffuseStream_ is made to wait on it before any
//     communication is enqueued (and allGatherStream_ transitively too).
void DeviceFairring::allGatherOneSlice(
    at::Tensor input,
    at::Tensor output,
    at::cuda::CUDAEvent initialEvent) {
  c10::cuda::CUDAGuard g(myDeviceIdxOnProcess_);
  at::cuda::CUDAEvent diffuseToAllGatherEvent;
  MY_CHECK(output.numel() % (numDevicesPerMachine_ * numMachines_) == 0);
  MY_CHECK(
      output.numel() == input.numel() * numDevicesPerMachine_ * numMachines_);
  // Every NCCL call below sizes the transfer with input's dtype but writes
  // into output's storage, so a dtype mismatch would overrun (or under-fill)
  // output instead of failing.
  MY_CHECK(input.scalar_type() == output.scalar_type());

  // output3d is always indexed [device on machine][machine][element];
  // deviceGlobalRankIsFavorable_ only selects which of the two memory orders
  // `output` uses (device-major vs machine-major, the latter exposed through a
  // transpose so the indexing stays the same).
  at::Tensor output3d;
  if (deviceGlobalRankIsFavorable_) {
    output3d = output.view({numDevicesPerMachine_, numMachines_, -1});
  } else {
    output3d =
        output.view({numMachines_, numDevicesPerMachine_, -1}).transpose(0, 1);
  }

  initialEvent.block(diffuseStream_);
  if (numMachines_ > 1) {
    // Cross-machine phase; presumably fills this device's row
    // output3d[myDeviceIdxOnMachine_] with the chunks of the same-index device
    // on every machine (NOTE(review): confirm against doDiffuse).
    NCCL_CHECK(ncclGroupStart());
    doDiffuse(
        input,
        output3d[myDeviceIdxOnMachine_],
        myMachineIdx_,
        diffuseComm_,
        diffuseStream_);
    NCCL_CHECK(ncclGroupEnd());
  }
  // Chain allGatherStream_ after diffuseStream_. This is done even when the
  // diffuse phase was skipped, so allGatherStream_ still (transitively) waits
  // on initialEvent.
  diffuseToAllGatherEvent.record(diffuseStream_);
  diffuseToAllGatherEvent.block(allGatherStream_);

  if (numMachines_ == 1) {
    // Single machine: NCCL reads input and writes output3d as flat buffers,
    // so both must be contiguous. (output3d passes even in the transposed
    // layout, since the swapped dimension has size 1 here.)
    MY_CHECK(input.is_contiguous());
    MY_CHECK(output3d.is_contiguous());
    NCCL_CHECK(ncclAllGather(
        input.data_ptr(),
        output3d.data_ptr(),
        input.numel(),
        torchToNcclDtype(input.scalar_type()),
        allGatherComm_.get(),
        allGatherStream_));
  } else if (numDevicesPerMachine_ > 1) {
    // Intra-machine phase; presumably exchanges the rows produced by the
    // diffuse phase among this machine's devices (NOTE(review): confirm
    // against doAllGather).
    NCCL_CHECK(ncclGroupStart());
    doAllGather(
        output3d, myDeviceIdxOnMachine_, allGatherComm_, allGatherStream_);
    NCCL_CHECK(ncclGroupEnd());
  }
  // Otherwise (several machines, one device each) the diffuse phase is the
  // only communication enqueued.
}