static ncclResult_t ofi_accept()

in src/nccl_ofi_net.c [1468:1612]


static ncclResult_t ofi_accept(void *listenComm, void **recvComm)
{
	/*
	 * ofi_accept() - Accept one incoming connection on a listen communicator.
	 *
	 * Posts a tagged receive on lComm->local_ep for the peer's endpoint
	 * address, drives nccl_ofi_progress() until lComm->accepted becomes
	 * true, inserts the received address into the component's address
	 * vector and builds a recvComm_t that reuses the listen communicator's
	 * endpoint, endpoint address and tag.
	 *
	 * @listenComm: listen communicator (listenComm_t *) to accept on. A
	 *              listen communicator can be accepted at most once.
	 * @recvComm:   out-parameter; written with the new recvComm_t only on
	 *              success.
	 *
	 * Returns ncclSuccess on success. Otherwise returns ncclSystemError or
	 * the error propagated from get_nccl_ofi_comp(), nccl_ofi_progress(),
	 * register_mr_buffers() or allocate_ofi_fl().
	 */
	ncclResult_t ret = ncclSuccess;
	ssize_t rc = 0;
	recvComm_t *rComm = NULL;
	listenComm_t *lComm = (listenComm_t *)listenComm;
	int dev = lComm->dev;
	nccl_ofi_t *nccl_ofi_comp = NULL;
	nccl_ofi_req_t *req = NULL;
	char remote_ep_addr[MAX_EP_ADDR] = {0};
	fi_addr_t remote_ep;
	uint64_t max_tag;
	size_t req_size = sizeof(nccl_ofi_req_t);
	struct fid_mr *mr_handle = NULL;

	pthread_mutex_lock(&nccl_ofi_lock);
	/*
	 * Snapshot the component pointer while holding nccl_ofi_lock so the
	 * NULL check and get_nccl_ofi_comp() below observe the same value.
	 */
	nccl_ofi_comp = nccl_ofi_component[dev];
	if (nccl_ofi_comp == NULL) {
		ret = ncclSystemError;
		NCCL_OFI_WARN("NCCL OFI component for dev %d is uninitialised",
			     dev);
		goto unlock;
	}

	/*
	 * NOTE(review): get_nccl_ofi_comp() presumably takes a reference on
	 * the component; the error paths below do not release it. Confirm
	 * whether a matching release call is needed on failure.
	 */
	ret = get_nccl_ofi_comp(dev);
	if (ret)
		goto unlock;
	pthread_mutex_unlock(&nccl_ofi_lock);

	max_tag = nccl_ofi_comp->max_tag;

	if (lComm->accepted == true) {
		ret = ncclSystemError;
		NCCL_OFI_WARN("listenComm object already has an active connection.");
		goto exit;
	}

	/* Allocate a NCCL OFI request */
	req = (nccl_ofi_req_t *)calloc(1, sizeof(nccl_ofi_req_t));
	if (OFI_UNLIKELY(req == NULL)) {
		NCCL_OFI_WARN("Unable to allocate nccl_ofi_req_t");
		ret = ncclSystemError;
		goto exit;
	}

	req->state = NCCL_OFI_REQ_CREATED;
	req->lComm = lComm;
	req->dev = dev;

	/*
	 * Post a buffer for receiving the peer's endpoint address. The tag is
	 * lComm->tag with every bit outside max_tag set, matched exactly
	 * (ignore mask 0) - presumably to keep connection messages apart from
	 * data-path tags; confirm against the connect side.
	 */
	do {
		rc = fi_trecv(lComm->local_ep, (void *)remote_ep_addr, MAX_EP_ADDR,
			      NULL, FI_ADDR_UNSPEC, lComm->tag | ~max_tag,
			      0, &req->ctx);
		if (rc == 0)
			break;
		else if (rc == -FI_EAGAIN) {
			/*
			 * Process completions so that you have enough
			 * resources for posting receive buffer
			 */
			ret = nccl_ofi_progress(nccl_ofi_comp);
			if (OFI_UNLIKELY(ret != 0))
				goto exit;
		}
		else {
			NCCL_OFI_WARN("Unable to post a buffer for receving connections for dev %d. RC: %zd, ERROR: %s",
				     dev, rc, fi_strerror(-rc));
			ret = ncclSystemError;
			goto exit;
		}
	} while (true);

	/*
	 * Progress NCCL_OFI until connection is accepted; lComm->accepted is
	 * presumably set by completion handling inside nccl_ofi_progress().
	 *
	 * NOTE(review): if progress fails here, the posted receive may still
	 * reference req->ctx and the stack buffer remote_ep_addr while req is
	 * freed and this frame returns. Confirm whether the receive must be
	 * cancelled and drained before bailing out.
	 */
	while (lComm->accepted == false) {
		ret = nccl_ofi_progress(nccl_ofi_comp);
		if (OFI_UNLIKELY(ret != 0))
			goto exit;
	}

	/*
	 * Insert remote EP address to AV. fi_av_insert() returns an int count
	 * of inserted addresses (or a negative error), so keep it in rc rather
	 * than in the ncclResult_t status.
	 */
	rc = fi_av_insert(nccl_ofi_comp->av, (void *)remote_ep_addr, 1,
			  &remote_ep, 0, NULL);
	if (OFI_UNLIKELY(rc != 1)) {
		NCCL_OFI_WARN("Unable to insert remote address into address vector for device %d. RC: %zd, ERROR: %s",
			      dev, rc, fi_strerror(-rc));
		ret = ncclSystemError;
		goto exit;
	}

	/* Build recvComm */
	rComm = (recvComm_t *)calloc(1, sizeof(recvComm_t));
	if (rComm == NULL) {
		NCCL_OFI_WARN("Unable to allocate receive Comm object for device %d",
			     dev);
		ret = ncclSystemError;
		goto exit;
	}

	/* The receive communicator shares the listen endpoint and tag. */
	rComm->tag = lComm->tag;
	rComm->local_ep = lComm->local_ep;
	rComm->local_ep_addr = lComm->local_ep_addr;
	rComm->remote_ep = remote_ep;
	rComm->dev = dev;

	if (support_gdr) {
		rComm->flush_buff.size = sizeof(rComm->flush_buff.host_buffer);

		/* Register flush dummy buffer for provider access */
		ret = register_mr_buffers(rComm, &rComm->flush_buff.host_buffer,
					  rComm->flush_buff.size, NCCL_PTR_HOST,
					  &mr_handle);
		if (OFI_UNLIKELY(ret != ncclSuccess)) {
			NCCL_OFI_WARN("Could not register dummy buffer for flush, dev:  %d",
				      dev);
			goto error;
		}
		rComm->flush_buff.mr_handle = mr_handle;
	}

	/* Pre-allocated buffers for data path */
	ret = allocate_ofi_fl(&rComm->nccl_ofi_reqs_fl, NCCL_OFI_MAX_REQUESTS,
			      req_size);
	if (OFI_UNLIKELY(ret != 0)) {
		NCCL_OFI_WARN("Could not allocate NCCL OFI requests free list for dev %d",
			     dev);
		goto error;
	}

	*recvComm = rComm;

	/* Success: only the temporary accept request needs releasing. */
	goto exit;

	/*
	 * Cleanup labels fall through in order: unlock -> error -> exit.
	 * Locals that were never acquired are still NULL, so each step is safe.
	 */
unlock:
	pthread_mutex_unlock(&nccl_ofi_lock);
error:
	/* fi_close() must not be called on a NULL fid. */
	if (mr_handle)
		fi_close((fid_t)mr_handle);
	free(rComm);
exit:
	free(req);
	return ret;
}