void DeviceFairring::allReduceOneSlice()

in fairring/device.cc [405:588]


// Runs one all-reduce over `slice` as a five-stage pipeline, with one CUDA
// stream per stage and CUDA events chaining each stage to the next:
//   1. reduce-scatter (intra-machine, on reduceScatterStream_)
//   2. collect        (inter-machine, on collectStream_)
//   3. local add      (sum of collected chunks, on addStream_)
//   4. diffuse        (inter-machine, on diffuseStream_)
//   5. all-gather     (intra-machine, on allGatherStream_)
// Stages 1 and 5 are skipped when there is a single device per machine;
// stages 2-4 are skipped when there is a single machine.
//
// `initialEvent`, when present, gates the start of the pipeline (it is
// blocked on before the first stage touches `slice`).
void DeviceFairring::allReduceOneSlice(
    at::Tensor slice,
    c10::optional<at::cuda::CUDAEvent> initialEvent) {
  // All work below is issued on this process's device.
  c10::cuda::CUDAGuard g(myDeviceIdxOnProcess_);

  c10::ScalarType dtype = slice.scalar_type();
  int64_t elementSizeInBytes = slice.element_size();

  // One event per stage boundary; each is recorded on the upstream stream and
  // blocked on by the downstream stream.
  at::cuda::CUDAEvent reduceScatterToCollectEvent;
  at::cuda::CUDAEvent collectToAddEvent;
  at::cuda::CUDAEvent addToDiffuseEvent;
  at::cuda::CUDAEvent diffuseToAllGatherEvent;

  // Reshape the slice to [numDevicesPerMachine_, numMachines_, chunk] so the
  // collectives below can address the (device, machine) sub-chunks directly.
  // If the element count isn't divisible by devices*machines, the divisible
  // prefix becomes slice3d and the leftover tail (fewer than devices*machines
  // elements) is processed through a borrowed padding-buffer slot, one
  // element per (device, machine) pair.
  at::Tensor slice3d;
  c10::optional<at::Tensor> padding;
  at::cuda::CUDAEvent* paddingEvent = nullptr;
  if (slice.numel() % (numDevicesPerMachine_ * numMachines_) == 0) {
    slice3d = slice.view({numDevicesPerMachine_, numMachines_, -1});
  } else {
    int64_t sliceSizeInElems = roundDownToNearestMultiple(
        slice.numel(), numDevicesPerMachine_ * numMachines_);
    slice3d = slice.index({torch::indexing::Slice(0, sliceSizeInElems)})
                  .view({numDevicesPerMachine_, numMachines_, -1});
    // Padding slots are reused round-robin; paddingEvent fences the reuse
    // (blocked on before we overwrite the slot, recorded once we're done).
    int64_t paddingSlotIdx = (nextPaddingSlot_++) % layout_.numPaddingSlots;
    padding = paddingBuffer_[paddingSlotIdx]
                  .view(dtype)
                  .flatten()
                  .index({torch::indexing::Slice(
                      0, numDevicesPerMachine_ * numMachines_)})
                  .view({numDevicesPerMachine_, numMachines_});
    paddingEvent = &paddingEvents_[paddingSlotIdx];
  }

  // The collect stage needs scratch space to receive the other machines'
  // chunks. With more than one device per machine we borrow the next device's
  // portion of slice3d (it isn't needed until the final all-gather). With a
  // single device that round-robin slot would alias this device's own chunk,
  // so a dedicated staging-buffer slot is used instead, fenced by
  // stagingEvent the same way padding slots are.
  at::Tensor slice3dStaging;
  c10::optional<at::Tensor> paddingStaging;
  at::cuda::CUDAEvent* stagingEvent = nullptr;
  if (numDevicesPerMachine_ == 1) {
    int64_t stagingSlotIdx = (nextStagingSlot_++) % layout_.numStagingSlots;
    slice3dStaging = stagingBuffer_[stagingSlotIdx]
                         .view(dtype)
                         .flatten()
                         .index({torch::indexing::Slice(
                             0, slice3d[myDeviceIdxOnMachine_].numel())})
                         .view({numMachines_, -1});
    if (padding) {
      paddingStaging = paddingStagingBuffer_[stagingSlotIdx]
                           .view(dtype)
                           .flatten()
                           .index({torch::indexing::Slice(
                               0, (*padding)[myDeviceIdxOnMachine_].numel())})
                           .view({numMachines_});
    }
    stagingEvent = &stagingEvents_[stagingSlotIdx];
  } else {
    slice3dStaging =
        slice3d[(myDeviceIdxOnMachine_ + 1) % numDevicesPerMachine_];
    if (padding) {
      paddingStaging =
          (*padding)[(myDeviceIdxOnMachine_ + 1) % numDevicesPerMachine_];
    }
  }

  // Don't start until the producer of `slice` (if any) has finished.
  if (initialEvent.has_value()) {
    initialEvent.value().block(reduceScatterStream_);
  }

  if (padding) {
    // Wait for any previous user of this padding slot before overwriting it,
    // then copy the non-divisible tail of `slice` into the slot.
    (*paddingEvent).block(reduceScatterStream_);
    // No need to zero out the padding: we don't care what value it has/gets.
    CUDA_CHECK(cudaMemcpyAsync(
        (*padding).data_ptr(),
        reinterpret_cast<uint8_t*>(slice.data_ptr()) +
            slice3d.numel() * elementSizeInBytes,
        (slice.numel() - slice3d.numel()) * elementSizeInBytes,
        cudaMemcpyDeviceToDevice,
        reduceScatterStream_));
  }

  // Stage 1: intra-machine reduce-scatter of slice3d (and the padding tail).
  // The two calls share one NCCL group so they are issued as a unit.
  if (numDevicesPerMachine_ > 1) {
    NCCL_CHECK(ncclGroupStart());
    doReduceScatter(
        slice3d,
        myDeviceIdxOnMachine_,
        reduceScatterComm_,
        reduceScatterStream_);
    if (padding) {
      doReduceScatter(
          *padding,
          myDeviceIdxOnMachine_,
          reduceScatterComm_,
          reduceScatterStream_);
    }
    NCCL_CHECK(ncclGroupEnd());
  }
  reduceScatterToCollectEvent.record(reduceScatterStream_);

  // If using a dedicated staging slot, wait for its previous user before the
  // collect stage writes into it.
  if (stagingEvent) {
    (*stagingEvent).block(collectStream_);
  }

  // Stage 2: inter-machine collect of this device's chunk into staging.
  reduceScatterToCollectEvent.block(collectStream_);
  if (numMachines_ > 1) {
    NCCL_CHECK(ncclGroupStart());
    doCollect(
        slice3d[myDeviceIdxOnMachine_],
        slice3dStaging,
        myMachineIdx_,
        collectComm_,
        collectStream_);
    if (padding) {
      doCollect(
          (*padding)[myDeviceIdxOnMachine_],
          *paddingStaging,
          myMachineIdx_,
          collectComm_,
          collectStream_);
    }
    NCCL_CHECK(ncclGroupEnd());
  }
  collectToAddEvent.record(collectStream_);

  // Stage 3: sum the collected per-machine contributions (dim 0 of the
  // staging tensors indexes machines) into this device/machine's own chunk.
  collectToAddEvent.block(addStream_);
  if (numMachines_ > 1) {
    c10::cuda::CUDAStreamGuard g(addStream_);
    // sum_out wants its first argument to be an lvalue (for no good reason)
    if (slice3d.numel() > 0) {
      auto out = slice3d[myDeviceIdxOnMachine_][myMachineIdx_];
      at::sum_out(out, slice3dStaging, {0});
    }
    if (padding) {
      auto paddingOut = (*padding)[myDeviceIdxOnMachine_][myMachineIdx_];
      at::sum_out(paddingOut, (*paddingStaging), {0});
    }
  }
  addToDiffuseEvent.record(addStream_);

  // The staging slot has now been consumed by the add; mark it reusable.
  if (stagingEvent) {
    (*stagingEvent).record(addStream_);
  }

  // Stage 4: inter-machine diffuse of the fully reduced chunk back into every
  // machine's copy of this device's row of slice3d.
  addToDiffuseEvent.block(diffuseStream_);
  if (numMachines_ > 1) {
    NCCL_CHECK(ncclGroupStart());
    doDiffuse(
        slice3d[myDeviceIdxOnMachine_][myMachineIdx_],
        slice3d[myDeviceIdxOnMachine_],
        myMachineIdx_,
        diffuseComm_,
        diffuseStream_);
    if (padding) {
      doDiffuse(
          (*padding)[myDeviceIdxOnMachine_][myMachineIdx_],
          (*padding)[myDeviceIdxOnMachine_],
          myMachineIdx_,
          diffuseComm_,
          diffuseStream_);
    }
    NCCL_CHECK(ncclGroupEnd());
  }
  diffuseToAllGatherEvent.record(diffuseStream_);

  // Stage 5: intra-machine all-gather so every device ends up with the full
  // reduced slice.
  diffuseToAllGatherEvent.block(allGatherStream_);
  if (numDevicesPerMachine_ > 1) {
    NCCL_CHECK(ncclGroupStart());
    doAllGather(
        slice3d, myDeviceIdxOnMachine_, allGatherComm_, allGatherStream_);
    if (padding) {
      doAllGather(
          *padding, myDeviceIdxOnMachine_, allGatherComm_, allGatherStream_);
    }
    NCCL_CHECK(ncclGroupEnd());
  }

  if (padding) {
    // Copy the reduced tail back from the padding slot into `slice`, then
    // mark the slot reusable for a later invocation.
    CUDA_CHECK(cudaMemcpyAsync(
        reinterpret_cast<uint8_t*>(slice.data_ptr()) +
            slice3d.numel() * elementSizeInBytes,
        (*padding).data_ptr(),
        (slice.numel() - slice3d.numel()) * elementSizeInBytes,
        cudaMemcpyDeviceToDevice,
        allGatherStream_));
    (*paddingEvent).record(allGatherStream_);
  }
}