static int sendrecv_recv_comm_flush()

in src/nccl_ofi_sendrecv.cpp [1216:1363]


static int sendrecv_recv_comm_flush(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers,
				    int *sizes, nccl_net_ofi_mr_handle_t **mhandles,
				    nccl_net_ofi_req_t **base_req)
{
	int ret = 0;
	nccl_net_ofi_sendrecv_recv_comm_t *r_comm =
		(nccl_net_ofi_sendrecv_recv_comm_t *)recv_comm;
	nccl_net_ofi_sendrecv_req_t *req = NULL;
	ssize_t rc = 0;
	uint64_t cuda_key = 0ULL;
	nccl_net_ofi_sendrecv_mr_handle_t *mr_handle = NULL;
	void *data = NULL;
	void *flush_mr_desc = NULL;
	int dev_id = recv_comm->base.dev_id;
	int flush_n = -1;
	auto **mr_handles = reinterpret_cast<nccl_net_ofi_sendrecv_mr_handle_t **>(mhandles);

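	/* A flush is unnecessary when the GDR flush is disabled or the
	 * platform does not support GDR; return with no request. */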
	if (ofi_nccl_gdr_flush_disable() || support_gdr == GDR_UNSUPPORTED)
		goto exit;

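	/* Prefer the CUDA-level GDR flush over an RDMA read when the
	 * runtime supports it. */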
#if HAVE_CUDA
	if (cuda_flush) {
		ret = nccl_net_ofi_cuda_flush_gpudirect_rdma_writes();
		if (ret != 0) {
			NCCL_OFI_WARN("Error performing CUDA GDR flush");
		}
		goto exit;
	}
#endif

	/* Plugin supports at most NCCL_OFI_MAX_RECVS receives per request */
	assert(n <= NCCL_OFI_MAX_RECVS);

	/*
	 * Find the first non-zero-sized request and issue the flush on
	 * its buffer; a single operation flushes all requests at once.
	 */
	for (int recv_n = 0; recv_n < n; recv_n++) {
		if (sizes[recv_n] != 0) {
			flush_n = recv_n;
			break;
		}
	}

	if (flush_n == -1) {
		/*
		 * Flush is an expensive operation, so don't issue fi_read
		 * for 0-sized messages. Since NCCL issues a flush for every
		 * irecv(), data is still guaranteed to be synced to the GPU
		 * even without it.
		 */
		goto exit;
	}

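	/* Use the chosen request's buffer and MR handle as the remote
	 * (GPU-side) target of the flush read. */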
	if (mr_handles && mr_handles[flush_n]) {
		mr_handle = mr_handles[flush_n];
	}
	data = buffers[flush_n];

	/* Support at most NCCL_OFI_MAX_REQUESTS inflight requests. */
	if (OFI_UNLIKELY(r_comm->num_inflight_reqs == NCCL_OFI_MAX_REQUESTS)) {
		ret = -ENOSPC;
		NCCL_OFI_WARN("Cannot support more than %d inflight requests",
			      NCCL_OFI_MAX_REQUESTS);
		goto exit;
	}

	/* Allocate NCCL OFI request */
	req = sendrecv_allocate_req(r_comm->nccl_ofi_reqs_fl);
	if (OFI_UNLIKELY(req == NULL)) {
		ret = -ENOTSUP;
		NCCL_OFI_WARN("Unable to get NCCL OFI request for device %d",
			      dev_id);
		goto exit;
	}

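	/* Initialize the request so the completion handler can attribute
	 * the read completion to this receive communicator. */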
	req->comm = &r_comm->base.base;
	req->dev_id = dev_id;
	req->direction = NCCL_OFI_SENDRECV_RECV;

	if (r_comm->flush_buff.mr_handle != NULL) {
		/* Not checking flush_mr_desc for NULL, as fi_mr_desc()
		 * returns valid descriptors for valid handles */
		flush_mr_desc = fi_mr_desc(r_comm->flush_buff.mr_handle->mr);
	}

	if (mr_handle != NULL && mr_handle->mr != nullptr) {
		/* Extract remote key */
		cuda_key = fi_mr_key(mr_handle->mr);
		if (OFI_UNLIKELY(cuda_key == FI_KEY_NOTAVAIL)) {
			ret = -ENOTSUP;
			NCCL_OFI_WARN("Memory registration may not have completed.");
			goto error;
		}
	}

	NCCL_OFI_TRACE_FLUSH_SENDRECV(req, base_req);

	/*
	 * Issue an RDMA read from the GPU buffer into the host flush
	 * buffer; the read cannot complete until earlier writes to the
	 * GPU buffer are visible. Providers using virtual addressing
	 * (virt_addr_mr) take the buffer's virtual address as the
	 * remote address; others take an offset into the MR, here 0.
	 */
	do {
		rc = fi_read(r_comm->local_ep, r_comm->flush_buff.host_buffer,
			     r_comm->flush_buff.size,
			     flush_mr_desc,
			     r_comm->local_ep_addr,
			     (uint64_t)(virt_addr_mr ? data : 0),
			     cuda_key, sendrecv_req_get_ofi_context(req));
		if (rc == 0) {
			break;
		} else if (rc == -FI_EAGAIN) {
			/* Retrieve and validate endpoint */
			nccl_net_ofi_sendrecv_ep_t *ep =
				(nccl_net_ofi_sendrecv_ep_t *)r_comm->base.base.ep;
			if (OFI_UNLIKELY(ep == NULL)) {
				ret = -EINVAL;
				NCCL_OFI_WARN("Invalid endpoint provided");
				goto error;
			}

			/*
			 * Process completions so the provider frees
			 * enough resources to issue fi_read again
			 */
			ret = sendrecv_cq_process(ep->cq);
			if (OFI_UNLIKELY(ret != 0))
				goto error;
		} else {
			NCCL_OFI_WARN("Unable to issue read operation for dev %d. RC: %zd, ERROR: %s",
				      dev_id, rc, fi_strerror(-rc));
			ret = -ENOTSUP;
			goto error;
		}
	} while (true);

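	/* The read was posted successfully; count it against the
	 * communicator's inflight-request limit. */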
	(r_comm->num_inflight_reqs)++;

	/* Set request size */
	req->size = r_comm->flush_buff.size;

	*base_req = &req->base;

	return ret;

 error:
	if (req)
		sendrecv_recv_comm_free_req(r_comm, dev_id, req, false);
 exit:
	*base_req = NULL;
	return ret;
}
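
The do/while around fi_read is the plugin's standard backpressure pattern: a
-FI_EAGAIN return means the provider is temporarily out of resources, so the
code makes CQ progress and retries the post. The sketch below isolates the
whole flush-by-read pattern in one place; it is illustrative only, assuming
an endpoint, CQ, registered host bounce buffer, and the GPU buffer's remote
key were initialized elsewhere. flush_via_rdma_read and its parameters are
hypothetical names, not part of the plugin API.

#include <rdma/fabric.h>
#include <rdma/fi_errno.h>
#include <rdma/fi_eq.h>
#include <rdma/fi_rma.h>

/* Minimal sketch of the flush-by-read pattern used above. */
static int flush_via_rdma_read(struct fid_ep *ep, struct fid_cq *cq,
			       void *host_buf, size_t len, void *host_desc,
			       fi_addr_t self_addr, uint64_t gpu_addr,
			       uint64_t gpu_key, void *context)
{
	ssize_t rc;

	do {
		/* Read from the GPU buffer into host memory; the read
		 * cannot complete before earlier writes to the GPU
		 * buffer have landed. */
		rc = fi_read(ep, host_buf, len, host_desc, self_addr,
			     gpu_addr, gpu_key, context);
		if (rc == -FI_EAGAIN) {
			/* Drain a completion so the provider can free
			 * resources, then retry the post. */
			struct fi_cq_tagged_entry entry;
			(void)fi_cq_read(cq, &entry, 1);
		}
	} while (rc == -FI_EAGAIN);

	return (rc == 0) ? 0 : (int)rc;
}

As in the function above, the posted read is later reported on the same CQ,
at which point the request handed back through *base_req is marked complete.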