in fairring/process_group.cc [136:169]
// All-reduces the tensors in `data` using the Fairring machine.
//
// Every tensor must be strided, CUDA-resident, non-overlapping and dense
// (enforced with MY_CHECK). Each tensor is handed to the machine as a flat
// view produced by viewAsFlat.
//
// The caller's `data` vector is NOT modified: the flattened views live in a
// local vector. As a result, callers keep their original tensor handles and
// shapes, and a failed check never leaves `data` partially rewritten.
// NOTE(review): this assumes viewAsFlat returns a view sharing storage with
// its input, so the reduced values land in the caller's tensors. The
// original code relied on the same property. TODO confirm.
//
// The underlying MachineFairring is created lazily on the first call. It is
// built from the set of CUDA devices that the first call's tensors live on.
// NOTE(review): later calls are not checked here against that device set,
// and the lazy init is unsynchronized. Presumably collectives are issued
// from a single thread; confirm.
//
// @param data  Tensors to reduce (results are expected in their storage).
// @param opts  Only `opts.reduceOp` is consulted.
// @return A WorkFairring wrapping the handle returned by the machine's
//         allReduce.
c10::intrusive_ptr<c10d::ProcessGroup::Work> ProcessGroupFairring::allreduce(
    std::vector<at::Tensor>& data,
    const c10d::AllreduceOptions& opts) {
  // return c10::make_intrusive<WorkFairring>(
  //     c10d::OpType::ALLREDUCE, ncclPG_->allreduce(data, opts)->getFuture());

  // Validate each tensor and collect its flat view locally. Assigning back
  // into `data` would silently swap out the caller's tensor handles.
  std::vector<at::Tensor> flatData;
  flatData.reserve(data.size());
  for (at::Tensor& t : data) {
    MY_CHECK(t.layout() == at::kStrided);
    MY_CHECK(t.is_cuda());
    MY_CHECK(t.is_non_overlapping_and_dense());
    flatData.push_back(viewAsFlat(t));
  }

  if (machine_ == nullptr) {
    // Collect the distinct CUDA device indices used by this call's tensors.
    std::set<c10::DeviceIndex> deviceSet;
    for (const at::Tensor& t : flatData) {
      // Defensive: MY_CHECK above already requires every tensor to be CUDA.
      if (t.is_cuda()) {
        deviceSet.insert(t.device().index());
      }
    }
    std::vector<c10::Device> devices;
    devices.reserve(deviceSet.size());
    for (const c10::DeviceIndex& idx : deviceSet) {
      devices.emplace_back(c10::kCUDA, idx);
    }
    machine_ = std::make_unique<fairring::MachineFairring>(
        store_,
        rank_,
        size_,
        std::move(devices),
        maxMemoryAllocatedInBytes_,
        maxPaddingAllocatedInBytes_,
        minParallelism_);
  }
  return c10::make_intrusive<WorkFairring>(
      c10d::OpType::ALLREDUCE, machine_->allReduce(opts.reduceOp, flatData));
}