static ncclResult_t ofi_connect()

in src/nccl_ofi_net.c [1314:1466]


/*
 * Establish the send side of a connection to a remote endpoint.
 *
 * The handle is parsed as MAX_EP_ADDR bytes of remote endpoint address
 * followed by a uint64_t tag. The remote address is inserted into the
 * device's address vector, a sendComm_t is allocated and populated, and a
 * "connect" message carrying the local endpoint address is sent to the
 * peer. The function does not return until that send has completed or an
 * error has been reported.
 *
 * @param dev       Index of the device to use; must be in
 *                  [0, ofi_ndevices).
 * @param handle    Connection handle: remote endpoint address followed by
 *                  the tag. Must not be NULL.
 * @param sendComm  Output; set to the newly allocated sendComm_t on
 *                  success only. Must not be NULL.
 *
 * @return ncclSuccess (0) on success. ncclSystemError on invalid
 *         arguments, an out-of-range tag, allocation failure or libfabric
 *         failure; errors returned by get_nccl_ofi_comp(),
 *         allocate_ofi_fl() and nccl_ofi_progress() are propagated as-is.
 */
static ncclResult_t ofi_connect(int dev, void *handle, void **sendComm)
{
	ncclResult_t ret = ncclSuccess;
	/* libfabric calls return plain ints, keep them out of the NCCL enum */
	int fi_rc = 0;
	ssize_t rc = 0;
	uint64_t tag = 0ULL;
	char remote_ep_addr[MAX_EP_ADDR] = {0};
	char local_ep_addr[MAX_EP_ADDR] = {0};
	size_t namelen = sizeof(local_ep_addr);
	fi_addr_t remote_addr;
	sendComm_t *sComm = NULL;
	uint64_t max_tag = 0;
	nccl_ofi_req_t *req = NULL;
	size_t req_size = sizeof(nccl_ofi_req_t);

	if (OFI_UNLIKELY(dev < 0 || dev >= ofi_ndevices)) {
		NCCL_OFI_WARN("Incorrect device ID %d provided. Correct values are from 0 to %d",
			      dev, ofi_ndevices - 1);
		ret = ncclSystemError;
		goto exit;
	}

	/* Both pointers are dereferenced unconditionally further down */
	if (OFI_UNLIKELY(handle == NULL || sendComm == NULL)) {
		NCCL_OFI_WARN("Invalid NULL handle or sendComm provided for device %d",
			      dev);
		ret = ncclSystemError;
		goto exit;
	}

	if (OFI_UNLIKELY(nccl_ofi_component == NULL)) {
		NCCL_OFI_WARN("NCCL OFI component is not initialised.");
		ret = ncclSystemError;
		goto exit;
	}

	/*
	 * Create libfabric components for the given NIC, if not
	 * already created.
	 */
	pthread_mutex_lock(&nccl_ofi_lock);
	ret = get_nccl_ofi_comp(dev);
	if (ret)
		goto unlock;
	pthread_mutex_unlock(&nccl_ofi_lock);
	max_tag = nccl_ofi_component[dev]->max_tag;

	/* Parse handle to get tag and remote name */
	memcpy(&remote_ep_addr, (char *)handle, MAX_EP_ADDR);
	memcpy(&tag, (char *)handle + MAX_EP_ADDR, sizeof(tag));
	if (tag < 1 || tag > max_tag) {
		NCCL_OFI_WARN("Received an invalid tag %lu for device %d", tag,
			       dev);
		/*
		 * ret is still 0 from get_nccl_ofi_comp() here; without an
		 * explicit error the caller would see success with an unset
		 * *sendComm.
		 */
		ret = ncclSystemError;
		goto exit;
	}

	/* Insert remote address into AV; success is "1 address inserted" */
	fi_rc = fi_av_insert(nccl_ofi_component[dev]->av,
			     (void *)remote_ep_addr, 1,
			     &remote_addr, 0, NULL);
	if (OFI_UNLIKELY(fi_rc != 1)) {
		NCCL_OFI_WARN("Unable to insert remote address into address vector for device %d. RC: %d",
			     dev, fi_rc);
		ret = ncclSystemError;
		goto exit;
	}

	/* Build sendComm */
	sComm = (sendComm_t *)calloc(1, sizeof(sendComm_t));
	if (OFI_UNLIKELY(sComm == NULL)) {
		NCCL_OFI_WARN("Couldn't allocate sendComm for dev %d", dev);
		ret = ncclSystemError;
		goto error;
	}

	sComm->tag = tag;
	sComm->local_ep = nccl_ofi_component[dev]->ep;
	sComm->remote_ep = remote_addr;
	sComm->dev = dev;

	/* Pre-allocated buffers for data path */
	ret = allocate_ofi_fl(&sComm->nccl_ofi_reqs_fl, NCCL_OFI_MAX_REQUESTS,
			      req_size);
	if (OFI_UNLIKELY(ret != 0)) {
		NCCL_OFI_WARN("Could not allocate NCCL OFI requests free list for dev %d",
			     dev);
		goto error;
	}

	req = allocate_nccl_ofi_request(sComm->nccl_ofi_reqs_fl);
	if (OFI_UNLIKELY(req == NULL)) {
		ret = ncclSystemError;
		NCCL_OFI_WARN("Unable to get NCCL OFI request for device %d",
			      sComm->dev);
		goto error;
	}

	req->sComm = sComm;
	req->dev = sComm->dev;
	req->direction = NCCL_OFI_SEND;

	/* Get local EP address to transfer in the connect message */
	fi_rc = fi_getname(&(nccl_ofi_component[dev]->ep->fid),
			   (void *)&local_ep_addr,
			   &namelen);
	if (fi_rc != 0) {
		NCCL_OFI_WARN("Call to fi_getname() failed with RC: %d, ERROR: %s",
			      fi_rc, fi_strerror(-fi_rc));
		ret = ncclSystemError;
		goto error;
	}

	/* Send "connect" message to remote EP */
	do {
		/*
		 * TODO: replace it with API of FI_INJECT type when most of
		 * providers can support it, so that need for completion check
		 * below can be lifted.
		 *
		 * The tag has every bit outside max_tag set in addition to
		 * the connection tag.
		 * NOTE(review): presumably this distinguishes connect
		 * messages from data messages - confirm against the
		 * receive side.
		 */
		rc = fi_tsend(sComm->local_ep, (void *)&local_ep_addr,
			      MAX_EP_ADDR, NULL, sComm->remote_ep,
			      sComm->tag | ~max_tag, &req->ctx);
		if (rc == 0)
			break;
		else if (rc == -FI_EAGAIN) {
			/*
			 * Process completions so that you have enough
			 * resources for sending connect message
			 */
			ret = nccl_ofi_progress(nccl_ofi_component[dev]);
			if (OFI_UNLIKELY(ret != 0))
				goto error;
		}
		else {
			NCCL_OFI_WARN("Unable to send connect message for dev %d. RC: %zd, ERROR: %s",
				     dev, rc, fi_strerror(-rc));
			ret = ncclSystemError;
			goto error;
		}
	} while (true);

	/* Ensure the message is sent. */
	do {
		ret = nccl_ofi_progress(nccl_ofi_component[dev]);
		if (OFI_UNLIKELY(ret != 0))
			goto error;
	} while (req->state != NCCL_OFI_REQ_COMPLETED);

	*sendComm = sComm;

	goto exit;

unlock:
	pthread_mutex_unlock(&nccl_ofi_lock);
error:
	/*
	 * Release the request before its owner: req was handed out from
	 * sComm->nccl_ofi_reqs_fl and carries a back-pointer to sComm, so
	 * it must not be released after sComm has been freed.
	 */
	if (req) {
		free_nccl_ofi_req(req, false);
		req = NULL;
	}
	/*
	 * NOTE(review): the free list created by allocate_ofi_fl() is not
	 * released on this path; add the matching teardown call if one
	 * exists.
	 */
	free(sComm);
exit:
	/* Success path: the connect request is no longer needed */
	if (req)
		free_nccl_ofi_req(req, false);
	return ret;
}