"""A DAG to run jetstream-pytorch inference benchmarks with nightly version."""
import datetime
from airflow import models
from airflow.models.baseoperator import chain
from dags import composer_env
from dags.common import test_owner
from dags.common.vm_resource import TpuVersion, Zone, Project, V5_NETWORKS, V5E_SUBNETWORKS, V5P_SUBNETWORKS, RuntimeVersion, V6E_GCE_NETWORK, V6E_GCE_SUBNETWORK
from dags.inference.configs import jetstream_pytorch_gce_config
from dags.multipod.configs.common import SetupMode, Platform
import numpy as np
# Run once a day at 4 am UTC (8 pm PST)
SCHEDULED_TIME = "0 4 * * *" if composer_env.is_prod_env() else None
with models.DAG(
    dag_id="jetstream_pytorch_inference",
    schedule=SCHEDULED_TIME,
    tags=["inference_team", "jetstream_pytorch", "nightly"],
    start_date=datetime.datetime(2024, 1, 19),
    catchup=False,
) as dag:
  test_name_prefix = "jetstream-pytorch-inference"
  test_models = {
      "llama3-8b": {
          "model_name": "llama-3",
          "size": "8b",
          "model_id": "meta-llama/Meta-Llama-3-8B-Instruct",
          "sleep_time": 120,
          "tpu_version_cores": [(TpuVersion.V5E, 8), (TpuVersion.TRILLIUM, 8)],
          "checkpoint": "gs://inference-benchmarks/models/llama3-8b-instruct/pytorch/llama3-8b-instruct-hf",
          "dataset": "openorca",
          "batch_sizes": [8, 32, 64, 128],
          "request_rate": 100,
          "num_prompts": 1000,
          "max_output_length": 1024,
          "quantize": [True, False],
      },
      "llama2-7b": {
          "model_name": "llama-2",
          "size": "7b",
          "model_id": "meta-llama/Llama-2-7b-chat-hf",
          "sleep_time": 120,
          "tpu_version_cores": [(TpuVersion.V5E, 8), (TpuVersion.TRILLIUM, 8)],
          "checkpoint": "gs://inference-benchmarks/models/llama2-7b-chat/pytorch/llama-2-7b-chat-hf",
          "dataset": "openorca",
          "batch_sizes": [8, 32, 64, 96, 128],
          "request_rate": 100,
          "num_prompts": 1000,
          "max_output_length": 1024,
          "quantize": [True, False],
      },
      "gemma-7b": {
          "model_name": "gemma",
          "size": "7b",
          "model_id": "google/gemma-7b",
          "sleep_time": 120,
          "tpu_version_cores": [(TpuVersion.V5E, 8), (TpuVersion.TRILLIUM, 8)],
          "checkpoint": "gs://inference-benchmarks/models/gemma-7b-it/pytorch/gemma-7b-it-hf",
          "dataset": "openorca",
          "tokenizer": "tokenizer.model",
          "batch_sizes": [8, 32, 64, 128],
          "request_rate": 100,
          "num_prompts": 1000,
          "max_output_length": 1024,
          "quantize": [True, False],
      },
      "llama2-13b": {
          "model_name": "llama-2",
          "size": "13b",
          "model_id": "meta-llama/Llama-2-13b-chat-hf",
          "sleep_time": 120,
          "tpu_version_cores": [(TpuVersion.V5E, 8), (TpuVersion.TRILLIUM, 8)],
          "checkpoint": "gs://inference-benchmarks/models/llama2-13b-chat/pytorch/llama-2-13b-chat-hf",
          "dataset": "openorca",
          "tokenizer": "tokenizer.llama2",
          "batch_sizes": [8, 32, 64, 96],
          "request_rate": 100,
          "num_prompts": 1000,
          "max_output_length": 1024,
          "quantize": [True, False],
      },
      "llama2-70b": {
          "model_name": "llama-2",
          "size": "70b",
          "model_id": "meta-llama/Llama-2-70b-chat-hf",
          "sleep_time": 120,
          "tpu_version_cores": [(TpuVersion.V5E, 8), (TpuVersion.TRILLIUM, 8)],
          "checkpoint": "gs://inference-benchmarks/models/llama2-70b-chat/pytorch/llama-2-70b-chat-hf",
          "dataset": "openorca",
          "tokenizer": "tokenizer.model",
          "batch_sizes": [8, 32, 64, 96],
          "request_rate": 100,
          "num_prompts": 1000,
          "max_output_length": 1024,
          "quantize": [True],
      },
  }
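
  # Sweep combinations to skip, as (model_name, size, batch_size, quantize).
  # The quantize entry is a string because model_configs stores str(quantize).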
  skip_settings = (
      ("llama-2", "13b", 96, "False"),
      ("llama-2", "7b", 128, "False"),
  )

  dags = []
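
  # For every model, sweep batch size x quantization x TPU generation and
  # build one nightly inference test per combination.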
  for model, sweep_model_configs in test_models.items():
    for batch_size in sweep_model_configs["batch_sizes"]:
      for quantize in sweep_model_configs["quantize"]:
        for tpu_version, tpu_cores in sweep_model_configs["tpu_version_cores"]:
          # Build the per-test config, fixing batch_size to a single value
          # rather than the full sweep list.
          model_configs = {}
          model_configs["model_name"] = sweep_model_configs["model_name"]
          model_configs["size"] = sweep_model_configs["size"]
          model_configs["model_id"] = sweep_model_configs["model_id"]
          model_configs["sleep_time"] = sweep_model_configs["sleep_time"]
          model_configs["checkpoint"] = sweep_model_configs["checkpoint"]
          model_configs["dataset"] = sweep_model_configs["dataset"]
          model_configs["batch_size"] = batch_size
          model_configs["per_device_batch_size"] = batch_size // tpu_cores
          model_configs["request_rate"] = sweep_model_configs["request_rate"]
          model_configs["num_prompts"] = sweep_model_configs["num_prompts"]
          model_configs["quantize"] = str(quantize)
          model_configs["max_output_length"] = sweep_model_configs[
              "max_output_length"
          ]

          # Skip combinations that are known not to fit, e.g. unquantized
          # Llama-2 13b with batch size 96 does not fit on a v5e-8.
          if (
              model_configs["model_name"],
              model_configs["size"],
              model_configs["batch_size"],
              model_configs["quantize"],
          ) in skip_settings:
            continue

          # Select the project, zone, network, and runtime version for the
          # target TPU generation (Trillium/v6e vs. v5e).
          if tpu_version == TpuVersion.TRILLIUM:
            project_name = Project.CLOUD_ML_AUTO_SOLUTIONS.value
            zone = Zone.EUROPE_WEST4_A.value
            network = V6E_GCE_NETWORK
            subnetwork = V6E_GCE_SUBNETWORK
            runtime_version = RuntimeVersion.V2_ALPHA_TPUV6.value
          else:
            project_name = Project.TPU_PROD_ENV_AUTOMATED.value
            zone = Zone.US_EAST1_C.value
            network = V5_NETWORKS
            subnetwork = V5E_SUBNETWORKS
            runtime_version = RuntimeVersion.V2_ALPHA_TPUV5_LITE.value

          jetstream_pytorch_nightly_1slice = jetstream_pytorch_gce_config.get_jetstream_pytorch_inference_nightly_config(
              tpu_version=tpu_version,
              tpu_cores=tpu_cores,
              tpu_zone=zone,
              runtime_version=runtime_version,
              project_name=project_name,
              time_out_in_min=60,
              is_tpu_reserved=True,
              test_name=f"{test_name_prefix}-nightly-{model}-batch_size-{batch_size}-quantized-{quantize}",
              test_mode=SetupMode.NIGHTLY,
              network=network,
              subnetwork=subnetwork,
              model_configs=model_configs,
          )
          dags.append(jetstream_pytorch_nightly_1slice)
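
  # Cap concurrency: split the generated tests into n_parallel_jobs groups and
  # chain the tests within each group so they run one after another.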
  n_parallel_jobs = 10
  chunks = np.array_split(dags, n_parallel_jobs)
  for chunk in chunks:
    for i in range(1, len(chunk)):
      chunk[i - 1] >> chunk[i]