static int accept()

in src/nccl_ofi_rdma.cpp [4925:5140]


/*
 * accept() - Drive the listen-side connection handshake of an RDMA listen
 * communicator.
 *
 * This is a resumable state machine keyed on l_comm->stage. Each call makes
 * as much progress as it can and returns early while waiting:
 *   COMM_CREATE_START, COMM_RECV_CONN: no work, advance immediately.
 *   COMM_CONN_REQ_PENDING: progress the listen endpoint's CQ until the
 *       connect message has been received, validate the peer's rail and
 *       control-rail counts, create the receive communicator and build the
 *       connect response message.
 *   COMM_SEND_CONN: post the connect response (retried on -FI_EAGAIN).
 *   COMM_CONN_RESP_REQ_PENDING: progress the receive communicator's endpoint
 *       until the response has been sent, then hand the receive communicator
 *       to the caller and mark the listen communicator COMM_CONNECTED.
 *
 * @param listen_comm  Listen communicator (an nccl_net_ofi_rdma_listen_comm_t).
 * @param recv_comm    Out: set to the new receive communicator once the
 *                     handshake completes. It is NULL on every other return,
 *                     including while the handshake is still in progress.
 *
 * @return 0 on success or when more progress is needed (check *recv_comm).
 *         On failure, a negative error code (-EINVAL or the error returned by
 *         a progress/send helper).
 */
static int accept(nccl_net_ofi_listen_comm_t *listen_comm,
			   nccl_net_ofi_recv_comm_t **recv_comm)
{
	int ret = 0;
	nccl_net_ofi_rdma_req_state_t req_state;

	nccl_net_ofi_rdma_listen_comm_t *l_comm =
		(nccl_net_ofi_rdma_listen_comm_t *)listen_comm;

	/* Extract communicator state from listen communicator object */
	nccl_net_ofi_rdma_recv_comm_t *r_comm = l_comm->r_comm;

	/* Extract request used for connect and connect response message */
	nccl_net_ofi_rdma_req_t *req = &l_comm->req;

	/* Extract struct used for message exchange */
	nccl_ofi_rdma_connection_info_t *conn_msg = &l_comm->conn_msg;

	/* Retrieve and validate endpoint */
	nccl_net_ofi_rdma_ep_t *l_comm_ep = (nccl_net_ofi_rdma_ep_t *)l_comm->base.base.ep;
	assert(l_comm_ep != NULL);

	/* When resuming a later stage, the receive communicator (and its
	 * endpoint) was already created by an earlier call. */
	nccl_net_ofi_rdma_ep_t *ep = NULL;
	if (l_comm->r_comm) {
		ep = (nccl_net_ofi_rdma_ep_t *)r_comm->base.base.ep;
		assert(ep != NULL);
	}

	/* Retrieve and validate device */
	nccl_net_ofi_rdma_domain_t *domain = rdma_endpoint_get_domain(l_comm_ep);
	assert(domain != NULL);
	nccl_net_ofi_rdma_device_t *device = rdma_domain_get_device(domain);
	assert(device != NULL);

	int dev_id = device->base.dev_id;

	/* Report no receive communicator until accept finalizes. This is done
	 * before any error check so a failure return never leaves a stale value
	 * in the caller's output pointer. */
	*recv_comm = NULL;

	if (l_comm->stage == COMM_CONNECTED) {
		NCCL_OFI_WARN("listenComm %p object already has an active connection (%d).",
			      l_comm, l_comm->stage);
		ret = -EINVAL;
		goto exit;
	}

	/*
	 * Take appropriate actions based on connection stage of communicator.
	 *
	 * Once we have completed the actions for a particular stage, we proceed
	 * to the next one until failure. This is to ensure we make maximum
	 * progress in a single function invocation.
	 */
	switch (l_comm->stage) {
	case COMM_CREATE_START:
		/* COMM_CREATE_START: nothing to allocate here; advance. */

		l_comm->stage = COMM_RECV_CONN;

		fallthrough;
	case COMM_RECV_CONN:

		l_comm->stage = COMM_CONN_REQ_PENDING;

		fallthrough;
	case COMM_CONN_REQ_PENDING:
		/* COMM_CONN_REQ_PENDING: Wait until connect message has been
		 * received. Then, prepare for sending connect accept message,
		 * i.e., create receive communicator and reset the previously
		 * used request. */

		/* Progress NCCL OFI engine so that connection is accepted */
		ret = ofi_process_cq(l_comm_ep);
		if (OFI_UNLIKELY(ret != 0)) {
			goto exit;
		}

		/* Check if the connect message is received */
		nccl_net_ofi_mutex_lock(&req->req_lock);
		req_state = req->state;
		nccl_net_ofi_mutex_unlock(&req->req_lock);

		/* Wait until connect message is received */
		if (req_state != NCCL_OFI_RDMA_REQ_COMPLETED) {
			return 0;
		}

		/* Number of remote rails and number of local rails match */
		if (conn_msg->num_rails != l_comm_ep->num_rails) {
			NCCL_OFI_WARN("Unexpected number of remote rails for dev %d. Expected %i but got %i",
				      dev_id, l_comm_ep->num_rails,
				      conn_msg->num_rails);
			ret = -EINVAL;
			goto exit;
		}

		/* Number of remote control rails and number of local control rails match */
		if (conn_msg->num_control_rails != l_comm_ep->num_control_rails) {
			NCCL_OFI_WARN("Unexpected number of remote control rails for dev %d. Expected %i but got %i",
				      dev_id, l_comm_ep->num_control_rails,
				      conn_msg->num_control_rails);
			ret = -EINVAL;
			goto exit;
		}

		/* Prepare receive communicator object for the received peer connection */
		r_comm = prepare_recv_comm(domain, l_comm_ep, conn_msg);
		if (OFI_UNLIKELY(r_comm == NULL)) {
			ret = -EINVAL;
			goto exit;
		}
		l_comm->r_comm = r_comm;

		/* prepare_recv_comm establishes the endpoint used for this r_comm,
		   so set the pointer here. */
		ep = (nccl_net_ofi_rdma_ep_t *)r_comm->base.base.ep;
		assert(ep != NULL);

		/*
		 * The libfabric resources maintained by the endpoint
		 * structure is passed from l_comm to r_comm so they can
		 * then be used by nccl_net_ofi_irecv. We want to make
		 * sure those resources are not freed up when we call
		 * nccl_net_ofi_closeListen so we maintain an additional
		 * refcnt and free it up when nccl_net_ofi_closeRecv is
		 * called.
		 *
		 * NOTE(review): if a later step below fails, r_comm is
		 * released via close_listen_recv_comm() at `exit`; confirm
		 * that path also drops this extra reference.
		 */
		nccl_net_ofi_mutex_lock(&(domain->base.domain_lock));
		ep->base.ref_cnt++;
		nccl_net_ofi_mutex_unlock(&(domain->base.domain_lock));

		/* Reset request state for connect response message */
		prepare_send_conn_resp_req(l_comm);

		/* Initialize connect response message */
		ret = prepare_conn_resp(ep, r_comm, dev_id);
		if (ret != 0) {
			goto exit;
		}

		l_comm->stage = COMM_SEND_CONN;

		fallthrough;
	case COMM_SEND_CONN:

		/* COMM_SEND_CONN: Send connect response message to remote.
		 * -FI_EAGAIN means resources are temporarily unavailable; stay
		 * in this stage and retry on the next call. */
		ret = post_send_conn_resp(r_comm, device, ep, req);
		if (ret == -FI_EAGAIN) {
			return 0;
		}
		else if (ret != 0) {
			goto exit;
		}

		l_comm->stage = COMM_CONN_RESP_REQ_PENDING;

		fallthrough;
	case COMM_CONN_RESP_REQ_PENDING:
		/* COMM_CONN_RESP_REQ_PENDING: Wait until connect
		 * response message has been delivered. Afterwards,
		 * cleanup and return receive communicator. */

		/* Progress our engine to get completions */
		ret = ofi_process_cq(ep);
		if (OFI_UNLIKELY(ret != 0)) {
			goto exit;
		}

		/* Check if the connect response message is sent */
		nccl_net_ofi_mutex_lock(&req->req_lock);
		req_state = req->state;
		nccl_net_ofi_mutex_unlock(&req->req_lock);

		/* Wait until connect response message is sent */
		if (req_state != NCCL_OFI_RDMA_REQ_COMPLETED) {
			return 0;
		}

		/* The free list item was allocated on the ep
		 * associated with the r_comm (as opposed to the
		 * l_comm).  ep should point to the recv comm ep at
		 * this point.
		 */
		nccl_ofi_freelist_entry_free(ep->conn_msg_fl, r_comm->conn_msg);
		r_comm->conn_msg = NULL;

		*recv_comm = &r_comm->base;

		/* NULL pointer to recv communicator stored in listen
		 * communicator's state to avoid that `close_listen_recv_comm'
		 * deallocates the receive communicator */
		l_comm->r_comm = NULL;

		l_comm->stage = COMM_CONNECTED;

		/* Count the new communicator only once it has actually been
		 * handed to the caller; error paths must not touch this. */
		nccl_net_ofi_mutex_lock(&comm_cleanup_list_lock);
		++num_open_comms;
		nccl_net_ofi_mutex_unlock(&comm_cleanup_list_lock);

		break;

	case COMM_CONNECTED:
	default:
		NCCL_OFI_WARN("Invalid state of receive communicator object: %d",
			      l_comm->stage);
		ret = -EINVAL;
		break;
	}

 exit:;
	/* Close receive communicator in case listen operation failed. On
	 * success l_comm->r_comm was cleared above, so the handed-out receive
	 * communicator is not released here.
	 *
	 * NOTE(review): l_comm->stage is not reset on failure; confirm callers
	 * never call accept() again after an error return, since the later
	 * stages dereference r_comm/ep. */
	int close_ret = close_listen_recv_comm(l_comm);
	if (close_ret) {
		NCCL_OFI_WARN("Failed to close listen communicator");
	}
	return ret ? ret : close_ret;
}