void DeviceFairring::reduceScatterOneSlice()

in fairring/device.cc [590:657]


// Reduce-scatters one slice of `input` across all devices of all machines and
// writes this device's reduced shard into `output`.
//
// The work is split over three streams, chained with CUDA events:
//   1. reduceScatterStream_ waits on `initialEvent`, then runs the
//      intra-machine step: a plain ncclReduceScatter when there is a single
//      machine, otherwise doReduceScatter on the 3-D view of `input`.
//   2. collectStream_ (only when numMachines_ > 1) runs doCollect, handing it
//      this device's row of the 3-D view and a staging row.
//   3. addStream_ (only when numMachines_ > 1) sums the staging row over its
//      machine dimension into `output`.
// The events are recorded/waited on unconditionally, so addStream_ is always
// ordered after the other two streams for this slice.
//
// Params:
//   input:  numel must be divisible by numDevicesPerMachine_ * numMachines_.
//           NOTE(review): rows of `input` (including another device's row used
//           as staging) appear to be overwritten when numMachines_ > 1 —
//           callers presumably must not rely on its contents afterwards; TODO
//           confirm against doReduceScatter/doCollect.
//   output: must have numel == input.numel() / (numDevicesPerMachine_ *
//           numMachines_).
//   initialEvent: reduceScatterStream_ waits on it before any work is queued.
void DeviceFairring::reduceScatterOneSlice(
    at::Tensor input,
    at::Tensor output,
    at::cuda::CUDAEvent initialEvent) {
  c10::cuda::CUDAGuard deviceGuard(myDeviceIdxOnProcess_);

  at::cuda::CUDAEvent reduceScatterToCollectEvent;
  at::cuda::CUDAEvent collectToAddEvent;

  // Validate the topology first: both counts are used as divisors below.
  // At least two devices per machine are required so that the staging row
  // chosen below, (myDeviceIdxOnMachine_ + 1) % numDevicesPerMachine_, can
  // never be this device's own row (with one device it would alias it).
  MY_CHECK(numMachines_ >= 1);
  MY_CHECK(numDevicesPerMachine_ >= 2);
  MY_CHECK(input.numel() % (numDevicesPerMachine_ * numMachines_) == 0);
  MY_CHECK(
      input.numel() == output.numel() * numDevicesPerMachine_ * numMachines_);

  // In both layouts input3d ends up indexed as [device, machine, chunk]; the
  // non-favorable layout stores machines outermost and is transposed here.
  at::Tensor input3d;
  if (deviceGlobalRankIsFavorable_) {
    input3d = input.view({numDevicesPerMachine_, numMachines_, -1});
  } else {
    input3d =
        input.view({numMachines_, numDevicesPerMachine_, -1}).transpose(0, 1);
  }

  // Row of the next device on this machine, borrowed as scratch space
  // ([machine, chunk]) for the collect and add stages.
  at::Tensor input3dStaging =
      input3d[(myDeviceIdxOnMachine_ + 1) % numDevicesPerMachine_];

  // Stage 1: intra-machine reduce-scatter.
  initialEvent.block(reduceScatterStream_);
  if (numMachines_ == 1) {
    // NCCL needs a dense send buffer; each rank receives output.numel() items.
    MY_CHECK(input3d.is_contiguous());
    NCCL_CHECK(ncclReduceScatter(
        input3d.data_ptr(),
        output.data_ptr(),
        output.numel(),
        torchToNcclDtype(output.scalar_type()),
        ncclSum,
        reduceScatterComm_.get(),
        reduceScatterStream_));
  } else {
    // numDevicesPerMachine_ >= 2 is guaranteed by the check above.
    NCCL_CHECK(ncclGroupStart());
    doReduceScatter(
        input3d,
        myDeviceIdxOnMachine_,
        reduceScatterComm_,
        reduceScatterStream_);
    NCCL_CHECK(ncclGroupEnd());
  }
  reduceScatterToCollectEvent.record(reduceScatterStream_);

  // Stage 2: inter-machine collect into the staging row.
  reduceScatterToCollectEvent.block(collectStream_);
  if (numMachines_ > 1) {
    NCCL_CHECK(ncclGroupStart());
    doCollect(
        input3d[myDeviceIdxOnMachine_],
        input3dStaging,
        myMachineIdx_,
        collectComm_,
        collectStream_);
    NCCL_CHECK(ncclGroupEnd());
  }
  collectToAddEvent.record(collectStream_);

  // Stage 3: sum the per-machine chunks of the staging row into `output`.
  collectToAddEvent.block(addStream_);
  if (numMachines_ > 1) {
    // Make addStream_ current so the ATen reduction is enqueued on it.
    c10::cuda::CUDAStreamGuard addStreamGuard(addStream_);
    if (input3d.numel() > 0) {
      at::sum_out(output, input3dStaging, {0});
    }
  }
}