in src/torch_ucc.cpp [268:323]
/**
 * Agrees on a communicator id with the other ranks via the OOB store and
 * returns the CommPG instance to use.
 *
 * Id negotiation (all store keys have the form "group_id<rank>"):
 *  - every rank proposes `comm_id % TORCH_UCX_MAX_COMM`, where `comm_id` is a
 *    function-local static counter;
 *  - non-zero ranks publish their proposal under their own key;
 *  - rank 0 reads the proposals of ranks 1..size-1, takes the maximum
 *    (including its own) and publishes it under "group_id0";
 *  - every rank then reads "group_id0", stores it in `oob->comm_id` and sets
 *    the static counter to that value plus one.
 *
 * @param id     Out-parameter. On return it holds this rank's proposal; on
 *               rank 0 that is the maximum over all proposals.
 * @param dev    Device the communicator is requested for.
 * @param oob    Out-of-band info; `rank`, `size` and `store` are read and
 *               `comm_id` is written.
 * @param logger Logger forwarded to the CommPG constructor.
 * @return A new CommPG, or, when `torch_ucc_config.shared_comm` is set, the
 *         process-wide instance tracked through a static weak_ptr (created on
 *         demand if no live instance exists).
 * @throws std::runtime_error if a value read from the store is too short to
 *         hold a comm id, or if the shared communicator is already bound to a
 *         different CUDA device index than `dev`.
 *
 * The whole body runs under a function-local static mutex.
 */
std::shared_ptr<CommPG> CommPG::get_comm(
    uint32_t& id,
    c10::Device dev,
    std::shared_ptr<torch_ucc_oob_coll_info_t> oob,
    const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger) {
  static std::mutex m;
  static std::weak_ptr<CommPG> comm;
  static uint32_t comm_id;

  // Turns a comm id into the raw byte payload written to the store.
  const auto encode_id = [](const uint32_t& value) {
    const auto* first = reinterpret_cast<const uint8_t*>(&value);
    return std::vector<uint8_t>(first, first + sizeof(value));
  };
  // Reads a comm id back from a store payload. The bytes are copied into a
  // real uint32_t instead of dereferencing a casted pointer, which avoids the
  // aliasing violation, and the length is validated so a short payload cannot
  // cause an out-of-bounds read.
  const auto decode_id = [](const std::vector<uint8_t>& bytes) {
    if (bytes.size() < sizeof(uint32_t)) {
      throw std::runtime_error(
          "group_id value read from the store is too short");
    }
    uint32_t value = 0;
    std::copy_n(
        bytes.data(), sizeof(value), reinterpret_cast<uint8_t*>(&value));
    return value;
  };

  std::lock_guard<std::mutex> lock(m);
  id = (comm_id % TORCH_UCX_MAX_COMM);

  if (oob->rank == 0) {
    // Rank 0 collects every other rank's proposal and keeps the largest.
    for (int i = 1; i < oob->size; i++) {
      id = std::max(
          id, decode_id(oob->store->get("group_id" + std::to_string(i))));
    }
  }
  // Non-zero ranks publish their proposal; rank 0 publishes the agreed id.
  oob->store->set("group_id" + std::to_string(oob->rank), encode_id(id));

  // All ranks adopt the value published by rank 0.
  // NOTE(review): the keys carry no per-call discriminator, so this relies on
  // the store not returning a value left over from an earlier call - confirm.
  oob->comm_id = decode_id(oob->store->get("group_id" + std::to_string(0)));
  comm_id = oob->comm_id + 1;

  if (!torch_ucc_config.shared_comm) {
    return std::make_shared<CommPG>(logger, oob, dev);
  }

  std::shared_ptr<CommPG> shared_comm = comm.lock();
  if (!shared_comm) {
    // No live shared communicator: create one and remember it weakly so it is
    // released once the last user drops it.
    shared_comm = std::make_shared<CommPG>(logger, oob, dev);
    comm = shared_comm;
  } else if (dev.is_cuda()) {
    // An existing shared communicator may only be reused with the CUDA device
    // index it was already bound to (if any).
    if ((shared_comm->cuda_device_index != TORCH_UCC_DEVICE_NOT_SET) &&
        (shared_comm->cuda_device_index != dev.index())) {
      TORCH_UCC_LOG_ERROR(
          TORCH_UCC_INIT,
          "ucc communicator was initialized with different cuda device,"
          "multi device is not supported");
      throw std::runtime_error(ucc_status_string(UCC_ERR_NOT_SUPPORTED));
    }
    shared_comm->cuda_device_index = dev.index();
  }
  return shared_comm;
}