in fairring/device.cc [80:114]
// Exchanges slices along dimension 0 with peer ranks over NCCL: for every
// index i in [0, sendT.size(0)), slice sendT[i] is sent to rank i and slice
// recvT[i] is filled with data received from rank i. All operations are
// enqueued on `stream`; this function does not synchronize the stream itself.
//
// Preconditions (enforced through MY_CHECK):
//   - both tensors are strided and have identical sizes and strides;
//   - the tensors have at least one dimension;
//   - 0 <= myRank < sendT.size(0);
//   - slice 0 of sendT is non-overlapping and dense, so that a slice can be
//     handed to NCCL as a flat buffer of numel() elements.
//
// NOTE(review): presumably sendT.size(0) equals the number of ranks in `comm`
// -- confirm against the callers.
void doCollect(
    at::Tensor sendT,
    at::Tensor recvT,
    int64_t myRank,
    NcclComm& comm,
    CudaStream& stream) {
  MY_CHECK(sendT.layout() == at::kStrided);
  MY_CHECK(recvT.layout() == at::kStrided);
  MY_CHECK(sendT.sizes() == recvT.sizes());
  MY_CHECK(sendT.strides() == recvT.strides());
  MY_CHECK(sendT.dim() >= 1);
  MY_CHECK((0 <= myRank) && (myRank < sendT.size(0)));
  MY_CHECK(sendT[0].is_non_overlapping_and_dense());
  // Nothing to transfer for empty tensors.
  if (sendT.numel() == 0) {
    return;
  }

  // The dtype conversions do not depend on the loop index; compute them once.
  const auto sendDtype = torchToNcclDtype(sendT.scalar_type());
  const auto recvDtype = torchToNcclDtype(recvT.scalar_type());

  // Every send below is only matched by a receive that is issued later (on
  // this rank for the self-transfer, on the peers for everything else). NCCL
  // requires point-to-point operations that must progress concurrently to be
  // fused into a single group; issued one by one they can block each other
  // forever. Group calls may be nested, so this remains valid if the caller
  // has already opened a group.
  // NOTE(review): if NCCL_CHECK throws between start and end, the group is
  // left open -- confirm how NCCL_CHECK reports failures.
  NCCL_CHECK(ncclGroupStart());
  for (const auto otherRank : c10::irange(sendT.size(0))) {
    // Take the slice once and reuse it for both the pointer and the count.
    const at::Tensor sendChunk = sendT[otherRank];
    NCCL_CHECK(ncclSend(
        sendChunk.data_ptr(),
        sendChunk.numel(),
        sendDtype,
        otherRank,
        comm.get(),
        stream));
  }
  for (const auto otherRank : c10::irange(recvT.size(0))) {
    const at::Tensor recvChunk = recvT[otherRank];
    NCCL_CHECK(ncclRecv(
        recvChunk.data_ptr(),
        recvChunk.numel(),
        recvDtype,
        otherRank,
        comm.get(),
        stream));
  }
  NCCL_CHECK(ncclGroupEnd());
}