# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
A DAG to run MaxText checkpointing tests.
"""
import datetime
from airflow import models
from dags import composer_env, gcs_bucket
from dags.common import test_owner
from dags.common.vm_resource import TpuVersion, Zone, DockerImage, XpkClusters
from dags.multipod.configs import gke_config
from dags.multipod.configs.common import SetupMode

# Run once a day at 10 am UTC (2 am PST)
SCHEDULED_TIME = "0 10 * * *" if composer_env.is_prod_env() else None

with models.DAG(
    dag_id="maxtext_checkpointing",
    schedule=SCHEDULED_TIME,
    tags=[
        "multipod_team",
        "maxtext",
        "stable",
        "nightly",
        "mlscale_devx",
    ],
    start_date=datetime.datetime(2024, 3, 1),
    catchup=False,
    concurrency=2,
) as dag:
  base_output_directory = f"{gcs_bucket.BASE_OUTPUT_DIR}/maxtext_checkpointing"
  dataset_path = gcs_bucket.MAXTEXT_DIR
  current_time = datetime.datetime.now()
  current_datetime = current_time.strftime("%Y-%m-%d-%H-%M-%S")
  docker_images = [
      (SetupMode.STABLE, DockerImage.MAXTEXT_TPU_JAX_STABLE_STACK),
      (SetupMode.NIGHTLY, DockerImage.MAXTEXT_TPU_JAX_NIGHTLY),
  ]
  test_configs = {
      # accelerator: list of slices to test
      "v4-8": [1],
      "v4-16": [1, 2],
  }
  clusters = {
      # accelerator: cluster name
      "v4-8": XpkClusters.TPU_V4_8_MAXTEXT_CLUSTER,
      "v4-16": XpkClusters.TPU_V4_16_CLUSTER,
  }

  for mode, image in docker_images:
    for accelerator, slices in test_configs.items():
      cores = accelerator.rsplit("-", maxsplit=1)[-1]
      for slice_num in slices:
        for chkpt_mode in ["sync", "async"]:
          async_checkpointing = chkpt_mode == "async"
          run_name = f" checkpointing-{mode.value}-{slice_num}x-{accelerator}-{chkpt_mode}-{current_datetime}"
          command = (
              "bash end_to_end/test_checkpointing.sh"
              f" {run_name} {base_output_directory} {dataset_path}"
              f" true tfds autoselected {async_checkpointing}",
          )
          maxtext_v4_configs_test = gke_config.get_gke_config(
              num_slices=slice_num,
              cluster=clusters[accelerator],
              time_out_in_min=60,
              test_name=f"maxtext-checkpointing-{mode.value}-{chkpt_mode}",
              run_model_cmds=command,
              docker_image=image.value,
              test_owner=test_owner.SURBHI_J,
          ).run()
