ncclResult_t model_get_coll_info_internal_v2()

in src/tuner/nccl_ofi_model.cpp [221:294]


/*
 * Select the lowest-cost (algorithm, protocol) pair for a collective by
 * evaluating the tuner cost model over every candidate combination.
 *
 * @param ctx             tuner context; ctx->type_ctx is read as the model context
 * @param collType        collective being tuned
 * @param nBytes          size of the collective, as passed to the cost model
 * @param collNetSupport  not consulted here; CollNet algorithms are always skipped
 * @param nvlsSupport     non-zero allows NCCL_ALGO_NVLS_TREE as a candidate
 * @param numPipeOps      forwarded unchanged to the cost model
 * @param algorithm       out: selected algorithm; left untouched when no
 *                        selection is made
 * @param protocol        out: selected protocol; left untouched when no
 *                        selection is made
 * @param nChannels       not read or written by this function
 *
 * @return ncclSuccess on every path. When the model context is missing, the
 *         job spans two nodes or fewer, or no candidate yields a non-negative
 *         cost, the outputs are not modified so that the caller falls back to
 *         NCCL's own tuning.
 */
ncclResult_t model_get_coll_info_internal_v2(nccl_ofi_tuner_context_t *ctx, ncclFunc_t collType, size_t nBytes,
					     int collNetSupport, int nvlsSupport, int numPipeOps, int *algorithm,
					     int *protocol, int *nChannels)
{
	float cost = 0;
	float lowest = FLT_MAX;
	int algo, proto = 0;
	/* Set once *algorithm / *protocol have been written by this function */
	int selected = 0;
	nccl_ofi_tuner_model_context_t *model_ctx = (nccl_ofi_tuner_model_context_t *)ctx->type_ctx;

	if (model_ctx == NULL) {
		/* we do not update cost table. Fall back to NCCL's tuner */
		NCCL_OFI_INFO(NCCL_TUNING, "Model Context is not ready. Fall back to NCCL's tuner.");
		return ncclSuccess;
	}

	/* Skip runs of 2 nodes or fewer and fall back to NCCL's internal tunings */
	if (model_ctx->dims.num_nodes <= 2) {
		return ncclSuccess;
	}

	/*
	 * apply p5/p5e platform specific quirk: for this exact topology and
	 * message-size window (3 GiB, 5 GiB] the choice is hard-coded and the
	 * cost model is bypassed entirely.
	 */
	if (model_ctx->platform == NCCL_OFI_TUNER_P5_P5E) {
		if (collType == ncclFuncAllReduce && model_ctx->dims.num_nodes == 16 &&
		    model_ctx->dims.num_ranks == 128 && nvlsSupport && nBytes > 3ULL * 1024ULL * 1024ULL * 1024ULL &&
		    nBytes <= 5ULL * 1024ULL * 1024ULL * 1024ULL) {
			lowest = 0;
			*algorithm = NCCL_ALGO_NVLS_TREE;
			*protocol = NCCL_PROTO_SIMPLE;
			selected = 1;
			goto exit;
		}
	}

	/*
	 * Ideally, this should just be a lookup and not be in-flight math
	 * We do not want divs in the hot path, but working with the API we've
	 * got now.
	 */
	for (algo = 0; algo < NCCL_NUM_ALGORITHMS; algo++) {
		/* No CollNet on AWS today */
		if (algo == NCCL_ALGO_COLLNET_DIRECT || algo == NCCL_ALGO_COLLNET_CHAIN)
			continue;

		/* Skip NCCL_ALGO_NVLS used only for single-node jobs */
		if (algo == NCCL_ALGO_NVLS)
			continue;

		if (!nvlsSupport && (algo == NCCL_ALGO_NVLS_TREE))
			continue;

		for (proto = 0; proto < NCCL_NUM_PROTOCOLS; proto++) {
			/* This is not a supported combination in NCCL */
			if (algo == NCCL_ALGO_NVLS_TREE && proto != NCCL_PROTO_SIMPLE)
				continue;

			cost = nccl_ofi_tuner_compute_cost(model_ctx->model_params, &model_ctx->dims,
							   collType, algo, proto, numPipeOps,  nBytes);
			/* A negative cost marks a combination the model rejects */
			if (cost < 0)
				continue;

			NCCL_OFI_TRACE(NCCL_TUNING, "Model Tuner Computed cost for algo %d proto %d pipe %d: cost %.8f µsecs.",
				       algo, proto, numPipeOps, cost);
			if (cost < lowest) {
				*algorithm = algo;
				*protocol = proto;
				lowest = cost;
				selected = 1;
			}
		}
	}

	if (!selected) {
		/*
		 * Nothing was written to *algorithm / *protocol, so do not
		 * dereference them for logging; report the fallback instead.
		 */
		NCCL_OFI_INFO(NCCL_TUNING, "Model Tuner found no usable algo/proto for coll %d size %zu. Fall back to NCCL's tuner.",
			      collType, nBytes);
		return ncclSuccess;
	}

exit:
	/* nBytes is a size_t, so it must be printed with %zu */
	NCCL_OFI_INFO(NCCL_TUNING, "Model Tuner Choosing algo %d proto %d with cost %.8f µsecs for coll %d size %zu.",
				    *algorithm, *protocol, lowest, collType, nBytes);
	return ncclSuccess;
}