def run_maxtext_tests()

in dags/sparsity_diffusion_devx/maxtext_moe_gpu_e2e.py [0:0]


def run_maxtext_tests():
  """Create pinned and stable-stack Mixtral 8x7B end-to-end GPU test tasks.

  For each configuration in `test_models_gpu`, two tasks are built with
  `gke_config.get_maxtext_end_to_end_gpu_gke_test_config(...).run()`. Both run
  on `XpkClusters.GPU_A3PLUS_CLUSTER` with a 90-minute timeout. The first task
  uses the `MAXTEXT_GPU_JAX_PINNED` image and the second uses the
  `MAXTEXT_GPU_JAX_STABLE_STACK` image. Within a configuration the two tasks
  are chained with `>>`, so the stable-stack task is ordered after the pinned
  one.
  """
  test_name_prefix = "maxtext"

  # One command serves every Mixtral 8x7B configuration. The checkpoint
  # locations are passed to the script as environment-variable assignments.
  # Each backslash-newline continues the string literal, so the leading spaces
  # of the next line become part of the command text. The shell treats those
  # spaces as ordinary separators.
  mixtral_8x7b_cmd = f"SCANNED_CHECKPOINT={SCANNED_CHECKPOINT} \
            UNSCANNED_CKPT_PATH={UNSCANNED_CKPT_PATH} \
            bash end_to_end/gpu/mixtral/test_8x7b.sh"

  # Configuration name -> (shell command, node count).
  test_models_gpu = {
      "mixtral-8x7b-1node": (mixtral_8x7b_cmd, 1),
      "mixtral-8x7b-2node": (mixtral_8x7b_cmd, 2),
  }

  # (test-name fragment, docker image), in the order the tasks are chained.
  image_variants = (
      ("pinned", DockerImage.MAXTEXT_GPU_JAX_PINNED),
      ("stable", DockerImage.MAXTEXT_GPU_JAX_STABLE_STACK),
  )

  for model, (test_script, nnodes) in test_models_gpu.items():
    pinned_task, stable_task = [
        gke_config.get_maxtext_end_to_end_gpu_gke_test_config(
            time_out_in_min=90,
            test_name=f"{test_name_prefix}-{variant}-{model}",
            run_model_cmds=(test_script,),
            # The node count is passed to the config builder as num_slices.
            num_slices=nnodes,
            cluster=XpkClusters.GPU_A3PLUS_CLUSTER,
            docker_image=image.value,
            test_owner=test_owner.MICHELLE_Y,
        ).run()
        for variant, image in image_variants
    ]
    pinned_task >> stable_task