void CommPG::ucx_connect_eps()

in src/torch_ucc.cpp [325:351]


void CommPG::ucx_connect_eps(
    std::vector<ucp_ep_h>& eps,
    std::shared_ptr<torch_ucc_oob_coll_info_t> oob) {
  ucp_address_t* local_addr;
  size_t local_addr_len;
  std::vector<uint8_t> peer_addr;

  TORCH_UCX_CHECK(
      ucp_worker_get_address(ucx_comm.worker, &local_addr, &local_addr_len),
      "failed to get worker address");

  std::vector<uint8_t> val = std::vector<uint8_t>(
      reinterpret_cast<uint8_t*>(local_addr),
      reinterpret_cast<uint8_t*>(local_addr) + local_addr_len);
  oob->store->set(oob->getKey("wa" + std::to_string(oob->rank)), val);
  ucp_worker_release_address(ucx_comm.worker, local_addr);
  eps.resize(oob->size);
  for (int i = 0; i < oob->size; i++) {
    peer_addr = oob->store->get(oob->getKey("wa" + std::to_string(i)));
    ucp_ep_params_t ep_params;
    ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS;
    ep_params.address = reinterpret_cast<ucp_address_t*>(peer_addr.data());
    TORCH_UCX_CHECK(
        ucp_ep_create(ucx_comm.worker, &ep_params, &(eps[i])),
        c10::str("failed to create endpoint with rank ", i));
  }
}