in dags/sparsity_diffusion_devx/maxtext_moe_gpu_e2e.py [0:0]
def run_maxtext_tests():
    """Build the pinned and stable Mixtral 8x7B GPU end-to-end test runs.

    Two variants are covered, "mixtral-8x7b-1node" and
    "mixtral-8x7b-2node". Both execute the same shell command; they differ
    only in the value passed as ``num_slices`` (1 and 2 respectively).

    For every variant two test configs are obtained from
    ``gke_config.get_maxtext_end_to_end_gpu_gke_test_config`` on
    ``XpkClusters.GPU_A3PLUS_CLUSTER`` and ``.run()`` is called on each:
    first with the ``MAXTEXT_GPU_JAX_PINNED`` image, then with the
    ``MAXTEXT_GPU_JAX_STABLE_STACK`` image. The two results are then linked
    with ``pinned >> stable``.
    """
    name_prefix = "maxtext"

    # One command serves both variants. Adjacent literals are concatenated
    # at compile time, so the resulting string is a single line with
    # single-space separators.
    mixtral_cmd = (
        f"SCANNED_CHECKPOINT={SCANNED_CHECKPOINT} "
        f"UNSCANNED_CKPT_PATH={UNSCANNED_CKPT_PATH} "
        "bash end_to_end/gpu/mixtral/test_8x7b.sh"
    )

    # variant name -> value forwarded as num_slices
    slices_by_variant = {
        "mixtral-8x7b-1node": 1,
        "mixtral-8x7b-2node": 2,
    }

    def _start_run(flavor, image, variant, slice_count):
        # Everything except the flavor label and the docker image is shared
        # between the pinned and the stable run.
        config = gke_config.get_maxtext_end_to_end_gpu_gke_test_config(
            time_out_in_min=90,
            test_name=f"{name_prefix}-{flavor}-{variant}",
            run_model_cmds=(mixtral_cmd,),
            num_slices=slice_count,
            cluster=XpkClusters.GPU_A3PLUS_CLUSTER,
            docker_image=image,
            test_owner=test_owner.MICHELLE_Y,
        )
        return config.run()

    for variant, slice_count in slices_by_variant.items():
        pinned_run = _start_run(
            "pinned",
            DockerImage.MAXTEXT_GPU_JAX_PINNED.value,
            variant,
            slice_count,
        )
        stable_run = _start_run(
            "stable",
            DockerImage.MAXTEXT_GPU_JAX_STABLE_STACK.value,
            variant,
            slice_count,
        )
        # NOTE(review): presumably an Airflow-style dependency, so the
        # stable run is scheduled after the pinned run; confirm against the
        # type returned by .run().
        pinned_run >> stable_run