static ncclResult_t nccl_ofi_tuner_init()

in src/tuner/nccl_ofi_tuner.cpp [46:151]


/*
 * Create the tuner context for a communicator.
 *
 * Decides, from the platform product name, the optional forced tuner type
 * and the communicator size, whether the "Region" or the "Model" based tuner
 * is used, allocates a context for it and runs that tuner's init_internal
 * hook. The whole selection runs with nccl_ofi_tuner_ctx_lock held.
 *
 * @param nRanks       number of ranks, forwarded to the support checks and
 *                     to the selected tuner's init_internal hook
 * @param nNodes       number of nodes, forwarded the same way
 * @param logFunction  logger stored in ofi_log_function
 * @param context      out parameter; receives the new context, or NULL when
 *                     no context was created (fallback or failure)
 *
 * @return ncclSuccess when a context was created, and also when this
 *         function deliberately creates none (platform name unavailable,
 *         forced type "Internal", or neither tuner supported);
 *         ncclInternalError when `context` is NULL or the allocation fails;
 *         otherwise the error returned by the tuner's init_internal hook.
 */
static ncclResult_t nccl_ofi_tuner_init(size_t nRanks, size_t nNodes, ncclDebugLogger_t logFunction, void **context)
{
	/*
	 * All locals are declared and initialized up front so that none of the
	 * `goto exit` statements below jumps over an initialization.
	 */
	const char *platform_type = NULL;
	const char *tuner_force_type = NULL;
	ncclResult_t ret = ncclSuccess;
	nccl_ofi_tuner_context_t *ctx = NULL;
	bool region_support = false;
	bool model_support = false;
	bool is_force_type_model = false;
	enum nccl_ofi_tuner_platform tuner_platform = NCCL_OFI_TUNER_UNKNOWN;

	ofi_log_function = logFunction;

	/* Without somewhere to store the context there is nothing useful to do. */
	if (context == NULL) {
		NCCL_OFI_WARN("Invalid tuner context pointer.");
		return ncclInternalError;
	}
	*context = NULL;

	nccl_net_ofi_mutex_lock(&nccl_ofi_tuner_ctx_lock);

	/*
	 * Retrieve platform type and pass to Region and Model based tuner support check functions.
	 * If both Region and Model based tuner are not supported, log a message and exit.
	 */
	platform_type = nccl_net_ofi_get_product_name();
	if (platform_type == NULL) {
		NCCL_OFI_WARN("NCCL_OFI_TUNER is not available because platform type is unavailable.");
		goto exit;
	}

	tuner_force_type = ofi_nccl_tuner_force_type();
	if (tuner_force_type != NULL) {
		if (strcmp(tuner_force_type, "Internal") == 0) {
			/* fallback to NCCL internal tuner */
			NCCL_OFI_INFO(NCCL_INIT | NCCL_TUNING,
				      "NCCL_OFI_TUNER_TYPE is Internal, Fall back to NCCL's tuner for platform : %s",
				      platform_type);
			goto exit;
		} else if (strcmp(tuner_force_type, "Model") == 0) {
			is_force_type_model = true;
		}
		/* Any other value leaves the default selection order in place. */
	}

	/* Any product name not matched here keeps NCCL_OFI_TUNER_UNKNOWN. */
	if (strcmp(platform_type, "p5.48xlarge") == 0 || strcmp(platform_type, "p5e.48xlarge") == 0) {
		tuner_platform = NCCL_OFI_TUNER_P5_P5E;
	} else if (strcmp(platform_type, "p5en.48xlarge") == 0) {
		tuner_platform = NCCL_OFI_TUNER_P5EN;
	}

	region_support = is_region_supported(tuner_platform, nRanks, nNodes);
	model_support = is_model_supported(tuner_platform, nRanks, nNodes);
	if (!region_support && !model_support) {
		NCCL_OFI_INFO(NCCL_INIT | NCCL_TUNING,
			      "NCCL_OFI_TUNER is not available for platform : %s, Fall back to NCCL's tuner",
			      platform_type);
		goto exit;
	}

	ctx = (nccl_ofi_tuner_context_t *)calloc(1, sizeof(*ctx));
	if (ctx == NULL) {
		NCCL_OFI_WARN("Context allocation failed.");
		ret = ncclInternalError;
		goto exit;
	}

	/*
	 * We reach here. It means the following two conditions are met.
	 *  - "Internal" force is not set by env variable
	 *  - at least one of "Region" or "Model" tuner is supported for the given platform, nRanks and nNodes
	 */

	/*
	 * We choose "Region" over "Model" when both are supported.
	 * TUNER_TYPE env variable is ignored if the forced tuner type is not
	 * supported by the given platform, nRanks and nNodes.
	 */
	if (region_support && !(model_support && is_force_type_model)) {
		ctx->type = NCCL_OFI_TUNER_TYPE_REGION;
		ctx->init_internal = region_init_internal;
		ctx->get_coll_info_internal_v3 = region_get_coll_info_internal_v3;
		ctx->get_coll_info_internal_v2 = region_get_coll_info_internal_v2;
		ctx->destroy_internal = region_destroy_internal;
		NCCL_OFI_INFO(NCCL_INIT | NCCL_TUNING, "Region base Tuner is chosen for platform: %s", platform_type);
	} else {
		/* Region was not selected, so Model must be the supported one. */
		assert(model_support);
		ctx->type = NCCL_OFI_TUNER_TYPE_MODEL;
		ctx->init_internal = model_init_internal;
		ctx->get_coll_info_internal_v3 = model_get_coll_info_internal_v3;
		ctx->get_coll_info_internal_v2 = model_get_coll_info_internal_v2;
		ctx->destroy_internal = model_destroy_internal;
		NCCL_OFI_INFO(NCCL_INIT | NCCL_TUNING, "Model base Tuner is chosen for platform: %s", platform_type);
	}

	ret = ctx->init_internal(ctx, tuner_platform, nRanks, nNodes);
	if (ret != ncclSuccess) {
		/* Do not report a successful init; the context is torn down below. */
		NCCL_OFI_WARN("Tuner initialization failed for platform: %s", platform_type);
		goto exit;
	}

	/* nRanks and nNodes are size_t, so %zu is the matching conversion. */
	NCCL_OFI_INFO(NCCL_INIT | NCCL_TUNING, "Tuner init: comm with %zu ranks and %zu nodes.", nRanks, nNodes);

exit:
	if (ret != ncclSuccess && ctx != NULL) {
		/*
		 * NOTE(review): nccl_ofi_tuner_destroy() is called while
		 * nccl_ofi_tuner_ctx_lock is still held. Confirm that it does
		 * not acquire the same lock, otherwise this path self-deadlocks.
		 */
		nccl_ofi_tuner_destroy((void *)ctx);
		ctx = NULL;
	}

	/* Publish the result: a valid context, or NULL on fallback/failure. */
	*context = (void *)ctx;
	nccl_net_ofi_mutex_unlock(&nccl_ofi_tuner_ctx_lock);

	return ret;
}