ncclResult_t nccl_net_ofi_regMrDmaBuf_v6()

in src/nccl_ofi_api.cpp [459:526]


/*
 * Register a buffer with the send or receive communicator behind `comm`.
 *
 * The buffer is described by a cache key built from (data, size) when
 * fd == -1, or from (fd, offset, size, data) when a file descriptor is
 * supplied and the build has HAVE_DECL_FI_MR_DMABUF. The key, the buffer
 * type and `mhandle` are then handed to the communicator's regMr callback.
 *
 * @param comm     Opaque communicator; must be a send or a receive comm.
 * @param data     Start address of the buffer.
 * @param size     Length of the buffer.
 * @param type     NCCL_PTR_HOST, or NCCL_PTR_CUDA / NCCL_PTR_NEURON when the
 *                 matching HAVE_CUDA / HAVE_NEURON support is compiled in.
 * @param offset   Offset used only for the fd-based cache key.
 * @param fd       File descriptor for the buffer, or -1 if there is none.
 * @param mhandle  Forwarded untouched to the communicator's regMr callback.
 *
 * @return check_return(ncclInvalidArgument) if the plugin is not initialized,
 *         check_return(ncclInternalError) for a NULL comm or an unsupported
 *         buffer type; otherwise nccl_net_ofi_retval_translate() applied to
 *         the regMr result, or to -EINVAL for an fd passed to a build without
 *         DMA-BUF support or for a communicator that is neither send nor
 *         receive.
 */
ncclResult_t nccl_net_ofi_regMrDmaBuf_v6(void* comm, void* data, size_t size,
					 int type, uint64_t offset,
					 int fd, void** mhandle)
{
	nccl_net_ofi_comm_t *ofi_comm = (nccl_net_ofi_comm_t *)comm;
	int rc;

	/* The plugin must be up before anything else is looked at. */
	if (OFI_UNLIKELY(plugin == NULL)) {
		NCCL_OFI_WARN("Error accessing plugin. Plugin has not been initialized yet.");
		return check_return(ncclInvalidArgument);
	}

	if (OFI_UNLIKELY(ofi_comm == NULL)) {
		NCCL_OFI_WARN("Invalid comm object provided");
		return check_return(ncclInternalError);
	}

	/* Host memory is always accepted; device memory only when compiled in. */
	bool type_supported = (type == NCCL_PTR_HOST);
#if HAVE_CUDA
	type_supported = type_supported || (type == NCCL_PTR_CUDA);
#endif
#if HAVE_NEURON
	type_supported = type_supported || (type == NCCL_PTR_NEURON);
#endif
	if (!type_supported) {
		NCCL_OFI_WARN("Invalid buffer type provided: %d", type);
		return check_return(ncclInternalError);
	}

	/* Build the key that describes the region to register. */
#if HAVE_DECL_FI_MR_DMABUF
	const nccl_ofi_mr_ckey_t mr_key = (fd != -1)
		? nccl_ofi_mr_ckey_mk_dmabuf(fd, offset, size, data)
		: nccl_ofi_mr_ckey_mk_vec(data, size);
#else
	if (fd != -1) {
		NCCL_OFI_WARN("Passed fd handle, but not compiled with DMA-BUF support.");
		return nccl_net_ofi_retval_translate(-EINVAL);
	}
	const nccl_ofi_mr_ckey_t mr_key = nccl_ofi_mr_ckey_mk_vec(data, size);
#endif

	/* Only send and receive communicators can register memory here. */
	if (ofi_comm->type == NCCL_NET_OFI_SEND_COMM) {
		nccl_net_ofi_send_comm_t *sender =
			(nccl_net_ofi_send_comm_t *)ofi_comm;
		rc = sender->regMr(sender, &mr_key, type, mhandle);
	} else if (ofi_comm->type == NCCL_NET_OFI_RECV_COMM) {
		nccl_net_ofi_recv_comm_t *receiver =
			(nccl_net_ofi_recv_comm_t *)ofi_comm;
		rc = receiver->regMr(receiver, &mr_key, type, mhandle);
	} else {
		/* Base, listen and any unknown communicator types end up here. */
		NCCL_OFI_WARN("Unexpected communicator type. Communicator type: %d",
			      ofi_comm->type);
		rc = -EINVAL;
	}

	return nccl_net_ofi_retval_translate(rc);
}