# dags/inference/jetstream_inference_e2e.py

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A DAG to run JetStream inference E2E test (JAX nightly, no schedule).

Usage:
    gcloud composer environments run ml-automation-solutions \
    --project=cloud-ml-auto-solutions \
    --location=us-central1 dags trigger \
    -- \
    jetstream_e2e_inference
"""

import datetime

from airflow import models
from dags.common.vm_resource import TpuVersion
from dags.inference.maxtext_model_config_generator import generate_model_configs

# Model identifiers and quantization-mode tags used as dict keys below.
LLAMA2_7B = "llama2-7b"
GEMMA_7B = "gemma-7b"
BASE_MODE = "base"
W_BF16_KV_BF16 = "w-b16-kv-b16"

# MaxText param-only decode checkpoints, keyed by model then model mode.
CKPT = {
    LLAMA2_7B: {
        BASE_MODE: "gs://inference-benchmarks/models/llama2-7b/2024-04-25-14-01/param-only-decode-ckpt-maxtext/checkpoints/0/items",
    },
    GEMMA_7B: {
        BASE_MODE: "gs://inference-benchmarks/models/gemma-7b/2024-04-25-14-01/param-only-decode-ckpt-maxtext/checkpoints/0/items"
    },
}

with models.DAG(
    dag_id="jetstream_e2e_inference",
    schedule=None,
    tags=["inference_team", "jetstream", "maxtext", "nightly", "e2e"],
    start_date=datetime.datetime(2024, 1, 19),
    catchup=False,
) as dag:
  test_name_prefix = "jetstream-e2e-inference"

  # Reusable config fragments: a per-model base template plus per-quant-mode
  # overlays. A concrete test is assembled below by merging them with `|`.
  test_templates = {
      # LLAMA2_7B
      LLAMA2_7B: {
          "maxtext_branch": "",
          "jetstream_branch": "",
          "sleep_time": 360,
          "time_out_in_min": 60,
          "tpu_version_cores": [(TpuVersion.V5E, 8), (TpuVersion.TRILLIUM, 8)],
          "model_name": LLAMA2_7B,
          "tokenizer": "tokenizer.llama2",
          "weight_dtype": "bfloat16",
          "scan_layers": "false",
          "max_prefill_predict_length": 1024,
          "max_target_length": 2048,
          "reshape_q": True,
          # (ici_fsdp_parallelism, ici_autoregressive_parallelism, ici_tensor_parallelism)
          "ici_parallelisms": [(1, 1, -1)],
          "dataset": "openorca",
          "num_prompts": 200,
          "max_output_length": 1024,
          "warmup_mode": "full",
      },
      f"{LLAMA2_7B}-{W_BF16_KV_BF16}-dot-product": {
          "attention": "dot_product",
          "request_rate": [0.0],
          "axis_order": [
              "0123-2013-2013",
          ],
      },
      # GEMMA_7B
      GEMMA_7B: {
          "maxtext_branch": "",
          "jetstream_branch": "",
          "sleep_time": 360,
          "time_out_in_min": 60,
          "tpu_version_cores": [(TpuVersion.V5E, 8), (TpuVersion.TRILLIUM, 8)],
          "model_name": GEMMA_7B,
          "tokenizer": "tokenizer.gemma",
          "weight_dtype": "bfloat16",
          "scan_layers": "false",
          "max_prefill_predict_length": 1024,
          "max_target_length": 2048,
          "reshape_q": True,
          # (ici_fsdp_parallelism, ici_autoregressive_parallelism, ici_tensor_parallelism)
          "ici_parallelisms": [(1, 1, -1)],
          "dataset": "sharegpt",
          "dataset_path": "~/ShareGPT_V3_unfiltered_cleaned_split.json",
          "request_rate": [0.0],
          "num_prompts": 200,
          "max_output_length": 1024,
          "warmup_mode": "full",
      },
      f"{GEMMA_7B}-{W_BF16_KV_BF16}-autoselect": {
          "attention": "autoselected",
          "request_rate": [0.0],
          "axis_order": ["0123-1203-1203"],
      },
  }

  # Concrete test configs: base template | quant-mode overlay | test-specific
  # keys. Later operands of `|` win on key collisions (PEP 584 semantics).
  tests = {
      # LLAMA2_7B
      f"{LLAMA2_7B}-{BASE_MODE}-{W_BF16_KV_BF16}": test_templates[LLAMA2_7B]
      | test_templates[f"{LLAMA2_7B}-{W_BF16_KV_BF16}-dot-product"]
      | {
          "checkpoint": CKPT[LLAMA2_7B][BASE_MODE],
          "model_mode": BASE_MODE,
          "quant_mode": W_BF16_KV_BF16,
          "quantization": "",
          "quantize_kvcache": "false",
          "per_device_batch_size": 12,
          "kv_quant_axis": "",
          "run_eval": True,
      },
      # GEMMA_7B
      f"{GEMMA_7B}-{BASE_MODE}-{W_BF16_KV_BF16}": test_templates[GEMMA_7B]
      | test_templates[f"{GEMMA_7B}-{W_BF16_KV_BF16}-autoselect"]
      | {
          "checkpoint": CKPT[GEMMA_7B][BASE_MODE],
          "model_mode": BASE_MODE,
          "quant_mode": W_BF16_KV_BF16,
          "quantization": "",
          "quantize_kvcache": "false",
          "per_device_batch_size": 12,
          "kv_quant_axis": "",
          "run_eval": True,
      },
  }

  # Allow-list of test names to run; if non-empty, anything not listed is
  # skipped. skip_configs is the complementary deny-list.
  run_configs = [
      f"{LLAMA2_7B}-{BASE_MODE}-{W_BF16_KV_BF16}",
      f"{GEMMA_7B}-{BASE_MODE}-{W_BF16_KV_BF16}",
  ]
  skip_configs = []

  for model_config_name, sweep_model_configs in tests.items():
    if run_configs and model_config_name not in run_configs:
      continue
    if skip_configs and model_config_name in skip_configs:
      continue

    # Sweep the cross-product of hardware and serving parameters. Each call
    # to generate_model_configs registers its tasks on the enclosing DAG via
    # the active `with models.DAG(...)` context, so the return value is not
    # needed here (the previous unused `dags` list and result binding were
    # dead code and have been removed).
    for tpu_version, tpu_cores in sweep_model_configs["tpu_version_cores"]:
      for axis_order in sweep_model_configs["axis_order"]:
        for ici_parallelism in sweep_model_configs["ici_parallelisms"]:
          for request_rate in sweep_model_configs["request_rate"]:
            generate_model_configs(
                test_name_prefix=test_name_prefix,
                model_config_name=model_config_name,
                sweep_model_configs=sweep_model_configs,
                axis_order=axis_order,
                ici_parallelism=ici_parallelism,
                request_rate=request_rate,
                tpu_version=tpu_version,
                tpu_cores=tpu_cores,
            )