in src/nccl_ofi_rdma.cpp [2645:2716]
/**
 * Finish setting up a send communicator from the received connect response.
 *
 * The response is read from s_comm->conn_msg->ptr and checked against the
 * local endpoint and device before it is used:
 *   - the connect response request must be in the COMPLETED state,
 *   - the number of rails and control rails must equal the endpoint's,
 *   - the peer's communicator ID must be below device->num_comm_ids.
 *
 * When all checks pass, the peer's communicator ID is stored in
 * s_comm->remote_comm_id, init_send_comm_rails() is called with the endpoint
 * names carried in the response, and finally the connect response request
 * and the connection message freelist entry are released and their pointers
 * in s_comm are cleared.
 *
 * @param s_comm  Send communicator; conn_msg and conn_resp_req are read here.
 *
 * @return 0 on success, -EINVAL if any validation fails, or the non-zero
 *         value returned by init_send_comm_rails(). On any non-zero return,
 *         conn_resp_req and conn_msg are left untouched by this function.
 */
static int finish_connect(nccl_net_ofi_rdma_send_comm_t *s_comm)
{
	/* The connect response is the payload of the connection message entry */
	nccl_ofi_rdma_connection_info_t *resp_info =
		(nccl_ofi_rdma_connection_info_t *)s_comm->conn_msg->ptr;

	/* The response request has to exist and be completed */
	assert(s_comm->conn_resp_req);
	if (s_comm->conn_resp_req->state != NCCL_OFI_RDMA_REQ_COMPLETED) {
		NCCL_OFI_WARN("Invalid connect response request state. Got %i but expected %i",
			      s_comm->conn_resp_req->state, NCCL_OFI_RDMA_REQ_COMPLETED);
		return -EINVAL;
	}

	/* Endpoint backing this communicator */
	nccl_net_ofi_rdma_ep_t *rdma_ep = (nccl_net_ofi_rdma_ep_t *)s_comm->base.base.ep;
	if (OFI_UNLIKELY(rdma_ep == NULL)) {
		NCCL_OFI_WARN("Invalid endpoint provided");
		return -EINVAL;
	}

	/* Device the endpoint belongs to */
	nccl_net_ofi_rdma_device_t *rdma_dev = rdma_endpoint_get_device(rdma_ep);
	if (OFI_UNLIKELY(rdma_dev == NULL)) {
		NCCL_OFI_WARN("Invalid device provided");
		return -EINVAL;
	}

	/* Device ID is only used for log messages and rail initialization */
	int device_id = rdma_dev->base.dev_id;

	/* Rail counts reported by the peer must equal the local endpoint's */
	if (resp_info->num_rails != rdma_ep->num_rails) {
		NCCL_OFI_WARN("Unexpected number of remote rails for dev %d. Expected %i but got %i",
			      device_id, rdma_ep->num_rails,
			      resp_info->num_rails);
		return -EINVAL;
	}

	if (resp_info->num_control_rails != rdma_ep->num_control_rails) {
		NCCL_OFI_WARN("Unexpected number of remote control rails for dev %d. Expected %i but got %i",
			      device_id, rdma_ep->num_control_rails,
			      resp_info->num_control_rails);
		return -EINVAL;
	}

	/* Reject communicator IDs outside the device's ID range */
	if (OFI_UNLIKELY(resp_info->local_comm_id >= rdma_dev->num_comm_ids)) {
		NCCL_OFI_WARN("Received an invalid communicator ID %u for device %d", resp_info->local_comm_id,
			      device_id);
		return -EINVAL;
	}

	/* The peer's local comm ID is our remote comm ID */
	s_comm->remote_comm_id = resp_info->local_comm_id;

	/* Hand the peer's endpoint names and rail counts to the rail setup */
	int rc = init_send_comm_rails(s_comm, rdma_ep, device_id,
				      resp_info->ep_names,
				      resp_info->num_rails,
				      resp_info->control_ep_names,
				      resp_info->num_control_rails);

	/* Release the response request and message buffer only on success */
	if (rc == 0) {
		s_comm->conn_resp_req->free(s_comm->conn_resp_req, false);
		s_comm->conn_resp_req = NULL;

		nccl_ofi_freelist_entry_free(rdma_ep->conn_msg_fl, s_comm->conn_msg);
		s_comm->conn_msg = NULL;
	}

	return rc;
}