dags/inference/maxtext_inference.py

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A DAG to run MaxText inference benchmarks with nightly version."""

import datetime
import numpy as np
from airflow import models
from dags import composer_env
from dags.common.vm_resource import TpuVersion
from dags.inference.maxtext_model_config_generator import generate_model_configs

USER_PREFIX = ""

MAXTEXT_BRANCH = ""
JETSTREAM_BRANCH = ""

maxtext_branch = "" if not MAXTEXT_BRANCH else f"-b {MAXTEXT_BRANCH}"
jetstream_branch = "" if not JETSTREAM_BRANCH else f"-b {JETSTREAM_BRANCH}"

# Run once a day at 8 am UTC (12 am PST)
SCHEDULED_TIME = "0 8 * * *" if composer_env.is_prod_env() else None

LLAMA2_7B = "llama2-7b"
LLAMA2_13B = "llama2-13b"
LLAMA2_70B = "llama2-70b"
GEMMA_7B = "gemma-7b"
MIXTRAL_8_7B = "mixtral-8x7b"

BASE_MODE = "base"
CHAT_MODE = "chat"
INSTRUCT_MODE = "instruct"

W_BF16_KV_BF16 = "w-b16-kv-b16"
W_INT8_KV_INT8 = "w-i8-kv-i8"

CKPT = {
    LLAMA2_7B: {
        BASE_MODE: "gs://inference-benchmarks/models/llama2-7b/2024-04-25-14-01/param-only-decode-ckpt-maxtext/checkpoints/0/items",
        CHAT_MODE: "gs://inference-benchmarks/models/llama2-7b-chat/2024-05-24-12-39/param-only-decode-ckpt-maxtext/checkpoints/0/items",
    },
    LLAMA2_13B: {
        BASE_MODE: "gs://inference-benchmarks/models/llama2-13b/2024-04-25-14-01/param-only-decode-ckpt-maxtext/checkpoints/0/items",
        CHAT_MODE: "gs://inference-benchmarks/models/llama2-13b-chat/2024-05-24-12-39/param-only-decode-ckpt-maxtext/checkpoints/0/items",
    },
    LLAMA2_70B: {
        CHAT_MODE: "gs://inference-benchmarks/models/llama2-70b-chat/2024-05-08-23-16/param-only-decode-ckpt-maxtext/checkpoints/0/items"
    },
    GEMMA_7B: {
        BASE_MODE: "gs://inference-benchmarks/models/gemma-7b/2024-04-25-14-01/param-only-decode-ckpt-maxtext/checkpoints/0/items"
    },
    MIXTRAL_8_7B: {
        # checkpoint created using these instructions - go/mixtral-inference-testing
        INSTRUCT_MODE: "gs://vipannalla_mixtral_ckpt/moe_matmul/moe_matmul_06_15_24/checkpoints/0/items/"
    },
}

dag_id = (
    "jetstream_benchmark_serving"
    if not USER_PREFIX
    else f"{USER_PREFIX}_jetstream_benchmark_serving"
)

tags = ["inference_team", "jetstream", "maxtext", "benchmark"]
if USER_PREFIX:
  tags.append(USER_PREFIX)

with models.DAG(
    dag_id=dag_id,
    tags=tags,
    start_date=datetime.datetime(2024, 1, 19),
    schedule=SCHEDULED_TIME,
    catchup=False,
) as dag:
  test_name_prefix = "max-js" if not USER_PREFIX else f"{USER_PREFIX}-max-js"

  # TODO: baseline layout can be deleted if meeting one of the below conditions:
  # - default cache layout performs better
  # - run layout tuning to get optimized cache layout for 0213
  llama2_7B_bf16_batch_sizes = [1, 2, 4, 8, 12]
  llama2_7B_int8_batch_sizes = [1, 2, 4, 8, 12, 24]
  llama2_13B_bf16_batch_sizes = [1, 2, 4, 8]
  llama2_13B_int8_batch_sizes = [1, 2, 4, 8]
  llama2_70B_bf16_batch_sizes = [1, 4, 8, 16, 24]
  llama2_70B_int8_batch_sizes = [1, 2, 4, 8, 11]
  gemma_7B_bf16_batch_sizes = [1, 2, 4, 8, 12]
  gemma_7B_int8_batch_sizes = [1, 2, 4, 8, 12, 24]

  test_templates = {
      # LLAMA2_7B
      LLAMA2_7B: {
          "maxtext_branch": maxtext_branch,
jetstream_branch, "sleep_time": 360, "time_out_in_min": 120, "tpu_version_cores": [(TpuVersion.V5E, 8), (TpuVersion.TRILLIUM, 8)], "model_name": LLAMA2_7B, "tokenizer": "tokenizer.llama2", "weight_dtype": "bfloat16", "scan_layers": "false", "max_prefill_predict_length": 1024, "max_target_length": 2048, "reshape_q": True, # (ici_fsdp_parallelism, ici_autoregressive_parallelism, ici_tensor_parallelism) "ici_parallelisms": [(1, 1, -1)], "dataset": "openorca", "num_prompts": 1000, "max_output_length": 1024, "warmup_mode": "full", }, f"{LLAMA2_7B}-{W_BF16_KV_BF16}-dot-product": { "attention": "dot_product", "request_rate": [0.0], "axis_order": [ "0123-2013-2013", # optimized layout for 0123 "0213-0213-0213", # default layout "0213-0213-0132", # optimized layout for 0213 ], }, f"{LLAMA2_7B}-{W_INT8_KV_INT8}-dot-product": { "attention": "dot_product", "request_rate": [0.0], "axis_order": [ "0213-0213-0213", # default layout "0213-0231-0213", # optimized layout for 0213 ], }, # LLAMA2_13B LLAMA2_13B: { "maxtext_branch": maxtext_branch, "jetstream_branch": jetstream_branch, "sleep_time": 360, "time_out_in_min": 120, "tpu_version_cores": [(TpuVersion.V5E, 8), (TpuVersion.TRILLIUM, 8)], "model_name": LLAMA2_13B, "tokenizer": "tokenizer.llama2", "weight_dtype": "bfloat16", "scan_layers": "false", "max_prefill_predict_length": 1024, "max_target_length": 2048, "reshape_q": True, # (ici_fsdp_parallelism, ici_autoregressive_parallelism, ici_tensor_parallelism) "ici_parallelisms": [(1, 1, -1)], "dataset": "openorca", "request_rate": [0.0], "num_prompts": 1000, "max_output_length": 1024, "warmup_mode": "full", }, f"{LLAMA2_13B}-{W_BF16_KV_BF16}-dot-product": { "attention": "dot_product", "request_rate": [0.0], "axis_order": [ "0123-1203-1203", # baseline layout "0213-0213-0213", # default layout ], }, f"{LLAMA2_13B}-{W_INT8_KV_INT8}-dot-product": { "attention": "dot_product", "request_rate": [0.0], "axis_order": [ "0123-1203-1203", # baseline layout "0213-0213-0213", # default layout ], }, # LLAMA2_70B LLAMA2_70B: { "maxtext_branch": maxtext_branch, "jetstream_branch": jetstream_branch, "sleep_time": 360, "time_out_in_min": 240, "tpu_version_cores": [(TpuVersion.V5P, 8), (TpuVersion.TRILLIUM, 8)], "model_name": LLAMA2_70B, "tokenizer": "tokenizer.llama2", "weight_dtype": "bfloat16", "scan_layers": "false", "max_prefill_predict_length": 1024, "max_target_length": 2048, "reshape_q": True, # (ici_fsdp_parallelism, ici_autoregressive_parallelism, ici_tensor_parallelism) "ici_parallelisms": [(1, 1, -1)], "dataset": "openorca", "num_prompts": 1000, "max_output_length": 1024, "warmup_mode": "full", }, f"{LLAMA2_70B}-{W_BF16_KV_BF16}-dot-product": { "attention": "dot_product", "request_rate": [0.0], "axis_order": [ "0123-1203-1203", # baseline layout "0213-0213-0213", # default layout ], }, f"{LLAMA2_70B}-{W_INT8_KV_INT8}-dot-product": { "attention": "dot_product", "request_rate": [0.0], "axis_order": [ "0123-1203-1203", # baseline layout "0213-0213-0213", # default layout ], }, # GEMMA_7B GEMMA_7B: { "maxtext_branch": maxtext_branch, "jetstream_branch": jetstream_branch, "sleep_time": 360, "time_out_in_min": 120, "tpu_version_cores": [(TpuVersion.V5E, 8), (TpuVersion.TRILLIUM, 8)], "model_name": GEMMA_7B, "tokenizer": "tokenizer.gemma", "weight_dtype": "bfloat16", "scan_layers": "false", "max_prefill_predict_length": 1024, "max_target_length": 2048, "reshape_q": True, # (ici_fsdp_parallelism, ici_autoregressive_parallelism, ici_tensor_parallelism) "ici_parallelisms": [(1, 1, -1)], "dataset": "sharegpt", 
"dataset_path": "~/ShareGPT_V3_unfiltered_cleaned_split.json", "request_rate": [0.0], "num_prompts": 1000, "max_output_length": 1024, "warmup_mode": "full", }, f"{GEMMA_7B}-{W_BF16_KV_BF16}-autoselect": { "attention": "autoselected", "request_rate": [0.0], "axis_order": [ "0123-1203-1203", # baseline layout "0213-0213-0213", # default layout ], }, f"{GEMMA_7B}-{W_INT8_KV_INT8}-autoselect": { "attention": "autoselected", "request_rate": [0.0], "axis_order": [ "0123-1203-1203", # baseline layout "0213-0213-0213", # default layout ], }, # MIXTRAL_8_7B MIXTRAL_8_7B: { "maxtext_branch": maxtext_branch, "jetstream_branch": jetstream_branch, "sleep_time": 240, "time_out_in_min": 240, "tpu_version_cores": [(TpuVersion.V5P, 8), (TpuVersion.TRILLIUM, 8)], "model_name": MIXTRAL_8_7B, "tokenizer": "gs://maxtext-external/mixtral-8x7B-v0.1-Instruct/tokenizer.mistral", "weight_dtype": "bfloat16", "scan_layers": "false", "max_prefill_predict_length": 2048, "max_target_length": 3072, "reshape_q": True, # (ici_fsdp_parallelism, ici_autoregressive_parallelism, ici_tensor_parallelism) "ici_parallelisms": [(1, 1, -1)], "dataset": "openorca", "num_prompts": 1000, "max_output_length": 1024, "warmup_mode": "full", }, f"{MIXTRAL_8_7B}-{W_BF16_KV_BF16}-dot-product": { "attention": "dot_product", "request_rate": [0.0], "axis_order": [ "0123-1203-1203", # baseline layout "0213-0213-0213", # default layout ], }, f"{MIXTRAL_8_7B}-{W_INT8_KV_INT8}-dot-product": { "attention": "dot_product", "request_rate": [0.0], "axis_order": [ "0123-1203-1203", # baseline layout "0213-0213-0213", # default layout ], }, } tests_llama2_7b_bf16_base_mode_tests = {} tests_llama2_7b_bf16_chat_mode_tests = {} tests_llama2_7b_int8_base_mode_tests = {} tests_llama2_7b_int8_chat_mode_tests = {} # llama 2 7B bfloat16 tests in both base and chat mode for bs in llama2_7B_bf16_batch_sizes: tests_llama2_7b_bf16_base_mode_tests[ f"{LLAMA2_7B}-{BASE_MODE}-{W_BF16_KV_BF16}-{bs}" ] = ( test_templates[LLAMA2_7B] | test_templates[f"{LLAMA2_7B}-{W_BF16_KV_BF16}-dot-product"] | { "checkpoint": CKPT[LLAMA2_7B][BASE_MODE], "model_mode": BASE_MODE, "quant_mode": W_BF16_KV_BF16, "quantization": "", "quantize_kvcache": "false", "per_device_batch_size": bs, "kv_quant_axis": "", "run_eval": False, } ) tests_llama2_7b_bf16_chat_mode_tests[ f"{LLAMA2_7B}-{CHAT_MODE}-{W_BF16_KV_BF16}-{bs}" ] = ( test_templates[LLAMA2_7B] | test_templates[f"{LLAMA2_7B}-{W_BF16_KV_BF16}-dot-product"] | { "checkpoint": CKPT[LLAMA2_7B][CHAT_MODE], "model_mode": CHAT_MODE, "quant_mode": W_BF16_KV_BF16, "quantization": "", "quantize_kvcache": "false", "per_device_batch_size": bs, "kv_quant_axis": "", "run_eval": True, } ) # llama 2 7B int8 tests in both base and chat mode for bs in llama2_7B_int8_batch_sizes: tests_llama2_7b_int8_base_mode_tests[ f"{LLAMA2_7B}-{BASE_MODE}-{W_INT8_KV_INT8}-{bs}" ] = ( test_templates[LLAMA2_7B] | test_templates[f"{LLAMA2_7B}-{W_INT8_KV_INT8}-dot-product"] | { "checkpoint": CKPT[LLAMA2_7B][BASE_MODE], "model_mode": BASE_MODE, "quant_mode": W_INT8_KV_INT8, "quantization": "int8", "quantize_kvcache": "true", "kv_quant_dtype": "int8", "per_device_batch_size": bs, "kv_quant_axis": "heads_and_dkv", "run_eval": False, } ) tests_llama2_7b_int8_chat_mode_tests[ f"{LLAMA2_7B}-{CHAT_MODE}-{W_INT8_KV_INT8}-{bs}" ] = ( test_templates[LLAMA2_7B] | test_templates[f"{LLAMA2_7B}-{W_BF16_KV_BF16}-dot-product"] | { "checkpoint": CKPT[LLAMA2_7B][CHAT_MODE], "model_mode": CHAT_MODE, "quant_mode": W_INT8_KV_INT8, "quantization": "int8", "quantize_kvcache": "true", 
"kv_quant_dtype": "int8", "per_device_batch_size": bs, "kv_quant_axis": "heads_and_dkv", "run_eval": True, } ) tests_llama2_13b_bf16_base_mode_tests = {} tests_llama2_13b_bf16_chat_mode_tests = {} tests_llama2_13b_int8_base_mode_tests = {} tests_llama2_13b_int8_chat_mode_tests = {} # llama 13B 7B bfloat16 tests in both base and chat mode for bs in llama2_13B_bf16_batch_sizes: tests_llama2_13b_bf16_base_mode_tests[ f"{LLAMA2_13B}-{BASE_MODE}-{W_BF16_KV_BF16}-{bs}" ] = ( test_templates[LLAMA2_13B] | test_templates[f"{LLAMA2_13B}-{W_BF16_KV_BF16}-dot-product"] | { "checkpoint": CKPT[LLAMA2_13B][BASE_MODE], "model_mode": BASE_MODE, "quant_mode": W_BF16_KV_BF16, "quantization": "", "quantize_kvcache": "false", "per_device_batch_size": bs, "kv_quant_axis": "", "run_eval": False, } ) tests_llama2_13b_bf16_chat_mode_tests[ f"{LLAMA2_13B}-{BASE_MODE}-{W_BF16_KV_BF16}-{bs}" ] = ( test_templates[LLAMA2_13B] | test_templates[f"{LLAMA2_13B}-{W_BF16_KV_BF16}-dot-product"] | { "checkpoint": CKPT[LLAMA2_13B][CHAT_MODE], "model_mode": CHAT_MODE, "quant_mode": W_BF16_KV_BF16, "quantization": "", "quantize_kvcache": "false", "per_device_batch_size": bs, "kv_quant_axis": "", "run_eval": True, } ) for bs in llama2_13B_int8_batch_sizes: tests_llama2_13b_int8_base_mode_tests[ f"{LLAMA2_13B}-{BASE_MODE}-{W_INT8_KV_INT8}-{bs}" ] = ( test_templates[LLAMA2_13B] | test_templates[f"{LLAMA2_13B}-{W_INT8_KV_INT8}-dot-product"] | { "checkpoint": CKPT[LLAMA2_13B][BASE_MODE], "model_mode": BASE_MODE, "quant_mode": W_INT8_KV_INT8, "quantization": "int8", "quantize_kvcache": "true", "kv_quant_dtype": "int8", "per_device_batch_size": bs, "kv_quant_axis": "heads_and_dkv", "run_eval": False, } ) tests_llama2_13b_int8_chat_mode_tests[ f"{LLAMA2_13B}-{CHAT_MODE}-{W_INT8_KV_INT8}-{bs}" ] = ( test_templates[LLAMA2_13B] | test_templates[f"{LLAMA2_13B}-{W_INT8_KV_INT8}-dot-product"] | { "checkpoint": CKPT[LLAMA2_13B][CHAT_MODE], "model_mode": CHAT_MODE, "quant_mode": W_INT8_KV_INT8, "quantization": "int8", "quantize_kvcache": "true", "kv_quant_dtype": "int8", "per_device_batch_size": bs, "kv_quant_axis": "heads_and_dkv", "run_eval": True, } ) tests_gemma_7b_bf16_base_mode_tests = {} tests_gemma_7b_int8_base_mode_tests = {} for bs in gemma_7B_bf16_batch_sizes: tests_gemma_7b_bf16_base_mode_tests[ f"{GEMMA_7B}-{BASE_MODE}-{W_BF16_KV_BF16}-{bs}" ] = ( test_templates[GEMMA_7B] | test_templates[f"{GEMMA_7B}-{W_BF16_KV_BF16}-autoselect"] | { "checkpoint": CKPT[GEMMA_7B][BASE_MODE], "model_mode": BASE_MODE, "quant_mode": W_BF16_KV_BF16, "quantization": "", "quantize_kvcache": "false", "per_device_batch_size": bs, "kv_quant_axis": "", "run_eval": False, } ) for bs in gemma_7B_int8_batch_sizes: tests_gemma_7b_int8_base_mode_tests[ f"{GEMMA_7B}-{BASE_MODE}-{W_INT8_KV_INT8}-{bs}" ] = ( test_templates[GEMMA_7B] | test_templates[f"{GEMMA_7B}-{W_INT8_KV_INT8}-autoselect"] | { "checkpoint": CKPT[GEMMA_7B][BASE_MODE], "model_mode": BASE_MODE, "quant_mode": W_INT8_KV_INT8, "quantization": "int8", "quantize_kvcache": "true", "kv_quant_dtype": "int8", "per_device_batch_size": bs, "kv_quant_axis": "heads_and_dkv", "run_eval": False, } ) tests = ( tests_llama2_7b_bf16_base_mode_tests | tests_llama2_7b_bf16_chat_mode_tests | tests_llama2_7b_int8_base_mode_tests | tests_llama2_7b_int8_chat_mode_tests | tests_llama2_13b_bf16_base_mode_tests | tests_llama2_13b_bf16_chat_mode_tests | tests_llama2_13b_int8_base_mode_tests | tests_llama2_13b_int8_chat_mode_tests | tests_gemma_7b_bf16_base_mode_tests | tests_gemma_7b_int8_base_mode_tests | { # LLAMA2_70B 
f"{LLAMA2_70B}-{CHAT_MODE}-{W_BF16_KV_BF16}": test_templates[ LLAMA2_70B ] | test_templates[f"{LLAMA2_70B}-{W_BF16_KV_BF16}-dot-product"] | { "checkpoint": CKPT[LLAMA2_70B][CHAT_MODE], "model_mode": CHAT_MODE, "quant_mode": W_BF16_KV_BF16, "quantization": "", "quantize_kvcache": "false", "per_device_batch_size": 24, "kv_quant_axis": "", "run_eval": True, }, f"{LLAMA2_70B}-{CHAT_MODE}-{W_INT8_KV_INT8}": test_templates[ LLAMA2_70B ] | test_templates[f"{LLAMA2_70B}-{W_INT8_KV_INT8}-dot-product"] | { "checkpoint": CKPT[LLAMA2_70B][CHAT_MODE], "model_mode": CHAT_MODE, "quant_mode": W_INT8_KV_INT8, "quantization": "int8", "quantize_kvcache": "true", "kv_quant_dtype": "int8", "per_device_batch_size": 48, "kv_quant_axis": "heads_and_dkv", "run_eval": True, }, # MIXTRAL_8_7B f"{MIXTRAL_8_7B}-{INSTRUCT_MODE}-{W_BF16_KV_BF16}": test_templates[ MIXTRAL_8_7B ] | test_templates[f"{MIXTRAL_8_7B}-{W_BF16_KV_BF16}-dot-product"] | { "checkpoint": CKPT[MIXTRAL_8_7B][INSTRUCT_MODE], "model_mode": INSTRUCT_MODE, "quant_mode": W_BF16_KV_BF16, "quantization": "", "quantize_kvcache": "false", "per_device_batch_size": 128, "kv_quant_axis": "", "run_eval": True, }, f"{MIXTRAL_8_7B}-{INSTRUCT_MODE}-{W_INT8_KV_INT8}": test_templates[ MIXTRAL_8_7B ] | test_templates[f"{MIXTRAL_8_7B}-{W_INT8_KV_INT8}-dot-product"] | { "checkpoint": CKPT[MIXTRAL_8_7B][INSTRUCT_MODE], "model_mode": INSTRUCT_MODE, "quant_mode": W_INT8_KV_INT8, "quantization": "int8", "quantize_kvcache": "true", "kv_quant_dtype": "int8", "per_device_batch_size": 258, "kv_quant_axis": "heads_and_dkv", "run_eval": True, }, } ) # run_configs = [ # f"{LLAMA2_7B}-{BASE_MODE}-{W_BF16_KV_BF16}", # f"{LLAMA2_7B}-{BASE_MODE}-{W_INT8_KV_INT8}", # f"{LLAMA2_7B}-{CHAT_MODE}-{W_BF16_KV_BF16}", # f"{LLAMA2_7B}-{CHAT_MODE}-{W_INT8_KV_INT8}", # f"{LLAMA2_13B}-{BASE_MODE}-{W_BF16_KV_BF16}", # f"{LLAMA2_13B}-{BASE_MODE}-{W_INT8_KV_INT8}", # f"{LLAMA2_13B}-{CHAT_MODE}-{W_BF16_KV_BF16}", # f"{LLAMA2_13B}-{CHAT_MODE}-{W_INT8_KV_INT8}", # f"{LLAMA2_70B}-{CHAT_MODE}-{W_BF16_KV_BF16}", # f"{LLAMA2_70B}-{CHAT_MODE}-{W_INT8_KV_INT8}", # f"{GEMMA_7B}-{BASE_MODE}-{W_BF16_KV_BF16}", # f"{GEMMA_7B}-{BASE_MODE}-{W_INT8_KV_INT8}", # f"{MIXTRAL_8_7B}-{INSTRUCT_MODE}-{W_BF16_KV_BF16}", # f"{MIXTRAL_8_7B}-{INSTRUCT_MODE}-{W_INT8_KV_INT8}", # ] skip_configs = [] dags = [] for model_config_name, sweep_model_configs in tests.items(): # if run_configs and model_config_name not in run_configs: # continue if skip_configs and model_config_name in skip_configs: continue for tpu_version, tpu_cores in sweep_model_configs["tpu_version_cores"]: for axis_order in sweep_model_configs["axis_order"]: for ici_parallelism in sweep_model_configs["ici_parallelisms"]: for request_rate in sweep_model_configs["request_rate"]: jetstream_benchmark_serving_kv_cache_layout = ( generate_model_configs( test_name_prefix=test_name_prefix, model_config_name=model_config_name, sweep_model_configs=sweep_model_configs, axis_order=axis_order, ici_parallelism=ici_parallelism, request_rate=request_rate, tpu_version=tpu_version, tpu_cores=tpu_cores, ) ) dags.append(jetstream_benchmark_serving_kv_cache_layout) # Cap the number of simultaneously requested v5e-8 due to resource contraints n_parallel_jobs = 4 chunks = np.array_split(dags, n_parallel_jobs) for chunk in chunks: for i in range(1, len(chunk)): chunk[i - 1] >> chunk[i]