in dags/inference/maxtext_inference_microbenchmark.py [0:0]
def get_concatenated_list_of_params(sweep_vm_count=1):
  """Build colon-joined sweep parameters, one string per sweep VM.

  Every permutation of the 4 cache axis indices is rendered as a
  comma-separated string (e.g. "0,1,2,3"), giving 24 axis orders. All
  24 * 24 = 576 ordered (prefill, ar) pairs of those orders are enumerated
  with the prefill order varying slowest, and each pair is given a running
  id starting at 0. The ids, the prefill orders and the ar orders are then
  each cut into `sweep_vm_count` contiguous, near-equal chunks, and every
  chunk is joined with ":".

  Args:
    sweep_vm_count: Number of chunks to cut the 576 combinations into. It is
      passed unchanged to `numpy.array_split` as the number of sections.

  Returns:
    A 3-tuple `(product_ids, prefill_axis_orders, ar_axis_orders)`. Each
    element is a list with one ":"-joined string per chunk; position i of the
    three lists describes the same chunk. A chunk with no entries yields an
    empty string.

  Raises:
    ValueError: Raised by `numpy.array_split` when `sweep_vm_count` is not
      a positive number of sections.
  """
  cache_rank = 4
  axis_order_strs = [
      ",".join(str(axis) for axis in axis_order)
      for axis_order in itertools.permutations(range(cache_rank))
  ]

  # `product(..., repeat=2)` walks the pairs in row-major order, i.e. the
  # first (prefill) element changes slowest; a pair's position in this list
  # is its product id.
  axis_order_pairs = list(itertools.product(axis_order_strs, repeat=2))
  product_ids = list(range(len(axis_order_pairs)))
  prefill_axis_orders = [prefill for prefill, _ in axis_order_pairs]
  ar_axis_orders = [ar for _, ar in axis_order_pairs]

  def _split_and_join(values):
    # One ":"-joined string per chunk. `str()` turns numpy scalars (ints or
    # strings) back into plain text before joining.
    return [
        ":".join(str(value) for value in chunk)
        for chunk in numpy.array_split(values, sweep_vm_count)
    ]

  return (
      _split_and_join(product_ids),
      _split_and_join(prefill_axis_orders),
      _split_and_join(ar_axis_orders),
  )