in dags/multipod/maxtext_gpu_end_to_end.py [0:0]
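# Imports this excerpt relies on (a sketch; only the standard-library and
# Airflow imports are certain -- the repo-internal module paths are not shown
# in this excerpt and are deliberately not guessed here):
import datetime

from airflow import models
from airflow.utils.task_group import TaskGroup

# gke_config, XpkClusters, DockerImage, and test_owner are repo-internal
# helpers imported elsewhere in this file; their exact module paths are not
# visible in this excerpt.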
def run_maxtext_tests(dag: models.DAG):
  test_name_prefix = "maxtext"
  # A per-run timestamp keeps each test's run_name unique across DAG runs.
  timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M")
  # Command prefixes shared by the train and decode smoke tests below.
  train_base = (
      "XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 TF_FORCE_GPU_ALLOW_GROWTH=true "
      "python3 -m MaxText.train MaxText/configs/base.yml "
      "base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset "
      "steps=2 enable_checkpointing=false attention=dot_product"
  )
  decode_base = (
      "XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 TF_FORCE_GPU_ALLOW_GROWTH=true "
      "python3 -m MaxText.decode MaxText/configs/base.yml "
      "base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset "
      "steps=2 enable_checkpointing=false attention=dot_product "
      "max_target_length=128 per_device_batch_size=1"
  )
  # Maps a short test name to (shell command, number of nodes it needs).
  test_models_gpu = {
      "train-c4-data": (f"{train_base} run_name=runner-{timestamp}-0", 1),
      "train-synthetic-data": (
          f"{train_base} run_name=runner-{timestamp}-1 dataset_type=synthetic",
          1,
      ),
      "train-flash": (
          f"{train_base} run_name=runner-{timestamp}-2 attention=cudnn_flash_te",
          1,
      ),
      "train-quarter-batch-size": (
          f"{train_base} run_name=runner-{timestamp}-3 per_device_batch_size=0.25 ici_tensor_parallelism=4",
          1,
      ),
      "train-int8": (
          f"{train_base} run_name=runner-{timestamp}-6 quantization=int8",
          1,
      ),
      "train-fp8": (
          f"{train_base} run_name=runner-{timestamp}-7 quantization=fp8",
          1,
      ),
      "decode": (f"{decode_base} run_name=runner-{timestamp}-4", 1),
      "decode-quarter-batch-size": (
          f"{decode_base} run_name=runner-{timestamp}-5 per_device_batch_size=.25 ici_tensor_parallelism=4",
          1,
      ),
      "generate-param-only-checkpoint": (
          "XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 TF_FORCE_GPU_ALLOW_GROWTH=true "
          f"bash end_to_end/test_generate_param_only_checkpoint.sh -r runner-{timestamp}-8 "
          "-o gs://runner-maxtext-logs -d gs://maxtext-dataset -i 4 -a dot_product",
          1,
      ),
      "generate-param-only-checkpoint-int8": (
          "XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 TF_FORCE_GPU_ALLOW_GROWTH=true "
          f"bash end_to_end/test_generate_param_only_checkpoint.sh -r runner-{timestamp}-9 "
          "-o gs://runner-maxtext-logs -d gs://maxtext-dataset -i 4 -q int8 -a dot_product",
          1,
      ),
      "grain-checkpoint-determinism": (
          "XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 TF_FORCE_GPU_ALLOW_GROWTH=true "
          "bash end_to_end/test_checkpointing.sh runner gs://runner-maxtext-logs "
          "gs://maxtext-dataset False c4-array_record dot_product",
          1,
      ),
      "checkpoint-compatibility": (
          "XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 TF_FORCE_GPU_ALLOW_GROWTH=true "
          "bash end_to_end/test_checkpoint_compatibility.sh runner "
          "gs://runner-maxtext-logs gs://maxtext-dataset dot_product",
          1,
      ),
      "llama2-7b-train-1node": ("bash MaxText/configs/a3/llama_2_7b/1vm.sh", 1),
      "llama2-7b-train-2node": ("bash MaxText/configs/a3/llama_2_7b/2vm.sh", 2),
      "llama2-7b": ("bash end_to_end/gpu/a3/test_llama2_7b.sh", 1),
  }
  quarantine_task_group = TaskGroup(
      group_id="Quarantine", dag=dag, prefix_group_id=False
  )
  for model, (test_script, nnodes) in test_models_gpu.items():
    # Build the same end-to-end test for both the A3 and A3+ GPU clusters.
    stable_a3_gpu = gke_config.get_maxtext_end_to_end_gpu_gke_test_config(
        time_out_in_min=300,
        test_name=f"{test_name_prefix}-stable-stack-{model}",
        run_model_cmds=(test_script,),
        num_slices=nnodes,
        cluster=XpkClusters.GPU_A3_CLUSTER,
        docker_image=DockerImage.MAXTEXT_GPU_JAX_STABLE_STACK.value,
        test_owner=test_owner.YUWEI_Y,
    ).run_with_quarantine(quarantine_task_group)
    stable_a3plus_gpu = gke_config.get_maxtext_end_to_end_gpu_gke_test_config(
        time_out_in_min=300,
        test_name=f"{test_name_prefix}-stable-stack-{model}",
        run_model_cmds=(test_script,),
        num_slices=nnodes,
        cluster=XpkClusters.GPU_A3PLUS_CLUSTER,
        docker_image=DockerImage.MAXTEXT_GPU_JAX_STABLE_STACK.value,
        test_owner=test_owner.YUWEI_Y,
    ).run_with_quarantine(quarantine_task_group)
    # Chain the tasks so the A3+ run starts only after its A3 counterpart.
    stable_a3_gpu >> stable_a3plus_gpu
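# Minimal usage sketch (an assumption, not part of the code shown above):
# DAG modules in this layout typically construct the DAG object and call the
# helper inside its context. The dag_id, schedule, and start_date below are
# illustrative placeholders only.
with models.DAG(
    dag_id="maxtext_gpu_end_to_end",
    schedule=None,
    start_date=datetime.datetime(2024, 1, 1),
    catchup=False,
) as dag:
  run_maxtext_tests(dag)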