in fairring/process_group.cc [187:240]
/// All-gather entry point.
///
/// Validates the per-device input and output tensors, packs them into
/// `TensorPair` entries (flattened input plus the flattened outputs combined
/// through `unUnbind`), constructs `machine_` on first use, and forwards the
/// pairs to `machine_->allGather`. The value returned by that call is wrapped
/// in a `WorkFairring` tagged with `c10d::OpType::ALLGATHER`.
///
/// @param outputTensors One list of output tensors per local device; each list
///        must contain `size_ * inputTensors.size()` tensors.
/// @param inputTensors One input tensor per local device.
/// @param opts Not consulted by this implementation.
/// @return The `WorkFairring` wrapping the result of `machine_->allGather`.
///
/// Every precondition is enforced through `MY_CHECK`.
c10::intrusive_ptr<c10d::ProcessGroup::Work> ProcessGroupFairring::allgather(
    std::vector<std::vector<at::Tensor>>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const c10d::AllgatherOptions& opts) {
  // Disabled alternative that forwards the call to the NCCL process group:
  // return c10::make_intrusive<WorkFairring>(
  //     c10d::OpType::ALLGATHER,
  //     ncclPG_->allgather(outputTensors, inputTensors, opts)->getFuture());
  MY_CHECK(inputTensors.size() == outputTensors.size());
  int64_t numDevicesPerRank = inputTensors.size();

  std::vector<fairring::MachineFairring::TensorPair> pairs;
  for (int64_t deviceOffset = 0; deviceOffset < numDevicesPerRank;
       ++deviceOffset) {
    const at::Tensor& input = inputTensors[deviceOffset];
    const std::vector<at::Tensor>& outputs = outputTensors[deviceOffset];

    // The output list length is validated before anything else on this device.
    MY_CHECK(
        static_cast<int64_t>(outputs.size()) == size_ * numDevicesPerRank);

    // The input must be a strided, dense CUDA tensor.
    MY_CHECK(input.layout() == at::kStrided);
    MY_CHECK(input.is_cuda());
    MY_CHECK(input.is_non_overlapping_and_dense());

    // Each output must satisfy the same constraints and agree with the input
    // on device, scalar type and element count.
    std::vector<at::Tensor> flatOutputs;
    for (const at::Tensor& output : outputs) {
      MY_CHECK(output.layout() == at::kStrided);
      MY_CHECK(output.is_cuda());
      MY_CHECK(output.is_non_overlapping_and_dense());
      MY_CHECK(output.device() == input.device());
      MY_CHECK(output.scalar_type() == input.scalar_type());
      MY_CHECK(output.numel() == input.numel());
      flatOutputs.push_back(viewAsFlat(output));
    }

    auto flatInput = viewAsFlat(input);
    auto combinedOutput = unUnbind(std::move(flatOutputs));
    pairs.push_back(fairring::MachineFairring::TensorPair{
        .input = std::move(flatInput), .output = std::move(combinedOutput)});
  }

  if (!machine_) {
    // First use: build the machine from the distinct CUDA device indices seen
    // among the inputs; std::set de-duplicates and orders them.
    std::set<c10::DeviceIndex> deviceIndices;
    for (const fairring::MachineFairring::TensorPair& entry : pairs) {
      if (!entry.input.is_cuda()) {
        continue;
      }
      deviceIndices.insert(entry.input.device().index());
    }

    std::vector<c10::Device> devices;
    for (const c10::DeviceIndex index : deviceIndices) {
      devices.emplace_back(c10::kCUDA, index);
    }

    machine_ = std::make_unique<fairring::MachineFairring>(
        store_,
        rank_,
        size_,
        std::move(devices),
        maxMemoryAllocatedInBytes_,
        maxPaddingAllocatedInBytes_,
        minParallelism_);
  }

  auto pending = machine_->allGather(std::move(pairs));
  return c10::make_intrusive<WorkFairring>(
      c10d::OpType::ALLGATHER, std::move(pending));
}