# dags/sparsity_diffusion_devx/maxtext_moe_tpu_e2e.py
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A DAG to run end-to-end MoE tests."""
import contextlib
import datetime

from airflow import models
from airflow.utils.task_group import TaskGroup

from dags import composer_env
from dags.common import test_owner
from dags.common.quarantined_tests import QuarantineTests
from dags.common.vm_resource import DockerImage, XpkClusters
from dags.multipod.configs import gke_config
from xlml.utils import name_format
# Run once a day at 1 am UTC (5 pm PST); in non-prod environments the DAG is
# never scheduled automatically (manual trigger only).
SCHEDULED_TIME = "0 1 * * *" if composer_env.is_prod_env() else None


with models.DAG(
    dag_id="maxtext_moe_tpu_e2e",
    schedule=SCHEDULED_TIME,
    tags=[
        "sparsity_diffusion_devx",
        "multipod_team",
        "maxtext",
        "tpu",
        "stable",
        "nightly",
        "mlscale_devx",
    ],
    start_date=datetime.datetime(2024, 11, 14),
    # Do not backfill runs missed while the DAG was paused.
    catchup=False,
) as dag:
  # Prefix shared by every generated test name in this DAG.
  test_name_prefix = "maxtext"

  # Task group that collects quarantined tests so they are visually separated
  # in the Airflow UI. prefix_group_id=False keeps task ids unprefixed.
  quarantine_task_group = TaskGroup(
      group_id="Quarantine", dag=dag, prefix_group_id=False
  )

  # Docker images keyed by build channel. NOTE: insertion order matters —
  # the chaining loops below run "stable" tasks before "nightly" ones.
  docker_image = {
      "stable": DockerImage.MAXTEXT_TPU_JAX_STABLE_STACK.value,
      "nightly": DockerImage.MAXTEXT_TPU_JAX_NIGHTLY.value,
  }
# Unchained tests
# TODO(ranran): add back ckpt conversation after b/384580048
test_models_tpu = {
"mixtral-8x22b": {
"script_name": "tpu/mixtral/8x22b/2_test_mixtral",
"cluster": XpkClusters.TPU_V6E_256_MLPERF_CLUSTER,
"time_out_in_min": 60,
},
}
unchained_tests = []
for model, test_scripts_details in test_models_tpu.items():
for image in docker_image.keys():
training_tpu = gke_config.get_gke_config(
time_out_in_min=test_scripts_details["time_out_in_min"],
test_name=f"{test_name_prefix}_{image}_{model}",
run_model_cmds=(
f"bash end_to_end/{test_scripts_details['script_name']}.sh",
),
docker_image=docker_image[image],
test_owner=test_owner.RAN_R,
cluster=test_scripts_details["cluster"],
).run_with_quarantine(quarantine_task_group)
unchained_tests.append(training_tpu)
# stable_tpu >> nightly_tpu
for i in range(len(unchained_tests) - 1):
unchained_tests[i] >> unchained_tests[i + 1]
# Chained tests
multicluster_test_models = {
"mixtral-8x7b": [
{
"script_name": "tpu/mixtral/8x7b/1_test_mixtral",
"cluster": XpkClusters.CPU_M1_MEGAMEM_96_CLUSTER,
"time_out_in_min": 240,
},
{
"script_name": "tpu/mixtral/8x7b/2_test_mixtral",
"cluster": XpkClusters.TPU_V6E_256_MLPERF_CLUSTER,
"time_out_in_min": 90,
},
],
}
def convert_checkpoint_and_run_training(
test_group_id,
test_name_prefix,
image,
docker_image,
model,
test_scripts_details,
):
with TaskGroup(group_id=test_group_id, prefix_group_id=False) as group:
test_name = f"{test_name_prefix}_{image}_{model}"
shared_gcs_location = name_format.generate_gcs_folder_location.override(
task_id=f"{test_group_id}_generate_gcs_folder_location"
)(
gcs_subfolder,
test_group_id,
)
conversion_cpu = gke_config.get_maxtext_cpu_end_to_end_gke_config(
time_out_in_min=test_scripts_details[0]["time_out_in_min"],
test_name=test_name,
run_model_cmds=(
f"export BASE_OUTPUT_PATH=$GCS_OUTPUT; bash end_to_end/{test_scripts_details[0]['script_name']}.sh",
),
docker_image=docker_image,
test_owner=test_owner.RAN_R,
cluster=test_scripts_details[0]["cluster"],
).run(gcs_location=shared_gcs_location)
training_tpu = gke_config.get_gke_config(
time_out_in_min=test_scripts_details[1]["time_out_in_min"],
test_name=test_name,
run_model_cmds=(
f"export BASE_OUTPUT_PATH=$GCS_OUTPUT; bash end_to_end/{test_scripts_details[1]['script_name']}.sh",
),
docker_image=docker_image,
test_owner=test_owner.RAN_R,
cluster=test_scripts_details[1]["cluster"],
).run(gcs_location=shared_gcs_location)
return conversion_cpu, training_tpu
tests = []
for model, test_scripts_details in multicluster_test_models.items():
gcs_subfolder = f"{test_owner.Team.SPARSITY_DIFFUSION_DEVX.value}/maxtext"
for image in docker_image.keys():
test_group_id = "chained_tests" + "_" + model + "_" + image
if QuarantineTests.is_quarantined(test_group_id):
with quarantine_task_group:
mode_cpu, mode_tpu = convert_checkpoint_and_run_training(
test_group_id,
test_name_prefix,
image,
docker_image[image],
model,
test_scripts_details,
)
else:
mode_cpu, mode_tpu = convert_checkpoint_and_run_training(
test_group_id,
test_name_prefix,
image,
docker_image[image],
model,
test_scripts_details,
)
tests.append(mode_cpu)
tests.append(mode_tpu)
# stable_cpu >> stable_tpu >> nightly_cpu >> nightly_tpu
for i in range(len(tests) - 1):
tests[i] >> tests[i + 1]