c10::intrusive_ptr ProcessGroupFairring::allreduce()

in fairring/process_group.cc [136:169]


/// Runs an all-reduce over `data` through the Fairring machine.
///
/// Each tensor in `data` is first asserted (via MY_CHECK) to have strided
/// layout, to live on a CUDA device, and to be non-overlapping and dense; it
/// is then overwritten, inside the caller's vector, with the result of
/// viewAsFlat(). If `machine_` has not been created yet, it is built here from
/// the distinct CUDA device indices found in `data` (in ascending index order,
/// since they are gathered in a std::set) and reused by later calls.
///
/// @param data  Tensors to reduce; entries are reassigned to viewAsFlat(entry).
/// @param opts  All-reduce options; only `opts.reduceOp` is read here.
/// @return A WorkFairring tagged c10d::OpType::ALLREDUCE that wraps whatever
///         machine_->allReduce() returns.
c10::intrusive_ptr<c10d::ProcessGroup::Work> ProcessGroupFairring::allreduce(
    std::vector<at::Tensor>& data,
    const c10d::AllreduceOptions& opts) {
  // Disabled alternative, kept for reference, that forwards to ncclPG_:
  //   return c10::make_intrusive<WorkFairring>(
  //       c10d::OpType::ALLREDUCE, ncclPG_->allreduce(data, opts)->getFuture());

  // Validate every input and swap it for its flattened form before anything
  // else happens, so the machine only ever sees flattened tensors.
  for (auto it = data.begin(); it != data.end(); ++it) {
    at::Tensor& tensor = *it;
    MY_CHECK(tensor.layout() == at::kStrided);
    MY_CHECK(tensor.is_cuda());
    MY_CHECK(tensor.is_non_overlapping_and_dense());
    tensor = viewAsFlat(tensor);
  }

  // Lazy, one-time construction of the machine from the devices seen on this
  // call.
  if (!machine_) {
    std::set<c10::DeviceIndex> cudaIndices;
    for (auto it = data.cbegin(); it != data.cend(); ++it) {
      if (!it->is_cuda()) {
        continue;
      }
      cudaIndices.insert(it->device().index());
    }

    std::vector<c10::Device> cudaDevices;
    for (auto idxIt = cudaIndices.cbegin(); idxIt != cudaIndices.cend();
         ++idxIt) {
      cudaDevices.emplace_back(c10::kCUDA, *idxIt);
    }

    machine_ = std::make_unique<fairring::MachineFairring>(
        store_,
        rank_,
        size_,
        std::move(cudaDevices),
        maxMemoryAllocatedInBytes_,
        maxPaddingAllocatedInBytes_,
        minParallelism_);
  }

  return c10::make_intrusive<WorkFairring>(
      c10d::OpType::ALLREDUCE, machine_->allReduce(opts.reduceOp, data));
}