in src/nccl_ofi_rdma.cpp [1235:1378]
/*
 * Handle a receive completion that landed in an rx buffer.
 *
 * The request attached to the completion is validated, the received length
 * is recorded in the rx buffer data, and the message is then dispatched on
 * its type:
 *   - CONN / CONN_RESP: the connection info is copied into the listen / send
 *     communicator, the matching request is completed, and the rx buffer is
 *     reposted from here.
 *   - CTRL / CTRL_NO_COMPLETION: forwarded to handle_ctrl_recv(); on success
 *     the send communicator's n_ctrl_received counter is bumped under
 *     ctrl_recv_lock.
 *   - CLOSE: forwarded to handle_close_msg_recv().
 *   - EAGER: forwarded to handle_eager_recv(); communicator id and sequence
 *     number are taken from the completion's immediate data.
 *
 * @param device       device used to look up communicators by id
 * @param rail_id      index of the rail the completion arrived on
 * @param cq_entry     completion entry (len and data are read)
 * @param rx_buff_req  rx buffer request that was the completion context
 * @param eager        true if this completion is for an eager rx buffer,
 *                     false if it is for a control rx buffer
 *
 * @return 0 on success,
 *         -EINVAL if rx_buff_req is NULL, its type does not match `eager`,
 *         or the message type is not recognized,
 *         otherwise the non-zero value returned by the failing helper.
 */
static inline int handle_rx_buff_recv(nccl_net_ofi_rdma_device_t *device, uint16_t rail_id, struct fi_cq_data_entry *cq_entry,
				      nccl_net_ofi_rdma_req_t *rx_buff_req, bool eager)
{
	int ret = 0;

	if (OFI_UNLIKELY(rx_buff_req == NULL)) {
		NCCL_OFI_WARN("RECV event had NULL ctx!");
		return -EINVAL;
	}

	/* An eager completion must carry an eager rx buffer request; any
	 * other completion must carry a control rx buffer request. */
	const bool req_type_matches = eager ? (rx_buff_req->type == NCCL_OFI_RDMA_EAGER_RX_BUFF)
					    : (rx_buff_req->type == NCCL_OFI_RDMA_CTRL_RX_BUFF);
	if (OFI_UNLIKELY(!req_type_matches)) {
		NCCL_OFI_WARN("Invalid non-rx_buff request as ctx!");
		return -EINVAL;
	}

	rdma_req_rx_buff_data_t *rx_buff_data = get_rx_buff_data(rx_buff_req);
	rx_buff_data->recv_len = cq_entry->len;
	nccl_net_ofi_rdma_ep_t *ep = rx_buff_data->ep;

	/* Debug-only sanity check: the buffer must belong to the rail the
	 * completion was reported on. */
#ifndef NDEBUG
	if (eager) {
		/* Eager traffic arrives on the data rails */
		assert(rx_buff_data->rail == &ep->rails[rail_id]);
	} else {
		/* Everything else arrives on the control rails */
		assert(rx_buff_data->rail == &ep->control_rails[rail_id]);
	}
#endif

	/* Eager payloads carry no header, so their type is implied by the
	 * `eager` flag. For all other messages the type lives in the first
	 * 4 bits; there is no common base header type, so the buffer is
	 * viewed as a control message to read it. */
	nccl_ofi_rdma_msg_type_t msg_type;
	if (eager) {
		msg_type = (nccl_ofi_rdma_msg_type_t)NCCL_OFI_RDMA_MSG_EAGER;
	} else {
		msg_type = get_rx_ctrl_msg(rx_buff_data)->type;
	}

	switch (msg_type) {
	case NCCL_OFI_RDMA_MSG_CONN: {
		/* Connect message received by a listen communicator */
		assert(sizeof(nccl_ofi_rdma_connection_info_t) == cq_entry->len);

		nccl_ofi_rdma_connection_info_t *conn_info = get_rx_connection_msg(rx_buff_data);
		nccl_net_ofi_rdma_listen_comm_t *l_comm =
			rdma_device_get_listen_comm(device, conn_info->remote_comm_id);

		assert(l_comm->req.comm->type == NCCL_NET_OFI_LISTEN_COMM);
		assert((nccl_net_ofi_comm_t *)l_comm == l_comm->req.comm);

		/* Stash the connect message in the communicator before
		 * signalling completion of its request */
		l_comm->conn_msg = *conn_info;

		ret = inc_req_completion(&l_comm->req, cq_entry->len, 1);
		if (OFI_UNLIKELY(ret != 0)) {
			break;
		}

		/* The message has been copied out, so the buffer can go
		 * back to the receive queue */
		ret = repost_rx_buff(ep, rx_buff_req);
		if (OFI_UNLIKELY(ret != 0)) {
			NCCL_OFI_WARN("Failed to repost rx buff");
		}
		break;
	}
	case NCCL_OFI_RDMA_MSG_CONN_RESP: {
		/* Connect response received by a send communicator */
		assert(sizeof(nccl_ofi_rdma_connection_info_t) == cq_entry->len);

		nccl_ofi_rdma_connection_info_t *resp_info = get_rx_connection_msg(rx_buff_data);
		nccl_net_ofi_rdma_send_comm_t *s_comm =
			rdma_device_get_send_comm(device, resp_info->remote_comm_id);

		assert(NULL != s_comm->conn_resp_req);
		assert(NCCL_NET_OFI_SEND_COMM == s_comm->conn_resp_req->comm->type);
		assert((nccl_net_ofi_comm_t *)s_comm == s_comm->conn_resp_req->comm);

		/* Stash the response in the communicator's connection
		 * message buffer before signalling completion */
		memcpy(s_comm->conn_msg->ptr, resp_info, sizeof(nccl_ofi_rdma_connection_info_t));

		ret = inc_req_completion(s_comm->conn_resp_req, cq_entry->len, 1);
		if (OFI_UNLIKELY(ret != 0)) {
			break;
		}

		/* The message has been copied out, so the buffer can go
		 * back to the receive queue */
		ret = repost_rx_buff(ep, rx_buff_req);
		if (OFI_UNLIKELY(ret != 0)) {
			NCCL_OFI_WARN("Failed to repost rx buff");
		}
		break;
	}
	case NCCL_OFI_RDMA_MSG_CTRL_NO_COMPLETION:
		/* handled exactly like NCCL_OFI_RDMA_MSG_CTRL: fall through */
	case NCCL_OFI_RDMA_MSG_CTRL: {
		/* Control message received by a send communicator */
		assert(cq_entry->len == nccl_net_ofi_rdma_ctrl_msg_size(ep->num_rails, ep->use_long_rkeys));

		nccl_net_ofi_rdma_ctrl_msg_t *ctrl = get_rx_ctrl_msg(rx_buff_data);
		nccl_net_ofi_rdma_send_comm_t *s_comm =
			rdma_device_get_send_comm(device, ctrl->remote_comm_id);

		NCCL_OFI_TRACE_SEND_CTRL_RECV(s_comm->base.base.dev_id, rail_id, s_comm, ctrl->msg_seq_num);

		ret = handle_ctrl_recv(s_comm, ctrl->msg_seq_num, rx_buff_req);
		if (OFI_UNLIKELY(ret != 0)) {
			break;
		}

		/* Count the control message only once it was handled
		 * successfully; the counter is guarded by ctrl_recv_lock */
		nccl_net_ofi_mutex_lock(&s_comm->ctrl_recv_lock);
		s_comm->n_ctrl_received += 1;
		nccl_net_ofi_mutex_unlock(&s_comm->ctrl_recv_lock);
		break;
	}
	case NCCL_OFI_RDMA_MSG_CLOSE:
		/* Close message: the handler receives the rx buffer request */
		assert(cq_entry->len == sizeof(nccl_net_ofi_rdma_close_msg_t));

		ret = handle_close_msg_recv(rx_buff_req);
		break;
	case NCCL_OFI_RDMA_MSG_EAGER: {
		/* Eager payload: communicator id and sequence number are
		 * both encoded in the completion's immediate data */
		nccl_net_ofi_rdma_recv_comm_t *recv_comm =
			rdma_device_get_recv_comm(device, GET_COMM_ID_FROM_IMM(cq_entry->data));

		NCCL_OFI_TRACE_EAGER_RECV(recv_comm->base.base.dev_id, rail_id, recv_comm,
					  GET_SEQ_NUM_FROM_IMM(cq_entry->data));

		ret = handle_eager_recv(recv_comm, GET_SEQ_NUM_FROM_IMM(cq_entry->data), rx_buff_req);
		break;
	}
	default:
		NCCL_OFI_WARN("Recv completion with unexpected type");
		ret = -EINVAL;
		break;
	}

	return ret;
}