dags/sparsity_diffusion_devx/project_bite_gpu_e2e.py (45 lines of code) (raw):
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A DAG to run end-to-end JAX Stable Stack tests for GCP GPUs."""
import datetime
from airflow import models
from dags import composer_env, gcs_bucket
from dags.common import test_owner
from dags.common.vm_resource import DockerImage, XpkClusters
from dags.sparsity_diffusion_devx.configs import gke_config as config
from xlml.utils import name_format
# Run once a day at 3 am UTC (7 pm PST)
SCHEDULED_TIME = "0 3 * * *" if composer_env.is_prod_env() else None
with models.DAG(
dag_id="project_bite_gpu_e2e",
schedule=SCHEDULED_TIME,
tags=[
"sparsity_diffusion_devx",
"multipod_team",
"gpu",
"axlearn",
"bite",
],
start_date=datetime.datetime(2024, 11, 12),
catchup=False,
) as dag:
current_datetime = config.get_current_datetime()
axlearn_test_configs = {
# accelerator: list of slices to test
"a3plus": [1, 2],
}
for accelerator, slices in axlearn_test_configs.items():
cores = accelerator.rsplit("-", maxsplit=1)[-1]
cluster = config.clusters[accelerator]
for slice_num in slices:
maxtext_jax_stable_stack_test = config.get_gpu_gke_test_config(
num_slices=slice_num,
cluster=cluster,
time_out_in_min=300,
run_model_cmds=(
"cd axlearn && "
"XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 TF_FORCE_GPU_ALLOW_GROWTH=true "
"python -m axlearn.common.launch_trainer_main "
f"--module=text.gpt.c4_trainer --config=fuji-test-v1 "
f"--trainer_dir={gcs_bucket.BASE_OUTPUT_DIR}/bite/gpu/jax-stable-stack/automated/{current_datetime} "
f"--data_dir={gcs_bucket.AXLEARN_DIR} --jax_backend=gpu ",
),
test_name=f"axlearn-jax-nightly-{accelerator}-{slice_num}x",
docker_image=DockerImage.AXLEARN_GPU_JAX_NIGHTLY.value,
test_owner=test_owner.Maggie_Z,
).run()