void doReduceScatter()

in fairring/device.cc [42:78]


// Runs an NCCL sum reduce-scatter over `t`, using a slice of `t` itself as the
// receive buffer.
//
// Parameters:
//   t      - strided, non-overlapping and dense tensor with at least one
//            dimension. Taken by value; the local handle may be rebound to a
//            transposed view below.
//   myRank - index along dim 0 of `t` selecting the slice that receives the
//            result; must satisfy 0 <= myRank < t.size(0).
//            NOTE(review): presumably this is the caller's rank in `comm` and
//            t.size(0) equals the communicator size -- neither is checked
//            here; confirm against callers.
//   comm   - wrapper whose get() result is passed to NCCL as the communicator.
//   stream - passed to NCCL as the stream argument of each call.
//
// Any failed MY_CHECK / NCCL_CHECK is reported through those macros.
void doReduceScatter(
    at::Tensor t,
    int64_t myRank,
    NcclComm& comm,
    CudaStream& stream) {
  // Preconditions on layout and on the requested slice index.
  MY_CHECK(t.layout() == at::kStrided);
  MY_CHECK(t.is_non_overlapping_and_dense());
  MY_CHECK(t.dim() >= 1);
  MY_CHECK((0 <= myRank) && (myRank < t.size(0)));

  // An empty tensor issues no NCCL call at all.
  if (t.numel() == 0) {
    return;
  }

  // Contiguous case: a single call covers the whole tensor, with the
  // myRank-th slice along dim 0 as the receive buffer.
  if (t.is_contiguous()) {
    const at::Tensor recvSlice = t[myRank];
    NCCL_CHECK(ncclReduceScatter(
        t.data_ptr(),
        recvSlice.data_ptr(),
        recvSlice.numel(),
        torchToNcclDtype(t.scalar_type()),
        ncclSum,
        comm.get(),
        stream));
    return;
  }

  // Non-contiguous case: the only supported form is one that becomes
  // contiguous once the first two dimensions are swapped.
  MY_CHECK(t.dim() >= 2);
  t = t.transpose(0, 1);
  MY_CHECK(t.is_contiguous());

  // After the swap, the original dim 0 (the one indexed by myRank) is dim 1,
  // so each index along the new leading dimension gets its own call.
  const auto ncclDtype = torchToNcclDtype(t.scalar_type());
  const int64_t numRows = t.size(0);
  for (int64_t row = 0; row < numRows; ++row) {
    const at::Tensor sendRow = t[row];
    const at::Tensor recvSlice = sendRow[myRank];
    NCCL_CHECK(ncclReduceScatter(
        sendRow.data_ptr(),
        recvSlice.data_ptr(),
        recvSlice.numel(),
        ncclDtype,
        ncclSum,
        comm.get(),
        stream));
  }
}