int nccl_net_ofi_sendrecv_init()

in src/nccl_ofi_sendrecv.cpp [2747:2942]


int nccl_net_ofi_sendrecv_init(const char *provider_filter,
			       nccl_net_ofi_plugin_t **plugin_p)
{
	int ret = 0;
	struct fi_info *provider_list = NULL;
	unsigned int num_providers;
	nccl_net_ofi_sendrecv_plugin_t *plugin = NULL;
	struct fi_info *hints;

	hints = fi_allocinfo();
	if (hints == NULL) {
		NCCL_OFI_WARN("Allocation of fi_info failed");
		ret = -FI_ENOMEM;
		goto error;
	}

	if (nccl_ofi_dmabuf_viable()) {
		sendrecv_get_hints(hints, true);
		ret = nccl_ofi_ofiutils_get_providers(provider_filter,
						      FI_VERSION(1, 20),
						      hints,
						      &provider_list,
						      &num_providers);
		if (ret == 0) {
			NCCL_OFI_TRACE(NCCL_INIT | NCCL_NET, "Using Libfabric 1.20 API, with DMA-BUF support");
			support_gdr = GDR_UNKNOWN;
			goto found;
		}
	}

	sendrecv_get_hints(hints, true);
	ret = nccl_ofi_ofiutils_get_providers(provider_filter, FI_VERSION(1, 18), hints,
					      &provider_list, &num_providers);
	if (ret == 0) {
		/* The 1.18 API allows providers to use CUDA to
		 * support HMEM pointers, so just having HMEM doesn't
		 * tell us anything about the usability of CUDA
		 * pointers with NCCL.  So leave the state unknown
		 * until we create an endpoint and try to disable
		 * CUDA
		 */
		NCCL_OFI_TRACE(NCCL_INIT | NCCL_NET,
			       "Using Libfabric 1.18 API, with GPUDirect RDMA support");
		support_gdr = GDR_UNKNOWN;
		goto found;
	}

	sendrecv_get_hints(hints, true);
	ret = nccl_ofi_ofiutils_get_providers(provider_filter, FI_VERSION(1, 6), hints,
					      &provider_list, &num_providers);
	if (ret == 0) {
		NCCL_OFI_TRACE(NCCL_INIT | NCCL_NET,
			       "Using Libfabric 1.6 API, with GPUDirect RDMA support");
		support_gdr = GDR_SUPPORTED;
		goto found;
	}

	sendrecv_get_hints(hints, false);
	ret = nccl_ofi_ofiutils_get_providers(provider_filter, FI_VERSION(1, 6), hints,
					      &provider_list, &num_providers);
	if (ret == 0) {
		NCCL_OFI_TRACE(NCCL_INIT | NCCL_NET,
			       "Using Libfabric 1.6 API, without GPUDirect RDMA support");
		support_gdr = GDR_UNSUPPORTED;
		goto found;
	}

	ret = -FI_ENODATA;
found:
	fi_freeinfo(hints);
	if (ret != 0 && ret != -FI_ENODATA) {
		NCCL_OFI_WARN("OFI fi_getinfo() call failed: %s", fi_strerror(ret));
		goto error;
	}
	if (provider_list == NULL) {
		ret = -FI_ENODATA;
		goto error;
	}

	/* The TCP provider in Libfabric versions prior to 2.2.0
	 * erroneously requires a unique MR key even when FI_RMA
	 * capabilities are not requested. Because we use local MRs
	 * even if the provider does not require FI_MR_LOCAL and
	 * because Libfabric clears the FI_MR_PROV_KEY mr_mode when
	 * FI_RMA is not requested, we pass 0 as the mr key for all
	 * registrations, tripping the TCP bug.
	 * On versions of Libfabric before the bug is fixed, we
	 * request FI_RMA capabilities from the tcp provider even
	 * though we don't need it, so that we see the cleared
	 * FI_MR_PROV_KEY, fi_mr_key() returns the passed key, and
	 * everyone is happy (modulo a potential slight performance
	 * hit for having the emulated RMA operations loaded).
	 */
	if (FI_VERSION_LT(fi_version(), FI_VERSION(2, 2)) &&
			strcmp(provider_list->fabric_attr->prov_name, "tcp") == 0) {
		struct fi_info *iter = provider_list;
		while (iter != NULL) {
			iter->caps |= FI_RMA;
			iter = iter->next;
		}
	}
	support_fi_rma = ((provider_list->caps & FI_RMA) != 0);

	/* Allow for multiple virtual nics per nic to increase
	 * throughput for NICs that do not handle single QP situations
	 * well. */
	if (nic_dup_conns > 1) {
		struct fi_info *input_iter, *tmp, *output_head, *output_tail;

		/* The goal of the next chunk of code is to make
		 * provider_list contain the existing providr
		 * structures nic_dup_conns times each.  We start by
		 * multiplying the number of devices (ie, the size of
		 * the provider_list array) by nic_dup_conns.  We then
		 * iterate over a new info list, adding that number of
		 * devices by repeatedly copying the entries in the
		 * original list.
		 *
		 * If the input list was info objects A, B, C and
		 * dup_conns was 2, the output array (ie, provider_list
		 * at the end) will be A, B, C, A, B, C.
		 *
		 * Note that this isn't entirely sufficient to get
		 * NCCL to use all the connections.  We must also fake
		 * the locality of the info structures so that they
		 * look like more appealing paths; see the dup_conns
		 * code in the PCIe path discovery logic.
		 */
		num_providers *= nic_dup_conns;

		input_iter = NULL;
		output_head = output_tail = NULL;
		for (size_t i = 0 ; i < num_providers ; i++) {
			/* note that because we'll iterate through
			   provider_list multiple times (because
			   num_providers is already multiplied by
			   nic_dup_conns), this check has to be in the
			   for loop.  Each time we reach the end of
			   the list, we'll see iter as NULL and
			   restart. */
			if (!input_iter)
				input_iter = provider_list;

			tmp = fi_dupinfo(input_iter);
			if (!tmp) {
				NCCL_OFI_WARN("DUP_CONNS fi_dupinfo failed.");
				ret = -ENOMEM;
				goto error;
			}
			/* just in case */
			tmp->next = NULL;

			if (!output_head)
				output_head = tmp;

			if (!output_tail) {
				output_tail = tmp;
			} else {
				output_tail->next = tmp;
				output_tail = tmp;
			}

			input_iter = input_iter->next;
		}

		fi_freeinfo(provider_list);
		provider_list = output_head;

		NCCL_OFI_INFO(NCCL_INIT, "DUP_CONNS of %d changing device count to %d",
			      nic_dup_conns, num_providers);
	}

	ret = nccl_net_ofi_query_provider_capabilities(provider_list, num_providers);
	if (ret != 0) {
		NCCL_OFI_WARN("Querying provider capabilities failed: %d", ret);
		goto error;
	}

	ret = nccl_net_ofi_sendrecv_plugin_create(num_providers, provider_list, &plugin);
	if (ret != 0) {
		NCCL_OFI_WARN("Unable to allocate nccl_net_ofi_plugin_t");
		goto error;
	}

	*plugin_p = &plugin->base;

	return ret;

 error:
	if (plugin != NULL) {
		plugin->base.release_plugin(&plugin->base);
		plugin = NULL;
	}

	return ret;
}