in src/nccl_ofi_rdma.cpp [6771:6943]
/*
 * Drive the send-side connection state machine for `handle`.
 *
 * The function is re-entrant: progress is stored in `handle->state`
 * (stage, communicator and in-flight connect request), and each call
 * resumes from the saved stage and advances as far as it can without
 * blocking.
 *
 * @param base_ep    Endpoint to connect on (used as nccl_net_ofi_rdma_ep_t).
 * @param handle     Connection handle; its `state` member is read and
 *                   updated by this function.
 * @param send_comm  Output. Set to NULL on entry; set to the new send
 *                   communicator only once the connection is established.
 *
 * @return 0 with *send_comm != NULL when the connection is established,
 *         0 with *send_comm == NULL when more calls are needed,
 *         a non-zero error code otherwise (-EINVAL / -ENOMEM, or the
 *         value reported by a failing helper).
 *
 * On the error paths that destroy the partially built communicator, the
 * saved state in `handle` is cleared so that it never refers to released
 * objects.
 */
static int connect(nccl_net_ofi_ep_t *base_ep,
		   nccl_net_ofi_conn_handle_t *handle,
		   nccl_net_ofi_send_comm_t **send_comm)
{
	int ret = 0;
	nccl_net_ofi_rdma_req_state_t conn_resp_req_state;
	nccl_net_ofi_rdma_req_state_t conn_msg_state;
	*send_comm = NULL;
	nccl_net_ofi_rdma_ep_t *ep =
		(nccl_net_ofi_rdma_ep_t *)base_ep;

	/* Extract connection state of the communicator */
	save_comm_state_t *comm_state = &(handle->state);
	nccl_net_ofi_rdma_req_t *req = (nccl_net_ofi_rdma_req_t *)comm_state->req;
	nccl_net_ofi_rdma_send_comm_t *s_comm =
		(nccl_net_ofi_rdma_send_comm_t *)comm_state->comm;

	/* Retrieve and validate devices */
	nccl_net_ofi_rdma_device_t *device = (nccl_net_ofi_rdma_device_t *)base_ep->domain->device;
	assert(device != NULL);

	/* A handle that already reached COMM_CONNECTED must not be reused */
	nccl_ofi_comm_stage_t stage = comm_state->stage;
	if (stage == COMM_CONNECTED) {
		NCCL_OFI_WARN("Handle %p object already has an active send communicator (%p).",
			      handle, s_comm);
		return -EINVAL;
	}

	ret = post_rx_buffs(ep);
	if (ret != 0) {
		NCCL_OFI_WARN("Error posting rx buffers: %d", ret);
		return ret;
	}

	/*
	 * Take appropriate actions based on connection stage of communicator.
	 *
	 * Once we have completed the actions for a particular stage, we proceed
	 * to the next one until failure. This is to ensure we make maximum
	 * progress in a single function invocation.
	 */
	switch (stage) {
	case COMM_CREATE_START:
		/* COMM_CREATE_START: Allocate data required for the
		 * connect function */

		/* When we are building the s_comm for the first time,
		 * it should *NOT* come initialized from handle. */
		assert(s_comm == NULL);

		/* Build send communicator with one comm rail */
		ret = create_send_comm(handle, ep, &s_comm);
		if (OFI_UNLIKELY(ret != 0)) {
			return ret;
		}
		if (OFI_UNLIKELY(s_comm == NULL)) {
			return -ENOMEM;
		}
		comm_state->comm = &s_comm->base.base;

		/* Prepare connect request to be sent to peer */
		req = prepare_send_conn_req(s_comm);
		if (OFI_UNLIKELY(req == NULL)) {
			ret = -ENOMEM;
			goto destroy_comm;
		}
		comm_state->req = &req->base;

		/* Prepare request to receive connect response message */
		s_comm->conn_resp_req = prepare_recv_conn_resp_req(s_comm);
		if (OFI_UNLIKELY(s_comm->conn_resp_req == NULL)) {
			/* NOTE(review): unlike the post_send_conn() failure
			 * path below, `req` is not explicitly freed here;
			 * presumably send_comm_destroy() reclaims it -
			 * confirm. */
			ret = -EINVAL;
			goto destroy_comm;
		}

		comm_state->stage = COMM_SEND_CONN;
		fallthrough;
	case COMM_SEND_CONN:
		/* COMM_SEND_CONN: Post a connect message to send peer connections */
		ret = post_send_conn(s_comm, device, ep, req);
		if (ret == -FI_EAGAIN) {
			/* Not an error: stay in this stage and let the
			 * caller invoke us again */
			return 0;
		}
		else if (ret != 0) {
			req->free(req, false);
			goto destroy_comm;
		}

		comm_state->stage = COMM_CONN_REQ_PENDING;
		fallthrough;
	case COMM_CONN_REQ_PENDING:
		/* COMM_CONN_REQ_PENDING: Wait until connect message
		 * has been sent. Afterwards, reset previously used
		 * request. */

		/* Progress our engine to get completions */
		ret = ofi_process_cq(ep);
		if (OFI_UNLIKELY(ret != 0)) {
			/* Send communicator cannot be closed since
			 * send request of send connect message is
			 * still pending */
			return ret;
		}

		/* Check if the connect message is sent */
		nccl_net_ofi_mutex_lock(&req->req_lock);
		conn_msg_state = req->state;
		nccl_net_ofi_mutex_unlock(&req->req_lock);

		/* Wait until connect message is sent */
		if (conn_msg_state != NCCL_OFI_RDMA_REQ_COMPLETED) {
			return 0;
		}

		/* Release connect message request */
		req->free(req, false);
		comm_state->req = NULL;
		req = NULL;

		comm_state->stage = COMM_RECV_CONN;
		fallthrough;
	case COMM_RECV_CONN:
		/* COMM_RECV_CONN: Receive connect response message from remote */
		assert(s_comm && s_comm->num_rails > 0);

		comm_state->stage = COMM_CONN_RESP_REQ_PENDING;
		fallthrough;
	case COMM_CONN_RESP_REQ_PENDING:
		/* Progress our engine to get completions. If the
		 * connect response message has arrived, the
		 * connection establishment will be finalized. */
		ret = ofi_process_cq(ep);
		if (OFI_UNLIKELY(ret != 0)) {
			return ret;
		}

		nccl_net_ofi_mutex_lock(&s_comm->conn_resp_req->req_lock);
		conn_resp_req_state = s_comm->conn_resp_req->state;
		nccl_net_ofi_mutex_unlock(&s_comm->conn_resp_req->req_lock);

		/* Wait until conn resp message is received */
		if (conn_resp_req_state != NCCL_OFI_RDMA_REQ_COMPLETED) {
			return 0;
		}

		ret = finish_connect(s_comm);
		if (OFI_UNLIKELY(ret != 0)) {
			/* NOTE(review): the communicator is left in place and
			 * the stage is unchanged on this path - confirm that
			 * is intended. */
			return ret;
		}

		comm_state->stage = COMM_CONNECTED;

		break;

	case COMM_CONNECTED:
	default:
		NCCL_OFI_WARN("Invalid state of send communicator object: %d", stage);
		return -EINVAL;
	}

	/* Connection established: account for the new communicator and
	 * hand it to the caller */
	nccl_net_ofi_mutex_lock(&comm_cleanup_list_lock);
	++num_open_comms;
	nccl_net_ofi_mutex_unlock(&comm_cleanup_list_lock);

	*send_comm = &s_comm->base;

	return ret;

destroy_comm:
	/*
	 * Tear down the partially built communicator and scrub the state
	 * saved in the handle. The handle is what this function resumes
	 * from, so it must not keep pointers to the objects released here;
	 * rewinding the stage makes a later call start from scratch instead
	 * of resuming with them.
	 */
	send_comm_destroy(s_comm);
	comm_state->comm = NULL;
	comm_state->req = NULL;
	comm_state->stage = COMM_CREATE_START;
	return ret;
}