void doDiffuse()

in fairring/device.cc [116:156]


// Point-to-point "diffuse" exchange: this rank sends the whole of `sendT` to
// every peer r in [0, recvT.size(0)) and receives peer r's contribution into
// the slice recvT[r]. All NCCL operations are enqueued on `stream`; nothing
// here waits for them to complete.
//
// Parameters:
//   sendT  - this rank's contribution. Must be strided and have the same
//            sizes, strides and dtype as a single slice recvT[i].
//   recvT  - output buffer whose dimension 0 indexes the peer rank. recvT[0]
//            must be non-overlapping and dense, so that each slice (and
//            sendT, which shares its sizes/strides) can be sent as a flat run
//            of numel() elements starting at data_ptr().
//   myRank - this rank's index; must satisfy 0 <= myRank < recvT.size(0).
//            NOTE(review): presumably recvT.size(0) equals the number of
//            ranks in `comm` — this is not checked here.
//   comm   - NCCL communicator wrapper; comm.get() is passed to ncclSend/Recv.
//   stream - stream on which the sends/receives are enqueued.
//
// If recvT[myRank] already starts at sendT's data pointer (in-place use), the
// self-send and self-receive are skipped because the data is already there.
//
// NOTE(review): no ncclGroupStart/ncclGroupEnd is visible in this function.
// All sends are issued before any receive, which presumably relies on the
// caller wrapping this call in an NCCL group — confirm at call sites.
void doDiffuse(
    at::Tensor sendT,
    at::Tensor recvT,
    int64_t myRank,
    NcclComm& comm,
    CudaStream& stream) {
  MY_CHECK(sendT.layout() == at::kStrided);
  MY_CHECK(recvT.layout() == at::kStrided);
  // The send below describes sendT's bytes using recvT's dtype; a mismatch
  // would transfer the wrong byte count and could read past sendT's buffer.
  MY_CHECK(sendT.scalar_type() == recvT.scalar_type());
  MY_CHECK(recvT.dim() >= 1);
  MY_CHECK((0 <= myRank) && (myRank < recvT.size(0)));
  MY_CHECK(sendT.sizes() == recvT[0].sizes());
  MY_CHECK(sendT.strides() == recvT[0].strides());
  MY_CHECK(recvT[0].is_non_overlapping_and_dense());
  if (recvT.numel() == 0) {
    return;
  }

  const int64_t numRanks = recvT.size(0);
  const auto ncclDtype = torchToNcclDtype(recvT.scalar_type());
  // In-place use: our own slice of the output already holds the data to send.
  const bool isInPlace = recvT[myRank].data_ptr() == sendT.data_ptr();

  // Issue every send first, then every receive (same order as before).
  for (const auto otherRank : c10::irange(numRanks)) {
    if (otherRank == myRank && isInPlace) {
      continue;
    }
    NCCL_CHECK(ncclSend(
        sendT.data_ptr(),
        sendT.numel(),
        ncclDtype,
        otherRank,
        comm.get(),
        stream));
  }
  for (const auto otherRank : c10::irange(numRanks)) {
    if (otherRank == myRank && isInPlace) {
      continue;
    }
    NCCL_CHECK(ncclRecv(
        recvT[otherRank].data_ptr(),
        recvT[otherRank].numel(),
        ncclDtype,
        otherRank,
        comm.get(),
        stream));
  }
}