def hybridsim_compile_and_run()

in dags/multipod/maxtext_configs_aot_hybridsim.py [0:0]


def hybridsim_compile_and_run(test_group_id):
  """Build a TaskGroup that AOT-compiles a MaxText config and runs HybridSim on it.

  The group wires three tasks in sequence:
    1. Generate a shared GCS folder location for this test group.
    2. AOT workload: run the MaxText config script with
       ``EXECUTABLE=train_compile.py`` while XLA dumps HLO to
       ``/tmp/xla_dump/``, then copy the dump to ``${GCS_OUTPUT}``.
    3. HybridSim workload: download ``run_hybridsim.sh`` and run it over the
       uploaded dump, writing ``estimated_cost_ns.jsonl``, which is collected
       as the job's JSON-lines metric.

  NOTE(review): apart from ``test_group_id``, this function reads ``tpu``,
  ``v5e_alt``, ``model_size``, ``num_cores``, ``n`` and ``clusters`` from the
  enclosing module scope — presumably bound by the DAG loop that calls it;
  confirm they are set before each call.

  Args:
    test_group_id: Group id of the created TaskGroup (with
      ``prefix_group_id=False``); also used in the GCS-location task id and
      passed to the GCS folder-location generator.

  Returns:
    The created ``TaskGroup``, so callers can attach upstream/downstream
    dependencies to the whole group.
  """
  with TaskGroup(group_id=test_group_id, prefix_group_id=False) as group:
    gcs_subfolder = f"{test_owner.Team.MULTIPOD.value}/maxtext"
    shared_gcs_location = name_format.generate_gcs_folder_location.override(
        task_id=f"{test_group_id}_generate_gcs_folder_location"
    )(
        f"{gcs_subfolder}/maxtext_configs_aot_hybridsim/v{tpu.value}",
        test_group_id,
    )

    # v5e uses an alternate version name (``v5e_alt``) for its MaxText config
    # directory and compile topology; other generations use ``tpu.value``.
    is_v5e = tpu.value == TpuVersion.V5E.value
    topology_version = v5e_alt if is_v5e else tpu.value

    # Run AOT workload: generate HLO, upload to GCS
    aot_cmd = (
        'export XLA_FLAGS="--xla_dump_to=/tmp/xla_dump/ --xla_dump_large_constants"',
        f"bash MaxText/configs/v{topology_version}/{model_size}.sh EXECUTABLE=train_compile.py M_COMPILE_TOPOLOGY=v{topology_version}-{num_cores} M_COMPILE_TOPOLOGY_NUM_SLICES={n}",
        # Recursive copy of the dump directory; presumably lands under
        # ${GCS_OUTPUT}xla_dump/, which the HybridSim step reads below.
        "gsutil -m cp -r /tmp/xla_dump/ ${GCS_OUTPUT}",
    )
    maxtext_aot = gke_config.get_gke_config(
        time_out_in_min=240,
        test_name=f"maxtext-{model_size}-{n}xv{tpu.value}-{num_cores}-aot",
        run_model_cmds=aot_cmd,
        docker_image=DockerImage.MAXTEXT_TPU_JAX_NIGHTLY.value,
        test_owner=test_owner.RAYMOND_Z,
    ).run(gcs_location=shared_gcs_location)

    # Run HybridSim workload: read HLO from GCS, generate estimated step time
    cluster = clusters[tpu]
    # HybridSim chip configuration: "default" for v5e, "megacore" otherwise.
    chip_config = "default" if is_v5e else "megacore"
    hybridsim_cmd = (
        "gsutil cp gs://cloud-hybridsim-prod/run_hybridsim.sh .",
        # ``${{...}}`` escapes the f-string braces, emitting ``${GCS_OUTPUT}``
        # for the shell. NOTE(review): ``${GCS_OUTPUT}xla_dump`` has no
        # separator, so this assumes GCS_OUTPUT ends with "/" — confirm.
        f"bash run_hybridsim.sh GCS_XLA_DUMP_PATH=${{GCS_OUTPUT}}xla_dump GCS_OUTPUT_PATH=${{GCS_OUTPUT}}estimated_cost_ns.jsonl CHIP_CONFIG={chip_config}",
    )
    # Collect the HybridSim output file as JSON-lines metrics; the file name
    # is presumably resolved against the runtime-generated GCS folder.
    job_metric_config = metric_config.MetricConfig(
        json_lines=metric_config.JSONLinesConfig(
            file_location="estimated_cost_ns.jsonl",
        ),
        use_runtime_generated_gcs_folder=True,
    )
    maxtext_hybridsim = gke_config.get_gke_config(
        cluster=cluster,
        time_out_in_min=240,
        test_name=f"maxtext-{model_size}-{n}xv{tpu.value}-{num_cores}-hybridsim",
        run_model_cmds=hybridsim_cmd,
        docker_image=DockerImage.CLOUD_HYBRIDSIM_NIGHTLY.value,
        test_owner=test_owner.RAYMOND_Z,
        user_specified_job_metric_config=job_metric_config,
    ).run(gcs_location=shared_gcs_location)

    # Run sequentially: HybridSim consumes the HLO dump the AOT job uploads.
    shared_gcs_location >> maxtext_aot >> maxtext_hybridsim

  return group