cli/golden_config_parser.py (62 lines of code) (raw):
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Library for parsing a golden qualified configuration from JSON to NodeConfig."""
import json
from typing import Any
from google.cloud import storage
import config
import dependency_version_parser
def get_golden_configs(
dependency_parsers: list[dependency_version_parser.DependencyVersionParser],
machine_type: str,
) -> list[config.NodeConfig]:
"""Returns the golden configurations for all supported machine types."""
qualified_configurations = _get_qualified_configurations(
bucket_name='aiinfra-qualified-cluster-configurations',
blob_name=f'{machine_type}/qualified_versions.json',
)
return [
_parse_qualified_configuration(
config,
[parser.name for parser in dependency_parsers],
)
for config in _filter_qualified_configurations_for_status(
qualified_configuration_versions=qualified_configurations['version'],
status='ACTIVE',
)
]
def _parse_qualified_configuration(
qualified_configuration: dict[str, str | list[str]],
dependency_names: list[str],
) -> config.NodeConfig:
"""Parses a qualified configuration into a NodeConfig proto."""
dependency_configs = {}
for dependency_name in dependency_names:
dependency_version = qualified_configuration.get(dependency_name, None)
if isinstance(dependency_version, str):
dependency_configs[dependency_name] = config.DependencyConfig(
name=dependency_name,
version=str(dependency_version),
)
elif isinstance(dependency_version, list):
deps = {
pair.split('=')[0]: pair.split('=')[1] for pair in dependency_version
}
dependency_configs[dependency_name] = config.DependencyConfig(
name=dependency_name,
version='',
config_settings=deps,
)
return config.NodeConfig(
name='GoldenConfig',
dependencies=dependency_configs,
)
def _filter_qualified_configurations_for_status(
qualified_configuration_versions: list[dict[str, str | list[str]]],
status: str,
) -> list[dict[str, Any]]:
"""Filters qualified configurations for a given machine type."""
return [
config_version
for config_version in qualified_configuration_versions
if config_version['status'] == status
]
def _get_qualified_configurations(bucket_name, blob_name) -> dict[str, Any]:
"""Reads data from a Google Cloud Storage bucket.
Args:
bucket_name: The name of the GCS bucket.
blob_name: The path to the file within the bucket.
Returns:
The contents of the file as a string.
"""
# Authenticate using the service account key file
storage_client = storage.Client.create_anonymous_client()
# Get the bucket and blob objects
bucket = storage_client.bucket(bucket_name)
blob = bucket.blob(blob_name)
return json.loads(blob.download_as_text())