dags/multipod/maxtext_sft_trainer.py (44 lines of code) (raw):

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""DAG to run MaxText SFT Trainer tests.

Registers one test per (setup mode, docker image) pair listed in
`docker_images`. Every test runs the same shell command sequence, which
exports the SFT settings and then invokes
`end_to_end/tpu/test_sft_trainer.sh`.
"""

import datetime

from airflow import models

from dags import composer_env, gcs_bucket
from dags.common import test_owner
from dags.common.vm_resource import DockerImage, XpkClusters
from dags.multipod.configs import gke_config
from dags.multipod.configs.common import SetupMode

# Run once a day at 10 am UTC (2 am PST)
SCHEDULED_TIME = '0 10 * * *' if composer_env.is_prod_env() else None

# NOTE(review): this lookup runs every time the DAG file is parsed. If the
# Airflow variable is not set, the default `None` is interpolated below as the
# literal string "None" -- confirm that is acceptable for the test script.
HF_TOKEN = models.Variable.get('HF_TOKEN', None)

with models.DAG(
    dag_id='maxtext_sft_trainer',
    schedule=SCHEDULED_TIME,
    tags=['multipod_team', 'maxtext', 'stable', 'nightly', 'mlscale_devx'],
    start_date=datetime.datetime(2025, 3, 1),
    catchup=False,
    concurrency=2,
) as dag:
  base_output_directory = f'{gcs_bucket.BASE_OUTPUT_DIR}/maxtext_sft_trainer'

  # Each entry pairs a setup mode with the docker image used for that test.
  docker_images = [
      (SetupMode.STABLE, DockerImage.MAXTEXT_TPU_JAX_STABLE_STACK),
      (SetupMode.NIGHTLY, DockerImage.MAXTEXT_TPU_JAX_NIGHTLY),
  ]

  # The command sequence does not depend on the mode or image, so it is built
  # once and shared by every test instead of being rebuilt per iteration.
  command = (
      f'export HF_TOKEN={HF_TOKEN}',
      'export PRE_TRAINED_MODEL=llama2-7b',
      'export PRE_TRAINED_MODEL_TOKENIZER=meta-llama/Llama-2-7b-hf',
      'export PRE_TRAINED_MODEL_CKPT_PATH=gs://maxtext-model-checkpoints/llama2-7b/2025-01-23-19-26/scanned/0/items',
      f'export BASE_OUTPUT_DIRECTORY={base_output_directory}',
      'export STEPS=2500',
      'export PROMPT="Suggest some famous landmarks in London."',
      'export RTOL=1e-05',
      'export ATOL=0.09',
      'export KL_DIV=7e-05',
      'bash end_to_end/tpu/test_sft_trainer.sh',
  )

  for mode, image in docker_images:
    # The mode value is part of the test name so the stable and nightly runs
    # are distinguishable.
    maxtext_v4_configs_test = gke_config.get_gke_config(
        cluster=XpkClusters.TPU_V4_8_MAXTEXT_CLUSTER,
        time_out_in_min=60,
        test_name=f'sft-trainer-{mode.value}',
        run_model_cmds=command,
        docker_image=image.value,
        test_owner=test_owner.SURBHI_J,
    ).run()