// DeviceFairring::allGatherOneSlice — excerpt from fairring/device.cc (original lines 659-708).


// Runs the all-gather phase for one slice: an optional cross-machine step on
// diffuseStream_, then an intra-machine all-gather on allGatherStream_, with
// CUDA events enforcing ordering between the two streams.
//
// Args:
//   input: this rank's contribution for the slice.
//   output: destination buffer; must hold exactly
//           numDevicesPerMachine_ * numMachines_ copies of input (checked
//           below).
//   initialEvent: event the diffuse stream waits on before any work starts
//           (taken by value; consumed by this call).
void DeviceFairring::allGatherOneSlice(
    at::Tensor input,
    at::Tensor output,
    at::cuda::CUDAEvent initialEvent) {
  // Ensure all CUDA/NCCL calls below target this process's device.
  c10::cuda::CUDAGuard g(myDeviceIdxOnProcess_);

  // Signals, on diffuseStream_, that the diffuse step has been enqueued, so
  // allGatherStream_ can be ordered after it.
  at::cuda::CUDAEvent diffuseToAllGatherEvent;

  // Output must split evenly into one input-sized chunk per (device, machine)
  // pair; the second check pins the chunk size to exactly input.numel().
  MY_CHECK(output.numel() % (numDevicesPerMachine_ * numMachines_) == 0);
  MY_CHECK(
      output.numel() == input.numel() * numDevicesPerMachine_ * numMachines_);
  // View output as 3-D with dim 0 indexing devices and dim 1 indexing
  // machines, regardless of the underlying memory layout. When the global
  // rank order is not "favorable" the memory is laid out machine-major, so we
  // view it that way and transpose (yielding a non-contiguous view).
  // NOTE(review): the exact meaning of deviceGlobalRankIsFavorable_ is defined
  // elsewhere — presumably it encodes which of the two rank orderings the
  // communicators were built with; confirm against the setup code.
  at::Tensor output3d;
  if (deviceGlobalRankIsFavorable_) {
    output3d = output.view({numDevicesPerMachine_, numMachines_, -1});
  } else {
    output3d =
        output.view({numMachines_, numDevicesPerMachine_, -1}).transpose(0, 1);
  }

  // Do not start until the producer of input (recorded in initialEvent) has
  // finished on its stream.
  initialEvent.block(diffuseStream_);

  if (numMachines_ > 1) {
    // Cross-machine exchange of this device's row of the output.
    // NOTE(review): doDiffuse is an opaque helper; the group start/end pair
    // suggests it issues multiple NCCL point-to-point/collective calls that
    // must be fused — confirm in its definition.
    NCCL_CHECK(ncclGroupStart());
    doDiffuse(
        input,
        output3d[myDeviceIdxOnMachine_],
        myMachineIdx_,
        diffuseComm_,
        diffuseStream_);
    NCCL_CHECK(ncclGroupEnd());
  }
  diffuseToAllGatherEvent.record(diffuseStream_);

  // Order the all-gather stream after the (possibly empty) diffuse work.
  diffuseToAllGatherEvent.block(allGatherStream_);
  if (numMachines_ == 1) {
    // Single machine: a plain NCCL all-gather writes each rank's input
    // directly into the flat output. Contiguity is required because
    // ncclAllGather takes a raw pointer; with numMachines_ == 1 both view
    // shapes above collapse to the same contiguous layout.
    MY_CHECK(output3d.is_contiguous());
    NCCL_CHECK(ncclAllGather(
        input.data_ptr(),
        output3d.data_ptr(),
        input.numel(),
        torchToNcclDtype(input.scalar_type()),
        allGatherComm_.get(),
        allGatherStream_));
  } else if (numDevicesPerMachine_ > 1) {
    // Multi-machine: share each device's (already diffused) row with the
    // other devices on this machine. With a single device per machine the
    // diffuse step alone already filled output3d, so nothing remains to do.
    // NOTE(review): doAllGather is an opaque helper — verify it gathers along
    // dim 0 of output3d as assumed here.
    NCCL_CHECK(ncclGroupStart());
    doAllGather(
        output3d, myDeviceIdxOnMachine_, allGatherComm_, allGatherStream_);
    NCCL_CHECK(ncclGroupEnd());
  }
}