static int recv()

in src/nccl_ofi_rdma.cpp [3524:3716]


static int recv(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers,
			 size_t *sizes, int *tags, nccl_net_ofi_mr_handle_t **mhandles,
			 nccl_net_ofi_req_t **base_req)
{
	int ret = 0;
	nccl_net_ofi_rdma_req_t *req = NULL;
	nccl_net_ofi_rdma_recv_comm_t *r_comm = (nccl_net_ofi_rdma_recv_comm_t *)recv_comm;
	rdma_req_recv_data_t *recv_data = NULL;
	nccl_net_ofi_rdma_ep_t *ep = NULL;
	nccl_net_ofi_rdma_domain_t *domain = NULL;
	nccl_net_ofi_rdma_device_t *device = NULL;
	int dev_id = 0;
	nccl_net_ofi_rdma_mr_handle_t **mr_handles = (nccl_net_ofi_rdma_mr_handle_t **)mhandles;
	uint16_t msg_seq_num = 0;
	bool eager = false;
	int i;
	bool recv_completion_optional = false;

	assert(r_comm != NULL);

	if (early_completion && *base_req == (void *)NCCL_NET_OPTIONAL_RECV_COMPLETION) {
		recv_completion_optional = true;
	}

	if (r_comm->comm_active == false) {
		NCCL_OFI_WARN("Called irecv on inactive communicator");
		ret = -EINVAL;
		goto error;
	}

	if (OFI_UNLIKELY(r_comm->num_inflight_reqs == NCCL_OFI_MAX_REQUESTS)) {
		ret = -ENOSPC;
		NCCL_OFI_WARN("Can not support more than %d inflight requests",
			      NCCL_OFI_MAX_REQUESTS);
		goto error;
	}

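	/* Look up the device id, endpoint, domain, and device backing this
	 * receive communicator */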
	dev_id = r_comm->base.base.dev_id;

	ep = (nccl_net_ofi_rdma_ep_t *)r_comm->base.base.ep;
	assert(ep != NULL);

	domain = rdma_endpoint_get_domain(ep);
	assert(domain != NULL);

	device = rdma_endpoint_get_device(ep);
	assert(device != NULL);

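	/* Make progress on any pending work for this endpoint before posting a
	 * new receive; -EAGAIN here means the network is still busy. */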
	ret = process_cq_if_pending(ep);
	if (ret == -EAGAIN) {
		/* Network is still busy. Return NULL to NCCL. */
		*base_req = NULL;
		ret = 0;
		goto error;
	}
	if (ret != 0) {
		goto error;
	}

	msg_seq_num = r_comm->next_msg_seq_num;

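	/* Check the message buffer for this sequence number: if the sender's
	 * eager data has already arrived, this receive is handled as an eager
	 * message; otherwise a fresh request is allocated below. */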
	eager = false;
	void *elem;
	nccl_ofi_msgbuff_elemtype_t type;
	nccl_ofi_msgbuff_status_t msg_stat;
	nccl_ofi_msgbuff_result_t mb_res;

	mb_res = nccl_ofi_msgbuff_retrieve(r_comm->msgbuff, msg_seq_num, &elem,
					   &type, &msg_stat);
	if (mb_res == NCCL_OFI_MSGBUFF_SUCCESS) {

		if (type == NCCL_OFI_MSGBUFF_REQ) {
			/* Shouldn't happen: duplicate request */
			NCCL_OFI_WARN("Duplicate request in message buffer for msg %hu", msg_seq_num);
			ret = -EINVAL;
			goto error;
		} else if (OFI_LIKELY(type == NCCL_OFI_MSGBUFF_BUFF)) {
			/* This is an eager message */
			eager = true;
		} else {
			NCCL_OFI_WARN("Invalid type in msg buff");
			ret = -EINVAL;
			goto error;
		}
	} else if ((mb_res == NCCL_OFI_MSGBUFF_INVALID_IDX) &&
		   (msg_stat == NCCL_OFI_MSGBUFF_NOTSTARTED)) {
		/* Allocate a new req */
	} else {
		NCCL_OFI_WARN("Message %hu has invalid status.", msg_seq_num);
		ret = -EINVAL;
		goto error;
	}

	/* NCCL versions prior to 2.24 require special handling for 0 byte
	 * messages when using user buffer registration.  NCCL passes the base
	 * pointer from the user buffer, but passes the registration from the
	 * channel buffer, to avoid an MR cache lookup.  This is fine with
	 * InfiniBand, where the spec says the SGE is not used for a 0 byte
	 * message, but is a problem for EFA, which validates the pointer / MR
	 * even for a 0 byte transfer.
	 *
	 * To handle this case, we use the flush buffer (note we still move 0
	 * bytes of data, we just need a valid SGE) instead of the provided base
	 * pointer and MR.
	 */
	for (i = 0 ; i < n ; i++) {
		if (sizes[i] == 0) {
			buffers[i] = domain->flush_buff.host_buffer;
			mr_handles[i] = domain->flush_buff.mr_handle;
		}
	}

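	/* Allocate the receive request for this message, including the control
	 * message send request that is posted to the sender below. */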
	ret = allocate_rdma_recv_req(r_comm, device, dev_id, msg_seq_num,
					buffers[0], sizes[0],
					mr_handles[0], &req, recv_completion_optional);
	if (ret != 0) {
		goto error;
	}

	recv_data = get_recv_data(req);

	if (eager) {
		nccl_net_ofi_rdma_req_t *rx_buff_req = (nccl_net_ofi_rdma_req_t *)elem;
		rdma_req_rx_buff_data_t *rx_buff_data = get_rx_buff_data(rx_buff_req);
		if (rx_buff_data->recv_len == 0) {
			/* Special case for zero-sized messages */
			ret = check_post_rx_buff_req(rx_buff_req);
			if (ret != 0) {
				NCCL_OFI_WARN("Failed call to check_post_rx_buff_req");
				goto error;
			}
			recv_data->eager_copy_req = NULL;
		} else {
			ret = alloc_eager_copy_req(req, r_comm, rx_buff_req);
			if (ret != 0) {
				goto error;
			}
		}
	}

	ret = insert_rdma_recv_req_into_msgbuff(r_comm, eager, &req);
	if (ret != 0 || req == NULL) {
		goto free_req;
	}

	/* At this point, we've successfully inserted a new request, so update the num inflight. */
	(r_comm->num_inflight_reqs)++;

	NCCL_OFI_TRACE_RECV(dev_id, r_comm, sizes[0], req, base_req);

	/* Send ctrl msg */
	nccl_net_ofi_mutex_lock(&r_comm->ctrl_counter_lock);
	r_comm->n_ctrl_sent += 1;
	nccl_net_ofi_mutex_unlock(&r_comm->ctrl_counter_lock);
	ret = receive_progress(recv_data->send_ctrl_req, true);
	if (OFI_UNLIKELY(ret != 0)) {
		/* TODO: Remove req from message buffer */
		goto error;
	}

	if (eager) {
		if (recv_data->eager_copy_req == NULL) {
			/* If we don't need to do eager copy, this recv is already complete */
			ret = inc_req_completion(req, 0, recv_data->total_num_compls);
			if (ret != 0) {
				goto error;
			}
		} else {
			/* Post eager copy */
			ret = receive_progress(recv_data->eager_copy_req, true);
			if (ret != 0) {
				NCCL_OFI_WARN("Failed to issue eager read");
				/* TODO: Remove req from message buffer */
				goto error;
			}
		}
	}

	/* Return request to NCCL */
	*base_req = (nccl_net_ofi_req_t *)req;
	/* Increment next_msg_seq_num for next call */
	r_comm->next_msg_seq_num = (r_comm->next_msg_seq_num + 1) & MSG_SEQ_NUM_MASK;

	goto exit;

 free_req:
 error:
	if (req)
		req->free(req, false);
	*base_req = NULL;
 exit:
	return ret;
}
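
Callers of this entry point have to handle the busy-network case above, where
recv() returns 0 but leaves *base_req set to NULL. The sketch below is a
minimal illustration of that retry contract, not the actual NCCL core code: it
assumes it lives in the same translation unit as recv() (the function is
static), and the helper name post_recv_with_retry and its variable names are
hypothetical.

static int post_recv_with_retry(nccl_net_ofi_recv_comm_t *recv_comm,
				void *dst_buffer, size_t dst_size, int dst_tag,
				nccl_net_ofi_mr_handle_t *mr_handle,
				nccl_net_ofi_req_t **req_out)
{
	void *buffers[1] = { dst_buffer };
	size_t sizes[1] = { dst_size };
	int tags[1] = { dst_tag };
	nccl_net_ofi_mr_handle_t *mhandles[1] = { mr_handle };

	*req_out = NULL;
	do {
		/* recv() returns 0 with *req_out == NULL when the network is
		 * still busy (the -EAGAIN path); any other non-zero return is
		 * a hard failure. In practice NCCL retries on a later
		 * progress pass rather than spinning like this. */
		int rc = recv(recv_comm, 1, buffers, sizes, tags, mhandles, req_out);
		if (rc != 0) {
			return rc;
		}
	} while (*req_out == NULL);

	return 0;
}

A caller that wants the optional-completion behavior instead presets *base_req
to (void *)NCCL_NET_OPTIONAL_RECV_COMPLETION before calling recv(), which is
what the early_completion check at the top of the function looks for.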