c10::intrusive_ptr<c10d::ProcessGroup::Work> ProcessGroupFairring::allgather()

in fairring/process_group.cc [187:240]


/// All-gather entry point of the Fairring process group.
///
/// Validates the per-device tensors, turns each input and its list of outputs
/// into one `TensorPair`, constructs `machine_` on first use, and forwards the
/// pairs to `machine_->allGather`, wrapping the result in a `WorkFairring`.
///
/// @param outputTensors One list of destination tensors per entry of
///     `inputTensors`; each list must contain `size_ * inputTensors.size()`
///     tensors whose device, scalar type and element count match the
///     corresponding input.
/// @param inputTensors Source tensors; every one must be a strided, dense,
///     non-overlapping CUDA tensor.
/// @param opts Not read by the active code path.
/// @return A `WorkFairring` tagged `c10d::OpType::ALLGATHER`.
///
/// Every precondition is enforced with `MY_CHECK`.
c10::intrusive_ptr<c10d::ProcessGroup::Work> ProcessGroupFairring::allgather(
    std::vector<std::vector<at::Tensor>>& outputTensors,
    std::vector<at::Tensor>& inputTensors,
    const c10d::AllgatherOptions& opts) {
  // Disabled alternative that forwards the call to ncclPG_ instead:
  // return c10::make_intrusive<WorkFairring>(
  //     c10d::OpType::ALLGATHER,
  //     ncclPG_->allgather(outputTensors, inputTensors, opts)->getFuture());

  MY_CHECK(inputTensors.size() == outputTensors.size());
  const int64_t numDevicesPerRank = static_cast<int64_t>(inputTensors.size());

  std::vector<fairring::MachineFairring::TensorPair> data;
  for (int64_t deviceOffset = 0; deviceOffset < numDevicesPerRank;
       ++deviceOffset) {
    // Expect one output slot for every (rank, local device) combination.
    MY_CHECK(
        static_cast<int64_t>(outputTensors[deviceOffset].size()) ==
        size_ * numDevicesPerRank);

    // Input requirements: strided layout, on CUDA, dense and non-overlapping.
    MY_CHECK(inputTensors[deviceOffset].layout() == at::kStrided);
    MY_CHECK(inputTensors[deviceOffset].is_cuda());
    MY_CHECK(inputTensors[deviceOffset].is_non_overlapping_and_dense());

    // Each output must meet the same requirements and agree with the input on
    // device, dtype and element count before it is flattened.
    std::vector<at::Tensor> flattened;
    for (const at::Tensor& t : outputTensors[deviceOffset]) {
      MY_CHECK(t.layout() == at::kStrided);
      MY_CHECK(t.is_cuda());
      MY_CHECK(t.is_non_overlapping_and_dense());
      MY_CHECK(t.device() == inputTensors[deviceOffset].device());
      MY_CHECK(t.scalar_type() == inputTensors[deviceOffset].scalar_type());
      MY_CHECK(t.numel() == inputTensors[deviceOffset].numel());
      flattened.emplace_back(viewAsFlat(t));
    }

    // NOTE(review): unUnbind presumably recombines the flat output views into
    // a single tensor -- confirm against its definition.
    auto flatInput = viewAsFlat(inputTensors[deviceOffset]);
    auto combinedOutput = unUnbind(std::move(flattened));
    data.push_back(fairring::MachineFairring::TensorPair{
        .input = std::move(flatInput), .output = std::move(combinedOutput)});
  }

  // Build the machine on first use, from the devices seen in this call.
  if (!machine_) {
    // std::set removes duplicate indices and iterates in ascending order.
    std::set<c10::DeviceIndex> uniqueIndices;
    for (const auto& entry : data) {
      if (!entry.input.is_cuda()) {
        continue;
      }
      uniqueIndices.insert(entry.input.device().index());
    }

    std::vector<c10::Device> cudaDevices;
    for (const c10::DeviceIndex index : uniqueIndices) {
      cudaDevices.emplace_back(c10::kCUDA, index);
    }

    machine_ = std::make_unique<fairring::MachineFairring>(
        store_,
        rank_,
        size_,
        std::move(cudaDevices),
        maxMemoryAllocatedInBytes_,
        maxPaddingAllocatedInBytes_,
        minParallelism_);
  }

  auto pending = machine_->allGather(std::move(data));
  return c10::make_intrusive<WorkFairring>(
      c10d::OpType::ALLGATHER, std::move(pending));
}