common/py_libs/config_spec.py
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines a ConfigSpec base class that can be used to define yaml schemas.
Dataclasses can inherit from this base class to construct a series of classes
that store and codify the YAML schema. YAML files can be loaded into a
dictionary and passed to the `from_dict` constructor. See the unit test for
examples.
"""
# TODO: Consider using protos.
# TODO: Consider moving to common.
# Right now there's an import collision with py_libs.logging

import dataclasses
import enum
import logging
import typing
from typing import Any, Dict, List, Type, TypeVar, Union

from exceptiongroup import ExceptionGroup

from common.py_libs import cortex_exceptions as cortex_exc
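
# Illustrative flow (hypothetical schema, not part of this module): a YAML
# document such as
#
#     display_name: demo
#     steps:
#       - name: extract
#         retries: 3
#
# is read with e.g. yaml.safe_load() into a dict and passed to
# SomePipelineSpec.from_dict(), where SomePipelineSpec is an assumed dataclass
# inheriting from ConfigSpec (see the runnable sketch at the end of this file).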


def _get_value_as_type(value: Any, field_type: Type[object]) -> Any:
    """Returns the given value as the specified field_type."""
    if value is None:
        return None
    is_builtin = field_type.__module__ == "builtins"
    is_enum = field_type.__class__ == type(enum.Enum)
    if is_builtin or field_type == Any:  # pylint: disable=comparison-with-callable
        return value
    elif isinstance(value, field_type):
        return value
    elif is_enum:
        return field_type[value.upper()]  # type: ignore
    # Recursively expands nested types within field_type.
    else:
        return field_type.from_dict(value)  # type: ignore
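
# For example (illustrative): given `class Color(enum.Enum): RED = 1`,
# _get_value_as_type("red", Color) returns Color.RED, while
# _get_value_as_type({"name": "x"}, SomeSpec) falls through to
# SomeSpec.from_dict({"name": "x"}) for a hypothetical ConfigSpec subclass.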


def _get_value_as_types(value: Any, field_types: List[Type[object]]) -> Any:
    """Returns the given value as the first matching field_type."""
    exceptions = []
    for field_type in field_types:
        try:
            return _get_value_as_type(value, field_type)
        # Re-raising exception in group.
        except Exception as e:  # pylint: disable=broad-except
            exceptions.append(e)
    raise ExceptionGroup(
        f"Couldn't get value as any of the annotated types: {field_types}.",
        exceptions)


def _unwrap_field_type(field_type: Type[object]) -> List[Type[object]]:
    """Returns a list of potential field types wrapped by List and Union."""
    logging.debug("Unwrapping field_type: %s", field_type)
    is_list = typing.get_origin(field_type) == list
    is_union = typing.get_origin(field_type) == Union  # pylint: disable=comparison-with-callable
    # For lists, recursively unwrap the element type.
    if is_list:
        if not typing.get_args(field_type):
            raise cortex_exc.TypeCError(
                "Lists must define an element datatype.")
        field_type = typing.get_args(field_type)[0]
        return _unwrap_field_type(field_type)
    # For unions, recursively unwrap all potential types.
    all_field_types = []
    if is_union:
        field_types = typing.get_args(field_type)
        for ft in field_types:
            all_field_types.extend(_unwrap_field_type(ft))
    # For single types, still return in a list.
    else:
        all_field_types.append(field_type)
    logging.debug("potential field_types: %s", all_field_types)
    return all_field_types
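
# Examples of the unwrapping above (illustrative):
#     _unwrap_field_type(str) == [str]
#     _unwrap_field_type(List[int]) == [int]
#     _unwrap_field_type(List[Union[int, str]]) == [int, str]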


def _get_name(d: Dict[str, Any]) -> str:
    """Returns the 'name' or 'display_name' value from the given dict."""
    name = d.get("name")
    display_name = d.get("display_name")
    if name:
        return name
    if display_name:
        return display_name
    raise cortex_exc.KeyCError(
        f"Dict doesn't contain a 'name' or 'display_name' field:\n{d}")
T = TypeVar("T", bound="ConfigSpec")
@dataclasses.dataclass
class ConfigSpec:
"""Base class that can be used to define yaml schemas.
Dataclasses can inherit from this base class to construct a series of
classes that store and codify the YAML schema. YAML files can be loaded into
a dictionary and passed to the `from_dict` constructor.
"""
@classmethod
def from_dict(cls: Type[T], config: dict) -> T:
"""Constructor that expects a dict loaded from a YAML file."""
logging.debug("Class: %s\n", cls.__name__)
# Retrieve the field type hints so that we can dynamically create
# appropriate nested types. An alternative way to retrieve field types
# would be to use dataclasses.fields(cls), however that approach doesn't
# resolve ForwardRefs for types that reference themselves.
field_types = typing.get_type_hints(cls)
attrs = {}
for key, value in config.items():
logging.debug("Key: %s", key)
logging.debug("Value: %s", value)
if key not in field_types:
raise cortex_exc.KeyCError(
f"{cls.__name__}.{key} is undefined or "
"missing type annotations.")
field_type = field_types[key]
logging.debug("field_type: %s", field_type)
possible_field_types = _unwrap_field_type(field_type)
if isinstance(value, list):
attrs[key] = [
_get_value_as_types(v, possible_field_types) for v in value
]
else:
attrs[key] = _get_value_as_types(value, possible_field_types)
return cls(**attrs)

    @classmethod
    def merge_from_dict(cls: Type[T], a_dict: Dict[str, Any],
                        b_dict: Dict[str, Any],
                        type_hints: Dict[str, Any]) -> Dict[str, Any]:
        """Returns a merged ConfigSpec represented as dictionaries.

        Both inputs and output are ConfigSpecs represented as dictionaries.
        Values from `a_dict` will take precedence over `b_dict` where they
        differ and are valid. For Lists, if the elements have a common `name`
        or `display_name` they will be considered the same element and will
        be merged.
        """
        for key, a_value in a_dict.items():
            if key not in b_dict:
                b_dict[key] = a_value
                continue
            # Ignore invalid values.
            if not a_value:
                continue
            field_type = type_hints[key]
            # For single values.
            if not isinstance(a_value, list):
                if issubclass(field_type, ConfigSpec):
                    item_type_hints = typing.get_type_hints(field_type)
                    nested_configspec = field_type.merge_from_dict(
                        a_value, b_dict[key], item_type_hints)
                    b_dict[key] = nested_configspec
                else:
                    b_dict[key] = a_value
                continue
            # Get element field types for lists.
            element_field_types = _unwrap_field_type(field_type)
            if len(element_field_types) == 0:
                raise cortex_exc.TypeCError(
                    f"Unable to get nested types for {field_type}")
            if len(element_field_types) > 1:
                raise cortex_exc.NotImplementedCError(
                    "ConfigSpec merge with union field types isn't supported.")
            element_field_type = element_field_types[0]
            # For valid non-ConfigSpec lists, a overwrites b.
            if not issubclass(element_field_type, ConfigSpec):
                b_dict[key] = a_value
                continue
            # For repeated ConfigSpecs:
            if issubclass(element_field_type, ConfigSpec):
                element_type_hints = typing.get_type_hints(element_field_type)
                # If the ConfigSpec doesn't have a name field, behave the same
                # as for non-ConfigSpec lists, where a just overwrites b.
                try:
                    _ = _get_name(element_type_hints)
                except cortex_exc.KeyCError:
                    b_dict[key] = a_value
                    continue
                # For named repeated ConfigSpecs, merge elements with the same
                # name/display_name.
                b_element_map = {_get_name(e): e for e in b_dict[key]}
                for a_element in a_value:
                    name = _get_name(a_element)
                    if name in b_element_map:
                        merged_element = element_field_type.merge_from_dict(
                            a_element, b_element_map[name], element_type_hints)
                        b_element_map[name] = merged_element
                    # Add elements with a new name/display_name.
                    else:
                        b_element_map[name] = a_element
                # Make sure the final list is ordered to be deterministic.
                ordered_elements = [
                    v for k, v in sorted(b_element_map.items())
                ]
                b_dict[key] = ordered_elements
                continue
            raise cortex_exc.TypeCError(
                f"Unexpected field type '{a_value.__class__.__name__}' while "
                f"merging instances of '{cls.__name__}'")
        return b_dict
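
    # Illustrative merge of the dict form (hypothetical values), assuming the
    # list elements are specs keyed by a `name` field:
    #     a = {"steps": [{"name": "x", "retries": 3}]}
    #     b = {"steps": [{"name": "x", "retries": 1}, {"name": "y"}]}
    #     merge_from_dict(a, b, hints)
    # leaves b["steps"] with the merged "x" element (retries == 3, taken from
    # `a`) plus the "y" element from `b`, sorted by name.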

    @classmethod
    def merge(cls: Type[T], a: T, b: T) -> T:
        """Returns a merged ConfigSpec from the two inputs.

        Values from `a` will take precedence over `b` where they differ and
        are valid. For Lists, if the elements have a common `name` or
        `display_name` they will be considered the same element and will be
        merged.
        """
        merge_dict = cls.merge_from_dict(dataclasses.asdict(a),
                                         dataclasses.asdict(b),
                                         typing.get_type_hints(cls))
        return cls.from_dict(merge_dict)
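

# The sketch below is illustrative only: `_ExampleChild` and `_ExamplePipeline`
# are hypothetical specs (they assume a toy schema, not anything defined by
# this module) and exist purely to demonstrate the from_dict/merge flow
# described in the module docstring.
if __name__ == "__main__":

    @dataclasses.dataclass
    class _ExampleChild(ConfigSpec):
        name: str = ""
        retries: int = 0

    @dataclasses.dataclass
    class _ExamplePipeline(ConfigSpec):
        display_name: str = ""
        steps: List[_ExampleChild] = dataclasses.field(default_factory=list)

    # Dicts shaped like the output of yaml.safe_load() on two config files.
    base = _ExamplePipeline.from_dict({
        "display_name": "demo",
        "steps": [{"name": "extract", "retries": 1}, {"name": "load"}],
    })
    override = _ExamplePipeline.from_dict({
        "display_name": "demo",
        "steps": [{"name": "extract", "retries": 3}],
    })

    # `override` takes precedence; list elements sharing a "name" are merged,
    # so the result keeps both steps with the "extract" retries set to 3.
    print(_ExamplePipeline.merge(override, base))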