in src/nccl_ofi_rdma.cpp [3524:3716]
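/*
 * Post a receive for the next message from the peer. This is the plugin's
 * irecv entry point: it allocates a receive request, sends a control
 * message advertising the receive buffer to the sender, and, if the
 * payload already arrived eagerly, schedules the copy into the user
 * buffer. Returns 0 with *base_req set to NULL when the network is busy,
 * in which case NCCL retries the receive.
 */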
static int recv(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers,
		size_t *sizes, int *tags, nccl_net_ofi_mr_handle_t **mhandles,
		nccl_net_ofi_req_t **base_req)
{
	int ret = 0;
	nccl_net_ofi_rdma_req_t *req = NULL;
	nccl_net_ofi_rdma_recv_comm_t *r_comm = (nccl_net_ofi_rdma_recv_comm_t *)recv_comm;
	rdma_req_recv_data_t *recv_data = NULL;
	nccl_net_ofi_rdma_ep_t *ep = NULL;
	nccl_net_ofi_rdma_domain_t *domain = NULL;
	nccl_net_ofi_rdma_device_t *device = NULL;
	int dev_id = 0;
	nccl_net_ofi_rdma_mr_handle_t **mr_handles = (nccl_net_ofi_rdma_mr_handle_t **)mhandles;
	uint16_t msg_seq_num = 0;
	bool eager = false;
	int i;
	bool recv_completion_optional = false;

	assert(r_comm != NULL);
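	/* NCCL may pass the sentinel NCCL_NET_OPTIONAL_RECV_COMPLETION in
	 * *base_req to signal that it does not require a completion for this
	 * receive; honor it only when early completion is enabled. */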
	if (early_completion && *base_req == (void *)NCCL_NET_OPTIONAL_RECV_COMPLETION) {
		recv_completion_optional = true;
	}

	if (r_comm->comm_active == false) {
		NCCL_OFI_WARN("Called irecv on inactive communicator");
		ret = -EINVAL;
		goto error;
	}

	if (OFI_UNLIKELY(r_comm->num_inflight_reqs == NCCL_OFI_MAX_REQUESTS)) {
		ret = -ENOSPC;
		NCCL_OFI_WARN("Can not support more than %d inflight requests",
			      NCCL_OFI_MAX_REQUESTS);
		goto error;
	}
	dev_id = r_comm->base.base.dev_id;

	ep = (nccl_net_ofi_rdma_ep_t *)r_comm->base.base.ep;
	assert(ep != NULL);

	domain = rdma_endpoint_get_domain(ep);
	assert(domain != NULL);

	device = rdma_endpoint_get_device(ep);
	assert(device != NULL);
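	/* Progress any pending operations and completion-queue entries before
	 * posting new work; -EAGAIN means the network is still busy and this
	 * receive is deferred to a later call. */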
	ret = process_cq_if_pending(ep);
	if (ret == -EAGAIN) {
		/* Network is still busy. Return NULL to NCCL. */
		*base_req = NULL;
		ret = 0;
		goto error;
	}
	if (ret != 0) {
		goto error;
	}
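	/* Look this sequence number up in the message buffer: if the sender's
	 * eager payload has already arrived, it is stashed there as a rx
	 * buffer entry and must be copied into the user buffer below. */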
	msg_seq_num = r_comm->next_msg_seq_num;

	eager = false;
	void *elem;
	nccl_ofi_msgbuff_elemtype_t type;
	nccl_ofi_msgbuff_status_t msg_stat;
	nccl_ofi_msgbuff_result_t mb_res;

	mb_res = nccl_ofi_msgbuff_retrieve(r_comm->msgbuff, msg_seq_num, &elem,
					   &type, &msg_stat);
	if (mb_res == NCCL_OFI_MSGBUFF_SUCCESS) {
		if (type == NCCL_OFI_MSGBUFF_REQ) {
			/* Shouldn't happen: duplicate request */
			NCCL_OFI_WARN("Duplicate request in message buffer for msg %hu", msg_seq_num);
			ret = -EINVAL;
			goto error;
		} else if (OFI_LIKELY(type == NCCL_OFI_MSGBUFF_BUFF)) {
			/* This is an eager message */
			eager = true;
		} else {
			NCCL_OFI_WARN("Invalid type in msg buff");
			ret = -EINVAL;
			goto error;
		}
	} else if ((mb_res == NCCL_OFI_MSGBUFF_INVALID_IDX) &&
		   (msg_stat == NCCL_OFI_MSGBUFF_NOTSTARTED)) {
		/* Allocate a new req */
	} else {
		NCCL_OFI_WARN("Message %hu has invalid status.", msg_seq_num);
		ret = -EINVAL;
		goto error;
	}
	/* NCCL versions prior to 2.24 require special handling for 0-byte
	 * messages when using user buffer registration. NCCL passes the base
	 * pointer from the user buffer, but passes the registration from the
	 * channel buffer, to avoid an MR cache lookup. This is fine with
	 * InfiniBand, where the spec says the SGE is not used for a 0-byte
	 * message, but is a problem for EFA, which validates the pointer/MR
	 * even for a 0-byte transfer.
	 *
	 * To handle this case, we use the flush buffer (note we still move 0
	 * bytes of data, we just need a valid SGE) instead of the provided
	 * base pointer and MR.
	 */
	for (i = 0; i < n; i++) {
		if (sizes[i] == 0) {
			buffers[i] = domain->flush_buff.host_buffer;
			mr_handles[i] = domain->flush_buff.mr_handle;
		}
	}
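	/* Only the first buffer of the group is used below; this path assumes
	 * the plugin advertises a grouped-receive limit of one message per
	 * irecv call, so NCCL's multi-recv case (n > 1) is not handled here. */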
	ret = allocate_rdma_recv_req(r_comm, device, dev_id, msg_seq_num,
				     buffers[0], sizes[0],
				     mr_handles[0], &req, recv_completion_optional);
	if (ret != 0) {
		goto error;
	}

	recv_data = get_recv_data(req);
	if (eager) {
		nccl_net_ofi_rdma_req_t *rx_buff_req = (nccl_net_ofi_rdma_req_t *)elem;
		rdma_req_rx_buff_data_t *rx_buff_data = get_rx_buff_data(rx_buff_req);
		if (rx_buff_data->recv_len == 0) {
			/* Special case for zero-sized messages */
			ret = check_post_rx_buff_req(rx_buff_req);
			if (ret != 0) {
				NCCL_OFI_WARN("Failed call to check_post_rx_buff_req");
				/* Take the common error path rather than
				 * returning directly, so req is freed and
				 * *base_req is reset. */
				goto error;
			}
			recv_data->eager_copy_req = NULL;
		} else {
			ret = alloc_eager_copy_req(req, r_comm, rx_buff_req);
			if (ret != 0) {
				goto error;
			}
		}
	}
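	/* Publish the request in the message buffer under its sequence
	 * number. req may come back NULL if the insert could not complete;
	 * the request is then freed and NULL is returned to NCCL, which
	 * retries the receive. */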
	ret = insert_rdma_recv_req_into_msgbuff(r_comm, eager, &req);
	if (ret != 0 || req == NULL) {
		goto free_req;
	}

	/* At this point, we've successfully inserted a new request, so update the num inflight. */
	(r_comm->num_inflight_reqs)++;

	NCCL_OFI_TRACE_RECV(dev_id, r_comm, sizes[0], req, base_req);
	/* Send the ctrl message advertising the receive buffer to the peer,
	 * and bump the counter of ctrl messages sent, which is matched
	 * against deliveries elsewhere. */
	nccl_net_ofi_mutex_lock(&r_comm->ctrl_counter_lock);
	r_comm->n_ctrl_sent += 1;
	nccl_net_ofi_mutex_unlock(&r_comm->ctrl_counter_lock);

	ret = receive_progress(recv_data->send_ctrl_req, true);
	if (OFI_UNLIKELY(ret != 0)) {
		/* TODO: Remove req from message buffer */
		goto error;
	}
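	/* For an eager message the data has already arrived: either there is
	 * nothing to copy (zero-byte case) and the request completes
	 * immediately, or an eager copy into the user buffer is posted. */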
	if (eager) {
		if (recv_data->eager_copy_req == NULL) {
			/* If we don't need to do eager copy, this recv is already complete */
			ret = inc_req_completion(req, 0, recv_data->total_num_compls);
			if (ret != 0) {
				goto error;
			}
		} else {
			/* Post eager copy */
			ret = receive_progress(recv_data->eager_copy_req, true);
			if (ret != 0) {
				NCCL_OFI_WARN("Failed to issue eager read");
				/* TODO: Remove req from message buffer */
				goto error;
			}
		}
	}
	/* Return request to NCCL */
	*base_req = (nccl_net_ofi_req_t *)req;

	/* Increment next_msg_seq_num for next call */
	r_comm->next_msg_seq_num = (r_comm->next_msg_seq_num + 1) & MSG_SEQ_NUM_MASK;

	goto exit;

 free_req:
 error:
	if (req)
		req->free(req, false);
	*base_req = NULL;
 exit:
	return ret;
}