def get_concatenated_list_of_params()

in dags/inference/maxtext_inference_microbenchmark.py [0:0]


def get_concatenated_list_of_params(sweep_vm_count=1):
  """Build colon-joined cache-axis-order sweep parameters, one entry per VM.

  Every permutation of ``range(cache_rank)`` is rendered as a comma-separated
  string (e.g. ``"0,1,2,3"``). All ordered pairs of those strings
  (prefill order, ar order) are enumerated in ``itertools.product`` order, and
  a pair's position in that enumeration is its product id. The ids are split
  into ``sweep_vm_count`` contiguous chunks with ``numpy.array_split`` (chunk
  sizes differ by at most one, larger chunks first), and each chunk is
  joined with ``":"``.

  Args:
    sweep_vm_count: Number of chunks to produce; forwarded to
      ``numpy.array_split`` as the section count.

  Returns:
    A 3-tuple of lists, each with one string per chunk:
      * the colon-joined product ids of the chunk,
      * the colon-joined prefill cache axis order strings of the chunk,
      * the colon-joined ar cache axis order strings of the chunk.
    A chunk that receives no ids (more chunks than pairs) yields ``""``.

  Raises:
    ValueError: Propagated from ``numpy.array_split`` when ``sweep_vm_count``
      is not positive.
  """
  cache_rank = 4
  axis_order_strs = [
      ",".join(str(axis) for axis in axis_order)
      for axis_order in itertools.permutations(range(cache_rank))
  ]
  # product() yields pairs with the second element varying fastest, so the
  # list index of a pair doubles as its product id.
  axis_order_pairs = list(itertools.product(axis_order_strs, repeat=2))

  # Split the ids once; the chunk boundaries depend only on the number of
  # pairs and on sweep_vm_count, so all three outputs share the same chunks.
  product_id_chunks = numpy.array_split(
      list(range(len(axis_order_pairs))), sweep_vm_count
  )

  two_axis_order_product_id_concat_list = [
      ":".join(str(product_id) for product_id in chunk)
      for chunk in product_id_chunks
  ]
  prefill_cache_axis_order_concat_list = [
      ":".join(axis_order_pairs[product_id][0] for product_id in chunk)
      for chunk in product_id_chunks
  ]
  ar_cache_axis_order_concat_list = [
      ":".join(axis_order_pairs[product_id][1] for product_id in chunk)
      for chunk in product_id_chunks
  ]
  return (
      two_axis_order_product_id_concat_list,
      prefill_cache_axis_order_concat_list,
      ar_cache_axis_order_concat_list,
  )