void ProcessGroupUCC::initComm()

in src/torch_ucc.cpp [1247:1286]


// Sets up the UCX/UCC communication state for this process group on device
// `dev` and, in CUDA builds, a UCC CUDA execution engine bound to a dedicated
// stream.
//
// First call (`comm` not yet set):
//   - in CUDA builds, makes `dev` the current CUDA device if it is a CUDA device;
//   - obtains a communicator via CommPG::get_comm;
//   - connects the UCX endpoints (`eps`) and creates the UCC team (`team`);
//   - moves the logger to the TORCH_UCC_READY phase.
// Later calls with a CUDA `dev`: `dev` must match the CUDA device already
// recorded on the communicator (if one is recorded); the index is then stored
// on the communicator.
//
// NOTE(review): `comm` is assigned before ucx_connect_eps/ucc_create_team run.
// If either throws, `comm` stays non-null and a later call takes the `else`
// path as if setup had completed. Confirm whether callers can retry after a
// failure here.
//
// @param dev  Device the caller intends to run collectives on.
// @throws std::runtime_error when `dev` is a CUDA device different from the
//         one the communicator was initialized with (multi-device is not
//         supported). Failures reported by the UCX/UCC helpers and by
//         TORCH_UCC_CHECK propagate to the caller.
void ProcessGroupUCC::initComm(c10::Device dev) {
  if (!comm) {
#ifdef USE_CUDA
    if (dev.is_cuda()) {
      // Select the target device before any communicator resources are
      // created for it.
      c10::cuda::set_device(dev.index());
    }
#endif
    comm = CommPG::get_comm(comm_id, dev, oob, logger);
    comm->ucx_connect_eps(eps, oob);
    TORCH_UCC_LOG_INFO(TORCH_UCC_INIT, "Successfully initialized UCX library");
    comm->ucc_create_team(team, oob);
    TORCH_UCC_LOG_INFO(TORCH_UCC_INIT, "Successfully initialized UCC library");
    logger->setPhase(TORCH_UCC_READY);
  } else if (dev.is_cuda()) {
    // A communicator serves a single CUDA device. TORCH_UCC_DEVICE_NOT_SET is
    // presumably the sentinel for "no CUDA device recorded yet".
    if ((comm->cuda_device_index != TORCH_UCC_DEVICE_NOT_SET) &&
        (comm->cuda_device_index != dev.index())) {
      TORCH_UCC_LOG_ERROR(
          TORCH_UCC_INIT,
          "ucc communicator was initialized with different cuda device, "
          "multi device is not supported");
      throw std::runtime_error(ucc_status_string(UCC_ERR_NOT_SUPPORTED));
    }
    comm->cuda_device_index = dev.index();
  }
#ifdef USE_CUDA
  // Create the CUDA execution engine once, the first time a CUDA device is seen.
  if (!cuda_ee && dev.is_cuda()) {
    // Dedicated stream from PyTorch's pool on `dev`; the `true` argument
    // requests a high-priority stream.
    stream = std::make_unique<at::cuda::CUDAStream>(
        at::cuda::getStreamFromPool(true, dev.index()));
    // Value-initialize so no field of the C struct is left indeterminate.
    ucc_ee_params_t params{};
    params.ee_type = UCC_EE_CUDA_STREAM;
    params.ee_context = static_cast<void*>(stream->stream());
    params.ee_context_size = sizeof(cudaStream_t);
    TORCH_UCC_CHECK(
        ucc_ee_create(team, &params, &cuda_ee),
        "failed to create UCC execution engine");
  }
#endif
}