in src/torch_ucc.cpp [1247:1286]
/**
 * @brief Lazily sets up the communicator, the UCC team and (for CUDA devices)
 *        the CUDA execution engine used by this process group.
 *
 * First call (no communicator yet):
 *   - when built with USE_CUDA and @p dev is a CUDA device, makes that device
 *     current before any communication resources are created;
 *   - obtains the communicator via CommPG::get_comm, connects the UCX
 *     endpoints, creates the UCC team and moves the logger to the
 *     TORCH_UCC_READY phase.
 *
 * Later calls (communicator already present):
 *   - for a CUDA @p dev, verifies that the communicator is not already bound
 *     to a different CUDA device index, then records the index.
 *
 * In both cases, when built with USE_CUDA, a CUDA execution engine is created
 * the first time a CUDA device is seen.
 *
 * @param dev Device the caller is about to communicate on.
 * @throws std::runtime_error if the communicator was previously bound to a
 *         different CUDA device (multi-device use is not supported).
 *         TORCH_UCC_CHECK presumably also raises when ucc_ee_create fails --
 *         confirm against the macro definition.
 */
void ProcessGroupUCC::initComm(c10::Device dev) {
  if (!comm) {
#ifdef USE_CUDA
    if (dev.is_cuda()) {
      // Select the device first so that resources created below are
      // associated with it.
      c10::cuda::set_device(dev.index());
    }
#endif
    comm = CommPG::get_comm(comm_id, dev, oob, logger);
    comm->ucx_connect_eps(eps, oob);
    TORCH_UCC_LOG_INFO(TORCH_UCC_INIT, "Successfully initialized UCX library");
    comm->ucc_create_team(team, oob);
    TORCH_UCC_LOG_INFO(TORCH_UCC_INIT, "Successfully initialized UCC library");
    logger->setPhase(TORCH_UCC_READY);
  } else {
    if (dev.is_cuda()) {
      // A communicator may only ever be used with a single CUDA device.
      if ((comm->cuda_device_index != TORCH_UCC_DEVICE_NOT_SET) &&
          (comm->cuda_device_index != dev.index())) {
        TORCH_UCC_LOG_ERROR(
            TORCH_UCC_INIT,
            "ucc communicator was initialized with different cuda device, "
            "multi device is not supported");
        throw std::runtime_error(ucc_status_string(UCC_ERR_NOT_SUPPORTED));
      }
      comm->cuda_device_index = dev.index();
    }
  }
#ifdef USE_CUDA
  if (!cuda_ee && dev.is_cuda()) {
    // Dedicated stream for UCC work; the `true` argument requests a
    // high-priority stream from the pool.
    stream = std::make_unique<at::cuda::CUDAStream>(
        at::cuda::getStreamFromPool(true, dev.index()));
    // Value-initialize so no field is left indeterminate before the UCC call.
    ucc_ee_params_t params{};
    params.ee_type = UCC_EE_CUDA_STREAM;
    params.ee_context = static_cast<void*>(stream->stream());
    params.ee_context_size = sizeof(cudaStream_t);
    TORCH_UCC_CHECK(
        ucc_ee_create(team, &params, &cuda_ee),
        "failed to create UCC execution engine");
  }
#endif
}