xlml/apis/metric_config.py (36 lines of code) (raw):
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Config file for benchmark metrics."""
import dataclasses
import enum
from typing import Iterable, List, Optional
# TODO(ranran): add project info to let users specify dataset location
class DatasetOption(enum.Enum):
  """Choice of dataset, identified by its dataset name.

  Each member's value is the dataset name string, e.g. "xlml_dataset".
  """

  BENCHMARK_DATASET = "benchmark_dataset"
  XLML_DATASET = "xlml_dataset"
class FormatType(enum.Enum):
  """Kinds of metric input.

  The members parallel the optional input configs on `MetricConfig`
  (`json_lines`, `tensorboard_summary`, `profile`). Member values come from
  `enum.auto()` and carry no meaning beyond distinguishing the members.
  """

  JSON_LINES = enum.auto()
  TENSORBOARD_SUMMARY = enum.auto()
  PROFILE = enum.auto()
class AggregationStrategy(enum.Enum):
  """Strategy for aggregating metric values.

  Selected through `SummaryConfig.aggregation_strategy`. This enum only names
  the strategy; the aggregation itself is done by the code that consumes the
  config. Member values come from `enum.auto()` and carry no meaning beyond
  distinguishing the members.
  """

  LAST = enum.auto()
  AVERAGE = enum.auto()
  MEDIAN = enum.auto()
class SshEnvVars(enum.Enum):
  """Environment-variable placeholders written in shell `${NAME}` syntax.

  Each value is the literal placeholder string (e.g. "${GCS_OUTPUT}"), not the
  variable's resolved value. NOTE(review): presumably these are embedded in
  commands run over SSH and expanded by the remote shell; confirm at the call
  sites.
  """

  GCS_OUTPUT = "${GCS_OUTPUT}"
  BASE_OUTPUT_PATH = "${BASE_OUTPUT_PATH}"
@dataclasses.dataclass
class JSONLinesConfig:
  """A class to set up JSON Lines config.

  Attributes:
    file_location: The location of the file in GCS. When
      `MetricConfig.use_runtime_generated_gcs_folder` is true, use a relative
      path.
  """

  file_location: str
@dataclasses.dataclass
class SummaryConfig:
  """A class to set up TensorBoard summary config.

  Attributes:
    file_location: The location of the file in GCS. When
      `MetricConfig.use_runtime_generated_gcs_folder` is true, use a relative
      path.
    aggregation_strategy: The aggregation strategy for metrics.
    include_tag_patterns: The matching patterns of tags that will be included.
      All tags are included by default.
    exclude_tag_patterns: The matching patterns of tags that will be excluded.
      No tag is excluded by default. These patterns take priority over
      `include_tag_patterns`.
    use_regex_file_location: Whether to use `file_location` as a regex to find
      the file in GCS.
  """

  file_location: str
  aggregation_strategy: AggregationStrategy
  # NOTE(review): typed as Iterable, so a one-shot iterator (e.g. a generator)
  # is accepted but can only be iterated once; prefer a list or tuple.
  include_tag_patterns: Optional[Iterable[str]] = None
  exclude_tag_patterns: Optional[Iterable[str]] = None
  use_regex_file_location: bool = False
@dataclasses.dataclass
class ProfileConfig:
  """A class to set up profile config.

  Attributes:
    file_locations: The locations of the profile files in GCS. When
      `MetricConfig.use_runtime_generated_gcs_folder` is true, use relative
      paths. If the JSON_LINES format type is used for metrics and dimensions,
      ensure the order of the profiles matches the order of the test runs in
      the JSON Lines file.
  """

  file_locations: List[str]
@dataclasses.dataclass
class MetricConfig:
  """Groups the input configs for benchmark metrics, dimensions, and profiles.

  Every field is optional; an input whose config is left as None is simply not
  configured.

  Attributes:
    json_lines: Config for JSON Lines input, or None.
    tensorboard_summary: Config for TensorBoard summary input, or None.
    profile: Config for profile input, or None.
    use_runtime_generated_gcs_folder: When True, paths are based on the
      benchmark_id via generate_gcs_folder_location(), and the file locations
      in the configs above are given as relative paths.
  """

  json_lines: Optional[JSONLinesConfig] = dataclasses.field(default=None)
  tensorboard_summary: Optional[SummaryConfig] = dataclasses.field(default=None)
  profile: Optional[ProfileConfig] = dataclasses.field(default=None)
  use_runtime_generated_gcs_folder: bool = dataclasses.field(default=False)