in fairring/device.cc [42:78]
// Sum-reduce-scatters `t` in place across the ranks of `comm`. The NCCL
// collectives are enqueued on `stream`; this function does not synchronize.
//
// Dimension 0 of `t` is the split dimension: the whole tensor is the send
// buffer and row `t[myRank]` is this caller's receive buffer. That means the
// NCCL in-place condition (recvbuff == sendbuff + rank * recvcount) holds.
//
// Two memory layouts are accepted:
//  * `t` is contiguous: a single reduce-scatter covers the whole tensor.
//  * `t.transpose(0, 1)` is contiguous, i.e. dim 1 is the outermost dim in
//    memory: one reduce-scatter is issued per index of the original dim 1.
//    Each one operates on the contiguous [size(0), ...] slab for that index.
//    Together they fill the same `t[myRank]` row.
// Any other layout fails a MY_CHECK. Empty tensors are a no-op.
//
// NOTE(review): assumes `myRank` is this caller's rank within `comm` and that
// `t.size(0)` equals the number of ranks in `comm`. Only
// `0 <= myRank < t.size(0)` is checked here — TODO confirm against callers.
void doReduceScatter(
    at::Tensor t,
    int64_t myRank,
    NcclComm& comm,
    CudaStream& stream) {
  MY_CHECK(t.layout() == at::kStrided);
  MY_CHECK(t.is_non_overlapping_and_dense());
  MY_CHECK(t.dim() >= 1);
  MY_CHECK((0 <= myRank) && (myRank < t.size(0)));
  if (t.numel() == 0) {
    return;
  }

  // Reduce-scatters one contiguous buffer whose leading dimension is split
  // across ranks; this rank's summed share is written into row `myRank`.
  auto reduceScatterRows = [&](const at::Tensor& whole) {
    at::Tensor mine = whole[myRank];
    NCCL_CHECK(ncclReduceScatter(
        whole.data_ptr(),
        mine.data_ptr(),
        mine.numel(),
        torchToNcclDtype(whole.scalar_type()),
        ncclSum,
        comm.get(),
        stream));
  };

  // Fast path: the whole tensor is already one contiguous buffer.
  if (t.is_contiguous()) {
    reduceScatterRows(t);
    return;
  }

  // Otherwise dim 1 must be the outermost dimension in memory, so the
  // transposed view is contiguous and each of its entries is a contiguous slab.
  MY_CHECK(t.dim() >= 2);
  at::Tensor slabs = t.transpose(0, 1);
  MY_CHECK(slabs.is_contiguous());
  for (const auto slab : c10::irange(slabs.size(0))) {
    reduceScatterRows(slabs[slab]);
  }
}