static int sendrecv_endpoint_connect()

in src/nccl_ofi_sendrecv.cpp [2190:2313]


static int sendrecv_endpoint_connect(nccl_net_ofi_ep_t *base_ep,
				     nccl_net_ofi_conn_handle_t *handle,
				     nccl_net_ofi_send_comm_t **send_comm)
{
	int ret = 0;
	ssize_t rc = 0;
	*send_comm = NULL;
	nccl_net_ofi_sendrecv_ep_t *ep =
		(nccl_net_ofi_sendrecv_ep_t *)base_ep;
	nccl_ofi_connection_info_t *conn_info = NULL;
	
	/* Retrieve and validate devices */
	nccl_net_ofi_sendrecv_device_t *device = sendrecv_endpoint_get_device(ep);
	if (OFI_UNLIKELY(device == NULL)) {
		NCCL_OFI_WARN("Error accessing devices array. Devices array has not been initialized.");
		return -EINVAL;
	}
	int dev_id = device->base.dev_id;

	/* Extract connection state of the communicator */
	save_comm_state_t *comm_state = &(handle->state);
	nccl_net_ofi_sendrecv_req_t *req = (nccl_net_ofi_sendrecv_req_t *)comm_state->req;
	nccl_net_ofi_sendrecv_send_comm_t *s_comm =
		(nccl_net_ofi_sendrecv_send_comm_t *)comm_state->comm;

	/*
	 * Take appropriate actions based on connection stage of communicator.
	 *
	 * Once we have completed the actions for a particular stage, we proceed
	 * to the next one until failure. This is to ensure we make maximum
	 * progress in a single function invocation.
	 */
	nccl_ofi_comm_stage_t stage = comm_state->stage;
	switch (stage) {
	case COMM_CREATE_START:
		/*
		 * When we are building the s_comm for the first time,
		 * it should *NOT* come initialized from handle.
		 */
		assert(s_comm == NULL);

		/* Build send_comm */
		ret = sendrecv_send_comm_create(handle, ep, &s_comm);
		if (OFI_UNLIKELY(ret != 0 || s_comm == NULL)) {
			return ret;
		}

		/* Prepare connect request to be sent to peer */
		req = sendrecv_send_comm_prepare_send_req(s_comm);
		if (OFI_UNLIKELY(req == NULL)) {
			free(s_comm);
			return -ENOMEM;
		}

		comm_state->stage = COMM_SEND_CONN;

		fallthrough;
	case COMM_SEND_CONN:
		/* Send "connect" message to remote EP */
		rc = sendrecv_send_comm_send_connect_message(s_comm, device, ep, req);
		if (rc == -FI_EAGAIN) {
			/* Save connection state */
			comm_state->comm = &s_comm->base.base;
			comm_state->req = &req->base;
			return 0;
		}
		else if (rc != 0) {
			sendrecv_send_comm_free_req(s_comm, dev_id, req, false);
			free(s_comm);
			return rc;
		}

		comm_state->stage = COMM_CONN_REQ_PENDING;
		fallthrough;
	case COMM_CONN_REQ_PENDING:
		conn_info = (nccl_ofi_connection_info_t *)s_comm->conn_info->ptr;
		if (conn_info->connect_to_self == 1) {
			NCCL_OFI_TRACE(NCCL_NET, "Connect to self; short circuit cleanup");
			/* short cut to avoid rendezvous
			   deadlock in GDR detection */
			comm_state->stage = COMM_CONNECTED;
			break;
		}

		/* Progress our engine to get completions */
		ret = sendrecv_cq_process(ep->cq);
		if (OFI_UNLIKELY(ret != 0)) {
			assert((nccl_net_ofi_comm_t *)s_comm == req->comm);
			sendrecv_send_comm_free_req(s_comm, dev_id, req, false);
			free(s_comm);
			return ret;
		}

		/* Check if the connect message is sent */
		if (req->state != NCCL_OFI_SENDRECV_REQ_COMPLETED) {
			/* Save connection state */
			comm_state->comm = &s_comm->base.base;
			comm_state->req = &req->base;
			return 0;
		}

		comm_state->stage = COMM_CONNECTED;

		break;

	case COMM_RECV_CONN:
	case COMM_CONN_RESP_REQ_PENDING:
	case COMM_CONNECTED:
	default:
		NCCL_OFI_WARN("Invalid state of send communicator object: %d", stage);
		return -EINVAL;
	};

	*send_comm = &s_comm->base;
	assert((nccl_net_ofi_comm_t *)s_comm == req->comm);
	conn_info = (nccl_ofi_connection_info_t *)s_comm->conn_info->ptr;
	if (conn_info->connect_to_self != 1) {
		sendrecv_send_comm_free_req(s_comm, dev_id, req, false);
		nccl_ofi_freelist_entry_free(ep->conn_msg_fl, s_comm->conn_info);
		s_comm->conn_info = NULL;
	}

	return ret;
}