dags/sparsity_diffusion_devx/project_bite_tpu_e2e.py (76 lines of code) (raw):
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""DAGs that run Project Bite (AXLearn) end-to-end training and unit tests on TPUs."""
import datetime
from airflow import models
from dags import composer_env
from dags.common import test_owner
from dags.common.vm_resource import TpuVersion, Zone, RuntimeVersion, Project
from dags.sparsity_diffusion_devx.configs import project_bite_config as config
# In production, run once a day at 18:00 UTC (10 am PST / 11 am PDT).
# Outside production no cron schedule is set (schedule=None).
if composer_env.is_prod_env():
    SCHEDULED_TIME = "0 18 * * *"
else:
    SCHEDULED_TIME = None
with models.DAG(
    dag_id="project_bite_tpu_e2e",
    schedule=SCHEDULED_TIME,
    tags=[
        "sparsity_diffusion_devx", "multipod_team", "tpu", "axlearn", "bite",
    ],
    start_date=datetime.datetime(2024, 4, 4),
    catchup=False,
) as dag:
    # Settings shared by every Fuji end-to-end run in this DAG.
    fuji_common_kwargs = {
        "tpu_cores": 8,
        "model_config": "fuji-test-v1",
        "time_out_in_min": 180,
        "task_owner": test_owner.Maggie_Z,
    }
    # TPU target shared by the two v6e-8 runs.
    v6e_target_kwargs = {
        "tpu_version": TpuVersion.TRILLIUM,
        "tpu_zone": Zone.US_EAST5_B.value,
        "runtime_version": RuntimeVersion.V2_ALPHA_TPUV6.value,
    }

    # AXLearn head against JAX head.
    # Fuji training on v5p-8, in the explicitly provided GCP project and
    # network (is_tpu_reserved=True).
    jax_fuji_v5p_8 = config.get_bite_tpu_config(
        tpu_version=TpuVersion.V5P,
        tpu_zone=Zone.US_EAST5_A.value,
        runtime_version=RuntimeVersion.V2_ALPHA_TPUV5.value,
        project_name=Project.TPU_PROD_ENV_AUTOMATED.value,
        network="mas-test",
        subnetwork="mas-test",
        is_tpu_reserved=True,
        **fuji_common_kwargs,
    )

    # AXLearn head against JAX head.
    # Fuji training on v6e-8.
    jax_fuji_v6e_8 = config.get_bite_tpu_config(
        **v6e_target_kwargs,
        **fuji_common_kwargs,
    )

    # AXLearn pinned version against JAX head.
    # Pinned commit comes from https://github.com/apple/axlearn/pull/1033,
    # dated March 6, 2025 (the test succeeds at this commit).
    pinned_axlearn_commit = "35a189c15bdd06416be743cecde272693363ce3c"
    jax_pinned_fuji_v6e_8 = config.get_bite_tpu_config(
        pinned_version=pinned_axlearn_commit,
        **v6e_target_kwargs,
        **fuji_common_kwargs,
    )
# Passed as default_args to the unit-test DAG: retries are set to 0.
default_unittest_args = {"retries": 0}
with models.DAG(
    dag_id="project_bite_tpu_unittests",
    schedule=SCHEDULED_TIME,
    tags=["sparsity_diffusion_devx", "tpu", "axlearn", "bite"],
    start_date=datetime.datetime(2025, 2, 24),
    catchup=False,
    default_args=default_unittest_args,
) as dag:
    # Unit-test run on a 4-core Trillium TPU in Zone.EUROPE_WEST4_A, using
    # the same 180-minute timeout as the end-to-end runs.
    unittests = config.get_bite_tpu_unittests_config(
        tpu_version=TpuVersion.TRILLIUM,
        tpu_cores=4,
        tpu_zone=Zone.EUROPE_WEST4_A.value,
        runtime_version=RuntimeVersion.V2_ALPHA_TPUV6.value,
        time_out_in_min=180,
        task_owner=test_owner.Andrew_S,
    )