dags/map_reproducibility/utils/benchmarkdb_utils.py (176 lines of code) (raw):

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Helpers for writing AOTC benchmark run summaries to the benchmark database."""

import getpass
import os
import sys
from typing import Optional


def write_run(
    model_id: str,
    hardware_id: str,
    software_id: str,
    number_of_nodes: int,
    number_of_chips: int,
    container_image_name: str,
    global_batch_size: int,
    precision: str,
    optimizer: str,
    seq_length: int,
    median_step_time: float,
    e2e_time: float,
    number_of_steps: int,
    mfu: float,
    tokens_per_second: float,
    writer_path: str,
    # True because if mfu is none, writing to db will fail anyway.
    run_success: bool = True,
    run_type: str = "perf_regression",
    run_release_status: str = "local",
    other_metrics_in_json: str = "",
    nccl_driver_nickname: Optional[str] = None,
    env_variables: str = "",
    framework_config_in_json: str = "",
    xla_flags: str = "",
    topology: str = "",
    dataset: str = "",
    num_of_superblock: Optional[int] = None,
    # None means "resolve to the current user at call time"; do NOT use
    # getpass.getuser() as the default value here — defaults are evaluated
    # once at import time, which freezes the importing user and can raise
    # in daemon environments that have no resolvable login name.
    update_person_ldap: Optional[str] = None,
    comment: str = "",
    is_test: bool = False,
    logs_profile: str = "",
    gcs_metrics_bucket: str = "",
    workload_others: str = "",
    experiment_id: str = "",
):
  """Writes a workload benchmark run manually to the database.

  This function validates the provided IDs and, if valid, constructs a
  WorkloadBenchmarkV2Schema object with the given data and writes it to the
  "run_summary" table in BigQuery.

  Args:
    model_id: The ID of the model used in the run.
    hardware_id: The ID of the hardware used in the run.
    software_id: The ID of the software stack used in the run.
    number_of_nodes: The number of nodes used in the run.
    number_of_chips: The number of chips used in the run.
    container_image_name: The name of the container image used in the run.
    global_batch_size: The global batch size used in the run.
    precision: The precision used in the run (e.g., fp32, bf16).
    optimizer: The optimizer used in the run (e.g., adam, sgd).
    seq_length: The sequence length used in the run.
    median_step_time: The median step time of the run.
    e2e_time: The end-to-end time of the run.
    number_of_steps: The number of steps taken in the run.
    mfu: The MFU (model flops utilization) achieved in the run.
    tokens_per_second: The tokens per second achieved in the run.
    writer_path: Filesystem path appended to sys.path so that the
      benchmark_db_writer package can be imported.
    run_success: Whether the run succeeded (default: True).
    run_type: The type of run (default: "perf_regression").
    run_release_status: possible values "local" ( code changes are done
      locally), "prep_release" ( all code code changes are present in the
      image)
    other_metrics_in_json: A JSON string containing other metrics.
    nccl_driver_nickname: The nickname of the NCCL driver used.
    env_variables: A string containing environment variables.
    framework_config_in_json: A JSON string containing framework
      configurations.
    xla_flags: A json string containing all the XLA flags.
    topology: The topology of the hardware used in the run. ( valid for TPUs)
    dataset: The dataset used in the run.
    num_of_superblock: The number of superblocks in the hardware. ( valid
      for GPUs)
    update_person_ldap: The LDAP ID of the person updating the record
      (default: current user, resolved at call time).
    comment: A comment about the run.
    is_test: Whether to use the testing project or the production project.
    logs_profile: Location of the profile logs for the run.
    gcs_metrics_bucket: GCS bucket holding the run's metrics.
    workload_others: Free-form string for additional workload attributes.
    experiment_id: Identifier grouping this run with related runs.

  Raises:
    ValueError: If any of the IDs are invalid.
  """
  # Resolve the updater at call time (see the parameter comment above).
  if update_person_ldap is None:
    update_person_ldap = getpass.getuser()

  # benchmark_db_writer is shipped alongside the workload rather than
  # installed, so it only becomes importable after extending sys.path.
  sys.path.append(writer_path)

  # pylint: disable=import-outside-toplevel
  import logging
  import uuid
  from typing import Type

  from benchmark_db_writer import bq_writer_utils
  from benchmark_db_writer.schema.workload_benchmark_v2 import workload_benchmark_v2_schema
  from benchmark_db_writer.schema.workload_benchmark_v2 import model_info_schema
  from benchmark_db_writer.schema.workload_benchmark_v2 import software_info_schema
  from benchmark_db_writer.schema.workload_benchmark_v2 import hardware_info_schema
  # pylint: enable=import-outside-toplevel

  logging.basicConfig(
      format="%(asctime)s %(levelname)-8s %(message)s",
      level=logging.INFO,
      datefmt="%Y-%m-%d %H:%M:%S",
  )
  logger = logging.getLogger(__name__)

  def get_db_client(
      table: str, dataclass_type: Type, is_test: bool = False
  ) -> bq_writer_utils.create_bq_writer_object:
    """Creates a BigQuery client object.

    Args:
      table: The name of the BigQuery table.
      dataclass_type: The dataclass type corresponding to the table schema.
      is_test: Whether to use the testing project or the production project.

    Returns:
      A BigQuery client object.
    """
    project = "supercomputer-testing" if is_test else "ml-workload-benchmarks"
    dataset = "mantaray_v2" if is_test else "benchmark_dataset_v2"
    print(f"Writing to project: {project}, dataset: {dataset}")
    return bq_writer_utils.create_bq_writer_object(
        project=project,
        dataset=dataset,
        table=table,
        dataclass_type=dataclass_type,
    )

  def _validate_id(
      id_value: str,
      table_name: str,
      id_field: str,
      dataclass_type: Type,
      is_test: bool = False,
  ) -> bool:
    """Generic function to validate an ID against a BigQuery table.

    Args:
      id_value: The ID value to validate.
      table_name: The name of the BigQuery table.
      id_field: The name of the ID field in the table.
      dataclass_type: The dataclass type corresponding to the table schema.
      is_test: Whether to use the testing project or the production project.

    Returns:
      True if the ID is valid, False otherwise.
    """
    client = get_db_client(table_name, dataclass_type, is_test)
    result = client.query(where={id_field: id_value})

    if not result:
      logger.info(
          "%s: %s is not present in the %s table ",
          id_field.capitalize(),
          id_value,
          table_name,
      )
      logger.info(
          "Please add %s specific row in %s table before adding to run summary table",
          id_value,
          table_name,
      )
      return False
    return True

  def validate_model_id(model_id: str, is_test: bool = False) -> bool:
    """Validates a model ID against the model_info table."""
    print("model id: " + model_id)
    id_val = _validate_id(
        model_id, "model_info", "model_id", model_info_schema.ModelInfo, is_test
    )
    if not id_val:
      print("model id validation failed")
      return False
    return True

  def validate_hardware_id(hardware_id: str, is_test: bool = False) -> bool:
    """Validates a hardware ID against the hardware_info table."""
    id_val = _validate_id(
        hardware_id,
        "hardware_info",
        "hardware_id",
        hardware_info_schema.HardwareInfo,
        is_test,
    )
    if not id_val:
      print("hardware id validation failed")
      return False
    return True

  def validate_software_id(software_id: str, is_test: bool = False) -> bool:
    """Validates a software ID against the software_info table."""
    id_val = _validate_id(
        software_id,
        "software_info",
        "software_id",
        software_info_schema.SoftwareInfo,
        is_test,
    )
    if not id_val:
      print("software id validation failed")
      return False
    return True

  print(model_id)
  # Only write the summary row when every foreign-key-like ID already exists
  # in its reference table; otherwise the row would be orphaned.
  if (
      validate_model_id(model_id, is_test)
      and validate_hardware_id(hardware_id, is_test)
      and validate_software_id(software_id, is_test)
  ):
    summary = workload_benchmark_v2_schema.WorkloadBenchmarkV2Schema(
        run_id=f"run-{uuid.uuid4()}",
        model_id=model_id,
        software_id=software_id,
        hardware_id=hardware_id,
        hardware_num_chips=number_of_chips,
        hardware_num_nodes=number_of_nodes,
        result_success=run_success,
        configs_framework=framework_config_in_json,
        configs_env=env_variables,
        configs_container_version=container_image_name,
        configs_xla_flags=xla_flags,
        configs_dataset=dataset,
        logs_artifact_directory="",
        update_person_ldap=update_person_ldap,
        run_source="automation",
        run_type=run_type,
        run_release_status=run_release_status,
        workload_precision=precision,
        workload_gbs=global_batch_size,
        workload_optimizer=optimizer,
        workload_sequence_length=seq_length,
        metrics_e2e_time=e2e_time,
        metrics_mfu=mfu,
        metrics_step_time=median_step_time,
        metrics_tokens_per_second=tokens_per_second,
        metrics_num_steps=number_of_steps,
        metrics_other=other_metrics_in_json,
        hardware_nccl_driver_nickname=nccl_driver_nickname,
        hardware_topology=topology,
        hardware_num_superblocks=num_of_superblock,
        logs_comments=comment,
        logs_profile=logs_profile,
        gcs_metrics_bucket=gcs_metrics_bucket,
        workload_others=workload_others,
        experiment_id=experiment_id,
    )

    client = get_db_client(
        "run_summary",
        workload_benchmark_v2_schema.WorkloadBenchmarkV2Schema,
        is_test,
    )
    print("******metrics query is******")
    print(summary)
    client.write([summary])
  else:
    raise ValueError("Could not upload data in run summary table")