dags/mlcompass/maxtext_gke.py (84 lines of code) (raw):

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This Airflow DAG runs a maxtext machine learning benchmark on a GKE cluster

Usage:
gcloud composer environments run ml-automation-solutions \
  --project=cloud-ml-auto-solutions \
  --location=us-central1 dags trigger \
  -- \
  mlcompass_maxtext_gke \
  --conf={\\\"uuid\\\":\\\"abc\\\"}
"""

import datetime
import json

from airflow import models
from airflow.decorators import task
from airflow.providers.google.cloud.hooks.gcs import GCSHook

from dags.common import test_owner
from xlml.utils import xpk


with models.DAG(
    dag_id="mlcompass_maxtext_gke",
    # No schedule: runs only start when the DAG is triggered (see Usage above).
    schedule=None,
    tags=["mlcompass", "maxtext"],
    start_date=datetime.datetime(2024, 9, 1),
    catchup=False,
    # "uuid" defaults to empty; load_xlml_state rejects an empty value.
    params={
        "uuid": "",
    },
    default_args={
        "retries": 0,
    },
) as dag:

    @task.python(multiple_outputs=True)
    def load_xlml_state(params: dict = None):
        """Fetch and parse the xlml state file for the requested run.

        Reads ``xlml/<uuid>/xlml_state.json`` from the
        ``mlcompass-jax-artifacts`` bucket through ``GCSHook`` and returns
        the decoded JSON document.

        Args:
          params: DAG run parameters; must hold a non-empty "uuid" entry.
            The task is invoked below without arguments, so this is
            expected to be supplied by Airflow at run time.

        Returns:
          The parsed content of the state file.

        Raises:
          RuntimeError: if the "uuid" parameter is empty.
        """
        dag.log.info(params)

        run_uuid = params["uuid"]
        if not run_uuid:
            raise RuntimeError("uuid is not set")

        state_object = f"xlml/{run_uuid}/xlml_state.json"
        raw_state = GCSHook().download("mlcompass-jax-artifacts", state_object)
        return json.loads(raw_state)

    xlml_state = load_xlml_state()

    # Individual entries of the loaded state, looked up by key on the task
    # output (the task is declared with multiple_outputs=True).
    # NOTE(review): these are references to the task's output rather than
    # plain strings at DAG-parse time; the f-strings below rely on them
    # rendering to something the xpk tasks can resolve -- confirm there.
    cluster_name = xlml_state["cluster_name"]
    cluster_project = xlml_state["cluster_project"]
    cluster_region = xlml_state["cluster_region"]
    cluster_zone = xlml_state["cluster_zone"]
    benchmark_id = xlml_state["test_name"]
    docker_image_path = xlml_state["docker_image_path"]
    accelerator_type = xlml_state["accelerator_type"]
    num_slices = xlml_state["num_slices"]
    model_name = xlml_state["model_name"]
    workdir_bucket = xlml_state["workdir_bucket"]
    workdir_path = xlml_state["workdir_path"]
    state_uuid = xlml_state["uuid"]

    gcs_path = f"gs://{workdir_bucket}/{workdir_path}"
    workload_id = f"mlc-{state_uuid}"

    # Sensor timeouts, expressed in seconds.
    workload_provision_timeout = datetime.timedelta(
        minutes=300
    ).total_seconds()
    workload_run_timeout = datetime.timedelta(minutes=60).total_seconds()

    # Submit the benchmark workload to the cluster.
    submit_workload = xpk.run_workload.override(owner=test_owner.ORTI_B)
    run_workload = submit_workload(
        task_id="run_workload",
        cluster_project=cluster_project,
        zone=cluster_zone,
        cluster_name=cluster_name,
        benchmark_id=benchmark_id,
        workload_id=workload_id,
        gcs_path=gcs_path,
        docker_image=docker_image_path,
        accelerator_type=accelerator_type,
        run_cmds=f"source benchmark_run.sh;run {model_name} {gcs_path}",
        num_slices=num_slices,
        use_vertex_tensorboard=False,
        use_pathways=False,
    )

    # Both wait tasks identify the workload with the same keyword arguments.
    workload_locator = {
        "workload_id": workload_id,
        "project_id": cluster_project,
        "region": cluster_region,
        "cluster_name": cluster_name,
    }

    wait_for_workload_start = xpk.wait_for_workload_start.override(
        timeout=workload_provision_timeout
    )(**workload_locator)

    wait_for_workload_completion = xpk.wait_for_workload_completion.override(
        timeout=workload_run_timeout
    )(**workload_locator)

    # Strictly sequential pipeline: load state, submit, wait for start,
    # wait for completion.
    (
        xlml_state
        >> run_workload
        >> wait_for_workload_start
        >> wait_for_workload_completion
    )