static int send()

in src/nccl_ofi_rdma.cpp [5857:6069]


/*
 * send() - start a send of `size` bytes from `data` on an RDMA send
 * communicator.
 *
 * The message is identified by the communicator's next message sequence
 * number. Depending on what the message buffer holds for that sequence
 * number, the new request is either:
 *  - posted right away as an RDMA write, when the receiver's control
 *    message has already been received (have_ctrl); the write metadata
 *    is copied from that control message;
 *  - posted right away as an eager send, when no control message is
 *    present, size <= ep->eager_send_size and this communicator has no
 *    non-eager writes in flight; or
 *  - left in the message buffer, to be sent once the control message
 *    arrives.
 * If posting returns -FI_EAGAIN, the request is appended to the
 * endpoint's pending request queue instead of failing.
 *
 * @param send_comm  send communicator (an RDMA send communicator)
 * @param data       source buffer; replaced by the domain flush buffer
 *                   when size == 0 (see the comment in the body)
 * @param size       number of bytes to send
 * @param tag        not used by this function (see TODO in the body)
 * @param mhandle    memory registration covering `data`
 * @param base_req   out: request handed back to the caller, or NULL
 *
 * @return 0 on success with *base_req set and next_msg_seq_num advanced.
 *         0 with *base_req == NULL when the send could not be started
 *         yet: process_cq_if_pending() reported -EAGAIN, or the
 *         message-buffer insertion produced no request. The sequence
 *         number is not advanced in that case.
 *         A negative error code on failure, with *base_req == NULL.
 */
static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, size_t size, int tag,
			 nccl_net_ofi_mr_handle_t *mhandle, nccl_net_ofi_req_t **base_req)
{
	int ret = 0;
	nccl_net_ofi_rdma_send_comm_t *s_comm = (nccl_net_ofi_rdma_send_comm_t *)send_comm;
	nccl_net_ofi_rdma_mr_handle_t *mr_handle = (nccl_net_ofi_rdma_mr_handle_t *)mhandle;
	nccl_net_ofi_rdma_ep_t *ep = NULL;
	nccl_net_ofi_rdma_domain_t *domain = NULL;
	nccl_net_ofi_rdma_req_t *req = NULL;
	/* Assigned after s_comm has been checked; do not dereference it here. */
	uint16_t msg_seq_num = 0;
	bool polled_cq = false;
	bool have_ctrl = false;
	bool eager = false;
	int dev_id = 0;

	/*
	 * Message-buffer lookup results. Declared at the top, before any goto:
	 * C++ rejects a goto that jumps past an initialized declaration into
	 * its scope, and every goto below targets the labels at the end.
	 */
	void *elem = NULL;
	nccl_ofi_msgbuff_elemtype_t type;
	nccl_ofi_msgbuff_status_t msg_stat;
	nccl_ofi_msgbuff_result_t mb_res;

	assert(s_comm != NULL);
	msg_seq_num = s_comm->next_msg_seq_num;

	if (s_comm->comm_active == false) {
		NCCL_OFI_WARN("Called isend on inactive communicator");
		ret = -EINVAL;
		goto error;
	}

	/*
	 * Support only NCCL_OFI_MAX_SEND_REQUESTS inflight requests. Use ">="
	 * so that an over-count can never slip past the limit.
	 */
	if (OFI_UNLIKELY(s_comm->num_inflight_reqs >= NCCL_OFI_MAX_SEND_REQUESTS)) {
		ret = -EINVAL;
		NCCL_OFI_WARN("Can not support more than %d inflight requests",
			      NCCL_OFI_MAX_SEND_REQUESTS);
		goto error;
	}

	dev_id = s_comm->base.base.dev_id;

	ep = (nccl_net_ofi_rdma_ep_t *)s_comm->base.base.ep;
	assert(ep != NULL);

	domain = rdma_endpoint_get_domain(ep);
	assert(domain != NULL);

	ret = process_cq_if_pending(ep);
	if (ret == -EAGAIN) {
		/* Network is still busy. Return NULL to NCCL. */
		*base_req = NULL;
		ret = 0;
		goto error;
	}
	if (ret != 0) {
		goto error;
	}

	/*
	 * TODO: Use NCCL provided tags when using grouped receives aka
	 * props->maxRecvs > 1.
	 */

retry:
	/* Retrieve entry from message buffer for msg_seq_num index */
	mb_res = nccl_ofi_msgbuff_retrieve(s_comm->msgbuff, msg_seq_num, &elem,
					   &type, &msg_stat);
	if (mb_res == NCCL_OFI_MSGBUFF_SUCCESS) {
		if (OFI_LIKELY(type == NCCL_OFI_MSGBUFF_BUFF)) {
			/*
			 * Received RDMA control message from receiver so
			 * allocate request and initiate RDMA write
			 */
			have_ctrl = true;
		} else if (type == NCCL_OFI_MSGBUFF_REQ) {
			/* Shouldn't happen: we already have a req in the message buffer */
			NCCL_OFI_WARN("Duplicate request in message buffer for msg %hu", msg_seq_num);
			ret = -EINVAL;
			goto error;
		} else {
			NCCL_OFI_WARN("Unexpected type of buffer retrieved from message buffer: %d",
				      type);
			ret = -EINVAL;
			goto error;
		}
	} else if ((mb_res == NCCL_OFI_MSGBUFF_INVALID_IDX) &&
		   (msg_stat == NCCL_OFI_MSGBUFF_NOTSTARTED)) {
		/*
		 * We haven't encountered this message sequence number.
		 * Allocate a request so that we are able to send RDMA write
		 * as soon as we receive the RDMA control message.
		 */
		have_ctrl = false;
	} else {
		NCCL_OFI_WARN("Message %hu has invalid status. res = %d and stat = %d",
			      msg_seq_num, mb_res, msg_stat);
		ret = -EINVAL;
		goto error;
	}

	/*
	 * No control message yet: poll the control rails once and repeat the
	 * lookup, so a control message that is already sitting in a CQ is
	 * used now rather than after the request has been queued.
	 */
	if (OFI_UNLIKELY(!polled_cq && !have_ctrl)) {
		for (uint16_t rail_id = 0; rail_id != ep->num_control_rails; ++rail_id) {
			nccl_net_ofi_ep_rail_t *rail = rdma_endpoint_get_control_rail(ep, rail_id);

			ret = ofi_process_cq_rail(ep, rail);
			if (OFI_UNLIKELY(ret != 0)) {
				goto error;
			}
		}
		polled_cq = true;
		goto retry;
	}

	/* NCCL versions prior to 2.24 require special handling for 0 byte
	 * messages when using user buffer registration.  NCCL passes the base
	 * pointer from the user buffer, but passes the registration from the
	 * channel buffer, to avoid an MR cache lookup.  This is fine with
	 * InfiniBand, where the spec says the SGE is not used for a 0 byte
	 * message, but is a problem for EFA, which validates the pointer / MR
	 * even for a 0 byte transfer.
	 *
	 * To handle this case, we use the flush buffer (note we still move 0
	 * bytes of data, we just need a valid SGE) instead of the provided base
	 * pointer and MR
	 */
	if (size == 0) {
		data = domain->flush_buff.host_buffer;
		mr_handle = domain->flush_buff.mr_handle;
	}

	/*
	 * Determine if this should be sent eagerly: only without a control
	 * message, within the eager size limit, and only while no RDMA writes
	 * are in flight on this communicator (presumably to keep eager data
	 * from overtaking earlier writes -- NOTE(review): confirm).
	 */
	eager = false;
	if (!have_ctrl && (ssize_t)size <= ep->eager_send_size && s_comm->num_inflight_writes == 0) {
		eager = true;
	}

	ret = alloc_rdma_send_req(s_comm, msg_seq_num, data,
				  size, mr_handle, eager, &req);
	if (OFI_UNLIKELY(ret != 0)) {
		goto error;
	}

	if (have_ctrl) {
		/*
		 * For already received RDMA control message, populate
		 * the RDMA write metadata from the rx buffer
		 */
		nccl_net_ofi_rdma_req_t *rx_buff_req = (nccl_net_ofi_rdma_req_t *)elem;
		ret = update_send_data_from_remote(s_comm, rx_buff_req, req);
		if (OFI_UNLIKELY(ret != 0)) {
			NCCL_OFI_WARN("Failed to copy ctrl data");
			goto error;
		}

		/* Post if needed */
		ret = check_post_rx_buff_req(rx_buff_req);
		if (OFI_UNLIKELY(ret != 0)) {
			goto error;
		}
	}

	/*
	 * On success with req == NULL no request was produced; with req NULL
	 * the cleanup path frees nothing and returns ret (0) with a NULL
	 * *base_req.
	 */
	ret = insert_rdma_send_req_into_msgbuff(s_comm, dev_id, have_ctrl, &req);
	if (OFI_UNLIKELY(ret != 0 || req == NULL)) {
		goto free_req;
	}

	/*
	 * At this point, we've successfully inserted a new request,
	 * so update the num inflight
	 */
	(s_comm->num_inflight_reqs)++;

	if (!eager) {
		(s_comm->num_inflight_writes)++;
	}

	NCCL_OFI_TRACE_SEND(req->dev_id, size, s_comm, msg_seq_num, req, base_req);

	/* Try posting RDMA write for received RDMA control messages */
	if (have_ctrl || eager) {

		ret = send_progress(req);
		if (ret == -FI_EAGAIN) {
			/* Add to pending reqs queue */
			nccl_net_ofi_mutex_lock(&ep->pending_reqs_lock);
			ep->pending_reqs_queue->push_back(req);
			nccl_net_ofi_mutex_unlock(&ep->pending_reqs_lock);
			ret = 0;
			NCCL_OFI_TRACE_PENDING_INSERT(req);
		} else if (OFI_UNLIKELY(ret != 0)) {
			NCCL_OFI_WARN("Failed to post send for msg %hu: %d",
				      msg_seq_num, ret);
			/*
			 * TODO: Remove req from message buffer.
			 *
			 * req is already referenced by the message buffer and
			 * counted in the inflight counters, and send_progress()
			 * may have posted part of it before failing (not
			 * verifiable from here). Freeing it via the shared
			 * cleanup path would leave those references dangling, so
			 * the request is deliberately not freed; only the error
			 * is reported.
			 */
			ret = -ENOTSUP;
			*base_req = NULL;
			goto exit;
		}
	}

	/* Return request to NCCL */
	*base_req = &req->base;
	/* Increment next_msg_seq_num for next call */
	s_comm->next_msg_seq_num = (s_comm->next_msg_seq_num + 1) & MSG_SEQ_NUM_MASK;

	goto exit;

 free_req:
 error:
	/* Only reached before req is published in the message buffer. */
	if (req)
		req->free(req, false);
	*base_req = NULL;
 exit:
	return ret;
}