in dags/multipod/maxtext_configs_aot_hybridsim.py [0:0]
def hybridsim_compile_and_run(test_group_id):
    """Build a task group that AOT-compiles a MaxText config and runs HybridSim on it.

    The group chains three tasks:
      1. Generate a GCS folder location shared by the two workloads below.
      2. On GKE, run the MaxText config with ``EXECUTABLE=train_compile.py``
         and XLA dumping enabled (to ``/tmp/xla_dump/``), then copy the dump
         into the shared GCS folder.
      3. On GKE, download and run ``run_hybridsim.sh`` against that dump,
         writing ``estimated_cost_ns.jsonl`` back to the shared folder. That
         file is registered as the job's JSON-lines metric output.

    NOTE(review): ``tpu``, ``model_size``, ``n``, ``num_cores``, ``v5e_alt`` and
    ``clusters`` are not parameters. They are looked up from the enclosing
    scope when this function runs, so call it only while those names hold
    the intended values.

    Args:
      test_group_id: Id of the created ``TaskGroup``. Child task ids are not
        prefixed with it. It is also passed to GCS folder generation and used
        to name the folder-generation task.

    Returns:
      The created ``TaskGroup``, so callers can attach upstream/downstream
      dependencies to the whole group.
    """
    with TaskGroup(group_id=test_group_id, prefix_group_id=False) as group:
        # Both workloads read/write the same generated GCS folder.
        gcs_subfolder = f"{test_owner.Team.MULTIPOD.value}/maxtext"
        shared_gcs_location = name_format.generate_gcs_folder_location.override(
            task_id=f"{test_group_id}_generate_gcs_folder_location"
        )(
            f"{gcs_subfolder}/maxtext_configs_aot_hybridsim/v{tpu.value}",
            test_group_id,
        )

        # Decide the V5E special case once so both workloads agree on it.
        is_v5e = tpu.value == TpuVersion.V5E.value
        # For V5E the MaxText config directory and compile topology use
        # `v5e_alt` instead of the TpuVersion value (presumably a naming
        # difference between the two systems -- confirm).
        config_version = v5e_alt if is_v5e else tpu.value
        job_prefix = f"maxtext-{model_size}-{n}xv{tpu.value}-{num_cores}"

        # Run AOT workload: generate HLO, upload to GCS.
        aot_cmd = (
            'export XLA_FLAGS="--xla_dump_to=/tmp/xla_dump/ --xla_dump_large_constants"',
            f"bash MaxText/configs/v{config_version}/{model_size}.sh EXECUTABLE=train_compile.py M_COMPILE_TOPOLOGY=v{config_version}-{num_cores} M_COMPILE_TOPOLOGY_NUM_SLICES={n}",
            # Deliberately not an f-string: ${GCS_OUTPUT} is left for the job's
            # shell to expand (presumably set from `gcs_location` -- confirm).
            "gsutil -m cp -r /tmp/xla_dump/ ${GCS_OUTPUT}",
        )
        maxtext_aot = gke_config.get_gke_config(
            time_out_in_min=240,
            test_name=f"{job_prefix}-aot",
            run_model_cmds=aot_cmd,
            docker_image=DockerImage.MAXTEXT_TPU_JAX_NIGHTLY.value,
            test_owner=test_owner.RAYMOND_Z,
        ).run(gcs_location=shared_gcs_location)

        # Run HybridSim workload: read HLO from GCS, generate estimated step time.
        cluster = clusters[tpu]
        chip_config = "default" if is_v5e else "megacore"
        hybridsim_cmd = (
            "gsutil cp gs://cloud-hybridsim-prod/run_hybridsim.sh .",
            # Doubled braces emit a literal ${GCS_OUTPUT} for the shell. The
            # suffixes are appended with no separator, so this assumes
            # GCS_OUTPUT ends with "/" (TODO confirm).
            f"bash run_hybridsim.sh GCS_XLA_DUMP_PATH=${{GCS_OUTPUT}}xla_dump GCS_OUTPUT_PATH=${{GCS_OUTPUT}}estimated_cost_ns.jsonl CHIP_CONFIG={chip_config}",
        )
        # Read HybridSim's JSON-lines output from the runtime-generated GCS
        # folder as this job's metrics.
        job_metric_config = metric_config.MetricConfig(
            json_lines=metric_config.JSONLinesConfig(
                file_location="estimated_cost_ns.jsonl",
            ),
            use_runtime_generated_gcs_folder=True,
        )
        maxtext_hybridsim = gke_config.get_gke_config(
            cluster=cluster,
            time_out_in_min=240,
            test_name=f"{job_prefix}-hybridsim",
            run_model_cmds=hybridsim_cmd,
            docker_image=DockerImage.CLOUD_HYBRIDSIM_NIGHTLY.value,
            test_owner=test_owner.RAYMOND_Z,
            user_specified_job_metric_config=job_metric_config,
        ).run(gcs_location=shared_gcs_location)

        # HybridSim must see the HLO dump, so it runs strictly after AOT.
        shared_gcs_location >> maxtext_aot >> maxtext_hybridsim

    return group