xlml/utils/metric.py

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities to process Benchmark metrics."""

import dataclasses
import datetime
import enum
import hashlib
import os
import re
from typing import Dict, Iterable, List, Optional
import uuid

from absl import logging
import airflow
from airflow.decorators import task
from airflow.exceptions import AirflowFailException
from airflow.models import TaskInstance
from airflow.operators.python import get_current_context
from xlml.apis import gcp_config, test_config
from xlml.apis import metric_config
from xlml.utils import bigquery, composer
from dags import composer_env
from google.cloud import storage
import jsonlines
import numpy as np
import tensorflow as tf
from tensorflow.core.util import event_pb2
from urllib.parse import urlparse


@dataclasses.dataclass
class TensorBoardScalar:
  metric_value: float
  step: int


class TaskState(enum.Enum):
  FAILED = "failed"
  SKIPPED = "upstream_failed"
  SUCCESS = "success"


def is_valid_tag(
    tag: str,
    include_tag_patterns: Optional[Iterable[str]],
    exclude_tag_patterns: Optional[Iterable[str]],
) -> bool:
  """Check if a tag is valid.

  Args:
    tag: The tag to check.
    include_tag_patterns: A list of patterns that should be included.
    exclude_tag_patterns: A list of patterns that should be excluded. These
      patterns take priority over include_tag_patterns if they conflict.

  Returns:
    A bool indicating whether this tag should be included.
  """
  if exclude_tag_patterns and any(
      re.match(x, tag) for x in exclude_tag_patterns
  ):  # check if the tag matches any exclude pattern
    return False
  if include_tag_patterns:  # check if the tag matches any include pattern
    return any(re.match(x, tag) for x in include_tag_patterns)
  return True


def read_from_tb(
    file_location: str,
    include_tag_patterns: Optional[Iterable[str]],
    exclude_tag_patterns: Optional[Iterable[str]],
) -> (Dict[str, List[TensorBoardScalar]], Dict[str, str]):
  """Read metrics and dimensions from a TensorBoard file.

  Args:
    file_location: The full path of a file in GCS.
    include_tag_patterns: The matching patterns of tags that will be included.
    exclude_tag_patterns: The matching patterns of tags that will be excluded.
      These patterns take priority over include_tag_patterns if they conflict.

  Returns:
    A dict that maps metric name to a list of TensorBoardScalar, and a dict
    that maps dimension name to dimension value.
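
  Example (illustrative only; the GCS path, tags, and values below are
  hypothetical):

    metrics, metadata = read_from_tb(
        "gs://my-bucket/output/events.out.tfevents.1700000000",
        include_tag_patterns=None,
        exclude_tag_patterns=("eval/.*",),
    )
    # metrics might look like {"train/loss": [TensorBoardScalar(2.1, 100)]}
    # metadata might look like {"config": "per_device_batch_size: 32"}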
""" metrics = {} metadata = {} serialized_examples = tf.data.TFRecordDataset(file_location) logging.info(f"TensorBoard metric_location is: {file_location}") for ex in serialized_examples: event = event_pb2.Event.FromString(ex.numpy()) for value in event.summary.value: if not is_valid_tag( value.tag, include_tag_patterns, exclude_tag_patterns ): continue value_type = value.metadata.plugin_data.plugin_name if value_type == "scalars": if value.tag not in metrics: metrics[value.tag] = [] t = tf.make_ndarray(value.tensor) metrics[value.tag].append(TensorBoardScalar(float(t), event.step)) elif value_type == "text": metadata[value.tag] = bytes(value.tensor.string_val[0]).decode("utf-8") elif value.HasField("simple_value"): # simple_value indicates the value is a float: # https://github.com/tensorflow/tensorflow/blob/4dacf3f/tensorflow/core/framework/summary.proto#L122 scalar = TensorBoardScalar(value.simple_value, event.step) metrics.setdefault(value.tag, []).append(scalar) else: logging.info( f"Discarding data point {value.tag} with type {value_type}." ) return metrics, metadata def aggregate_metrics( metrics: Iterable[TensorBoardScalar], strategy: metric_config.AggregationStrategy, ) -> float: """Get the aggregated value based on stragety. Args: metrics: The TensorBoardScalar from TensorBoard file. strategy: The strategy for aggregate values. Returns: A value after aggregation. """ if strategy == metric_config.AggregationStrategy.LAST: last_value = max(metrics, key=lambda p: p.step) return last_value.metric_value elif strategy == metric_config.AggregationStrategy.AVERAGE: return np.mean([m.metric_value for m in metrics]) elif strategy == metric_config.AggregationStrategy.MEDIAN: return np.median([m.metric_value for m in metrics]) else: raise NotImplementedError(f"Unknown aggregation strategy: {strategy}") def download_object_from_gcs( source_location: str, destination_location: str ) -> None: """Download object from GCS bucket. Args: source_location: The full path of a file in GCS. destination_location: The local path of the file. """ storage_client = storage.Client() bucket_name = source_location.split("/")[2] object_name = "/".join(source_location.split("/")[3:]) bucket = storage_client.bucket(bucket_name) blob = bucket.blob(object_name) blob.download_to_filename(destination_location) logging.info( ( "Download JSON Lines file from" f" {source_location} to {destination_location}" ) ) def process_json_lines( base_id: str, file_location: str, ) -> ( List[List[bigquery.MetricHistoryRow]], List[List[bigquery.MetadataHistoryRow]], ): """Process metrics and dimensions from JSON Lines file. Args: base_id: The unique ID for this test job. file_location: The full path of a file in GCS. Returns: A list of MetricHistoryRow for all test runs, and a list of MetadataHistoryRow ofr all test runs in a test job. 
""" tmp_location = "/tmp/ml-auto-solutions-metrics.jsonl" download_object_from_gcs(file_location, tmp_location) metric_list = [] metadata_list = [] with jsonlines.open(tmp_location) as reader: index = 0 for object in reader: uuid = generate_row_uuid(base_id, index) index += 1 raw_metrics = object["metrics"] metadata = object["dimensions"] metric_history_rows = [] metadata_history_rows = [] for key, value in raw_metrics.items(): metric_history_rows.append( bigquery.MetricHistoryRow( job_uuid=uuid, metric_key=key, metric_value=value ) ) for key, value in metadata.items(): metadata_history_rows.append( bigquery.MetadataHistoryRow( job_uuid=uuid, metadata_key=key, metadata_value=value ) ) metric_list.append(metric_history_rows) metadata_list.append(metadata_history_rows) return metric_list, metadata_list def process_tensorboard_summary( base_id: str, summary_config: metric_config.SummaryConfig, use_generated_gcs_folder: bool, generated_gcs_folder: Optional[str], ) -> ( List[List[bigquery.MetricHistoryRow]], List[List[bigquery.MetadataHistoryRow]], ): """Process metrics and dimensions from TensorBoard file. Args: base_id: The unique ID for this test job. summary_config: The configs for TensorBoard summary. use_generated_gcs_folder: The indicator to use default gcs folder. generated_gcs_folder: The GCS path of default folder. Returns: A list of MetricHistoryRow for a test run, and a list of MetadataHistoryRow ofr a test run in a test job. """ uuid = generate_row_uuid(base_id, 0) if isinstance(summary_config.file_location, airflow.XComArg): file_location = summary_config.file_location.resolve(get_current_context()) else: if use_generated_gcs_folder: file_location = os.path.join( generated_gcs_folder, summary_config.file_location ) else: file_location = summary_config.file_location if summary_config.use_regex_file_location: file_location = get_gcs_file_location_with_regex(file_location) if file_location == "": return [[]], [[]] aggregation_strategy = summary_config.aggregation_strategy include_tag_patterns = summary_config.include_tag_patterns exclude_tag_patterns = summary_config.exclude_tag_patterns metrics, metadata = read_from_tb( file_location, include_tag_patterns, exclude_tag_patterns ) aggregated_metrics = {} for key, value in metrics.items(): aggregated_metrics[key] = aggregate_metrics(value, aggregation_strategy) print("aggregated_metrics", aggregated_metrics) metric_history_rows = [] metadata_history_rows = [] for key, value in aggregated_metrics.items(): metric_history_rows.append( bigquery.MetricHistoryRow( job_uuid=uuid, metric_key=key, metric_value=value ) ) for key, value in metadata.items(): metadata_history_rows.append( bigquery.MetadataHistoryRow( job_uuid=uuid, metadata_key=key, metadata_value=value ) ) return [metric_history_rows], [metadata_history_rows] def get_gcs_file_location_with_regex(file_location: str) -> str: """ Get a file from GCS given a regex in the form of `gs://<your_bucket>/<your_file_path_regex>`. Does not support bucket name or path regex. Only supports file name regex. Args: file_location: File location regex in the form of `gs://<your_bucket>/<path>/<your_file_name_regex>`. Returns: The file location of the first file that fits the given regex. 
""" storage_client = storage.Client() url = urlparse(file_location) bucket_name = url.netloc file_path = url.path.strip("/") file_path_regex = re.compile(file_path) prefix = "/".join(file_path.split("/")[:-1]) all_blobs_names = [ b.name for b in storage_client.list_blobs(bucket_name, prefix=prefix) ] try: return ( f"gs://{bucket_name}/" f"{next(filter(file_path_regex.match, all_blobs_names))}" ) except StopIteration: logging.warning(f"No objects matched supplied regex: {file_location}") return "" # TODO(qinwen): implement profile metrics & upload to Vertex AI TensorBoard def process_profile( uuid: str, file_location: str ) -> List[List[bigquery.MetricHistoryRow]]: raise NotImplementedError def encode_url(url: str) -> str: """Replace characters with % followed by two hexadecimal digits. Args: url: The url to be encoded. Returns: An encoded url. """ return str(url).replace(":", "%3A").replace("+", "%2B") def add_airflow_metadata( base_id: str, project_name: str, metadata: List[List[bigquery.MetricHistoryRow]], ) -> List[List[bigquery.MetricHistoryRow]]: """Add airflow metadata: run_id, prev_start_date_success, and airflow_dag_run_link. Args: base_id: The base id to generate uuid. metadata: The data to append airflow metadata. configs: The GCP configs to get composer metadata. Returns: The data with airflow metadata. """ context = get_current_context() run_id = context["run_id"] prev_start_date_success = str(context["prev_start_date_success"]) dag_run = context["dag_run"] dag_id = dag_run.dag_id task_id = context["task"].task_id dag_run_id = encode_url(run_id) airflow_link = composer.get_airflow_url( project_name, os.environ.get(composer_env.COMPOSER_LOCATION), os.environ.get(composer_env.COMPOSER_ENVIRONMENT), ) airflow_dag_run_link = ( f"{airflow_link}/dags/{dag_id}/" f"grid?dag_run_id={dag_run_id}&task_id={task_id}" ) logging.info(f"airflow_dag_run_link is {airflow_dag_run_link}") # append airflow metadata for each test run. 
  for index in range(len(metadata)):
    uuid = generate_row_uuid(base_id, index)
    airflow_meta = []
    airflow_meta.append(
        bigquery.MetadataHistoryRow(
            job_uuid=uuid, metadata_key="run_id", metadata_value=run_id
        )
    )
    if context["prev_start_date_success"]:
      airflow_meta.append(
          bigquery.MetadataHistoryRow(
              job_uuid=uuid,
              metadata_key="prev_start_date_success",
              metadata_value=prev_start_date_success,
          )
      )
    airflow_meta.append(
        bigquery.MetadataHistoryRow(
            job_uuid=uuid,
            metadata_key="airflow_dag_run_link",
            metadata_value=airflow_dag_run_link,
        )
    )
    metadata[index].extend(airflow_meta)
  return metadata


def add_test_config_metadata(
    base_id: str,
    task_test_config: test_config.TestConfig[test_config.Accelerator],
    task_gcp_config: gcp_config.GCPConfig,
    task_metric_config: metric_config.MetricConfig,
    metadata: List[List[bigquery.MetadataHistoryRow]],
) -> List[List[bigquery.MetadataHistoryRow]]:
  for index in range(len(metadata)):
    uuid = generate_row_uuid(base_id, index)
    test_config_meta = []
    test_config_meta.append(
        bigquery.MetadataHistoryRow(
            job_uuid=uuid,
            metadata_key="accelerator",
            metadata_value=task_test_config.accelerator.name,
        )
    )
    test_config_meta.append(
        bigquery.MetadataHistoryRow(
            job_uuid=uuid,
            metadata_key="project",
            metadata_value=task_gcp_config.project_name,
        )
    )
    if hasattr(task_test_config, "num_slices"):
      test_config_meta.append(
          bigquery.MetadataHistoryRow(
              job_uuid=uuid,
              metadata_key="num_slices",
              metadata_value=task_test_config.num_slices,
          )
      )
      test_config_meta.append(
          bigquery.MetadataHistoryRow(
              job_uuid=uuid,
              metadata_key="multislice_topology",
              metadata_value=(
                  f"{task_test_config.num_slices}"
                  f"x{task_test_config.accelerator.name}"
              ),
          )
      )
    if (
        task_metric_config is not None
        and task_metric_config.tensorboard_summary
    ):
      test_config_meta.append(
          bigquery.MetadataHistoryRow(
              job_uuid=uuid,
              metadata_key="metric_aggregation_strategy",
              metadata_value=task_metric_config.tensorboard_summary.aggregation_strategy.name,
          )
      )
    metadata[index].extend(test_config_meta)
  return metadata


def generate_row_uuid(base_id: str, index: int) -> str:
  """Generate the uuid for an entry.

  Args:
    base_id: The process id generated once per post-process task group.
    index: The index of the test run.

  Returns:
    A uuid for the table entry.
  """
  return hashlib.sha256(str(base_id + str(index)).encode("utf-8")).hexdigest()


@task(trigger_rule="all_done")
def generate_process_id() -> str:
  """Generate a process id that will be the base id for uuids of test runs.

  Returns:
    A random uuid.
  """
  return str(uuid.uuid4())


def update_dataset_name_if_needed(
    prod_dataset_name: metric_config.DatasetOption,
) -> str:
  """Update the dataset name based on stage (if needed).

  All data from the prod env is sent to benchmark_dataset or xlml_dataset;
  the rest is sent to dev_benchmark_dataset or dev_xlml_dataset.
  """
  if not composer_env.is_prod_env():
    logging.info("This is a non-prod run; sending all data to the dev dataset.")
    return f"dev_{prod_dataset_name.value}"
  return prod_dataset_name.value


def get_xpk_job_status(benchmark_id: str) -> bigquery.JobStatus:
  """Get job status for the GKE run.
  FAILED - if any failure occurs in run_model
  SUCCESS - end-to-end model tests are successful in run_model
  """
  context = get_current_context()
  execution_date = context["dag_run"].logical_date
  current_dag = context["dag"]
  workload_completion = current_dag.get_task(
      task_id=f"{benchmark_id}.run_model.wait_for_workload_completion"
  )
  workload_completion_ti = TaskInstance(workload_completion, execution_date)
  workload_completion_state = workload_completion_ti.current_state()
  if workload_completion_state == TaskState.SUCCESS.value:
    logging.info(
        "The wait_for_workload_completion state is success, and the job status"
        " is success."
    )
    return bigquery.JobStatus.SUCCESS

  logging.info(
      "The wait_for_workload_completion state is not success, and the job"
      " status is failed."
  )
  return bigquery.JobStatus.FAILED


def get_gke_job_status(
    task_test_config: test_config.TestConfig[test_config.Accelerator],
) -> bigquery.JobStatus:
  """Get job status for the GKE run.

  FAILED - if any failure occurs in setup & run_model (including timeout of
    run_model).
  SUCCESS - end-to-end model tests are successful from provision to run_model
  """
  context = get_current_context()
  execution_date = context["dag_run"].logical_date
  current_dag = context["dag"]
  benchmark_id = task_test_config.benchmark_id

  # Check setup status to see if the setup step is successful.
  setup_task = current_dag.get_task(
      task_id=f"{benchmark_id}.generate_gcs_folder_location"
  )
  setup_ti = TaskInstance(setup_task, execution_date)
  setup_state = setup_ti.current_state()
  if setup_state == TaskState.FAILED.value:
    logging.info("The setup state is failed, and the job status is failed.")
    return bigquery.JobStatus.FAILED

  # Check run_model status to see if the run_model step is successful.
  run_model_task = current_dag.get_task(
      task_id=f"{benchmark_id}.run_model.stream_logs"
  )
  run_model_ti = TaskInstance(run_model_task, execution_date)
  run_model_state = run_model_ti.current_state()
  if run_model_state == TaskState.SUCCESS.value:
    logging.info(
        "The run_model state is success, and the job status is success."
    )
    return bigquery.JobStatus.SUCCESS

  logging.info("The run_model state is failed, and the job status is failed.")
  return bigquery.JobStatus.FAILED


def get_gce_job_status(
    task_test_config: test_config.TestConfig[test_config.Accelerator],
    use_startup_script: bool,
) -> bigquery.JobStatus:
  """Get job status for the GCE run.

  MISSED - if any failure occurs in initialize & create_queued_resource
  FAILED - if any failure occurs in setup & run_model (including timeout of
    run_model) for the SSH method.
  FAILED - if any failure occurs in check_if_startup_script_end (including
    timeout of check_if_startup_script_end) for the startup script method.
  SUCCESS - end-to-end model tests are successful from provision to run_model
  """
  context = get_current_context()
  execution_date = context["dag_run"].logical_date
  current_dag = context["dag"]
  benchmark_id = task_test_config.benchmark_id

  # GCE SSH method
  if not use_startup_script:
    if isinstance(task_test_config.accelerator, test_config.Tpu):
      # Check wait status to see if wait_for_ready_queued_resource is successful.
      wait_task = current_dag.get_task(
          task_id=f"{benchmark_id}.provision.create_queued_resource.wait_for_ready_queued_resource"
      )
    elif isinstance(task_test_config, test_config.GpuVmTest):
      if task_test_config.use_existing_instance:
        wait_task = current_dag.get_task(
            task_id=f"{benchmark_id}.provision.get_existing_resource"
        )
      else:
        wait_task = current_dag.get_task(
            task_id=f"{benchmark_id}.provision.create_resource.get_ip_address"
        )
    else:
      raise NotImplementedError(
          f"Unable to get task for {type(task_test_config.accelerator)}."
      )
    wait_ti = TaskInstance(wait_task, execution_date)
    wait_state = wait_ti.current_state()
    if wait_state == TaskState.SKIPPED.value:
      logging.info(
          "The wait_for_ready_queued_resource state is skipped, and the job"
          " status is missed."
      )
      return bigquery.JobStatus.MISSED

    # Check setup status to see if the setup step is successful.
    if (
        hasattr(task_test_config, "use_existing_instance")
        and task_test_config.use_existing_instance
    ):
      get_instance_task = current_dag.get_task(
          task_id=f"{benchmark_id}.provision.get_existing_resource"
      )
      get_instance_ti = TaskInstance(get_instance_task, execution_date)
      get_instance_state = get_instance_ti.current_state()
      if get_instance_state == TaskState.FAILED.value:
        logging.info(
            "The getting existing instance state is failed, and the job"
            " status is failed."
        )
        return bigquery.JobStatus.FAILED
    else:
      setup_task = current_dag.get_task(
          task_id=f"{benchmark_id}.provision.setup"
      )
      setup_ti = TaskInstance(setup_task, execution_date)
      setup_state = setup_ti.current_state()
      if setup_state == TaskState.FAILED.value:
        logging.info("The setup state is failed, and the job status is failed.")
        return bigquery.JobStatus.FAILED

    # Check run_model status to see if the run_model step is successful.
    run_model_task = current_dag.get_task(task_id=f"{benchmark_id}.run_model")
    run_model_ti = TaskInstance(run_model_task, execution_date)
    run_model_state = run_model_ti.current_state()
    if run_model_state == TaskState.SUCCESS.value:
      logging.info(
          "The run_model state is success, and the job status is success."
      )
      return bigquery.JobStatus.SUCCESS

    logging.info("The run_model state is failed, and the job status is failed.")
    return bigquery.JobStatus.FAILED

  # GCE startup script method
  else:
    # Check wait status to see if the provision step is successful.
    wait_task = current_dag.get_task(
        task_id=f"{benchmark_id}.provision_with_startup_script.create_queued_resource.wait_for_ready_queued_resource"
    )
    wait_ti = TaskInstance(wait_task, execution_date)
    wait_state = wait_ti.current_state()
    if wait_state == TaskState.SKIPPED.value:
      logging.info(
          "The wait_for_ready_queued_resource state is skipped, and the job"
          " status is missed."
      )
      return bigquery.JobStatus.MISSED

    # Check startup_script status to see if the startup_script step is successful.
    startup_script_task = current_dag.get_task(
        task_id=f"{benchmark_id}.provision_with_startup_script.create_queued_resource.check_if_startup_script_end"
    )
    startup_script_ti = TaskInstance(startup_script_task, execution_date)
    startup_script_state = startup_script_ti.current_state()
    if startup_script_state == TaskState.FAILED.value:
      logging.info(
          "The startup_script state is failed, and the job status is failed."
      )
      return bigquery.JobStatus.FAILED
    else:
      logging.info(
          "The startup_script state is success, and the job status is success."
      )
      return bigquery.JobStatus.SUCCESS


# TODO(ranran): handle Airflow retry to avoid duplicate records in tables
@task
def process_metrics(
    base_id: str,
    task_test_config: test_config.TestConfig[test_config.Accelerator],
    task_metric_config: Optional[metric_config.MetricConfig],
    task_gcp_config: gcp_config.GCPConfig,
    use_startup_script: bool = False,
    folder_location: Optional[str] = None,
) -> None:
  benchmark_id = task_test_config.benchmark_id
  current_time = datetime.datetime.now()
  has_profile = False
  metric_history_rows_list = [[]]
  metadata_history_rows_list = [[]]
  profile_history_rows_list = []

  # Process metrics, metadata, and profile.
  if task_metric_config:
    if task_metric_config.json_lines:
      absolute_path = (
          os.path.join(
              folder_location, task_metric_config.json_lines.file_location
          )
          if task_metric_config.use_runtime_generated_gcs_folder
          else task_metric_config.json_lines.file_location
      )
      metric_history_rows_list, metadata_history_rows_list = process_json_lines(
          base_id, absolute_path
      )
    if task_metric_config.tensorboard_summary:
      (
          metric_history_rows_list,
          metadata_history_rows_list,
      ) = process_tensorboard_summary(
          base_id,
          task_metric_config.tensorboard_summary,
          task_metric_config.use_runtime_generated_gcs_folder,
          folder_location,
      )
    if task_metric_config.profile:
      has_profile = True
      num_profiles = len(task_metric_config.profile.file_locations)
      for index in range(num_profiles):
        profile_history_rows = process_profile(
            base_id, task_metric_config.profile.file_locations[index]
        )
        profile_history_rows_list.append(profile_history_rows)

  # Add default airflow metadata.
  metadata_history_rows_list = add_airflow_metadata(
      base_id, task_gcp_config.composer_project, metadata_history_rows_list
  )
  metadata_history_rows_list = add_test_config_metadata(
      base_id,
      task_test_config,
      task_gcp_config,
      task_metric_config,
      metadata_history_rows_list,
  )

  # Append profile metrics to metric_history_rows_list, if any.
  if has_profile:
    if len(metric_history_rows_list) != len(profile_history_rows_list):
      logging.error(
          f"The num of profiles is {len(profile_history_rows_list)}, but it is"
          " different from the number of test runs"
          f" {len(metric_history_rows_list)}. Ignoring profiles."
      )
    else:
      for index in range(len(metric_history_rows_list)):
        metric_history_rows_list[index].extend(profile_history_rows_list[index])

  test_run_rows = []
  dataset_name = update_dataset_name_if_needed(task_gcp_config.dataset_name)
  bigquery_metric = bigquery.BigQueryMetricClient(
      task_gcp_config.dataset_project, dataset_name
  )
  if hasattr(task_test_config, "cluster_name"):
    test_job_status = get_xpk_job_status(task_test_config.benchmark_id)
  elif isinstance(task_test_config, test_config.GpuGkeTest):
    test_job_status = get_gke_job_status(task_test_config)
  else:
    test_job_status = get_gce_job_status(task_test_config, use_startup_script)

  for index in range(len(metadata_history_rows_list)):
    job_history_row = bigquery.JobHistoryRow(
        uuid=generate_row_uuid(base_id, index),
        timestamp=current_time,
        owner=task_test_config.task_owner,
        job_name=benchmark_id,
        job_status=test_job_status.value,
    )
    test_run_row = bigquery.TestRun(
        job_history_row,
        metric_history_rows_list[index],
        metadata_history_rows_list[index],
    )
    test_run_rows.append(test_run_row)

  print("Test run rows:", test_run_rows)
  bigquery_metric.insert(test_run_rows)