in fairring/device.cc [158:192]
// All-gathers `t` in place along its leading dimension using NCCL.
//
// The slice `t[myRank]` serves as this caller's send buffer and `t` itself is
// the receive buffer. Each NCCL call is given the communicator obtained from
// `comm` and the stream `stream`.
//
// Preconditions (enforced with MY_CHECK):
//   - `t` is strided, non-overlapping and dense, with at least one dimension;
//   - `myRank` is a valid index into dimension 0 of `t`;
//   - when `t` is not contiguous, it has at least two dimensions and becomes
//     contiguous once dimensions 0 and 1 are swapped.
//
// A tensor without elements results in no NCCL call at all.
//
// NOTE(review): presumably t.size(0) equals the number of ranks in `comm`;
// nothing here verifies that -- confirm against the callers.
void doAllGather(
    at::Tensor t,
    int64_t myRank,
    NcclComm& comm,
    CudaStream& stream) {
  MY_CHECK(t.layout() == at::kStrided);
  MY_CHECK(t.is_non_overlapping_and_dense());
  MY_CHECK(t.dim() >= 1);
  MY_CHECK((0 <= myRank) && (myRank < t.size(0)));

  // Nothing to exchange for an empty tensor.
  if (t.numel() == 0) {
    return;
  }

  // Fast path: the whole tensor is one contiguous buffer, so a single
  // all-gather with our own slice as the send buffer covers everything.
  if (t.is_contiguous()) {
    const auto ncclDtype = torchToNcclDtype(t.scalar_type());
    const at::Tensor mySlice = t[myRank];
    NCCL_CHECK(ncclAllGather(
        mySlice.data_ptr(),
        t.data_ptr(),
        mySlice.numel(),
        ncclDtype,
        comm.get(),
        stream));
    return;
  }

  // Slow path: only the layout where swapping the first two dimensions yields
  // a contiguous tensor is supported. After the swap, the rank dimension is
  // dimension 1, and each index along dimension 0 is gathered separately.
  MY_CHECK(t.dim() >= 2);
  t = t.transpose(0, 1);
  MY_CHECK(t.is_contiguous());

  const auto ncclDtype = torchToNcclDtype(t.scalar_type());
  const int64_t numRows = t.size(0);
  for (int64_t row = 0; row < numRows; ++row) {
    const at::Tensor rowBuffer = t[row];
    const at::Tensor myPart = rowBuffer[myRank];
    NCCL_CHECK(ncclAllGather(
        myPart.data_ptr(),
        rowBuffer.data_ptr(),
        myPart.numel(),
        ncclDtype,
        comm.get(),
        stream));
  }
}