dags/inference/maxtext_inference_microbenchmark.py (271 lines of code) (raw):

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A DAG to run MaxText inference microbenchmarks with nightly version."""

import datetime
import pytz
import itertools
import numpy
from airflow import models
from dags.common.vm_resource import (
    TpuVersion,
    Zone,
    Project,
    V5_NETWORKS,
    V5E_SUBNETWORKS,
    V5P_SUBNETWORKS,
    RuntimeVersion,
    V6E_GCE_NETWORK,
    V6E_GCE_SUBNETWORK,
)
from dags.inference.configs import maxtext_inference_microbenchmark_gce_config
from dags.multipod.configs.common import SetupMode

USER_PREFIX = ""
MAXTEXT_BRANCH = ""

LLAMA2_7B = "llama2-7b"
LLAMA2_13B = "llama2-13b"

W_BF16_KV_BF16 = "w-b16-kv-b16"
W_INT8_KV_INT8 = "w-i8-kv-i8"

BASE_OUTPUT_DIRECTORY = (
    "gs://inference-benchmarks/logs/maxtext-inference-microbenchmark"
)

# Evaluated once per import of this module; it becomes part of the output
# directory used by every test built below.
test_run_datetime = datetime.datetime.now(
    pytz.timezone("America/Los_Angeles")
).strftime("%Y%m%d-%H%M%S")

# Keys copied verbatim from a sweep config into the per-test model config.
# The order is the insertion order of the resulting dict.
_PASSTHROUGH_CONFIG_KEYS = (
    "maxtext_branch",
    "model_name",
    "quant_mode",
    "sleep_time",
    "tokenizer",
    "weight_dtype",
    "scan_layers",
    "max_prefill_predict_length",
    "max_target_length",
    "attention",
    "per_device_batch_size",
    "quantization",
    "quantize_kvcache",
    "kv_quant_axis",
    "base_output_directory",
    "inference_microbenchmark_prefill_lengths",
    "inference_microbenchmark_stages",
    "inference_microbenchmark_loop_iters",
    "profiler",
    "save_config_to_gcs",
    "reshape_q",
)


def get_concatenated_list_of_params(sweep_vm_count=1):
  """Build the per-VM sweep of (prefill, autoregressive) cache axis orders.

  Every permutation of the 4 cache axes is rendered as a comma-separated
  string (e.g. "0,1,2,3"). The sweep is the cartesian product of those
  permutations with themselves, with the prefill order varying slowest and
  the autoregressive order fastest. Each pair is numbered by its position in
  that product.

  The ids, prefill orders and autoregressive orders are each split into
  `sweep_vm_count` contiguous chunks with `numpy.array_split`, and every
  chunk is joined with ":".

  Args:
    sweep_vm_count: Number of chunks to split the sweep into.

  Returns:
    A 3-tuple of lists, each of length `sweep_vm_count`:
    (":"-joined product ids, ":"-joined prefill axis orders,
    ":"-joined autoregressive axis orders). Entry `i` of each list belongs
    to the same chunk.
  """
  cache_rank = 4
  axis_orders = [
      ",".join(str(axis) for axis in order)
      for order in itertools.permutations(range(cache_rank))
  ]
  order_pairs = list(itertools.product(axis_orders, repeat=2))

  def _split_and_join(values):
    # str() normalises numpy scalars (ints and strings) before joining.
    return [
        ":".join(str(value) for value in chunk)
        for chunk in numpy.array_split(values, sweep_vm_count)
    ]

  return (
      _split_and_join(list(range(len(order_pairs)))),
      _split_and_join([prefill for prefill, _ in order_pairs]),
      _split_and_join([ar for _, ar in order_pairs]),
  )


def generate_model_configs(
    test_name_prefix,
    model_config_name,
    sweep_model_configs,
    compute_axis_order,
    ici_parallelism,
    vm_number,
    tpu_version,
    tpu_cores,
):
  """Build the microbenchmark test config for one sweep point on one VM.

  Note: this function reads the module-level lists
  `two_axis_order_product_id_concat_list`,
  `prefill_cache_axis_order_concat_list` and
  `ar_cache_axis_order_concat_list`, which are assigned further down in this
  module. It must only be called after they have been populated.

  Args:
    test_name_prefix: Prefix prepended to the generated test name.
    model_config_name: Name of the entry being swept; used in the run tag.
    sweep_model_configs: Dict holding every key in
      `_PASSTHROUGH_CONFIG_KEYS` plus "time_out_in_min".
    compute_axis_order: Comma-separated axis order string.
    ici_parallelism: 3-item sequence unpacked into (ici_fsdp_parallelism,
      ici_autoregressive_parallelism, ici_tensor_parallelism).
    vm_number: Index selecting this VM's chunk of the sweep lists.
    tpu_version: `TpuVersion.V5E` or `TpuVersion.TRILLIUM`.
    tpu_cores: Passed through to the GCE config as `tpu_cores`.

  Returns:
    The value returned by
    `maxtext_inference_microbenchmark_gce_config.config`.

  Raises:
    KeyError: If `sweep_model_configs` lacks a required key.
    ValueError: If `tpu_version` is not one of the supported versions.
  """
  model_configs = {
      "model_config_name": model_config_name,
      "compute_axis_order": compute_axis_order,
  }
  (
      model_configs["ici_fsdp_parallelism"],
      model_configs["ici_autoregressive_parallelism"],
      model_configs["ici_tensor_parallelism"],
  ) = ici_parallelism

  for key in _PASSTHROUGH_CONFIG_KEYS:
    model_configs[key] = sweep_model_configs[key]

  model_configs[
      "two_axis_order_product_id_list"
  ] = two_axis_order_product_id_concat_list[vm_number]
  model_configs[
      "prefill_cache_axis_order_list"
  ] = prefill_cache_axis_order_concat_list[vm_number]
  model_configs["ar_cache_axis_order_list"] = ar_cache_axis_order_concat_list[
      vm_number
  ]

  attention = sweep_model_configs["attention"]
  per_device_batch_size = sweep_model_configs["per_device_batch_size"]
  compute_axis_order_tag = compute_axis_order.replace(",", "")
  test_run_tag = (
      f"{model_config_name}-bs{per_device_batch_size}-{attention[:3]}"
      f"-{compute_axis_order_tag}-vm{vm_number}"
  )
  test_name = f"{test_name_prefix}-{test_run_tag}"
  model_configs["run_name"] = test_run_tag

  if tpu_version == TpuVersion.V5E:
    # v5e benchmarks
    project_name = Project.TPU_PROD_ENV_AUTOMATED.value
    zone = Zone.US_EAST1_C.value
    network = V5_NETWORKS
    subnetwork = V5E_SUBNETWORKS
    runtime_version = RuntimeVersion.V2_ALPHA_TPUV5_LITE.value
  elif tpu_version == TpuVersion.TRILLIUM:
    project_name = Project.CLOUD_ML_AUTO_SOLUTIONS.value
    zone = Zone.EUROPE_WEST4_A.value
    network = V6E_GCE_NETWORK
    subnetwork = V6E_GCE_SUBNETWORK
    runtime_version = RuntimeVersion.V2_ALPHA_TPUV6.value
  else:
    # Without this branch the locals above would be unbound and the call
    # below would fail with an UnboundLocalError.
    raise ValueError(
        f"Unsupported tpu_version for inference microbenchmark: {tpu_version}"
    )

  maxtext_kv_cache_layout_optimization = (
      maxtext_inference_microbenchmark_gce_config.config(
          tpu_version=tpu_version,
          tpu_cores=tpu_cores,
          tpu_zone=zone,
          time_out_in_min=sweep_model_configs["time_out_in_min"],
          test_name=test_name,
          test_mode=SetupMode.STABLE,
          project_name=project_name,
          runtime_version=runtime_version,
          network=network,
          subnetwork=subnetwork,
          is_tpu_reserved=True,
          model_configs=model_configs,
          maxtext_branch=model_configs["maxtext_branch"],
      )
  )

  return maxtext_kv_cache_layout_optimization


dag_id = (
    "maxtext-inference-microbenchmark"
    if not USER_PREFIX
    else f"{USER_PREFIX}-maxtext-inference-microbenchmark"
)
tags = ["inference_team", "maxtext", "microbenchmark"]

if USER_PREFIX:
  tags.append(USER_PREFIX)

with models.DAG(
    dag_id=dag_id,
    tags=tags,
    start_date=datetime.datetime(2024, 1, 19),
    schedule=None,
    catchup=False,
) as dag:
  test_name_prefix = (
      "max-micro" if not USER_PREFIX else f"{USER_PREFIX}-max-micro"
  )

  # The sweep is split into this many chunks; one test is built per chunk.
  sweep_vm_count = 8
  (
      two_axis_order_product_id_concat_list,
      prefill_cache_axis_order_concat_list,
      ar_cache_axis_order_concat_list,
  ) = get_concatenated_list_of_params(sweep_vm_count=sweep_vm_count)

  test_templates = {
      LLAMA2_7B: {
          "maxtext_branch": ""
          if not MAXTEXT_BRANCH
          else f"-b {MAXTEXT_BRANCH}",
          "sleep_time": 60,
          "tpu_version_cores": [
              (TpuVersion.V5E, 8),
              (TpuVersion.TRILLIUM, 8),
          ],
          "model_name": LLAMA2_7B,
          "tokenizer": "tokenizer.llama2",
          "weight_dtype": "bfloat16",
          "scan_layers": "false",
          "max_prefill_predict_length": 1024,
          "max_target_length": 2048,
          "attention": "dot_product",
          # (ici_fsdp_parallelism, ici_autoregressive_parallelism, ici_tensor_parallelism)
          "ici_parallelisms": [(1, 1, -1)],
          "inference_microbenchmark_prefill_lengths": "64,128,256,512,1024",
          "inference_microbenchmark_stages": "prefill, generate",
          "inference_microbenchmark_loop_iters": 10,
          "base_output_directory": f"{BASE_OUTPUT_DIRECTORY}/{test_name_prefix}/kv_cache_layout_optimization/{test_run_datetime}",
          "profiler": "xplane",
          "save_config_to_gcs": "true",
          "reshape_q": "true",
          "compute_axis_order": ["0,2,1,3"],
      },
  }

  tests = {
      f"{LLAMA2_7B}-{W_BF16_KV_BF16}": test_templates[LLAMA2_7B]
      | {
          "quant_mode": W_BF16_KV_BF16,
          "quantization": "",
          "quantize_kvcache": "false",
          "per_device_batch_size": 10,
          "kv_quant_axis": "",
          "time_out_in_min": 330,
      },
      f"{LLAMA2_7B}-{W_INT8_KV_INT8}": test_templates[LLAMA2_7B]
      | {
          "quant_mode": W_INT8_KV_INT8,
          "quantization": "int8",
          "quantize_kvcache": "true",
          "per_device_batch_size": 24,
          "kv_quant_axis": "heads_and_dkv",
          "time_out_in_min": 360,
      },
  }

  # An empty `run_configs` means "no allow-list"; `skip_configs` always wins
  # for any name it contains.
  run_configs = [
      f"{LLAMA2_7B}-{W_INT8_KV_INT8}",
  ]
  skip_configs = [
      f"{LLAMA2_7B}-{W_BF16_KV_BF16}",
  ]

  for model_config_name, sweep_model_configs in tests.items():
    if run_configs and model_config_name not in run_configs:
      continue
    if model_config_name in skip_configs:
      continue

    # One test per (TPU version, compute axis order, ICI parallelism, VM);
    # product() iterates in the same order as the equivalent nested loops.
    for (
        (tpu_version, tpu_cores),
        compute_axis_order,
        ici_parallelism,
        vm_number,
    ) in itertools.product(
        sweep_model_configs["tpu_version_cores"],
        sweep_model_configs["compute_axis_order"],
        sweep_model_configs["ici_parallelisms"],
        range(sweep_vm_count),
    ):
      maxtext_kv_cache_layout_optimization = generate_model_configs(
          test_name_prefix=test_name_prefix,
          model_config_name=model_config_name,
          sweep_model_configs=sweep_model_configs,
          compute_axis_order=compute_axis_order,
          ici_parallelism=ici_parallelism,
          vm_number=vm_number,
          tpu_version=tpu_version,
          tpu_cores=tpu_cores,
      )