c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupUCC::collective_post()

in src/torch_ucc.cpp [703:762]


c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupUCC::collective_post(
    OpType opType,
    PreProcess preproc,
    PostProcess postproc,
    ucc_coll_args_t& coll,
    std::unique_ptr<ProcessGroupUCC::WorkData> data,
    c10::Device dev,
    std::vector<at::Tensor> &outputTensors,
    const char* prof_title) {
  // Posts the UCC collective described by `coll` on device `dev` and returns
  // the WorkUCC handle that tracks it.
  //
  // All devices:
  //   * `coll` is passed through set_timeout() first;
  //   * the WorkUCC receives `prof_title` only when
  //     torch_ucc_config.enable_profiling is set (nullptr otherwise);
  //   * a copy of `outputTensors` is stored in work->outputs_.
  //
  // CPU: optionally creates a future that is NOT completed here, then hands
  //   `data`, `work`, `coll` and `team` to comm->enqueue_collective().
  //   Note that `preproc` / `postproc` are not invoked on this path.
  //
  // CUDA (USE_CUDA builds only): orders `*stream` after the caller's current
  //   stream via a pooled event, runs preproc() -> enqueue -> postproc() with
  //   `*stream` current, stores the re-recorded event in work->fence, and
  //   optionally creates and immediately completes a device-bound future.
  //
  // Any other device type is logged and raises std::runtime_error carrying
  // the UCC_ERR_NOT_SUPPORTED status string.
  set_timeout(coll);

  const char* profTitle =
      torch_ucc_config.enable_profiling ? prof_title : nullptr;
  auto work = c10::make_intrusive<ProcessGroupUCC::WorkUCC>(opType, profTitle);

  // The work handle keeps its own copy of the output tensor list.
  work->outputs_ = std::make_shared<std::vector<at::Tensor>>(outputTensors);

  const auto deviceType = dev.type();

  if (deviceType == c10::DeviceType::CPU) {
    // The future is left uncompleted here; presumably `comm` completes it
    // once the collective finishes (not visible in this function).
    if (torch_ucc_config.use_future) {
      work->future_ = c10::make_intrusive<at::ivalue::Future>(
          c10::ListType::create(c10::TensorType::get()));
    }
    comm->enqueue_collective(std::move(data), work, coll, team);
    return work;
  }

#ifdef USE_CUDA
  if (deviceType == c10::DeviceType::CUDA) {
    // Record on the caller's current stream for `dev`, then block `*stream`
    // on it so the collective is ordered after work the caller already queued.
    auto fenceEvent = getPooledEvent();
    fenceEvent->record(at::cuda::getCurrentCUDAStream(dev.index()));
    fenceEvent->block(*stream);

    // Everything from here to the end of this scope runs with `*stream` as
    // the current CUDA stream.
    at::cuda::CUDAStreamGuard uccStreamGuard(*stream);
    preproc();
    comm->enqueue_cuda_collective(std::move(data), work, coll, team, cuda_ee);
    postproc();

    // Re-record on `*stream` so the event follows the work issued above, and
    // hand it over to the work handle.
    fenceEvent->record(*stream);
    work->fence = std::move(fenceEvent);
    work->ep = &ep;

    if (torch_ucc_config.use_future) {
      // Keep `*stream` current while the device-bound future is created and
      // completed.
      c10::cuda::CUDAMultiStreamGuard futureStreamGuard(*stream);
      std::vector<c10::Device> futureDevices{dev};
      work->future_ = c10::make_intrusive<at::ivalue::Future>(
          c10::ListType::create(c10::TensorType::get()), futureDevices);

      // Attach the profiling end callback (if any) before completing.
      if (work->recordFunctionEndCallback_) {
        work->future_->addCallback([work](at::ivalue::Future& /* unused */) {
          work->recordFunctionEndCallback_();
        });
      }

      work->future_->markCompleted(c10::IValue(outputTensors));
    }
    return work;
  }
#endif // #ifdef USE_CUDA

  // Neither CPU nor (when compiled in) CUDA: unsupported.
  TORCH_UCC_LOG_ERROR(
      TORCH_UCC_COLL_POST, c10::str("unsupported device type ", dev.str()));
  throw std::runtime_error(ucc_status_string(UCC_ERR_NOT_SUPPORTED));
}