ncclResult_t model_get_coll_info_internal_v3()

in src/tuner/nccl_ofi_model.cpp [139:219]


/*
 * Model-based selection of the algorithm/protocol pair for one collective.
 *
 * Evaluates nccl_ofi_tuner_compute_cost() for every candidate (algo, proto)
 * combination and marks the cheapest one in the cost table by setting its
 * entry to 0.0. All other entries are left untouched. When no selection can
 * be made the table is not modified at all, so the caller keeps its own
 * choice.
 *
 * @param ctx            Tuner context; ctx->type_ctx holds the model context
 *                       (may be NULL, in which case nothing is done).
 * @param collType       Collective being tuned.
 * @param nBytes         Size of the collective in bytes.
 * @param numPipeOps     Forwarded to the cost model.
 * @param collCostTable  Cost table, accessed as float[][NCCL_NUM_PROTOCOLS].
 * @param numAlgo        Number of algorithm rows; used as the upper bound for
 *                       the chosen algorithm index.
 *                       NOTE(review): assumed to be the row count of
 *                       collCostTable -- confirm against the caller.
 * @param numProto       Number of protocol columns; used as the upper bound
 *                       for the chosen protocol index (same assumption).
 * @param nChannels      Not read or written by this function.
 *
 * @return ncclSuccess in all cases.
 */
ncclResult_t model_get_coll_info_internal_v3(nccl_ofi_tuner_context_t *ctx,
					     ncclFunc_t collType,
					     size_t nBytes,
					     int numPipeOps,
					     float **collCostTable,
					     int numAlgo,
					     int numProto,
					     int *nChannels)
{
	float lowest = FLT_MAX;
	int algo = 0;
	int proto = 0;
	float(*table)[NCCL_NUM_PROTOCOLS] = (float(*)[NCCL_NUM_PROTOCOLS])collCostTable;
	int chosen_algo = NCCL_ALGO_UNDEF;
	int chosen_proto = NCCL_PROTO_UNDEF;
	nccl_ofi_tuner_model_context_t *model_ctx = (nccl_ofi_tuner_model_context_t *)ctx->type_ctx;

	if (model_ctx == NULL) {
		/* we do not update cost table. Fall back to NCCL's tuner */
		NCCL_OFI_INFO(NCCL_TUNING, "Model Context is not ready. Fall back to NCCL's tuner.");
		return ncclSuccess;
	}

	/* Skip runs of 2 nodes or fewer and fall back to NCCL's internal tunings */
	if (model_ctx->dims.num_nodes <= 2) {
		return ncclSuccess;
	}

	/* apply p5/p5e platform specific quirk */
	if (model_ctx->platform == NCCL_OFI_TUNER_P5_P5E) {
		if (collType == ncclFuncAllReduce && model_ctx->dims.num_nodes == 16 &&
		    model_ctx->dims.num_ranks == 128 && nBytes > 3ULL * 1024ULL * 1024ULL * 1024ULL &&
		    nBytes <= 5ULL * 1024ULL * 1024ULL * 1024ULL) {
			/* Hard-coded choice; the model is bypassed, so report a cost of 0 */
			lowest = 0;
			chosen_algo = NCCL_ALGO_NVLS_TREE;
			chosen_proto = NCCL_PROTO_SIMPLE;
			goto table_update;
		}
	}

	/*
	 * Ideally, this should just be a lookup and not be in-flight math
	 * We do not want divs in the hot path, but working with the API we've
	 * got now.
	 */
	for (algo = 0; algo < NCCL_NUM_ALGORITHMS; algo++) {
		/* No CollNet on AWS today */
		if (algo == NCCL_ALGO_COLLNET_DIRECT || algo == NCCL_ALGO_COLLNET_CHAIN) {
			continue;
		}

		/* Skip NCCL_ALGO_NVLS used only for single-node jobs */
		if (algo == NCCL_ALGO_NVLS) {
			continue;
		}

		for (proto = 0; proto < NCCL_NUM_PROTOCOLS; proto++) {
			float cost;

			/* This is not a supported combination in NCCL */
			if (algo == NCCL_ALGO_NVLS_TREE && proto != NCCL_PROTO_SIMPLE) {
				continue;
			}

			cost = nccl_ofi_tuner_compute_cost(model_ctx->model_params, &model_ctx->dims,
							   collType, algo, proto, numPipeOps, nBytes);
			/* A negative cost is never a candidate */
			if (cost < 0) {
				continue;
			}

			NCCL_OFI_TRACE(NCCL_TUNING, "Model Tuner Computed cost for algo %d proto %d pipe %d: cost %.8f µsecs.",
				       algo, proto, numPipeOps, cost);
			if (cost < lowest) {
				chosen_algo = algo;
				chosen_proto = proto;
				lowest = cost;
			}
		}
	}

table_update:
	/*
	 * The model may have produced no candidate at all (every cost was
	 * negative), or the selection may not fit in the table handed to us.
	 * Writing to the table in either case would be an out-of-bounds
	 * access, so leave it untouched and let NCCL decide.
	 */
	if (chosen_algo == NCCL_ALGO_UNDEF || chosen_proto == NCCL_PROTO_UNDEF ||
	    chosen_algo < 0 || chosen_algo >= numAlgo ||
	    chosen_proto < 0 || chosen_proto >= numProto) {
		NCCL_OFI_INFO(NCCL_TUNING,
			      "Model Tuner found no usable algo/proto for coll %d size %zu. Fall back to NCCL's tuner.",
			      collType, nBytes);
		return ncclSuccess;
	}

	/* Mark the selection by setting its table entry to zero */
	table[chosen_algo][chosen_proto] = 0.0;

	/* Log the modelled cost, not the table entry that was just zeroed */
	NCCL_OFI_INFO(NCCL_TUNING, "Model Tuner Choosing algo %d proto %d with cost %.8f µsecs for coll %d size %zu.",
		      chosen_algo, chosen_proto, lowest, collType, nBytes);

	return ncclSuccess;
}