in src/torch_ucc.cpp [325:351]
void CommPG::ucx_connect_eps(
std::vector<ucp_ep_h>& eps,
std::shared_ptr<torch_ucc_oob_coll_info_t> oob) {
ucp_address_t* local_addr;
size_t local_addr_len;
std::vector<uint8_t> peer_addr;
TORCH_UCX_CHECK(
ucp_worker_get_address(ucx_comm.worker, &local_addr, &local_addr_len),
"failed to get worker address");
std::vector<uint8_t> val = std::vector<uint8_t>(
reinterpret_cast<uint8_t*>(local_addr),
reinterpret_cast<uint8_t*>(local_addr) + local_addr_len);
oob->store->set(oob->getKey("wa" + std::to_string(oob->rank)), val);
ucp_worker_release_address(ucx_comm.worker, local_addr);
eps.resize(oob->size);
for (int i = 0; i < oob->size; i++) {
peer_addr = oob->store->get(oob->getKey("wa" + std::to_string(i)));
ucp_ep_params_t ep_params;
ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS;
ep_params.address = reinterpret_cast<ucp_address_t*>(peer_addr.data());
TORCH_UCX_CHECK(
ucp_ep_create(ucx_comm.worker, &ep_params, &(eps[i])),
c10::str("failed to create endpoint with rank ", i));
}
}