dags/inference/maxtext_model_config_generator.py
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A helper to generate maxtext model configs."""
from dags.common.vm_resource import (
    TpuVersion,
    Zone,
    Project,
    V5_NETWORKS,
    V5E_SUBNETWORKS,
    V5P_SUBNETWORKS,
    RuntimeVersion,
    V6E_GCE_NETWORK,
    V6E_GCE_SUBNETWORK,
)
from dags.inference.configs import jetstream_benchmark_serving_gce_config
from dags.multipod.configs.common import SetupMode


def generate_model_configs(
    test_name_prefix,
    model_config_name,
    sweep_model_configs,
    axis_order,
    ici_parallelism,
    request_rate,
    tpu_version,
    tpu_cores,
):
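  """Builds the model_configs for one sweep point and returns its GCE task.

  Args:
    test_name_prefix: Prefix for the generated benchmark test name.
    model_config_name: Name of the model configuration being swept.
    sweep_model_configs: Sweep-wide settings (branches, tokenizer, dataset,
      quantization, timeouts, etc.).
    axis_order: Three dash-separated digit strings giving the compute,
      prefill-cache, and autoregressive-cache axis orders.
    ici_parallelism: An (fsdp, autoregressive, tensor) parallelism triple.
    request_rate: Request rate for the benchmark run.
    tpu_version: TpuVersion enum value selecting the TPU generation.
    tpu_cores: Number of TPU cores to run on.

  Returns:
    The configured jetstream_benchmark_serving task.
  """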
  model_configs = {}
  model_configs["model_config_name"] = model_config_name
  (
      compute_axis_order,
      prefill_cache_axis_order,
      ar_cache_axis_order,
  ) = axis_order.split("-")
  compute_axis_order = ",".join(compute_axis_order)
  prefill_cache_axis_order = ",".join(prefill_cache_axis_order)
  ar_cache_axis_order = ",".join(ar_cache_axis_order)
  model_configs["compute_axis_order"] = compute_axis_order
  model_configs["prefill_cache_axis_order"] = prefill_cache_axis_order
  model_configs["ar_cache_axis_order"] = ar_cache_axis_order
  (
      model_configs["ici_fsdp_parallelism"],
      model_configs["ici_autoregressive_parallelism"],
      model_configs["ici_tensor_parallelism"],
  ) = ici_parallelism
model_configs["request_rate"] = request_rate
model_configs["maxtext_branch"] = sweep_model_configs["maxtext_branch"]
model_configs["jetstream_branch"] = sweep_model_configs["jetstream_branch"]
model_configs["model_name"] = sweep_model_configs["model_name"]
model_configs["model_mode"] = sweep_model_configs["model_mode"]
model_configs["quant_mode"] = sweep_model_configs["quant_mode"]
model_configs["sleep_time"] = sweep_model_configs["sleep_time"]
model_configs["tokenizer"] = sweep_model_configs["tokenizer"]
model_configs["weight_dtype"] = sweep_model_configs["weight_dtype"]
model_configs["scan_layers"] = sweep_model_configs["scan_layers"]
model_configs["max_prefill_predict_length"] = sweep_model_configs[
"max_prefill_predict_length"
]
model_configs["max_target_length"] = sweep_model_configs["max_target_length"]
model_configs["attention"] = sweep_model_configs["attention"]
model_configs["reshape_q"] = sweep_model_configs["reshape_q"]
model_configs["per_device_batch_size"] = sweep_model_configs[
"per_device_batch_size"
]
model_configs["checkpoint"] = sweep_model_configs["checkpoint"]
model_configs["quantization"] = sweep_model_configs["quantization"]
model_configs["quantize_kvcache"] = sweep_model_configs["quantize_kvcache"]
model_configs["kv_quant_dtype"] = sweep_model_configs.get(
"kv_quant_dtype", ""
)
model_configs["kv_quant_axis"] = sweep_model_configs["kv_quant_axis"]
model_configs["dataset"] = sweep_model_configs["dataset"]
model_configs["dataset_path"] = sweep_model_configs.get("dataset_path", "")
model_configs["num_prompts"] = sweep_model_configs["num_prompts"]
model_configs["max_output_length"] = sweep_model_configs["max_output_length"]
model_configs["warmup_mode"] = sweep_model_configs["warmup_mode"]
model_configs["run_eval"] = sweep_model_configs["run_eval"]
  per_device_batch_size = model_configs["per_device_batch_size"]
  attention = model_configs["attention"][:3]
  kv_quant_axis = model_configs["kv_quant_axis"].replace("_", "")
  test_run_tag = (
      model_config_name
      if not kv_quant_axis
      else f"{model_config_name}-{kv_quant_axis}"
  )
  test_run_tag = (
      f"{test_run_tag}-rate{str(request_rate).replace('.', '_')}"
      f"-pdbs{per_device_batch_size}-{attention}"
      f"-{compute_axis_order.replace(',', '')}"
      f"-{prefill_cache_axis_order.replace(',', '')}"
      f"-{ar_cache_axis_order.replace(',', '')}"
  )
  test_name = f"{test_name_prefix}-{test_run_tag}"
  if tpu_version == TpuVersion.V5E:
    # v5e benchmarks
    project_name = Project.TPU_PROD_ENV_AUTOMATED.value
    zone = Zone.US_EAST1_C.value
    network = V5_NETWORKS
    subnetwork = V5E_SUBNETWORKS
    runtime_version = RuntimeVersion.V2_ALPHA_TPUV5_LITE.value
  elif tpu_version == TpuVersion.V5P:
    # v5p benchmarks
    zone = Zone.US_EAST5_A.value
    runtime_version = RuntimeVersion.V2_ALPHA_TPUV5.value
    project_name = Project.TPU_PROD_ENV_AUTOMATED.value
    network = V5_NETWORKS
    subnetwork = V5P_SUBNETWORKS
  elif tpu_version == TpuVersion.TRILLIUM:
    # v6e (Trillium) benchmarks
    zone = Zone.EUROPE_WEST4_A.value
    runtime_version = RuntimeVersion.V2_ALPHA_TPUV6.value
    project_name = Project.CLOUD_ML_AUTO_SOLUTIONS.value
    network = V6E_GCE_NETWORK
    subnetwork = V6E_GCE_SUBNETWORK
  else:
    # Fail fast instead of hitting a NameError on the undefined locals below.
    raise ValueError(f"Unsupported TPU version: {tpu_version}")
  jetstream_benchmark_serving = (
      jetstream_benchmark_serving_gce_config.get_config(
          tpu_version=tpu_version,
          tpu_cores=tpu_cores,
          tpu_zone=zone,
          time_out_in_min=sweep_model_configs["time_out_in_min"],
          test_name=test_name,
          test_mode=SetupMode.STABLE,
          project_name=project_name,
          runtime_version=runtime_version,
          network=network,
          subnetwork=subnetwork,
          is_tpu_reserved=True,
          model_configs=model_configs,
          maxtext_branch=model_configs["maxtext_branch"],
          jetstream_branch=model_configs["jetstream_branch"],
      )
  )
  return jetstream_benchmark_serving
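

# Minimal usage sketch (hypothetical values; real sweeps pass these arguments
# from the inference DAG definitions that iterate over axis orders, request
# rates, and parallelism settings):
#
#   task = generate_model_configs(
#       test_name_prefix="jetstream-benchmark",
#       model_config_name="llama2-7b",
#       sweep_model_configs=sweep_model_configs,
#       axis_order="0123-2013-2013",
#       ici_parallelism=(1, 1, -1),
#       request_rate=5.0,
#       tpu_version=TpuVersion.V5E,
#       tpu_cores=8,
#   )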