dags/sparsity_diffusion_devx/jax_stable_stack_tpu_e2e.py (114 lines of code) (raw):

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A DAG to run end-to-end JAX Stable Stack TPU tests.

For every configured accelerator / slice-count combination this module builds
one GKE test per workload (MaxText, MaxDiffusion, AXLearn) through
`config.get_gke_config` and registers it via `run_with_quarantine` against a
shared "Quarantine" task group. MaxText and MaxDiffusion are additionally run
once per (setup mode, docker image) pair.
"""

import datetime
from airflow import models
from airflow.utils.task_group import TaskGroup
from dags import composer_env, gcs_bucket
from dags.common import test_owner
from dags.common.vm_resource import Project, TpuVersion, CpuVersion, Zone, DockerImage, GpuVersion, XpkClusters
from dags.sparsity_diffusion_devx.configs import gke_config as config
from dags.multipod.configs.common import SetupMode
from xlml.utils import name_format

# Run once a day at 3 am UTC (7 pm PST)
SCHEDULED_TIME = "0 3 * * *" if composer_env.is_prod_env() else None

# Shell prefix shared by every workload command below.
# NOTE(review): the assignments are chained to the workload with `&&`, i.e.
# they form a separate command instead of prefixing the Python invocation.
# Depending on how `run_model_cmds` is executed they may therefore not be
# visible in the training process' environment -- confirm whether an `export`
# is intended. The string is kept unchanged to preserve current behavior.
TPU_ENV_CMD_PREFIX = (
    "JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true "
    "TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true TPU_SLICE_BUILDER_DUMP_ICI=true "
    "JAX_FORCE_TPU_INIT=true ENABLE_TPUNETD_CLIENT=true && "
)


with models.DAG(
    dag_id="jax_stable_stack_tpu_e2e",
    schedule=SCHEDULED_TIME,
    tags=[
        "sparsity_diffusion_devx",
        "multipod_team",
        "maxtext",
        "maxdiffusion",
        "axlearn",
        "tpu",
        "jax-stable-stack",
        "mlscale_devx",
    ],
    start_date=datetime.datetime(2024, 6, 7),
    catchup=False,
) as dag:
  # Timestamp embedded in run names and output directories.
  current_datetime = config.get_current_datetime()

  maxtext_test_configs = {
      # accelerator: list of slices to test
      "v4-16": [1, 2],
      "v6e-256": [1],
  }
  maxdiffusion_test_configs = {
      # accelerator: list of slices to test
      "v4-8": [1],
      "v6e-256": [1],
  }
  axlearn_test_configs = {
      # accelerator: list of slices to test
      "v4-16": [1],
  }

  # Single task group handed to every `run_with_quarantine` call below.
  quarantine_task_group = TaskGroup(
      group_id="Quarantine", dag=dag, prefix_group_id=False
  )

  # (setup mode, docker image) pairs; each test is instantiated once per pair.
  maxtext_docker_images = [
      (SetupMode.STABLE, DockerImage.MAXTEXT_TPU_JAX_STABLE_STACK_CANDIDATE),
      (SetupMode.NIGHTLY, DockerImage.MAXTEXT_TPU_STABLE_STACK_NIGHTLY_JAX),
  ]
  maxdiffusion_docker_images = [
      (
          SetupMode.STABLE,
          DockerImage.MAXDIFFUSION_TPU_JAX_STABLE_STACK,
      ),
      (
          SetupMode.NIGHTLY,
          DockerImage.MAXDIFFUSION_TPU_STABLE_STACK_NIGHTLY_JAX,
      ),
  ]

  # MaxText: one test per accelerator x slice count x docker image.
  for accelerator, slices in maxtext_test_configs.items():
    # Trailing number of the accelerator name, e.g. "16" for "v4-16".
    cores = accelerator.rsplit("-", maxsplit=1)[-1]
    cluster = config.clusters[accelerator]
    for slice_num in slices:
      for mode, image in maxtext_docker_images:
        maxtext_jax_stable_stack_test = config.get_gke_config(
            num_slices=slice_num,
            cluster=cluster,
            time_out_in_min=60,
            run_model_cmds=(
                f"{TPU_ENV_CMD_PREFIX}"
                "python -m MaxText.train MaxText/configs/base.yml "
                f"run_name={slice_num}slice-V{cluster.device_version}_{cores}-maxtext-jax-stable-stack-{current_datetime} "
                "steps=30 per_device_batch_size=1 max_target_length=4096 model_name=llama2-7b "
                "enable_checkpointing=false attention=dot_product remat_policy=minimal_flash use_iota_embed=true scan_layers=false "
                "dataset_type=synthetic async_checkpointing=false "
                f"base_output_directory={gcs_bucket.BASE_OUTPUT_DIR}/maxtext/jax-stable-stack/automated/{current_datetime}",
            ),
            test_name=f"maxtext-jax-stable-stack-{mode.value}-{accelerator}-{slice_num}x",
            docker_image=image.value,
            test_owner=test_owner.PARAM_B,
        ).run_with_quarantine(quarantine_task_group)

  # MaxDiffusion: one test per accelerator x slice count x docker image.
  for accelerator, slices in maxdiffusion_test_configs.items():
    cores = accelerator.rsplit("-", maxsplit=1)[-1]
    cluster = config.clusters[accelerator]
    for slice_num in slices:
      for mode, image in maxdiffusion_docker_images:
        maxdiffusion_jax_stable_stack_test = config.get_gke_config(
            num_slices=slice_num,
            cluster=cluster,
            time_out_in_min=60,
            run_model_cmds=(
                f"{TPU_ENV_CMD_PREFIX}"
                "pip install . && python src/maxdiffusion/train.py src/maxdiffusion/configs/base_2_base.yml "
                f"run_name={slice_num}slice-V{cluster.device_version}_{cores}-maxdiffusion-jax-stable-stack-{current_datetime} "
                f"output_dir={gcs_bucket.BASE_OUTPUT_DIR}/maxdiffusion-jax-stable-stack-{mode.value}-{accelerator}-{slice_num}/automated/{current_datetime}",
            ),
            test_name=f"maxdiffusion-jax-stable-stack-{mode.value}-{accelerator}-{slice_num}x",
            # Use the image paired with `mode`; a hard-coded stable image here
            # would make the "nightly" task silently re-run the stable image.
            docker_image=image.value,
            test_owner=test_owner.PARAM_B,
        ).run_with_quarantine(quarantine_task_group)

  # AXLearn: one test per accelerator x slice count, single docker image.
  for accelerator, slices in axlearn_test_configs.items():
    cluster = config.clusters[accelerator]
    for slice_num in slices:
      axlearn_jax_stable_stack_test = config.get_gke_config(
          num_slices=slice_num,
          cluster=cluster,
          time_out_in_min=300,
          run_model_cmds=(
              f"{TPU_ENV_CMD_PREFIX}"
              "cd axlearn && python -m axlearn.common.launch_trainer_main "
              "--module=text.gpt.c4_trainer --config=fuji-test-v1 "
              f"--trainer_dir={gcs_bucket.BASE_OUTPUT_DIR}/bite/jax-stable-stack/automated/{current_datetime} "
              f"--data_dir={gcs_bucket.AXLEARN_DIR} --jax_backend=tpu ",
          ),
          test_name=f"axlearn-jax-stable-stack-{accelerator}-{slice_num}x",
          docker_image=DockerImage.AXLEARN_TPU_JAX_STABLE_STACK.value,
          test_owner=test_owner.PARAM_B,
      ).run_with_quarantine(quarantine_task_group)