static int sendrecv_comm_mr_base_reg()

in src/nccl_ofi_sendrecv.cpp [867:946]


/*
 * Register the memory region described by @ckey for @base_comm's endpoint,
 * going through the domain's MR cache when one is configured.
 *
 * If the domain has an MR cache, the cache lock is taken before the lookup.
 * It is released only after registration and cache insertion have completed,
 * so a missing entry is registered and inserted under a single critical
 * section. A cache hit returns the cached handle without registering again.
 *
 * @param base_comm  Communicator whose endpoint (and that endpoint's device
 *                   and domain) is used for the registration.
 * @param ckey       Cache key describing the region; used both as the MR
 *                   cache key and as the registration descriptor.
 * @param type       Forwarded unchanged to sendrecv_mr_base_register().
 * @param mr_handle  Out: the cached or newly registered handle on success,
 *                   NULL if lookup/registration/insertion fails. Not written
 *                   when endpoint or device validation fails.
 *
 * @return 0 on success; -EINVAL if the endpoint or device is invalid, or if
 *         registration reports success without producing a handle; otherwise
 *         the non-zero code returned by registration or cache insertion.
 */
static int sendrecv_comm_mr_base_reg(nccl_net_ofi_comm_t *base_comm,
				     nccl_ofi_mr_ckey_ref ckey,
				     int type,
				     nccl_net_ofi_sendrecv_mr_handle_t **mr_handle)
{
	/* Retrieve and validate endpoint */
	nccl_net_ofi_sendrecv_ep_t *ep =
		(nccl_net_ofi_sendrecv_ep_t *)base_comm->ep;
	nccl_ofi_idpool_t *key_pool = NULL;
	if (OFI_UNLIKELY(ep == NULL)) {
		NCCL_OFI_WARN("Invalid endpoint provided");
		return -EINVAL;
	}

	/* Retrieve and validate device */
	nccl_net_ofi_sendrecv_device_t *device =
		sendrecv_endpoint_get_device(ep);
	if (OFI_UNLIKELY(device == NULL)) {
		NCCL_OFI_WARN("Invalid device provided");
		return -EINVAL;
	}

	nccl_net_ofi_sendrecv_domain_t *domain = sendrecv_endpoint_get_domain(ep);
	assert(domain != NULL);

	int dev_id = device->base.dev_id;

	int ret = 0;
	nccl_ofi_mr_cache_t *mr_cache = domain->base.mr_cache;
	nccl_net_ofi_sendrecv_mr_handle_t *ret_handle = nullptr;

	if (mr_cache) {
		/*
		 * MR cache is locked between lookup and insert, to be sure we
		 * insert a missing entry
		 */
		nccl_net_ofi_mutex_lock(&mr_cache->lock);
		ret_handle = static_cast<nccl_net_ofi_sendrecv_mr_handle_t *>(
			nccl_ofi_mr_cache_lookup_entry(mr_cache, ckey));

		if (ret_handle) {
			/* Cache hit */
			goto unlock;
		}
		/* Cache miss */
	}

	key_pool = domain->base.mr_rkey_pool;
	/*
	 * Declared without an initializer on purpose: the "goto unlock" above
	 * must not jump past an initialized declaration (ill-formed in C++).
	 */
	struct fid_domain *ofi_domain;
	ofi_domain = sendrecv_endpoint_get_ofi_domain(ep);
	ret = sendrecv_mr_base_register(ofi_domain, ep->ofi_ep, key_pool,
					dev_id, ckey, type, &ret_handle);
	if (OFI_UNLIKELY(ret_handle == NULL || ret != 0)) {
		if (ret == 0) {
			/*
			 * Registration claimed success but produced no handle.
			 * Never report success to the caller with a NULL handle.
			 */
			NCCL_OFI_WARN("Memory registration returned no handle");
			ret = -EINVAL;
		}
		ret_handle = NULL;
		goto unlock;
	}

	if (mr_cache) {
		ret = nccl_ofi_mr_cache_insert_entry(mr_cache, ckey, ret_handle);
		if (OFI_UNLIKELY(ret != 0)) {
			/* MR cache insert failed. Deregister memory region without
			 * trying to delete MR cache entry.
			 */
			if (sendrecv_comm_mr_base_dereg(ret_handle, key_pool, NULL) != 0) {
				NCCL_OFI_WARN("Error deregistering memory region for addr %ld (%s)",
					      nccl_ofi_mr_ckey_baseaddr(ckey), nccl_ofi_mr_ckey_type_str(ckey));
			}
			ret_handle = NULL;
			goto unlock;
		}
	}

unlock:
	if (mr_cache) {
		nccl_net_ofi_mutex_unlock(&mr_cache->lock);
	}

	*mr_handle = ret_handle;
	return ret;
}