std::shared_ptr<CommPG> CommPG::get_comm()

in src/torch_ucc.cpp [268:323]


/// Agrees on a communicator id with the other ranks through the OOB store and
/// returns the CommPG instance to use for the new process group.
///
/// Id negotiation:
///   - every rank starts from its local counter modulo TORCH_UCX_MAX_COMM;
///   - rank 0 raises its candidate to the maximum of all values published by
///     the other ranks under "group_id<rank>", then publishes the result under
///     "group_id0";
///   - every rank reads "group_id0" into `oob->comm_id` and advances the local
///     counter past it.
///
/// @param id     Output. Receives this rank's candidate id; on rank 0 this is
///               the agreed maximum. On other ranks it is NOT overwritten with
///               the agreed value (that one is stored in `oob->comm_id`).
/// @param dev    Device the process group will run on; for CUDA devices its
///               index is recorded on / checked against the shared
///               communicator.
/// @param oob    Out-of-band info (rank, size, store). `oob->comm_id` is
///               updated with the id published by rank 0.
/// @param logger Logger forwarded to the CommPG constructor.
/// @return The process-wide shared communicator when
///         `torch_ucc_config.shared_comm` is set, otherwise a new CommPG.
/// @throws std::runtime_error if a store value is too short to hold an id, or
///         if the shared communicator is already bound to a different CUDA
///         device.
///
/// NOTE(review): this relies on `oob->store->get` waiting until the requested
/// key has been set by the peer rank — confirm against the store
/// implementation in use.
std::shared_ptr<CommPG> CommPG::get_comm(
    uint32_t& id,
    c10::Device dev,
    std::shared_ptr<torch_ucc_oob_coll_info_t> oob,
    const c10::intrusive_ptr<ProcessGroupUCCLogger>& logger) {
  static std::mutex m;
  // Weak reference: this function does not keep the shared communicator alive
  // on its own; it is re-created once all owners have released it.
  static std::weak_ptr<CommPG> comm;
  static uint32_t comm_id;

  // Serializes an id into the raw byte vector kept in the store.
  const auto encode_id = [](const uint32_t& value) {
    const auto* first = reinterpret_cast<const uint8_t*>(&value);
    return std::vector<uint8_t>(first, first + sizeof(value));
  };
  // Deserializes an id from a store value. Copies byte-wise instead of
  // dereferencing the buffer as uint32_t*, and rejects values that are too
  // short instead of reading past the end of the vector.
  const auto decode_id = [](const std::vector<uint8_t>& bytes) {
    if (bytes.size() < sizeof(uint32_t)) {
      throw std::runtime_error(
          "ucc group id read from the store has unexpected size");
    }
    uint32_t value = 0;
    std::copy_n(
        bytes.data(), sizeof(value), reinterpret_cast<uint8_t*>(&value));
    return value;
  };

  // Guards the function-local statics above.
  std::lock_guard<std::mutex> lock(m);
  id = (comm_id % TORCH_UCX_MAX_COMM);

  if (oob->rank == 0) {
    // Rank 0 picks the largest candidate across all ranks.
    for (int i = 1; i < oob->size; i++) {
      id = std::max(
          id, decode_id(oob->store->get("group_id" + std::to_string(i))));
    }
  }
  // Every rank publishes its (possibly raised, on rank 0) candidate.
  oob->store->set("group_id" + std::to_string(oob->rank), encode_id(id));

  // The value published by rank 0 is the agreed id for this group.
  oob->comm_id = decode_id(oob->store->get("group_id" + std::to_string(0)));
  comm_id = oob->comm_id + 1;

  if (!torch_ucc_config.shared_comm) {
    return std::make_shared<CommPG>(logger, oob, dev);
  }

  std::shared_ptr<CommPG> shared_comm = comm.lock();
  if (!shared_comm) {
    shared_comm = std::make_shared<CommPG>(logger, oob, dev);
    comm = shared_comm;
    return shared_comm;
  }

  if (dev.is_cuda()) {
    // An existing shared communicator can be bound to one CUDA device only.
    if ((shared_comm->cuda_device_index != TORCH_UCC_DEVICE_NOT_SET) &&
        (shared_comm->cuda_device_index != dev.index())) {
      TORCH_UCC_LOG_ERROR(
          TORCH_UCC_INIT,
          "ucc communicator was initialized with different cuda device, "
          "multi device is not supported");
      throw std::runtime_error(ucc_status_string(UCC_ERR_NOT_SUPPORTED));
    }
    shared_comm->cuda_device_index = dev.index();
  }
  return shared_comm;
}