in fairring/device.cc [116:156]
// Sends `sendT` to every rank in [0, recvT.size(0)) and receives one tensor
// from each of those ranks into the matching slice `recvT[rank]`, using NCCL
// point-to-point calls enqueued on `stream`.
//
// Preconditions (enforced with MY_CHECK):
//   - both tensors use the strided layout and share the same scalar type;
//   - `recvT` has at least one dimension and `myRank` indexes its first one;
//   - `sendT` has the same sizes and strides as `recvT[0]`, which must be
//     dense and non-overlapping so that it can be handed to NCCL as a flat
//     buffer of `numel()` elements.
//
// If `recvT[myRank]` aliases `sendT` (same data pointer) the self send/recv
// pair is skipped, since the data is already where it should end up.
//
// NOTE(review): the sends and receives are issued back to back without an
// ncclGroupStart()/ncclGroupEnd() pair in this function; presumably the caller
// opens the NCCL group around this call -- confirm against the call sites.
// NOTE(review): recvT.size(0) is assumed to match the number of ranks in
// `comm`; this is not verified here.
void doDiffuse(
    at::Tensor sendT,
    at::Tensor recvT,
    int64_t myRank,
    NcclComm& comm,
    CudaStream& stream) {
  MY_CHECK(sendT.layout() == at::kStrided);
  MY_CHECK(recvT.layout() == at::kStrided);
  // The send below is typed with recvT's dtype, so a mismatching sendT would
  // make NCCL misinterpret (and possibly over-read) the send buffer.
  MY_CHECK(sendT.scalar_type() == recvT.scalar_type());
  MY_CHECK(recvT.dim() >= 1);
  MY_CHECK((0 <= myRank) && (myRank < recvT.size(0)));
  MY_CHECK(sendT.sizes() == recvT[0].sizes());
  MY_CHECK(sendT.strides() == recvT[0].strides());
  MY_CHECK(recvT[0].is_non_overlapping_and_dense());
  if (recvT.numel() == 0) {
    // Nothing to transfer.
    return;
  }

  // True when the caller passed its own slice of recvT as the send buffer, in
  // which case the transfer to/from ourselves would be a no-op copy.
  const bool selfIsInPlace = recvT[myRank].data_ptr() == sendT.data_ptr();

  // First enqueue all the sends...
  for (const auto otherRank : c10::irange(recvT.size(0))) {
    if (otherRank == myRank && selfIsInPlace) {
      continue;
    }
    NCCL_CHECK(ncclSend(
        sendT.data_ptr(),
        sendT.numel(),
        torchToNcclDtype(recvT.scalar_type()),
        otherRank,
        comm.get(),
        stream));
  }
  // ...then all the matching receives, one slice per peer.
  for (const auto otherRank : c10::irange(recvT.size(0))) {
    if (otherRank == myRank && selfIsInPlace) {
      continue;
    }
    NCCL_CHECK(ncclRecv(
        recvT[otherRank].data_ptr(),
        recvT[otherRank].numel(),
        torchToNcclDtype(recvT.scalar_type()),
        otherRank,
        comm.get(),
        stream));
  }
}