dags/inference/configs/jetstream_pytorch_gce_config.py

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities to construct configs for jetstream-pytorch inference DAG."""

import datetime
import json
from typing import Dict
from xlml.apis import gcp_config, metric_config, task, test_config
from dags.common import test_owner
from dags.multipod.configs import common
from dags.common.vm_resource import TpuVersion, Project, RuntimeVersion

PROJECT_NAME = Project.CLOUD_ML_AUTO_SOLUTIONS.value
RUNTIME_IMAGE = RuntimeVersion.TPU_UBUNTU2204_BASE.value
GCS_SUBFOLDER_PREFIX = test_owner.Team.INFERENCE.value


def get_jetstream_pytorch_inference_nightly_config(
    tpu_version: TpuVersion,
    tpu_cores: int,
    tpu_zone: str,
    time_out_in_min: int,
    test_name: str,
    test_mode: common.SetupMode,
    project_name: str = PROJECT_NAME,
    runtime_version: str = RUNTIME_IMAGE,
    network: str = "default",
    subnetwork: str = "default",
    is_tpu_reserved: bool = True,
    num_slices: int = 1,
    model_configs: Dict = {},
):
  job_gcp_config = gcp_config.GCPConfig(
      project_name=project_name,
      zone=tpu_zone,
      dataset_name=metric_config.DatasetOption.BENCHMARK_DATASET,
  )

  set_up_cmds = (
      "pip install --upgrade pip",
      # Create a python virtual environment
      "sudo apt-get -y update",
      "sudo apt-get -y install python3.10-venv",
      "sudo apt-get -y install jq",
      "python -m venv .env",
      "source .env/bin/activate",
      # Setup jetstream-pytorch
      "git clone https://github.com/google/jetstream-pytorch.git",
      "cd jetstream-pytorch",
      "source install_everything.sh",
      """pip install -r deps/JetStream/benchmarks/requirements.in \
      -r deps/JetStream/requirements.txt
      """,
  )

  additional_metadata_dict = model_configs.copy()
  additional_metadata_dict.pop("sleep_time")

  run_model_cmds = (
      # Start virtual environment
      "source .env/bin/activate",
      # Get commit hash of the jetstream-pytorch and jetstream repos
      f"export METADATA_DICT='{json.dumps(additional_metadata_dict)}'",
      'cd jetstream-pytorch && export JETSTREAM_PYTORCH_COMMIT_HASH=$(git log -1 --format="%H") && cd ..',
      'cd jetstream-pytorch/deps/JetStream && export JETSTREAM_COMMIT_HASH=$(git log -1 --format="%H") && cd ../../..',
      'export METADATA_DICT=$(jq -c \'. + { "jetstream_pytorch_commit_hash": $newVal}\' --arg newVal ${JETSTREAM_PYTORCH_COMMIT_HASH} <<<"$METADATA_DICT")',
      'export METADATA_DICT=$(jq -c \'. + { "jetstream_commit_hash": $newVal}\' --arg newVal ${JETSTREAM_COMMIT_HASH} <<<"$METADATA_DICT")',
      ### Benchmark
      "cd jetstream-pytorch",
      # Configure flags
      f"export MODEL_NAME={model_configs['model_name']}",
      f"export SIZE={model_configs['size']}",
      f"export MODEL_ID={model_configs['model_id']}",
      f"export BATCH_SIZE={model_configs['batch_size']}",
      f"export CKPT_PATH={model_configs['checkpoint']}",
      f"export QUANTIZE={str(model_configs['quantize'])}",
      "mkdir -p /dev/shm/ckpt_dir/${MODEL_ID}/hf_original",
      "gsutil cp -r ${CKPT_PATH}/* /dev/shm/ckpt_dir/${MODEL_ID}/hf_original/",
      # Start jetstream-pytorch server in the background
      """jpt serve \
      --model_id=${MODEL_ID} \
      --working_dir=/dev/shm/ckpt_dir \
      --override_batch_size=${BATCH_SIZE} \
      --internal_use_local_tokenizer=True \
      --quantize_weights=${QUANTIZE}&""",
      "pip install --force-reinstall --no-deps nltk==3.8.1",
      # Give server time to start
      f"sleep {model_configs['sleep_time']}",
      # Run benchmark, run eval, save benchmark and eval results, and save
      # predictions to /tmp/request-outputs.json
      f"""python deps/JetStream/benchmarks/benchmark_serving.py \
      --tokenizer /dev/shm/ckpt_dir/{model_configs['model_id']}/hf_original/tokenizer.model \
      --model {model_configs['model_name']} \
      --num-prompts {model_configs['num_prompts']} \
      --dataset {model_configs['dataset']} \
      --max-output-length {model_configs['max_output_length']} \
      --request-rate {model_configs['request_rate']} \
      --warmup-mode sampled \
      --save-result \
      --additional-metadata-metrics-to-save ${{METADATA_DICT}} \
      --save-request-outputs \
      --run-eval true""",
      'export BENCHMARK_OUTPUT=$(find . -name "*JetStream*" -type f -printf "%T@ %Tc %p\n" | sort -n | head -1 | awk \'NF>1{print $NF}\')',
      # Stop JetStream server
      "kill -9 %%",
      # Upload results (in jsonlines format) to GCS to be post-processed into
      # our BigQuery table
      "mv ${BENCHMARK_OUTPUT} metric_report.jsonl",
      f"gsutil cp metric_report.jsonl {metric_config.SshEnvVars.GCS_OUTPUT.value}",
  )

  job_test_config = test_config.TpuVmTest(
      test_config.Tpu(
          version=tpu_version,
          cores=tpu_cores,
          runtime_version=runtime_version,
          reserved=is_tpu_reserved,
          network=network,
          subnetwork=subnetwork,
      ),
      test_name=test_name,
      set_up_cmds=set_up_cmds,
      run_model_cmds=run_model_cmds,
      timeout=datetime.timedelta(minutes=time_out_in_min),
      task_owner=test_owner.XIANG_S,
      num_slices=num_slices,
      gcs_subfolder=f"{GCS_SUBFOLDER_PREFIX}/jetstream_pytorch",
  )

  job_metric_config = metric_config.MetricConfig(
      json_lines=metric_config.JSONLinesConfig("metric_report.jsonl"),
      use_runtime_generated_gcs_folder=True,
  )

  return task.run_queued_resource_test(
      task_test_config=job_test_config,
      task_gcp_config=job_gcp_config,
      task_metric_config=job_metric_config,
  )