dags/inference/maxtext_model_config_generator.py

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A helper to generate maxtext model configs."""

from dags.common.vm_resource import (
    TpuVersion,
    Zone,
    Project,
    V5_NETWORKS,
    V5E_SUBNETWORKS,
    V5P_SUBNETWORKS,
    RuntimeVersion,
    V6E_GCE_NETWORK,
    V6E_GCE_SUBNETWORK,
)
from dags.inference.configs import jetstream_benchmark_serving_gce_config
from dags.multipod.configs.common import SetupMode


def generate_model_configs(
    test_name_prefix,
    model_config_name,
    sweep_model_configs,
    axis_order,
    ici_parallelism,
    request_rate,
    tpu_version,
    tpu_cores,
):
  """Builds the model_configs for one sweep point and returns a benchmark config.

  Args:
    test_name_prefix: Prefix prepended to the generated test name.
    model_config_name: Name of the model configuration being swept.
    sweep_model_configs: Sweep-wide settings copied into model_configs.
    axis_order: Three dash-separated axis orders (compute, prefill cache,
      autoregressive cache), e.g. "2013-2013-2013".
    ici_parallelism: (fsdp, autoregressive, tensor) ICI parallelism values.
    request_rate: Benchmark request rate; encoded into the test name.
    tpu_version: TpuVersion value (V5E, V5P, or TRILLIUM).
    tpu_cores: Number of TPU cores to request.

  Returns:
    The config produced by jetstream_benchmark_serving_gce_config.get_config.
  """
  model_configs = {}
  model_configs["model_config_name"] = model_config_name

  # Each axis order arrives as a digit string (e.g. "2013") and is expanded
  # into the comma-separated form (e.g. "2,0,1,3").
  (
      compute_axis_order,
      prefill_cache_axis_order,
      ar_cache_axis_order,
  ) = axis_order.split("-")
  compute_axis_order = ",".join(compute_axis_order)
  prefill_cache_axis_order = ",".join(prefill_cache_axis_order)
  ar_cache_axis_order = ",".join(ar_cache_axis_order)

  model_configs["compute_axis_order"] = compute_axis_order
  model_configs["prefill_cache_axis_order"] = prefill_cache_axis_order
  model_configs["ar_cache_axis_order"] = ar_cache_axis_order
  (
      model_configs["ici_fsdp_parallelism"],
      model_configs["ici_autoregressive_parallelism"],
      model_configs["ici_tensor_parallelism"],
  ) = ici_parallelism
  model_configs["request_rate"] = request_rate

  # Copy the sweep-wide settings through to this sweep point's config.
  model_configs["maxtext_branch"] = sweep_model_configs["maxtext_branch"]
  model_configs["jetstream_branch"] = sweep_model_configs["jetstream_branch"]
  model_configs["model_name"] = sweep_model_configs["model_name"]
  model_configs["model_mode"] = sweep_model_configs["model_mode"]
  model_configs["quant_mode"] = sweep_model_configs["quant_mode"]
  model_configs["sleep_time"] = sweep_model_configs["sleep_time"]
  model_configs["tokenizer"] = sweep_model_configs["tokenizer"]
  model_configs["weight_dtype"] = sweep_model_configs["weight_dtype"]
  model_configs["scan_layers"] = sweep_model_configs["scan_layers"]
  model_configs["max_prefill_predict_length"] = sweep_model_configs[
      "max_prefill_predict_length"
  ]
  model_configs["max_target_length"] = sweep_model_configs["max_target_length"]
  model_configs["attention"] = sweep_model_configs["attention"]
  model_configs["reshape_q"] = sweep_model_configs["reshape_q"]
  model_configs["per_device_batch_size"] = sweep_model_configs[
      "per_device_batch_size"
  ]
  model_configs["checkpoint"] = sweep_model_configs["checkpoint"]
  model_configs["quantization"] = sweep_model_configs["quantization"]
  model_configs["quantize_kvcache"] = sweep_model_configs["quantize_kvcache"]
  model_configs["kv_quant_dtype"] = sweep_model_configs.get(
      "kv_quant_dtype", ""
  )
  model_configs["kv_quant_axis"] = sweep_model_configs["kv_quant_axis"]
  model_configs["dataset"] = sweep_model_configs["dataset"]
  model_configs["dataset_path"] = sweep_model_configs.get("dataset_path", "")
  model_configs["num_prompts"] = sweep_model_configs["num_prompts"]
  model_configs["max_output_length"] = sweep_model_configs["max_output_length"]
  model_configs["warmup_mode"] = sweep_model_configs["warmup_mode"]
  model_configs["run_eval"] = sweep_model_configs["run_eval"]
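  # Worked example of the tag built below (hypothetical sweep values): with
  # axis_order "2013-2013-2013", request_rate 5.0, per_device_batch_size 4,
  # attention "autoselected", and an empty kv_quant_axis, test_run_tag
  # becomes "<model_config_name>-rate5_0-pdbs4-aut-2013-2013-2013".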
model_configs["per_device_batch_size"] attention = model_configs["attention"][:3] kv_quant_axis = "".join( [axis for axis in model_configs["kv_quant_axis"].split("_")] ) test_run_tag = ( model_config_name if not kv_quant_axis else f"{model_config_name}-{kv_quant_axis}" ) test_run_tag = f"{test_run_tag}-rate{str(request_rate).replace('.', '_')}-pdbs{per_device_batch_size}-{attention}-{compute_axis_order.replace(',', '')}-{prefill_cache_axis_order.replace(',', '')}-{ar_cache_axis_order.replace(',', '')}" test_name = f"{test_name_prefix}-{test_run_tag}" if tpu_version == TpuVersion.V5E: # v5e benchmarks project_name = Project.TPU_PROD_ENV_AUTOMATED.value zone = Zone.US_EAST1_C.value network = V5_NETWORKS subnetwork = V5E_SUBNETWORKS runtime_version = RuntimeVersion.V2_ALPHA_TPUV5_LITE.value elif tpu_version == TpuVersion.V5P: zone = Zone.US_EAST5_A.value runtime_version = RuntimeVersion.V2_ALPHA_TPUV5.value project_name = Project.TPU_PROD_ENV_AUTOMATED.value network = V5_NETWORKS subnetwork = V5P_SUBNETWORKS elif tpu_version == TpuVersion.TRILLIUM: zone = Zone.EUROPE_WEST4_A.value runtime_version = RuntimeVersion.V2_ALPHA_TPUV6.value project_name = Project.CLOUD_ML_AUTO_SOLUTIONS.value network = V6E_GCE_NETWORK subnetwork = V6E_GCE_SUBNETWORK jetstream_benchmark_serving = ( jetstream_benchmark_serving_gce_config.get_config( tpu_version=tpu_version, tpu_cores=tpu_cores, tpu_zone=zone, time_out_in_min=sweep_model_configs["time_out_in_min"], test_name=test_name, test_mode=SetupMode.STABLE, project_name=project_name, runtime_version=runtime_version, network=network, subnetwork=subnetwork, is_tpu_reserved=True, model_configs=model_configs, maxtext_branch=model_configs["maxtext_branch"], jetstream_branch=model_configs["jetstream_branch"], ) ) return jetstream_benchmark_serving