dags/inference/configs/maxtext_inference_microbenchmark_gce_config.py

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities to construct configs for the maxtext inference microbenchmark DAG."""

import datetime
import json
from typing import Dict

from xlml.apis import gcp_config, metric_config, task, test_config
from dags.common import test_owner
from dags.multipod.configs import common
from dags.common.vm_resource import TpuVersion, Project, RuntimeVersion

PROJECT_NAME = Project.CLOUD_ML_AUTO_SOLUTIONS.value
RUNTIME_IMAGE = RuntimeVersion.TPU_UBUNTU2204_BASE.value
GCS_SUBFOLDER_PREFIX = test_owner.Team.INFERENCE.value


def config(
    tpu_version: TpuVersion,
    tpu_cores: int,
    tpu_zone: str,
    time_out_in_min: int,
    test_name: str,
    test_mode: common.SetupMode,
    project_name: str = PROJECT_NAME,
    runtime_version: str = RUNTIME_IMAGE,
    network: str = "default",
    subnetwork: str = "default",
    is_tpu_reserved: bool = True,
    num_slices: int = 1,
    model_configs: Dict = {},
    maxtext_branch: str = "",
):
  job_gcp_config = gcp_config.GCPConfig(
      project_name=project_name,
      zone=tpu_zone,
      dataset_name=metric_config.DatasetOption.BENCHMARK_DATASET,
  )

  set_up_cmds = (
      "pip install --upgrade pip",
      # Download maxtext.
      f"if [ ! -d maxtext ]; then git clone {maxtext_branch} https://github.com/google/maxtext.git; fi",
      # Create a Python virtual environment.
      "sudo apt-get -y update",
      "sudo apt-get -y install python3.10-venv",
      "sudo apt-get -y install jq",
      "python -m venv .env",
      "source .env/bin/activate",
      # Set up MaxText.
      f"cd maxtext && bash setup.sh MODE={test_mode.value} && cd ..",
      "pip install torch --index-url https://download.pytorch.org/whl/cpu",
  )

  additional_metadata_dict = {
      "quant_mode": f"{model_configs['quant_mode']}",
      "two_axis_order_product_id_list": f"{model_configs['two_axis_order_product_id_list']}",
      "prefill_cache_axis_order_list": f"{model_configs['prefill_cache_axis_order_list']}",
      "ar_cache_axis_order_list": f"{model_configs['ar_cache_axis_order_list']}",
      "accelerator": f"v{tpu_version.value}-{tpu_cores}",
      "flatten_microbenchmark_results": "true",
  }

  run_model_cmds = (
      # Start the virtual environment.
      "source .env/bin/activate",
      # Record the maxtext commit hash in the benchmark metadata file.
      "cd maxtext",
      f"export METADATA_DICT='{json.dumps(additional_metadata_dict)}'",
      'export MAXTEXT_COMMIT_HASH=$(git log -1 --format="%H")',
      'jq \'. + { "maxtext_commit_hash": $newVal}\' --arg newVal ${MAXTEXT_COMMIT_HASH} <<<"$METADATA_DICT" > MaxText/metadata.json',
      "cat MaxText/metadata.json",
      ### Benchmark
      # Configure XLA flags.
      "export XLA_FLAGS='--xla_disable_hlo_passes=rematerialization'",
      f"""python3 -m MaxText.inference_microbenchmark_sweep \
        MaxText/configs/base.yml \
        model_name={model_configs['model_name']} \
        tokenizer_path=assets/{model_configs['tokenizer']} \
        weight_dtype={model_configs['weight_dtype']} \
        scan_layers={model_configs['scan_layers']} \
        max_prefill_predict_length={model_configs['max_prefill_predict_length']} \
        max_target_length={model_configs['max_target_length']} \
        attention={model_configs['attention']} \
        ici_fsdp_parallelism={model_configs['ici_fsdp_parallelism']} \
        ici_autoregressive_parallelism={model_configs['ici_autoregressive_parallelism']} \
        ici_tensor_parallelism={model_configs['ici_tensor_parallelism']} \
        quantization={model_configs['quantization']} \
        quantize_kvcache={model_configs['quantize_kvcache']} \
        per_device_batch_size={model_configs['per_device_batch_size']} \
        inference_microbenchmark_prefill_lengths={model_configs['inference_microbenchmark_prefill_lengths']} \
        inference_microbenchmark_stages={model_configs['inference_microbenchmark_stages']} \
        inference_microbenchmark_loop_iters={model_configs['inference_microbenchmark_loop_iters']} \
        base_output_directory={model_configs['base_output_directory']} \
        run_name={model_configs['run_name']} \
        profiler={model_configs['profiler']} \
        save_config_to_gcs={model_configs['save_config_to_gcs']} \
        reshape_q={model_configs['reshape_q']} \
        kv_quant_axis={model_configs['kv_quant_axis']} \
        compute_axis_order={model_configs['compute_axis_order']} \
        inference_metadata_file=MaxText/metadata.json""",
      # Publish the sweep results as the metric report.
      "cat inference_microbenchmark_sweep_results.jsonl",
      "mv inference_microbenchmark_sweep_results.jsonl metric_report.jsonl",
      f"gsutil cp metric_report.jsonl {metric_config.SshEnvVars.GCS_OUTPUT.value}",
  )

  job_test_config = test_config.TpuVmTest(
      test_config.Tpu(
          version=tpu_version,
          cores=tpu_cores,
          runtime_version=runtime_version,
          reserved=is_tpu_reserved,
          network=network,
          subnetwork=subnetwork,
      ),
      test_name=test_name,
      set_up_cmds=set_up_cmds,
      run_model_cmds=run_model_cmds,
      timeout=datetime.timedelta(minutes=time_out_in_min),
      task_owner=test_owner.MORGAN_D,
      num_slices=num_slices,
      gcs_subfolder=f"{GCS_SUBFOLDER_PREFIX}/maxtext",
  )

  job_metric_config = metric_config.MetricConfig(
      json_lines=metric_config.JSONLinesConfig("metric_report.jsonl"),
      use_runtime_generated_gcs_folder=True,
  )

  return task.run_queued_resource_test(
      task_test_config=job_test_config,
      task_gcp_config=job_gcp_config,
      task_metric_config=job_metric_config,
  )