in src/nccl_ofi_sendrecv.cpp [2190:2313]
static int sendrecv_endpoint_connect(nccl_net_ofi_ep_t *base_ep,
nccl_net_ofi_conn_handle_t *handle,
nccl_net_ofi_send_comm_t **send_comm)
{
int ret = 0;
ssize_t rc = 0;
*send_comm = NULL;
nccl_net_ofi_sendrecv_ep_t *ep =
(nccl_net_ofi_sendrecv_ep_t *)base_ep;
nccl_ofi_connection_info_t *conn_info = NULL;
/* Retrieve and validate devices */
nccl_net_ofi_sendrecv_device_t *device = sendrecv_endpoint_get_device(ep);
if (OFI_UNLIKELY(device == NULL)) {
NCCL_OFI_WARN("Error accessing devices array. Devices array has not been initialized.");
return -EINVAL;
}
int dev_id = device->base.dev_id;
/* Extract connection state of the communicator */
save_comm_state_t *comm_state = &(handle->state);
nccl_net_ofi_sendrecv_req_t *req = (nccl_net_ofi_sendrecv_req_t *)comm_state->req;
nccl_net_ofi_sendrecv_send_comm_t *s_comm =
(nccl_net_ofi_sendrecv_send_comm_t *)comm_state->comm;
/*
* Take appropriate actions based on connection stage of communicator.
*
* Once we have completed the actions for a particular stage, we proceed
* to the next one until failure. This is to ensure we make maximum
* progress in a single function invocation.
*/
nccl_ofi_comm_stage_t stage = comm_state->stage;
switch (stage) {
case COMM_CREATE_START:
/*
* When we are building the s_comm for the first time,
* it should *NOT* come initialized from handle.
*/
assert(s_comm == NULL);
/* Build send_comm */
ret = sendrecv_send_comm_create(handle, ep, &s_comm);
if (OFI_UNLIKELY(ret != 0 || s_comm == NULL)) {
return ret;
}
/* Prepare connect request to be sent to peer */
req = sendrecv_send_comm_prepare_send_req(s_comm);
if (OFI_UNLIKELY(req == NULL)) {
free(s_comm);
return -ENOMEM;
}
comm_state->stage = COMM_SEND_CONN;
fallthrough;
case COMM_SEND_CONN:
/* Send "connect" message to remote EP */
rc = sendrecv_send_comm_send_connect_message(s_comm, device, ep, req);
if (rc == -FI_EAGAIN) {
/* Save connection state */
comm_state->comm = &s_comm->base.base;
comm_state->req = &req->base;
return 0;
}
else if (rc != 0) {
sendrecv_send_comm_free_req(s_comm, dev_id, req, false);
free(s_comm);
return rc;
}
comm_state->stage = COMM_CONN_REQ_PENDING;
fallthrough;
case COMM_CONN_REQ_PENDING:
conn_info = (nccl_ofi_connection_info_t *)s_comm->conn_info->ptr;
if (conn_info->connect_to_self == 1) {
NCCL_OFI_TRACE(NCCL_NET, "Connect to self; short circuit cleanup");
/* short cut to avoid rendezvous
deadlock in GDR detection */
comm_state->stage = COMM_CONNECTED;
break;
}
/* Progress our engine to get completions */
ret = sendrecv_cq_process(ep->cq);
if (OFI_UNLIKELY(ret != 0)) {
assert((nccl_net_ofi_comm_t *)s_comm == req->comm);
sendrecv_send_comm_free_req(s_comm, dev_id, req, false);
free(s_comm);
return ret;
}
/* Check if the connect message is sent */
if (req->state != NCCL_OFI_SENDRECV_REQ_COMPLETED) {
/* Save connection state */
comm_state->comm = &s_comm->base.base;
comm_state->req = &req->base;
return 0;
}
comm_state->stage = COMM_CONNECTED;
break;
case COMM_RECV_CONN:
case COMM_CONN_RESP_REQ_PENDING:
case COMM_CONNECTED:
default:
NCCL_OFI_WARN("Invalid state of send communicator object: %d", stage);
return -EINVAL;
};
*send_comm = &s_comm->base;
assert((nccl_net_ofi_comm_t *)s_comm == req->comm);
conn_info = (nccl_ofi_connection_info_t *)s_comm->conn_info->ptr;
if (conn_info->connect_to_self != 1) {
sendrecv_send_comm_free_req(s_comm, dev_id, req, false);
nccl_ofi_freelist_entry_free(ep->conn_msg_fl, s_comm->conn_info);
s_comm->conn_info = NULL;
}
return ret;
}