static inline int handle_rx_buff_recv()

in src/nccl_ofi_rdma.cpp [1235:1378]


/*
 * Process a receive completion that landed in one of the endpoint's rx
 * buffers and dispatch it according to the message type it carries.
 *
 * @param device       Device used to look up the listen/send/recv
 *                     communicator addressed by the message.
 * @param rail_id      Index of the rail the completion was reported on;
 *                     used for the debug rail check and for tracing.
 * @param cq_entry     Completion entry; `len` is recorded as the received
 *                     length and `data` supplies the immediate data for
 *                     eager messages.
 * @param rx_buff_req  Request that was attached to the completion as its
 *                     context. Must be an eager rx buffer request when
 *                     `eager` is true and a ctrl rx buffer request
 *                     otherwise.
 * @param eager        True if the completion is for an eager message. In
 *                     that case the message type is not read from the
 *                     buffer.
 *
 * @return 0 on success, -EINVAL if rx_buff_req is NULL, has the wrong
 *         request type, or the message type is not recognized; otherwise
 *         the non-zero value returned by the failing helper.
 */
static inline int handle_rx_buff_recv(nccl_net_ofi_rdma_device_t *device, uint16_t rail_id, struct fi_cq_data_entry *cq_entry,
				     nccl_net_ofi_rdma_req_t *rx_buff_req, bool eager)
{
	int rc = 0;

	if (OFI_UNLIKELY(rx_buff_req == NULL)) {
		NCCL_OFI_WARN("RECV event had NULL ctx!");
		return -EINVAL;
	}

	/* The request type has to agree with the kind of completion we
	 * were told to expect. */
	const bool req_type_ok = eager ? (rx_buff_req->type == NCCL_OFI_RDMA_EAGER_RX_BUFF)
	                               : (rx_buff_req->type == NCCL_OFI_RDMA_CTRL_RX_BUFF);
	if (OFI_UNLIKELY(!req_type_ok)) {
		NCCL_OFI_WARN("Invalid non-rx_buff request as ctx!");
		return -EINVAL;
	}

	rdma_req_rx_buff_data_t *rx_buff_data = get_rx_buff_data(rx_buff_req);
	rx_buff_data->recv_len = cq_entry->len;

	nccl_net_ofi_rdma_ep_t *ep = rx_buff_data->ep;

	/* Debug-only sanity check: eager messages must arrive on a data
	 * rail, everything else on a control rail. */
	assert(eager ? (rx_buff_data->rail == &ep->rails[rail_id])
	             : (rx_buff_data->rail == &ep->control_rails[rail_id]));

	/* There is no common base header type, so for non-eager messages
	 * the buffer is viewed as a control message just to read the type
	 * field. Eager messages carry no header; their type is implied. */
	nccl_ofi_rdma_msg_type_t msg_type;
	if (eager) {
		msg_type = (nccl_ofi_rdma_msg_type_t)NCCL_OFI_RDMA_MSG_EAGER;
	} else {
		msg_type = get_rx_ctrl_msg(rx_buff_data)->type;
	}

	switch (msg_type) {
	case NCCL_OFI_RDMA_MSG_CONN: {
		/* Connect message received */
		assert(sizeof(nccl_ofi_rdma_connection_info_t) == cq_entry->len);

		nccl_ofi_rdma_connection_info_t *conn_info = get_rx_connection_msg(rx_buff_data);
		nccl_net_ofi_rdma_listen_comm_t *listen_comm =
			rdma_device_get_listen_comm(device, conn_info->remote_comm_id);

		assert(listen_comm->req.comm->type == NCCL_NET_OFI_LISTEN_COMM);
		assert((nccl_net_ofi_comm_t *)listen_comm == listen_comm->req.comm);

		/* Keep a copy of the message in the listen communicator
		 * before the rx buffer is handed back. */
		listen_comm->conn_msg = *conn_info;

		rc = inc_req_completion(&listen_comm->req, cq_entry->len, 1);
		if (OFI_UNLIKELY(rc != 0)) {
			return rc;
		}

		/* The buffer contents are no longer needed; try to repost it */
		rc = repost_rx_buff(ep, rx_buff_req);
		if (OFI_UNLIKELY(rc != 0)) {
			NCCL_OFI_WARN("Failed to repost rx buff");
			return rc;
		}
		break;
	}
	case NCCL_OFI_RDMA_MSG_CONN_RESP: {
		/* Connect response message received */
		assert(sizeof(nccl_ofi_rdma_connection_info_t) == cq_entry->len);

		nccl_ofi_rdma_connection_info_t *conn_resp = get_rx_connection_msg(rx_buff_data);
		nccl_net_ofi_rdma_send_comm_t *s_comm =
			rdma_device_get_send_comm(device, conn_resp->remote_comm_id);

		assert(NULL != s_comm->conn_resp_req);
		assert(NCCL_NET_OFI_SEND_COMM == s_comm->conn_resp_req->comm->type);
		assert((nccl_net_ofi_comm_t *)s_comm == s_comm->conn_resp_req->comm);

		/* Keep a copy of the response in the send communicator
		 * before the rx buffer is handed back. */
		memcpy(s_comm->conn_msg->ptr, conn_resp, sizeof(*conn_resp));

		rc = inc_req_completion(s_comm->conn_resp_req, cq_entry->len, 1);
		if (OFI_UNLIKELY(rc != 0)) {
			return rc;
		}

		/* The buffer contents are no longer needed; try to repost it */
		rc = repost_rx_buff(ep, rx_buff_req);
		if (OFI_UNLIKELY(rc != 0)) {
			NCCL_OFI_WARN("Failed to repost rx buff");
			return rc;
		}
		break;
	}
	case NCCL_OFI_RDMA_MSG_CTRL_NO_COMPLETION:
		/* handled exactly like NCCL_OFI_RDMA_MSG_CTRL */
	case NCCL_OFI_RDMA_MSG_CTRL: {
		/* Control message received */
		assert(cq_entry->len == nccl_net_ofi_rdma_ctrl_msg_size(ep->num_rails, ep->use_long_rkeys));

		nccl_net_ofi_rdma_ctrl_msg_t *ctrl_msg = get_rx_ctrl_msg(rx_buff_data);
		nccl_net_ofi_rdma_send_comm_t *s_comm =
			rdma_device_get_send_comm(device, ctrl_msg->remote_comm_id);

		NCCL_OFI_TRACE_SEND_CTRL_RECV(s_comm->base.base.dev_id, rail_id, s_comm, ctrl_msg->msg_seq_num);

		/* The rx buffer request is passed on to the handler; it is
		 * not reposted from here. */
		rc = handle_ctrl_recv(s_comm, ctrl_msg->msg_seq_num, rx_buff_req);
		if (OFI_UNLIKELY(rc != 0)) {
			return rc;
		}

		/* Count the control message under the communicator's lock */
		nccl_net_ofi_mutex_lock(&s_comm->ctrl_recv_lock);
		++s_comm->n_ctrl_received;
		nccl_net_ofi_mutex_unlock(&s_comm->ctrl_recv_lock);
		break;
	}
	case NCCL_OFI_RDMA_MSG_CLOSE:
		/* Close message received; the handler's result is returned as-is */
		assert(cq_entry->len == sizeof(nccl_net_ofi_rdma_close_msg_t));

		rc = handle_close_msg_recv(rx_buff_req);
		break;
	case NCCL_OFI_RDMA_MSG_EAGER: {
		/* Eager message received: communicator id and sequence
		 * number are both taken from the completion's immediate data. */
		nccl_net_ofi_rdma_recv_comm_t *r_comm =
			rdma_device_get_recv_comm(device, GET_COMM_ID_FROM_IMM(cq_entry->data));

		NCCL_OFI_TRACE_EAGER_RECV(r_comm->base.base.dev_id, rail_id, r_comm,
					  GET_SEQ_NUM_FROM_IMM(cq_entry->data));

		/* The rx buffer request is passed on to the handler; it is
		 * not reposted from here. */
		rc = handle_eager_recv(r_comm, GET_SEQ_NUM_FROM_IMM(cq_entry->data), rx_buff_req);
		if (OFI_UNLIKELY(rc != 0)) {
			return rc;
		}
		break;
	}
	default:
		NCCL_OFI_WARN("Recv completion with unexpected type");
		return -EINVAL;
	}

	return rc;
}