in src/nccl_ofi_sendrecv.cpp [1523:1690]
/*
 * Accept an incoming connection on a sendrecv listen communicator.
 *
 * This is a resumable state machine driven by l_comm->state.stage. Each
 * call makes as much progress as it can and returns; the caller calls
 * again until *recv_comm is non-NULL.
 *   COMM_CREATE_START     - prepare the receive request for the peer's
 *                           connection message and take an extra endpoint
 *                           reference on behalf of the future recv comm.
 *   COMM_RECV_CONN        - allocate the connection-info buffer (once) and
 *                           post the receive for the connection message.
 *   COMM_CONN_REQ_PENDING - progress the CQ until l_comm->accepted is set
 *                           (and, for a connect-to-self, until our own
 *                           connect-side request has completed).
 * Once the connection message is in, a receive communicator is built
 * from the peer's endpoint name carried in the connection info.
 *
 * @param listen_comm  Listen communicator (a sendrecv listen comm).
 * @param recv_comm    Out: the new receive communicator on success, or
 *                     NULL when more progress is needed.
 * @return 0 on success or when the caller must retry (*recv_comm == NULL),
 *         a negative error code on failure.
 */
static int sendrecv_listen_comm_accept(nccl_net_ofi_listen_comm_t *listen_comm,
				       nccl_net_ofi_recv_comm_t **recv_comm)
{
	int ret = 0;
	nccl_net_ofi_sendrecv_listen_comm_t *l_comm =
		(nccl_net_ofi_sendrecv_listen_comm_t *)listen_comm;

	/*
	 * Reject the call if a connection has already been accepted on this
	 * listen comm, unless we are still finishing that acceptance.
	 */
	if (l_comm->state.stage != COMM_CONN_REQ_PENDING && l_comm->accepted) {
		NCCL_OFI_WARN("listen_comm %p object already has an active connection (%d).",
			      listen_comm, l_comm->accepted);
		return -EINVAL;
	}

	*recv_comm = NULL;

	/* Extract communicator state from listen communicator object */
	save_comm_state_t *comm_state = &l_comm->state;
	nccl_net_ofi_sendrecv_recv_comm_t *r_comm;
	/* Request saved by an earlier call that had to return early */
	nccl_net_ofi_sendrecv_req_t *req = (nccl_net_ofi_sendrecv_req_t *)comm_state->req;

	/* Retrieve and validate endpoint */
	nccl_net_ofi_sendrecv_ep_t *ep =
		(nccl_net_ofi_sendrecv_ep_t *)l_comm->base.base.ep;
	if (OFI_UNLIKELY(ep == NULL)) {
		ret = -EINVAL;
		NCCL_OFI_WARN("Invalid endpoint provided");
		return ret;
	}

	nccl_net_ofi_sendrecv_domain_t *domain =
		sendrecv_endpoint_get_domain(ep);
	assert(domain != NULL);

	/* Retrieve and validate device */
	nccl_net_ofi_sendrecv_device_t *device =
		sendrecv_endpoint_get_device(ep);
	if (OFI_UNLIKELY(device == NULL)) {
		ret = -EINVAL;
		NCCL_OFI_WARN("Invalid device provided");
		return ret;
	}

	/* Connection-info buffer may already exist from a previous call */
	nccl_ofi_connection_info_t *conn_info = NULL;
	if (l_comm->conn_info != NULL) {
		conn_info = (nccl_ofi_connection_info_t *)l_comm->conn_info->ptr;
	}

	/*
	 * Take appropriate actions based on connection stage of communicator.
	 *
	 * Once we have completed the actions for a particular stage, we proceed
	 * to the next one until failure. This is to ensure we make maximum
	 * progress in a single function invocation.
	 */
	nccl_ofi_comm_stage_t stage = comm_state->stage;

	switch (stage) {
	case COMM_CREATE_START:
		/* Prepare receive request to accept connections */
		req = sendrecv_recv_req_prepare(l_comm);
		if (req == NULL) {
			return -ENOMEM;
		}

		/*
		 * The libfabric resources maintained by the endpoint
		 * structure is passed from l_comm to r_comm so they can
		 * then be used by nccl_net_ofi_irecv. We want to make
		 * sure those resources are not freed up when we call
		 * nccl_net_ofi_closeListen so we maintain an additional
		 * refcnt and free it up when nccl_net_ofi_closeRecv is
		 * called.
		 *
		 * The reference is taken only after the request has been
		 * allocated: an allocation failure leaves the stage at
		 * COMM_CREATE_START, and bumping first would leak one
		 * reference per retry.
		 */
		nccl_net_ofi_mutex_lock(&(domain->base.domain_lock));
		ep->base.ref_cnt++;
		nccl_net_ofi_mutex_unlock(&(domain->base.domain_lock));

		comm_state->stage = COMM_RECV_CONN;
		fallthrough;
	case COMM_RECV_CONN:
		/* Allocate memory for peer address for the first time ONLY */
		if (l_comm->conn_info == NULL) {
			l_comm->conn_info = nccl_ofi_freelist_entry_alloc(ep->conn_msg_fl);
			if (l_comm->conn_info == NULL) {
				NCCL_OFI_WARN("Failed to allocate connection info entry");
				free(req);
				/* Do not leave a dangling pointer to the freed request */
				comm_state->req = NULL;
				return -ENOMEM;
			}
			conn_info = (nccl_ofi_connection_info_t *)l_comm->conn_info->ptr;
		}

		/* Post a receive message to receive peer connections */
		ret = sendrecv_recv_conn_post(l_comm, ep, conn_info,
					      sizeof(nccl_ofi_connection_info_t), req);
		if (ret == -FI_EAGAIN) {
			/* Save recv request and buffer address for retry */
			comm_state->req = &req->base;
			return 0;
		} else if (ret != 0) {
			free(req);
			comm_state->req = NULL;
			nccl_ofi_freelist_entry_free(ep->conn_msg_fl, l_comm->conn_info);
			l_comm->conn_info = NULL;
			return ret;
		}

		comm_state->stage = COMM_CONN_REQ_PENDING;
		fallthrough;
	case COMM_CONN_REQ_PENDING:
		/* Progress NCCL OFI engine so that connection is accepted */
		ret = sendrecv_cq_process(ep->cq);
		if (OFI_UNLIKELY(ret != 0)) {
			/*
			 * NOTE(review): if the connection was not yet accepted the
			 * receive may still be posted with this request; the
			 * error is treated as fatal here as before.
			 */
			free(req);
			comm_state->req = NULL;
			return ret;
		}

		if (l_comm->accepted != true) {
			/* Save recv request and buffer to retest completion */
			comm_state->req = &req->base;
			return 0;
		}

		if (OFI_UNLIKELY(conn_info == NULL)) {
			NCCL_OFI_WARN("conn_info unexpectedly NULL in COMM_CONN_REQ_PENDING");
			return -EINVAL;
		}

		if (conn_info->connect_to_self) {
			NCCL_OFI_TRACE(NCCL_NET, "Accept from self; cleaning up");
			nccl_net_ofi_sendrecv_req_t *conn_info_req =
				(nccl_net_ofi_sendrecv_req_t *)conn_info->req;
			if (conn_info_req->state != NCCL_OFI_SENDRECV_REQ_COMPLETED) {
				/*
				 * Wait for our own connect-side request to complete.
				 * Save the receive request so the next call can free
				 * it; when this point is reached on the same call that
				 * prepared it, it is not yet stored anywhere else.
				 */
				comm_state->req = &req->base;
				return 0;
			}
		}

		/* Done processing the request so free it */
		free(req);
		comm_state->req = NULL;
		comm_state->stage = COMM_CONNECTED;
		break;
	case COMM_SEND_CONN:
	case COMM_CONN_RESP_REQ_PENDING:
	case COMM_CONNECTED:
	default:
		NCCL_OFI_WARN("Invalid state of receive communicator object: %d",
			      stage);
		return -EINVAL;
	}

	/* Prepare receive communicator object for the received peer connection */
	r_comm = sendrecv_recv_comm_prepare(l_comm, device, domain, ep, conn_info->ep_name);
	if (OFI_UNLIKELY(r_comm == NULL)) {
		return -ENOMEM;
	}

	/* The peer's endpoint name has been consumed; recycle the buffer */
	nccl_ofi_freelist_entry_free(ep->conn_msg_fl, l_comm->conn_info);
	l_comm->conn_info = NULL;

	comm_state->comm = &r_comm->base.base;
	*recv_comm = &r_comm->base;

	return ret;
}