void platform_sort_rails()

in src/platform-aws.cpp [816:930]


/*
 * Reorder the rail list so that consecutive entries cycle through the
 * vf indices 0, 1, ..., highest, 0, 1, ... as reported by
 * get_rail_vf_idx().
 *
 * info_list   in/out: head of a linked list of fi_info entries.  On a
 *             successful reorder the original list is released with
 *             fi_freeinfo() and *info_list is pointed at a new list made
 *             of fi_dupinfo() copies in the new order.
 * num_rails   number of entries expected in *info_list.
 * num_groups  number of groups the rails are split across; a reorder is
 *             only attempted when there is more than one rail per group.
 *
 * The function is best effort: on any failure (allocation, vf lookup,
 * count mismatch, or an index sequence that cannot be satisfied) a
 * warning is logged where applicable and *info_list is left untouched.
 */
void platform_sort_rails(struct fi_info **info_list, size_t num_rails, size_t num_groups)
{
	struct fi_info **info_array = NULL;
	struct fi_info *info_iter = NULL;
	size_t *vf_array = NULL;
	struct fi_info *output_info_list = NULL;
	struct fi_info *output_info_end = NULL;
	size_t highest_vf_idx = 0;
	size_t next_vf_idx = 0;
	size_t info_count;

	/* Nothing sensible can be done without a list, and a zero group
	 * count would make the rails-per-group division below undefined. */
	if (info_list == NULL || num_groups == 0) {
		return;
	}

	/* we only want to reorder if there's more than one NIC per
	 * group (ie, per GPU).  Less than that (P4d or trainium), we
	 * assume topo ordering is sufficient */
	if ((num_rails / num_groups) <= 1) {
		return;
	}

	info_array = (struct fi_info **)calloc(num_rails, sizeof(struct fi_info*));
	if (info_array == NULL) {
		NCCL_OFI_WARN("Did not reorder arrays due to calloc failure");
		goto cleanup;
	}

	vf_array = (size_t *)calloc(num_rails, sizeof(size_t));
	if (vf_array == NULL) {
		NCCL_OFI_WARN("Did not reorder arrays due to calloc failure");
		goto cleanup;
	}

	/* copy the input list into an array so that we can more
	 * easily associate more data (like the vf array) with the
	 * input and keep everything organized.  Working on duplicates
	 * keeps the caller's list intact until the reorder is known to
	 * have succeeded. */
	info_iter = *info_list;
	info_count = 0;
	while (info_iter != NULL && info_count < num_rails) {
		info_array[info_count] = fi_dupinfo(info_iter);
		if (info_array[info_count] == NULL) {
			NCCL_OFI_WARN("fi_dupinfo failed");
			goto cleanup;
		}
		info_iter = info_iter->next;

		int ret = get_rail_vf_idx(info_array[info_count]);
		if (ret < 0) {
			NCCL_OFI_WARN("lookup of rail for index %zu failed: %s",
				      info_count, strerror(-ret));
			goto cleanup;
		}
		vf_array[info_count] = ret;
		if (vf_array[info_count] > highest_vf_idx) {
			highest_vf_idx = vf_array[info_count];
		}

		info_count++;
	}
	if (info_count != num_rails) {
		NCCL_OFI_WARN("Info count (%zu) and num_rails (%zu) do not match.  Aborting reorder.",
			      info_count, num_rails);
		goto cleanup;
	}

	/* No reorder required, as devices all have the same vf idx
	   and end result would be the input array */
	if (highest_vf_idx == 0) {
		goto cleanup;
	}

	/* Build the output list one entry per pass: take the first
	 * not-yet-consumed entry whose vf index matches the one we want
	 * next, then advance the wanted index round-robin. */
	for (size_t i = 0 ; i < num_rails ; i++) {
		size_t j;
		for (j = 0 ; j < num_rails ; j++) {
			/* NULL slots have already been moved to the output list */
			if (info_array[j] == NULL) {
				continue;
			}

			if (vf_array[j] == next_vf_idx) {
				if (output_info_list == NULL) {
					output_info_list = output_info_end = info_array[j];
				} else {
					output_info_end->next = info_array[j];
					output_info_end = info_array[j];
				}
				/* ownership moved to the output list; the tail's
				 * next pointer is left as fi_dupinfo() returned it
				 * (presumably NULL -- confirm against libfabric) */
				info_array[j] = NULL;
				next_vf_idx = (next_vf_idx + 1) % (highest_vf_idx + 1);
				break;
			}
		}
		/* inner loop ran to completion without a match */
		if (j == num_rails) {
			NCCL_OFI_WARN("Did not find a device with expected index %zu", next_vf_idx);
			goto cleanup;
		}
	}

	/* Success: swap the caller's list for the reordered one and clear
	 * the local handle so cleanup does not free what we just returned. */
	fi_freeinfo(*info_list);
	*info_list = output_info_list;
	output_info_list = NULL;

cleanup:
	if (info_array != NULL) {
		/* only entries not handed over to the output list remain here */
		for (size_t i = 0 ; i < num_rails ; i++) {
			if (info_array[i] != NULL) {
				fi_freeinfo(info_array[i]);
			}
		}
		free(info_array);
	}
	free(vf_array);
	/* non-NULL only when the reorder was abandoned part way through */
	if (output_info_list != NULL) {
		fi_freeinfo(output_info_list);
	}

	return;
}