static int sendrecv_listen_comm_accept()

in src/nccl_ofi_sendrecv.cpp [1523:1690]


/*
 * Accept a connection on a sendrecv listen communicator without blocking.
 *
 * Drives the listen communicator's connection state machine
 * (l_comm->state.stage) as far as possible in a single call:
 *
 *   COMM_CREATE_START     -> allocate the connection receive request and take
 *                            an extra endpoint reference for the future r_comm
 *   COMM_RECV_CONN        -> allocate the connection-info buffer (once) and
 *                            post a receive for the peer's connection message
 *   COMM_CONN_REQ_PENDING -> progress the endpoint CQ until l_comm->accepted
 *                            becomes true (and, for a connection from self,
 *                            until the connecting side's request completes)
 *
 * When those stages are done, a receive communicator is built from the
 * peer endpoint name carried in the connection info and handed back.
 *
 * Any stage that cannot finish yet saves its request in comm_state->req and
 * returns 0, so the caller is expected to call again.
 *
 * @param listen_comm  listen communicator (cast to the sendrecv variant)
 * @param recv_comm    out: set to the new receive communicator once the
 *                     connection is established; NULL otherwise
 *
 * @return 0 when a receive communicator was created, or when the connection
 *         is still in progress (*recv_comm is NULL; call again);
 *         a negative error code on failure.
 */
static int sendrecv_listen_comm_accept(nccl_net_ofi_listen_comm_t *listen_comm,
				       nccl_net_ofi_recv_comm_t **recv_comm)
{
	int ret = 0;

	nccl_net_ofi_sendrecv_listen_comm_t *l_comm =
		(nccl_net_ofi_sendrecv_listen_comm_t *)listen_comm;

	/* Never leave the caller with a stale pointer, even on early errors */
	*recv_comm = NULL;

	if (l_comm->state.stage != COMM_CONN_REQ_PENDING && l_comm->accepted) {
		NCCL_OFI_WARN("listen_comm %p object already has an active connection (%d).",
			      listen_comm, l_comm->accepted);
		return -EINVAL;
	}

	/* Extract communicator state from listen communicator object */
	save_comm_state_t *comm_state = &l_comm->state;
	nccl_net_ofi_sendrecv_recv_comm_t *r_comm = NULL;
	/* Request saved by a previous call that had to be retried (may be NULL) */
	nccl_net_ofi_sendrecv_req_t *req = (nccl_net_ofi_sendrecv_req_t *)comm_state->req;

	/* Retrieve and validate endpoint */
	nccl_net_ofi_sendrecv_ep_t *ep =
		(nccl_net_ofi_sendrecv_ep_t *)l_comm->base.base.ep;
	if (OFI_UNLIKELY(ep == NULL)) {
		ret = -EINVAL;
		NCCL_OFI_WARN("Invalid endpoint provided");
		return ret;
	}

	nccl_net_ofi_sendrecv_domain_t *domain =
		sendrecv_endpoint_get_domain(ep);
	assert(domain != NULL);

	/* Retrieve and validate device */
	nccl_net_ofi_sendrecv_device_t *device =
		sendrecv_endpoint_get_device(ep);
	if (OFI_UNLIKELY(device == NULL)) {
		ret = -EINVAL;
		NCCL_OFI_WARN("Invalid device provided");
		return ret;
	}

	/* Connection-info buffer allocated by an earlier call, if any */
	nccl_ofi_connection_info_t *conn_info = NULL;
	if (l_comm->conn_info != NULL) {
		conn_info = (nccl_ofi_connection_info_t *)l_comm->conn_info->ptr;
	}

	/*
	 * Take appropriate actions based on connection stage of communicator.
	 *
	 * Once we have completed the actions for a particular stage, we proceed
	 * to the next one until failure. This is to ensure we make maximum
	 * progress in a single function invocation.
	 */
	nccl_ofi_comm_stage_t stage = comm_state->stage;
	switch (stage) {
	case COMM_CREATE_START:

		/*
		 * Prepare receive request to accept connections. This is done
		 * before taking the endpoint reference below: if it fails, the
		 * stage stays COMM_CREATE_START, and a retry must not take the
		 * reference a second time.
		 */
		req = sendrecv_recv_req_prepare(l_comm);
		if (req == NULL) {
			return -ENOMEM;
		}

		/*
		 * The libfabric resources maintained by the endpoint
		 * structure are passed from l_comm to r_comm so they can
		 * then be used by nccl_net_ofi_irecv. We want to make
		 * sure those resources are not freed up when we call
		 * nccl_net_ofi_closeListen so we maintain an additional
		 * refcnt and free it up when nccl_net_ofi_closeRecv is
		 * called.
		 */
		nccl_net_ofi_mutex_lock(&(domain->base.domain_lock));
		ep->base.ref_cnt++;
		nccl_net_ofi_mutex_unlock(&(domain->base.domain_lock));

		comm_state->stage = COMM_RECV_CONN;
		fallthrough;
	case COMM_RECV_CONN:

		/*
		 * A previous attempt at this stage failed and released its
		 * request; prepare a fresh one rather than posting a NULL request.
		 */
		if (req == NULL) {
			req = sendrecv_recv_req_prepare(l_comm);
			if (req == NULL) {
				return -ENOMEM;
			}
		}

		/* Allocate memory for peer address for the first time ONLY */
		if (l_comm->conn_info == NULL) {
			l_comm->conn_info = nccl_ofi_freelist_entry_alloc(ep->conn_msg_fl);
			if (l_comm->conn_info == NULL) {
				NCCL_OFI_WARN("Failed to allocate connection info entry");
				free(req);
				/* Do not leave a dangling request for the next call */
				comm_state->req = NULL;
				return -ENOMEM;
			}
			conn_info = (nccl_ofi_connection_info_t *)l_comm->conn_info->ptr;
		}

		/* Post a receive message to receive peer connections */
		ret = sendrecv_recv_conn_post(l_comm, ep, conn_info,
			sizeof(nccl_ofi_connection_info_t), req);
		if (ret == -FI_EAGAIN) {
			/* Save recv request and buffer address for retry */
			comm_state->req = &req->base;
			return 0;
		} else if (ret != 0) {
			free(req);
			comm_state->req = NULL;
			nccl_ofi_freelist_entry_free(ep->conn_msg_fl, l_comm->conn_info);
			l_comm->conn_info = NULL;
			return ret;
		}

		comm_state->stage = COMM_CONN_REQ_PENDING;

		fallthrough;
	case COMM_CONN_REQ_PENDING:

		/* Progress NCCL OFI engine so that connection is accepted */
		ret = sendrecv_cq_process(ep->cq);
		if (OFI_UNLIKELY(ret != 0)) {
			free(req);
			comm_state->req = NULL;
			return ret;
		}

		if (l_comm->accepted != true) {
			/* Save recv request and buffer to retest completion */
			comm_state->req = &req->base;
			return 0;
		}

		if (OFI_UNLIKELY(conn_info == NULL)) {
			NCCL_OFI_WARN("conn_info unexpectedly NULL in COMM_CONN_REQ_PENDING");
			return -EINVAL;
		}

		if (conn_info->connect_to_self) {
			NCCL_OFI_TRACE(NCCL_NET, "Accept from self; cleaning up");
			nccl_net_ofi_sendrecv_req_t *conn_info_req =
				(nccl_net_ofi_sendrecv_req_t *)conn_info->req;
			if (conn_info_req->state != NCCL_OFI_SENDRECV_REQ_COMPLETED) {
				/*
				 * Save the request so the next call can free it.
				 * This stage may have been reached by fallthrough
				 * in this same call, in which case it was never saved.
				 */
				comm_state->req = &req->base;
				return 0;
			}
		}

		/* Done processing the request so free it */
		free(req);
		comm_state->req = NULL;
		comm_state->stage = COMM_CONNECTED;

		break;

	case COMM_SEND_CONN:
	case COMM_CONN_RESP_REQ_PENDING:
	case COMM_CONNECTED:
	default:
		NCCL_OFI_WARN("Invalid state of receive communicator object: %d",
			      stage);
		return -EINVAL;
	}

	/* Prepare receive communicator object for the received peer connection */
	r_comm = sendrecv_recv_comm_prepare(l_comm, device, domain, ep, conn_info->ep_name);
	if (OFI_UNLIKELY(r_comm == NULL)) {
		return -ENOMEM;
	}

	/* The peer endpoint name has been consumed; release the buffer */
	nccl_ofi_freelist_entry_free(ep->conn_msg_fl, l_comm->conn_info);
	l_comm->conn_info = NULL;

	comm_state->comm = &r_comm->base.base;
	*recv_comm = &r_comm->base;

	return 0;
}