dags/map_reproducibility/utils/internal_aotc_workload.py

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Workload functions for AOTC reproducibility benchmarks."""

import os
import tempfile

from airflow.decorators import task
from airflow.hooks.subprocess import SubprocessHook
from airflow.operators.python import get_current_context

from dags.map_reproducibility.utils.common_utils import configure_project_and_cluster
from dags.map_reproducibility.utils.common_utils import install_helm_cmds
from dags.map_reproducibility.utils.common_utils import namespace_cmds
from dags.map_reproducibility.utils.common_utils import internal_wait_for_jobs_cmds
from dags.map_reproducibility.utils.common_utils import cleanup_cmds, cleanup_all_runs_cmds
from dags.map_reproducibility.utils.common_utils import git_cookie_authdaemon
from dags.map_reproducibility.utils.common_utils import clone_recipes_gob, clone_internal_recipes_gob
from dags.map_reproducibility.utils.common_utils import helm_apply_cmds_internal_run
from dags.map_reproducibility.utils.common_utils import get_bq_writer_repo
from dags.map_reproducibility.utils.benchmarkdb_utils import write_run
from dags.map_reproducibility.utils.common_utils import get_internal_pre_workload_cmds, get_internal_pre_workload_job_name
from dags.map_reproducibility.utils.common_utils import get_gpu_recipe_cmd
from dags.map_reproducibility.utils.common_utils import get_bq_writer_path
from dags.map_reproducibility.utils.common_utils import get_recipe_repo_path, get_internal_recipe_repo_path
from dags.map_reproducibility.utils.common_utils import get_cluster
from dags.map_reproducibility.utils.common_utils import calculate_maxtext_metrics, get_skip_steps_for_metrics_calculation
from dags.map_reproducibility.utils.common_utils import copy_bucket_cmds_maxtext, get_job_gcs_bucket_folder
from dags.map_reproducibility.utils.common_utils import parse_internal_config_filename
from dags.map_reproducibility.utils.common_utils import parse_internal_config_content
from dags.map_reproducibility.utils.constants import Optimizer, KUEUE_NAME, NUM_STEPS


@task
def run_internal_aotc_workload(
    relative_config_yaml_path,
    test_run=False,
    backfill=False,
    timeout=None,
    image_version=None,
):
  """Runs the AOTC workload benchmark.

  Args:
    relative_config_yaml_path: Path to the config YAML, relative to the root
      of the internal recipes repo.
    test_run: Whether the BigQuery row written for this run is marked as a
      test entry (ignored for backfill runs).
    backfill: Whether this is a backfill run; backfill runs are labeled as
      such in the run comment and are never marked as test runs.
    timeout: Overall workload timeout; the in-container timeout is derived
      from it by subtracting 4.
    image_version: Docker image used to run the workload.
  """
  # Get the current context to access the DAG ID.
  context = get_current_context()
  dag_id = context["dag"].dag_id

  # Parse the initial config from the YAML filename.
  config_yaml_name = relative_config_yaml_path.rsplit("/", maxsplit=1)[
      -1
  ].replace(".yaml", "")
  config = parse_internal_config_filename(config_yaml_name)

  # Get the derived configuration.
  cluster, cluster_region = get_cluster(config.HYPERCOMPUTER)
  docker_image = image_version
  values_name = f"{config.HYPERCOMPUTER}_{config.FRAMEWORK}_values"

  with tempfile.TemporaryDirectory() as tmpdir:
    hook = SubprocessHook()
    result = hook.run_command(
        [
            "bash",
            "-c",
            ";".join(
                git_cookie_authdaemon()
                + clone_recipes_gob()
                + clone_internal_recipes_gob()
                + get_bq_writer_repo()
            ),
        ],
        cwd=tmpdir,
    )
    recipe_repo_root = get_recipe_repo_path(tmpdir)
    bq_writer_repo_root = get_bq_writer_path(tmpdir)

    # Update paths now that we have the repo paths.
    internal_recipe_repo_root = get_internal_recipe_repo_path(tmpdir)
    values_file_path = f"{internal_recipe_repo_root}/values/{values_name}.yaml"
    model_specific_values_file_path = (
        f"{internal_recipe_repo_root}/values/{config_yaml_name}_values.yaml"
    )
    if os.path.exists(model_specific_values_file_path):
      # Prefer the model-specific values file when it exists.
      values_file_path = model_specific_values_file_path
      print(
          f"Using model-specific values file: {model_specific_values_file_path}"
      )
    else:
      print(
          f"Model-specific values file not found, using general values file: {values_file_path}"
      )

    full_config_yaml_path = (
        f"{internal_recipe_repo_root}/{relative_config_yaml_path}"
    )

    print(f"values_file_path is {values_file_path}")
    print(f"full_config_yaml_path is {full_config_yaml_path}")

    # Parse the config content now that we have the file path.
    config = parse_internal_config_content(full_config_yaml_path, config=config)
    job_name = get_internal_pre_workload_job_name(
        model_id=config.MODEL_ID,
        precision=config.PRECISION,
        num_gpus=config.NUM_GPUS,
        framework=config.FRAMEWORK,
        cluster=config.HYPERCOMPUTER,
    )

    # Print the DAG ID along with the job name.
    print(f"Running job '{job_name}' in DAG '{dag_id}'")

    container_timeout = int(timeout) - 4
    print(f"container timeout is {container_timeout}")
    result = hook.run_command(
        [
            "bash",
            "-c",
            ";".join(
                configure_project_and_cluster(cluster, cluster_region)
                + get_gpu_recipe_cmd(
                    config.HYPERCOMPUTER,
                    config.MODEL_ID,
                    config.FRAMEWORK,
                    recipe_repo_root,
                )
                + install_helm_cmds()
                + namespace_cmds()
                + get_internal_pre_workload_cmds(job_name=job_name)
                + helm_apply_cmds_internal_run(
                    config.FRAMEWORK,
                    config.HYPERCOMPUTER,
                    full_config_yaml_path,
                    internal_recipe_repo_root,
                    values_file_path,
                    docker_image,
                    cluster_name=cluster,
                    kueue_name=KUEUE_NAME,
                    additional_cmds=f" --set workload.gpus={config.NUM_GPUS} ",
                )
                + internal_wait_for_jobs_cmds(timeout=container_timeout)
                + copy_bucket_cmds_maxtext(tmpdir)
                + cleanup_cmds()
            ),
        ],
        cwd=tmpdir,
    )
    assert result.exit_code == 0, f"Command failed with code {result.exit_code}"

    log_location = os.path.join(tmpdir, "tflog/metrics")

    comment = (
        "internal recipes regression tests"
        if not backfill
        else "internal recipes regression tests backfill"
    )
    is_db_test_run = False if backfill else test_run
    gcs_bucket = get_job_gcs_bucket_folder(job_name)
    print(f"GCS bucket is {gcs_bucket}")

    # Calculate MFU and step time from the logged metrics, skipping the
    # warm-up steps determined by the config.
    skip_first_n_steps = get_skip_steps_for_metrics_calculation(config)
    mfu, step_time = calculate_maxtext_metrics(
        log_location,
        config.HYPERCOMPUTER,
        skip_first=skip_first_n_steps,
    )

    print(f"mfu: {mfu}")
    print(f"step_time: {step_time}")

    write_run(
        model_id=config.HELM_NAME_MODEL_ID,
        hardware_id=config.HYPERCOMPUTER,
        software_id=config.SOFTWARE_ID,
        number_of_nodes=config.NUM_GPUS / 8,
        number_of_chips=config.NUM_GPUS,
        container_image_name=docker_image,
        global_batch_size=config.per_device_batch_size * config.NUM_GPUS,
        precision=config.PRECISION,
        optimizer=Optimizer.ADAM,
        seq_length=config.max_target_length,
        median_step_time=step_time,
        e2e_time=step_time * NUM_STEPS,
        number_of_steps=NUM_STEPS,
        mfu=mfu,
        tokens_per_second=1,
        writer_path=bq_writer_repo_root,
        run_type="internal_perf_regression",
        topology="",
        comment=comment,
        is_test=is_db_test_run,
        gcs_metrics_bucket=gcs_bucket,
        workload_others=str(config),
        experiment_id=job_name,
    )


@task
def cleanup_cml_workloads(cluster, cluster_region):
  """Cleans up all existing benchmark runs on the given cluster."""
  with tempfile.TemporaryDirectory() as tmpdir:
    hook = SubprocessHook()
    result = hook.run_command(
        [
            "bash",
            "-c",
            ";".join(cleanup_all_runs_cmds(cluster, cluster_region)),
        ],
        cwd=tmpdir,
    )
    assert result.exit_code == 0, f"Command failed with code {result.exit_code}"
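

# --- Illustrative usage sketch (not part of the original module) ---
# A minimal sketch of how these tasks could be wired together in a DAG
# definition. The DAG id, config path, image tag, cluster name, and region
# below are hypothetical placeholders; the production DAGs that call these
# helpers live elsewhere under dags/map_reproducibility/. The __main__ guard
# keeps the Airflow DAG parser from ever registering this example.
if __name__ == "__main__":
  import datetime

  from airflow import models

  with models.DAG(
      dag_id="example_internal_aotc_dag",  # hypothetical DAG id
      schedule=None,
      start_date=datetime.datetime(2024, 1, 1),
      catchup=False,
      tags=["example"],
  ):
    # Run one benchmark config, then clean up leftover runs on the cluster.
    workload = run_internal_aotc_workload(
        relative_config_yaml_path="values/example_config.yaml",  # hypothetical
        timeout="90",
        image_version="example-image:latest",  # hypothetical image tag
    )
    cleanup = cleanup_cml_workloads(
        cluster="example-cluster",  # hypothetical cluster name
        cluster_region="us-central1",  # hypothetical region
    )
    workload >> cleanup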