void doCollect()

Excerpt from `fairring/device.cc`, lines 80–114:


// Exchanges one slice of a tensor with every peer over NCCL.
//
// For each index r in [0, sendT.size(0)), enqueues a send of sendT[r] to
// rank r and a receive into recvT[r] from rank r, all on `stream`. Dimension
// 0 of both tensors is therefore indexed by peer rank; presumably
// sendT.size(0) equals the communicator's world size — TODO confirm against
// callers.
//
// Params:
//   sendT  - strided tensor; slice r is the payload sent to rank r.
//   recvT  - strided tensor with the same sizes, strides and dtype as sendT;
//            slice r is written with the payload received from rank r.
//   myRank - this process's rank; only used here to validate that it is a
//            valid index into dimension 0.
//   comm   - NCCL communicator used for every send/recv.
//   stream - stream on which all NCCL operations are enqueued.
//
// Does nothing if the tensors are empty. All sends and receives are issued
// inside a single ncclGroupStart/ncclGroupEnd section (see below).
void doCollect(
    at::Tensor sendT,
    at::Tensor recvT,
    int64_t myRank,
    NcclComm& comm,
    CudaStream& stream) {
  MY_CHECK(sendT.layout() == at::kStrided);
  MY_CHECK(recvT.layout() == at::kStrided);
  MY_CHECK(sendT.sizes() == recvT.sizes());
  MY_CHECK(sendT.strides() == recvT.strides());
  // The send side and the receive side each pass their own dtype to NCCL; a
  // mismatch would make the sizes of matching send/recv pairs disagree.
  MY_CHECK(sendT.scalar_type() == recvT.scalar_type());
  MY_CHECK(sendT.dim() >= 1);
  MY_CHECK((0 <= myRank) && (myRank < sendT.size(0)));
  // Each slice is handed to NCCL as a flat buffer of numel() elements starting
  // at data_ptr(), so it must occupy a dense memory region. All dim-0 slices
  // share the same sizes/strides (and recvT shares sendT's), so checking
  // sendT[0] covers every slice of both tensors.
  MY_CHECK(sendT[0].is_non_overlapping_and_dense());
  if (sendT.numel() == 0) {
    return;
  }

  const auto ncclDtype = torchToNcclDtype(sendT.scalar_type());

  // Every rank sends to all peers (including itself) and receives from all
  // peers. Ungrouped p2p calls must each complete before the next can
  // progress, which would deadlock here. Fusing them in a group lets NCCL
  // match them concurrently. Group calls nest, so this stays correct if the
  // caller has already opened a group.
  NCCL_CHECK(ncclGroupStart());
  for (const auto otherRank : c10::irange(sendT.size(0))) {
    at::Tensor sendSlice = sendT[otherRank];
    at::Tensor recvSlice = recvT[otherRank];
    NCCL_CHECK(ncclSend(
        sendSlice.data_ptr(),
        sendSlice.numel(),
        ncclDtype,
        otherRank,
        comm.get(),
        stream));
    NCCL_CHECK(ncclRecv(
        recvSlice.data_ptr(),
        recvSlice.numel(),
        ncclDtype,
        otherRank,
        comm.get(),
        stream));
  }
  NCCL_CHECK(ncclGroupEnd());
}