in src/nccl_ofi_rdma.cpp [4925:5140]
/**
 * Drive the accept side of the connection handshake on a listen communicator.
 *
 * The function is a resumable state machine keyed on l_comm->stage. Each call
 * makes as much progress as possible and returns early (with 0) whenever it
 * has to wait for a completion or the send could not be posted yet
 * (-FI_EAGAIN); the caller is expected to call again.
 *
 * @param listen_comm  Listen communicator; cast to the RDMA listen
 *                     communicator type internally.
 * @param recv_comm    Out parameter. Set to NULL while the handshake is still
 *                     in progress or on failure inside the state machine, and
 *                     to the new receive communicator once the connect
 *                     response message has completed. Note: it is NOT written
 *                     when the function bails out because the listen
 *                     communicator is already in COMM_CONNECTED.
 *
 * @return 0 when the handshake completed (*recv_comm != NULL) or is still in
 *         progress (*recv_comm == NULL); otherwise a non-zero error code
 *         (-EINVAL for state/rail mismatches and receive communicator
 *         creation failure, or whatever a failing helper returned). On the
 *         exit path the result of close_listen_recv_comm() is returned if no
 *         earlier error was recorded.
 */
static int accept(nccl_net_ofi_listen_comm_t *listen_comm,
		  nccl_net_ofi_recv_comm_t **recv_comm)
{
	int ret = 0;
	nccl_net_ofi_rdma_req_state_t req_state;

	nccl_net_ofi_rdma_listen_comm_t *l_comm =
		(nccl_net_ofi_rdma_listen_comm_t *)listen_comm;

	/* Extract communicator state from listen communicator object */
	nccl_net_ofi_rdma_recv_comm_t *r_comm = l_comm->r_comm;

	/* Extract request used for connect and connect response message */
	nccl_net_ofi_rdma_req_t *req = &l_comm->req;

	/* Extract struct used for message exchange */
	nccl_ofi_rdma_connection_info_t *conn_msg = &l_comm->conn_msg;

	/* Retrieve and validate endpoint */
	nccl_net_ofi_rdma_ep_t *l_comm_ep = (nccl_net_ofi_rdma_ep_t *)l_comm->base.base.ep;
	assert(l_comm_ep != NULL);

	/* Endpoint of the receive communicator. Only known once a receive
	 * communicator exists, i.e. when resuming a later stage or after
	 * prepare_recv_comm() below. */
	nccl_net_ofi_rdma_ep_t *ep = NULL;
	if (l_comm->r_comm) {
		ep = (nccl_net_ofi_rdma_ep_t *)r_comm->base.base.ep;
		assert(ep != NULL);
	}

	/* Retrieve and validate device */
	nccl_net_ofi_rdma_domain_t *domain = rdma_endpoint_get_domain(l_comm_ep);
	assert(domain != NULL);
	nccl_net_ofi_rdma_device_t *device = rdma_domain_get_device(domain);
	assert(device != NULL);

	int dev_id = device->base.dev_id;

	if (l_comm->stage == COMM_CONNECTED) {
		NCCL_OFI_WARN("listenComm %p object already has an active connection (%d).",
			      l_comm, l_comm->stage);
		ret = -EINVAL;
		goto exit;
	}

	/* Set return receive communicator to NULL until accept finalizes */
	*recv_comm = NULL;

	/*
	 * Take appropriate actions based on connection stage of communicator.
	 *
	 * Once we have completed the actions for a particular stage, we proceed
	 * to the next one until failure. This is to ensure we make maximum
	 * progress in a single function invocation.
	 */
	switch (l_comm->stage) {
	case COMM_CREATE_START:
		/* COMM_CREATE_START: nothing to do here beyond advancing the
		 * stage. */
		l_comm->stage = COMM_RECV_CONN;

		fallthrough;
	case COMM_RECV_CONN:
		/* COMM_RECV_CONN: likewise only advances the stage. */
		l_comm->stage = COMM_CONN_REQ_PENDING;

		fallthrough;
	case COMM_CONN_REQ_PENDING:
		/* COMM_CONN_REQ_PENDING: Wait until connect message has been
		 * received. Then, prepare for sending connect accept message,
		 * i.e., create receive communicator and reset the previously
		 * used request. */

		/* Progress NCCL OFI engine so that connection is accepted */
		ret = ofi_process_cq(l_comm_ep);
		if (OFI_UNLIKELY(ret != 0)) {
			goto exit;
		}

		/* Check if the connect message is received */
		nccl_net_ofi_mutex_lock(&req->req_lock);
		req_state = req->state;
		nccl_net_ofi_mutex_unlock(&req->req_lock);

		/* Wait until connect message is received. Returning directly
		 * (not via exit) keeps the listen communicator state intact so
		 * the next call resumes in this stage. */
		if (req_state != NCCL_OFI_RDMA_REQ_COMPLETED) {
			return 0;
		}

		/* Number of remote rails and number of local rails match */
		if (conn_msg->num_rails != l_comm_ep->num_rails) {
			NCCL_OFI_WARN("Unexpected number of remote rails for dev %d. Expected %i but got %i",
				      dev_id, l_comm_ep->num_rails,
				      conn_msg->num_rails);
			ret = -EINVAL;
			goto exit;
		}

		/* Number of remote control rails and number of local control rails match */
		if (conn_msg->num_control_rails != l_comm_ep->num_control_rails) {
			NCCL_OFI_WARN("Unexpected number of remote control rails for dev %d. Expected %i but got %i",
				      dev_id, l_comm_ep->num_control_rails,
				      conn_msg->num_control_rails);
			ret = -EINVAL;
			goto exit;
		}

		/* Prepare receive communicator object for the received peer connection */
		r_comm = prepare_recv_comm(domain, l_comm_ep, conn_msg);
		if (OFI_UNLIKELY(r_comm == NULL)) {
			ret = -EINVAL;
			goto exit;
		}
		l_comm->r_comm = r_comm;

		/* prepare_recv_comm establishes the endpoint used for this r_comm,
		   so set the pointer here. */
		ep = (nccl_net_ofi_rdma_ep_t *)r_comm->base.base.ep;
		assert(ep != NULL);

		/*
		 * The libfabric resources maintained by the endpoint
		 * structure is passed from l_comm to r_comm so they can
		 * then be used by nccl_net_ofi_irecv. We want to make
		 * sure those resources are not freed up when we call
		 * nccl_net_ofi_closeListen so we maintain an additional
		 * refcnt and free it up when nccl_net_ofi_closeRecv is
		 * called.
		 */
		nccl_net_ofi_mutex_lock(&(domain->base.domain_lock));
		ep->base.ref_cnt++;
		nccl_net_ofi_mutex_unlock(&(domain->base.domain_lock));

		/* Reset request state for connect response message */
		prepare_send_conn_resp_req(l_comm);

		/* Initialize connect response message */
		ret = prepare_conn_resp(ep, r_comm, dev_id);
		if (ret != 0) {
			goto exit;
		}

		l_comm->stage = COMM_SEND_CONN;

		fallthrough;
	case COMM_SEND_CONN:
		/* COMM_SEND_CONN: Send connect response message to remote */
		ret = post_send_conn_resp(r_comm, device, ep, req);
		if (ret == -FI_EAGAIN) {
			/* Could not post right now; retry on the next call */
			return 0;
		}
		else if (ret != 0) {
			goto exit;
		}

		l_comm->stage = COMM_CONN_RESP_REQ_PENDING;

		fallthrough;
	case COMM_CONN_RESP_REQ_PENDING:
		/* COMM_CONN_RESP_REQ_PENDING: Wait until connect
		 * response message has been delivered. Afterwards,
		 * cleanup and return receive communicator. */

		/* Progress our engine to get completions */
		ret = ofi_process_cq(ep);
		if (OFI_UNLIKELY(ret != 0)) {
			goto exit;
		}

		/* Check if the connect response message is sent */
		nccl_net_ofi_mutex_lock(&req->req_lock);
		req_state = req->state;
		nccl_net_ofi_mutex_unlock(&req->req_lock);

		/* Wait until connect response message is sent */
		if (req_state != NCCL_OFI_RDMA_REQ_COMPLETED) {
			return 0;
		}

		/* The free list item was allocated on the ep
		 * associated with the r_comm (as opposed to the
		 * l_comm). ep should point to the recv comm ep at
		 * this point.
		 */
		nccl_ofi_freelist_entry_free(ep->conn_msg_fl, r_comm->conn_msg);
		r_comm->conn_msg = NULL;

		*recv_comm = &r_comm->base;

		/* NULL pointer to recv communicator stored in listen
		 * communicator's state to avoid that `close_listen_recv_comm'
		 * deallocates the receive communicator */
		l_comm->r_comm = NULL;

		l_comm->stage = COMM_CONNECTED;

		break;

	case COMM_CONNECTED:
	default:
		NCCL_OFI_WARN("Invalid state of receive communicator object: %d",
			      l_comm->stage);
		ret = -EINVAL;
		/* No communicator was handed out, so skip the open-communicator
		 * accounting below, like every other failure path does. */
		goto exit;
	}

	/* Only reached after a successful handshake (via the `break' above):
	 * account for the receive communicator just returned to the caller. */
	nccl_net_ofi_mutex_lock(&comm_cleanup_list_lock);
	++num_open_comms;
	nccl_net_ofi_mutex_unlock(&comm_cleanup_list_lock);

 exit:;
	/* Close receive communicator in case listen operation failed */
	int close_ret = close_listen_recv_comm(l_comm);
	if (close_ret) {
		NCCL_OFI_WARN("Failed to close listen communicator");
	}
	return ret ? ret : close_ret;
}