def run_maxtext_tests()

in dags/multipod/maxtext_gpu_end_to_end.py


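# Note: this excerpt omits the module's imports; it assumes datetime,
# Airflow's models and TaskGroup, and the repo's gke_config, XpkClusters,
# DockerImage, and test_owner helpers are already in scope.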
def run_maxtext_tests(dag: models.DAG):
  test_name_prefix = "maxtext"

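  # Timestamp suffix keeps each run_name (and its GCS output path) unique per DAG run.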
  timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M")
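  # Shared train/decode command prefixes; each test below appends its own run_name and flags.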
  train_base = (
      "XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 TF_FORCE_GPU_ALLOW_GROWTH=true "
      "python3 -m MaxText.train MaxText/configs/base.yml "
      "base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset "
      "steps=2 enable_checkpointing=false attention=dot_product"
  )
  decode_base = (
      "XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 TF_FORCE_GPU_ALLOW_GROWTH=true "
      "python3 -m MaxText.decode MaxText/configs/base.yml "
      "base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset "
      "steps=2 enable_checkpointing=false attention=dot_product "
      "max_target_length=128 per_device_batch_size=1"
  )
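  # Maps test name to (command to run, number of nodes/slices to schedule).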
  test_models_gpu = {
      "train-c4-data": (f"{train_base} run_name=runner-{timestamp}-0", 1),
      "train-synthetic-data": (
          f"{train_base} run_name=runner-{timestamp}-1 dataset_type=synthetic",
          1,
      ),
      "train-flash": (
          f"{train_base} run_name=runner-{timestamp}-2 attention=cudnn_flash_te",
          1,
      ),
      "train-quarter-batch-size": (
          f"{train_base} run_name=runner-{timestamp}-3 per_device_batch_size=0.25 ici_tensor_parallelism=4",
          1,
      ),
      "train-int8": (
          f"{train_base} run_name=runner-{timestamp}-6 quantization=int8",
          1,
      ),
      "train-fp8": (
          f"{train_base} run_name=runner-{timestamp}-7 quantization=fp8",
          1,
      ),
      "decode": (f"{decode_base} run_name=runner-{timestamp}-4", 1),
      "decode-quarter-batch-size": (
          f"{decode_base} run_name=runner-{timestamp}-5 per_device_batch_size=.25 ici_tensor_parallelism=4",
          1,
      ),
      "generate-param-only-checkpoint": (
          "XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 TF_FORCE_GPU_ALLOW_GROWTH=true "
          f"bash end_to_end/test_generate_param_only_checkpoint.sh -r runner-{timestamp}-8 "
          "-o gs://runner-maxtext-logs -d gs://maxtext-dataset -i 4 -a dot_product",
          1,
      ),
      "generate-param-only-checkpoint-int8": (
          "XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 TF_FORCE_GPU_ALLOW_GROWTH=true "
          f"bash end_to_end/test_generate_param_only_checkpoint.sh -r runner-{timestamp}-9 "
          "-o gs://runner-maxtext-logs -d gs://maxtext-dataset -i 4 -q int8 -a dot_product",
          1,
      ),
      "grain-checkpoint-determinism": (
          "XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 TF_FORCE_GPU_ALLOW_GROWTH=true "
          "bash end_to_end/test_checkpointing.sh runner gs://runner-maxtext-logs "
          "gs://maxtext-dataset False c4-array_record dot_product",
          1,
      ),
      "checkpoint-compatibility": (
          "XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 TF_FORCE_GPU_ALLOW_GROWTH=true "
          "bash end_to_end/test_checkpoint_compatibility.sh runner "
          "gs://runner-maxtext-logs gs://maxtext-dataset dot_product",
          1,
      ),
      "llama2-7b-train-1node": ("bash MaxText/configs/a3/llama_2_7b/1vm.sh", 1),
      "llama2-7b-train-2node": ("bash MaxText/configs/a3/llama_2_7b/2vm.sh", 2),
      "llama2-7b": ("bash end_to_end/gpu/a3/test_llama2_7b.sh", 1),
  }

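  # Task group that run_with_quarantine uses to collect quarantined tests.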
  quarantine_task_group = TaskGroup(
      group_id="Quarantine", dag=dag, prefix_group_id=False
  )

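  # Run every test on both the A3 and A3+ GPU clusters; the A3 task is ordered before its A3+ counterpart.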
  for model, (test_script, nnodes) in test_models_gpu.items():
    stable_a3_gpu = gke_config.get_maxtext_end_to_end_gpu_gke_test_config(
        time_out_in_min=300,
        test_name=f"{test_name_prefix}-stable-stack-{model}",
        run_model_cmds=(test_script,),
        num_slices=nnodes,
        cluster=XpkClusters.GPU_A3_CLUSTER,
        docker_image=DockerImage.MAXTEXT_GPU_JAX_STABLE_STACK.value,
        test_owner=test_owner.YUWEI_Y,
    ).run_with_quarantine(quarantine_task_group)
    stable_a3plus_gpu = gke_config.get_maxtext_end_to_end_gpu_gke_test_config(
        time_out_in_min=300,
        test_name=f"{test_name_prefix}-stable-stack-{model}",
        run_model_cmds=(test_script,),
        num_slices=nnodes,
        cluster=XpkClusters.GPU_A3PLUS_CLUSTER,
        docker_image=DockerImage.MAXTEXT_GPU_JAX_STABLE_STACK.value,
        test_owner=test_owner.YUWEI_Y,
    ).run_with_quarantine(quarantine_task_group)
    stable_a3_gpu >> stable_a3plus_gpu
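
For context, a minimal sketch of how a helper like this is typically attached to an Airflow DAG; the dag_id, schedule, and start_date below are illustrative placeholders, not the values from the real dags/multipod/maxtext_gpu_end_to_end.py:

import datetime

from airflow import models


with models.DAG(
    dag_id="maxtext_gpu_end_to_end",  # illustrative; the real DAG defines its own id
    schedule=None,                    # illustrative; the real DAG defines its own schedule
    start_date=datetime.datetime(2024, 1, 1),
    catchup=False,
) as dag:
  run_maxtext_tests(dag)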