chz/blueprint/_blueprint.py (952 lines of code) (raw):
from __future__ import annotations
import ast
import collections.abc
import dataclasses
import functools
import inspect
import io
import sys
import textwrap
import typing
from dataclasses import dataclass
from typing import Any, Callable, Final, Generic, Mapping, Protocol
from typing_extensions import TypeVar
import chz
from chz.blueprint._argmap import ArgumentMap, Layer, join_arg_path
from chz.blueprint._argv import argv_to_blueprint_args
from chz.blueprint._entrypoint import (
ConstructionException,
EntrypointHelpException,
ExtraneousBlueprintArg,
InvalidBlueprintArg,
MissingBlueprintArg,
)
from chz.blueprint._lazy import (
Evaluatable,
ParamRef,
Thunk,
Value,
check_reference_targets,
evaluate,
)
from chz.field import Field
from chz.tiepin import (
CastError,
_simplistic_try_cast,
_simplistic_type_of_value,
eval_in_context,
is_kwargs_unpack,
is_subtype_instance,
is_typed_dict,
type_repr,
)
from chz.util import _MISSING_TYPE, MISSING
# Type variable for the return type of a factory callable (see _construct_factory).
_T = TypeVar("_T")
# Covariant type variable, defaulting to Any, for the type a Blueprint produces.
_T_cov_def = TypeVar("_T_cov_def", covariant=True, default=Any)
class SpecialArg:
    """Marker base class for argument values that Blueprint handles specially.

    Values deriving from this class (Castable, Reference) are never assigned to a parameter
    directly, even when the parameter's type would accept them.
    """
class Castable(SpecialArg):
    """A wrapper class for str if you want a Blueprint value to be magically type aware casted."""

    def __init__(self, value: str) -> None:
        """Wrap the raw, not yet casted, string."""
        self.value = value

    def __repr__(self) -> str:
        return "Castable(" + repr(self.value) + ")"

    def __hash__(self) -> int:
        # Hash by the raw string only
        return hash(self.value)

    def __eq__(self, other: object) -> bool:
        """Compare against another Castable, or against a value after casting to its type."""
        if isinstance(other, Castable):
            return self.value == other.value
        # For any other object, try to cast our string to that object's type and compare
        # the results. A failed cast simply means "not equal".
        try:
            return _simplistic_try_cast(self.value, type(other)) == other
        except CastError:
            return False
class Reference(SpecialArg):
    """A reference to another parameter in a Blueprint."""

    def __init__(self, ref: str) -> None:
        """Store the path of the referenced parameter.

        Raises:
            ValueError: If ``ref`` contains a wildcard (``...``).
        """
        has_wildcard = "..." in ref
        if has_wildcard:
            raise ValueError("Cannot use wildcard as a reference target")
        self.ref = ref

    def __repr__(self) -> str:
        return "Reference(" + repr(self.ref) + ")"
@dataclass(frozen=True)
class _MakeResult:
# Bundle of everything Blueprint._make_lazy discovers while resolving arguments against the
# target. It is consumed by Blueprint._make_from_make_result and Blueprint.get_help.
#
# `value_mapping` is a dictionary mapping from parameter paths to Evaluatable values.
# This ultimately contains all the kinds of values we will use in instantiation.
# See chz.blueprint._lazy.evaluate for an example of using Evaluatable.
value_mapping: dict[str, Evaluatable]
# `all_params` is a dictionary containing all parameters we discover, mapping from that param
# path to the parameter. Note what parameters we discover will depend on polymorphic
# construction via meta_factories. We use all_params to provide a useful --help (and various
# other things, e.g. detect clobbering when checking for extraneous arguments)
all_params: dict[str, _Param]
# `used_args` is a set of (key, layer_index) tuples that we use to track which arguments from
# arg_map we've used. We use this to check for extraneous arguments.
used_args: set[tuple[str, int]]
# `meta_factory_value` records what meta_factory we're using. This makes --help more
# understandable in the presence of polymorphism, especially when factories come from
# blueprint_unspecified. It's conceptually the same information as in Thunk.fn in value_mapping,
# but preserves user input for variadics or generics (instead of being a constructor function)
meta_factory_value: dict[str, Any]
# `missing_params` is a list of parameters we know are required but that haven't been
# specified. In theory, this is unnecessary because `__init__` will raise an error if
# a required param is missing, but this improves diagnostics.
missing_params: list[str]
def _entrypoint_caster(o: str) -> object:
    """Field-level cast installed on the root (entrypoint) parameter.

    It always refuses, so a string is never interpreted as the value of the entrypoint itself.

    Raises:
        chz.tiepin.CastError: Always.
    """
    message = "Will not interpret entrypoint as a value"
    raise chz.tiepin.CastError(message)
class Blueprint(Generic[_T_cov_def]):
    """A recipe for constructing a target from layered arguments.

    Arguments are accumulated as layers of an ArgumentMap (see `apply` and `apply_from_argv`).
    They are resolved against the parameters of the target when `make` or `get_help` is
    called, or eagerly when `apply` is used with ``strict=True``.
    """

    def __init__(
        self, target: chz.factories.MetaFactory | type[_T_cov_def] | Callable[..., _T_cov_def]
    ) -> None:
        """Instantiate a Blueprint.

        Args:
            target: The target object or callable we will instantiate or call.

        Raises:
            ValueError: If ``target`` is not a MetaFactory, the standard meta factory built
                for it has no unspecified factory, and ``target`` is not callable.
        """
        self.target = target

        entrypoint_type: Any
        if not isinstance(target, chz.factories.MetaFactory):
            # Plain class or callable: wrap it in the standard meta factory
            entrypoint_type = target
            self.meta_factory = chz.factories.standard(target)
            if self.meta_factory.unspecified_factory() is None:
                if not callable(target):
                    raise ValueError(f"{target} is not callable")
                # Fall back to treating the target itself as the unspecified factory
                self.meta_factory = chz.factories.standard(object, unspecified=target)
                entrypoint_type = object
            self.entrypoint_repr = type_repr(target)
            entrypoint_doc = getattr(target, "__doc__", "")
        else:
            # The caller handed us a meta factory directly
            self.meta_factory = target
            if isinstance(target, chz.factories.standard):
                entrypoint_type = target.annotation
                entrypoint_doc = getattr(entrypoint_type, "__doc__", "")
            else:
                entrypoint_type = object
                entrypoint_doc = ""
            self.entrypoint_repr = type_repr(entrypoint_type)

        # The root parameter has the empty path and refuses to be casted from a string
        root_doc = entrypoint_doc.strip() if entrypoint_doc else ""
        self.param = _Param(
            name="",
            type=entrypoint_type,
            meta_factory=self.meta_factory,
            default=None,
            doc=root_doc,
            blueprint_cast=_entrypoint_caster,
            metadata={},
        )
        self._arg_map = ArgumentMap([])

    def clone(self) -> Blueprint[_T_cov_def]:
        """Make a copy of this Blueprint."""
        fresh = Blueprint(self.target)
        return fresh.apply(self)

    def apply(
        self,
        values: Blueprint[_T_cov_def] | Mapping[str, Any],
        layer_name: str | None = None,
        *,
        subpath: str | None = None,
        strict: bool = False,
    ) -> Blueprint[_T_cov_def]:
        """Modify this Blueprint by partially applying some arguments.

        Args:
            values: The values to partially apply to this Blueprint
            layer_name: The name of the layer to add (allows identification of the source of values)
            subpath: A subpath to nest the argument names under
            strict: Whether to eagerly check for extraneous arguments. This may not work well in
                cases where a polymorphic field is applied later.

        Returns:
            This Blueprint (modified in place), to allow chaining.

        Raises:
            ValueError: If ``values`` is a Blueprint for a different target and no subpath is given.
            TypeError: If ``values`` is neither a mapping nor a Blueprint.
        """
        if isinstance(values, Mapping):
            new_layer = Layer(values, layer_name).nest_subpath(subpath)
            self._arg_map.add_layer(new_layer)
        elif isinstance(values, Blueprint):
            # Without a subpath, the other Blueprint must describe the very same target
            if subpath is None and values.target is not self.target:
                raise ValueError(
                    f"Cannot apply Blueprint for {type_repr(values.target)} to Blueprint for "
                    f"{type_repr(self.target)}"
                )
            for layer in values._arg_map._layers:
                self._arg_map.add_layer(layer.nest_subpath(subpath))
        else:
            raise TypeError(f"Expected dict or Blueprint, got {type(values)}")

        if strict:
            r = self._make_lazy()
            self._arg_map.check_extraneous(
                r.used_args, r.all_params.keys(), entrypoint_repr=self.entrypoint_repr
            )
        return self

    def apply_from_argv(
        self, argv: list[str], allow_hyphens: bool = False, layer_name: str = "command line"
    ) -> Blueprint[_T_cov_def]:
        """Apply arguments from argv to this Blueprint.

        Raises:
            EntrypointHelpException: If "--help" is in argv. The exception carries the help
                text, and the first "--help" is removed from the passed list.
        """
        non_help_args = [a for a in argv if a != "--help"]
        values = argv_to_blueprint_args(non_help_args, allow_hyphens=allow_hyphens)
        self.apply(values, layer_name=layer_name)

        if "--help" not in argv:
            return self
        argv.remove("--help")
        raise EntrypointHelpException(self.get_help(color=sys.stdout.isatty()))

    def _make_lazy(self) -> _MakeResult:
        """Resolve the applied arguments against the target without instantiating anything.

        Raises:
            ConstructionException: If construction reports an issue, or nothing could be
                constructed and the meta factory has no unspecified factory.
        """
        # Output collections filled in by _construct_param; see _MakeResult for their meaning
        outputs: dict[str, Any] = {
            "all_params": {},
            "used_args": set(),
            "meta_factory_value": {},
            "missing_params": [],
        }

        value_mapping: dict[str, Evaluatable] | ConstructionIssue | None
        value_mapping = _construct_param(self.param, "", self._arg_map, **outputs)

        if isinstance(value_mapping, ConstructionIssue):
            raise ConstructionException(value_mapping.issue)

        if value_mapping is None:
            # value_mapping is None if _construct_param / _construct_unspecified_param
            # wants us to fallback to the default or thinks we're missing required arguments
            # This is sort of equivalent to what happens in _construct_factory
            unspecified_factory = self.meta_factory.unspecified_factory()
            if unspecified_factory is None:
                raise ConstructionException(
                    f"Cannot construct {self.target} because it has no unspecified factory"
                )
            value_mapping = {"": Thunk(unspecified_factory, {})}
            missing_params = outputs["missing_params"]
            if "" in missing_params:
                missing_params.remove("")

        return _MakeResult(value_mapping=value_mapping, **outputs)

    def _make_from_make_result(self, r: _MakeResult) -> _T_cov_def:
        """Validate a _MakeResult and evaluate it into the final object.

        Raises:
            MissingBlueprintArg: If required parameters were left unspecified.
        """
        # Note we check for extraneous args first, so we get better errors for typos
        self._arg_map.check_extraneous(
            r.used_args, r.all_params.keys(), entrypoint_repr=self.entrypoint_repr
        )
        check_reference_targets(r.value_mapping, r.all_params.keys())

        if r.missing_params:
            raise MissingBlueprintArg(
                f"Missing required arguments for parameter(s): {', '.join(r.missing_params)}"
            )
        # __chz_blueprint__
        return evaluate(r.value_mapping)

    def make(self) -> _T_cov_def:
        """Instantiate or call the target object or callable."""
        return self._make_from_make_result(self._make_lazy())

    def make_from_argv(
        self, argv: list[str] | None = None, allow_hyphens: bool = False
    ) -> _T_cov_def:
        """Like make, but suitable for command line entrypoints.

        This will apply arguments from argv to this Blueprint before attempting to make the target.
        If "--help" is in argv, `apply_from_argv` raises EntrypointHelpException carrying the
        help text instead of making the target.

        Args:
            argv: The arguments to apply; defaults to ``sys.argv[1:]``.
            allow_hyphens: Forwarded to `apply_from_argv`.
        """
        effective_argv = sys.argv[1:] if argv is None else argv
        self.apply_from_argv(effective_argv, allow_hyphens=allow_hyphens)
        return self.make()

    def get_help(self, *, color: bool = False) -> str:
        """Get help text for this Blueprint.

        Note that applied arguments may affect output here, e.g. in case of polymorphically
        constructed fields.

        Args:
            color: Whether to decorate the output with ANSI escape sequences.

        Returns:
            The help text: warnings (if any), the entry point, its doc and one row per
            discovered parameter.
        """
        r = self._make_lazy()

        f = io.StringIO()
        output = functools.partial(print, file=f)

        # Problems are reported as warnings, not raised, so that help is always available
        try:
            self._arg_map.check_extraneous(
                r.used_args, r.all_params.keys(), entrypoint_repr=self.entrypoint_repr
            )
        except ExtraneousBlueprintArg as e:
            output(f"WARNING: {e}\n")

        try:
            check_reference_targets(r.value_mapping, r.all_params.keys())
        except InvalidBlueprintArg as e:
            output(f"WARNING: {e}\n")

        if r.missing_params:
            output(
                f"WARNING: Missing required arguments for parameter(s): {', '.join(r.missing_params)}\n"
            )

        output(f"Entry point: {self.entrypoint_repr}")
        output()
        if self.param.doc:
            output(textwrap.indent(self.param.doc, " "))
            output()

        root_is_meta_factory = isinstance(self.target, chz.factories.MetaFactory)
        param_output = []
        for param_path, param in r.all_params.items():
            # TODO: cast or meta_factory info, not just type
            found_arg = self._arg_map.get_kv(param_path)

            if found_arg is not None:
                # An argument was applied: show it and the layer it came from
                if isinstance(found_arg.value, Castable):
                    found_arg_str = repr(found_arg.value.value)[1:-1]
                elif isinstance(found_arg.value, Reference):
                    found_arg_str = f"@={found_arg.value.ref}"
                else:
                    found_arg_str = type_repr(found_arg.value)
                if found_arg.layer_name:
                    if color:
                        found_arg_str += (
                            f" \033[90m(from \033[94m{found_arg.layer_name}\033[90m)\033[0m"
                        )
                    else:
                        found_arg_str += f" (from {found_arg.layer_name})"
            elif param_path == "" and not root_is_meta_factory:
                # If we're not using root polymorphism, skip this param
                continue
            elif param_path in r.meta_factory_value:
                found_arg_str = type_repr(r.meta_factory_value[param_path])
                found_arg_str += " \033[90m(meta_factory)\033[0m" if color else " (meta_factory)"
            elif param.default is not None:
                found_arg_str = param.default.to_help_str()
                found_arg_str += " \033[90m(default)\033[0m" if color else " (default)"
            elif (
                param.meta_factory is not None
                and (factory := param.meta_factory.unspecified_factory()) is not None
                and factory is not param.type
            ):
                if getattr(factory, "__name__", None) == "<lambda>":
                    found_arg_str = _lambda_repr(factory) or type_repr(factory)
                else:
                    found_arg_str = type_repr(factory)
                found_arg_str += (
                    " \033[90m(blueprint_unspecified)\033[0m"
                    if color
                    else " (blueprint_unspecified)"
                )
            else:
                found_arg_str = "-"

            param_output.append(
                (param_path or "<entrypoint>", type_repr(param.type), found_arg_str, param.doc)
            )

        # Column widths, clipped so that one long cell does not push every row to the right
        clip = 40
        lens = tuple(min(clip, max(map(len, column))) for column in zip(*param_output))

        output("Arguments:")
        for p, typ, arg, doc in param_output:
            col = 0
            target_col = 0
            line = io.StringIO()
            add = functools.partial(print, file=line, end="")

            # The first three columns are padded; only the first is emphasised when colored
            for index, raw_string in enumerate((p, typ, arg)):
                add(" ")
                if color and index == 0:
                    add(f"\033[1m{raw_string}\033[0m")
                else:
                    add(raw_string)
                col += 2 + len(raw_string)
                target_col += 2 + lens[index]
                if col <= target_col:
                    pad = target_col - col
                else:
                    # We already overshot the target column
                    pad = (-len(raw_string)) % 20
                add(" " * pad)
                col += pad

            # The doc column comes last and needs no padding
            raw_string = doc
            add(" ")
            if color:
                add(f"\033[90m{raw_string}\033[0m")
            else:
                add(raw_string)

            output(line.getvalue().rstrip())

        return f.getvalue()
def _lambda_repr(fn) -> str | None:
try:
src = inspect.getsource(fn).strip()
tree = ast.parse(src)
nodes = [node for node in ast.walk(tree) if isinstance(node, ast.Lambda)]
if len(nodes) != 1:
return None
return ast.unparse(nodes[0])
except Exception:
return None
@dataclass(frozen=True, kw_only=True)
class _Default:
    """The default of a parameter: either a plain value or a zero-argument factory.

    Whichever of ``value`` / ``factory`` is not in use is set to MISSING.
    """

    value: Any | _MISSING_TYPE
    factory: Callable[..., Any] | _MISSING_TYPE

    def to_help_str(self) -> str:
        """Return a short description of this default, suitable for --help output.

        Factories are shown as a call (``factory()``). Lambdas are shown by their source, and
        fall back to ``<default>`` when the source cannot be recovered. Plain values are shown
        by their repr, unless that is longer than 40 characters.
        """
        if self.factory is not MISSING:
            if getattr(self.factory, "__name__", None) == "<lambda>":
                lambda_src = _lambda_repr(self.factory)
                # _lambda_repr returns None if the source is unavailable. This has to be
                # checked before formatting: a formatted string is always truthy, so an
                # `or "<default>"` after the f-string would never trigger.
                if not lambda_src:
                    return "<default>"
                return f"({lambda_src})()"
            # type_repr works reasonably well for functions too
            return f"{type_repr(self.factory)}()"

        ret = repr(self.value)
        if len(ret) > 40:
            return "<default>"
        return ret

    def instantiate(self) -> Any:
        """Return the default value, calling the factory if there is one."""
        if not isinstance(self.factory, _MISSING_TYPE):
            return self.factory()
        return self.value

    @classmethod
    def from_field(cls, field: Field) -> _Default | None:
        """Build a _Default from a chz Field, or return None if the field has no default."""
        if field._default is MISSING and field._default_factory is MISSING:
            return None
        return _Default(value=field._default, factory=field._default_factory)

    @classmethod
    def from_inspect_param(cls, sigparam: inspect.Parameter) -> _Default | None:
        """Build a _Default from a signature parameter, or return None if it has no default."""
        if sigparam.default is sigparam.empty:
            return None
        return _Default(value=sigparam.default, factory=MISSING)
@dataclass(frozen=True, kw_only=True)
class _Param:
    """Description of a single parameter discovered while walking a Blueprint target."""

    name: str
    type: Any
    meta_factory: chz.factories.MetaFactory | None
    default: _Default | None
    doc: str
    blueprint_cast: Callable[[str], object] | None
    metadata: dict[str, Any]

    def cast(self, value: str) -> object:
        """Cast a string to a value for this parameter.

        A field-level cast takes precedence, then the meta factory, then a plain cast to the
        parameter's type.
        """
        # If we have a field-level cast, always use that!
        field_level_cast = self.blueprint_cast
        if field_level_cast is not None:
            return field_level_cast(value)

        # TODO: maybe MetaFactory should have default impl and this should be:
        # return chz.factories.MetaFactory().perform_cast(value, self.type)
        if self.meta_factory is None:
            return _simplistic_try_cast(value, self.type)

        # If we have a meta_factory, route casting through it. This allows user expectations
        # of types that result from casting to better match their expectations from polymorphic
        # construction (e.g. using default_cls from chz.factories.subclass)
        return self.meta_factory.perform_cast(value)
def _get_variadic_elements(obj_path: str, arg_map: ArgumentMap) -> set[str]:
elements = set()
for subpath in arg_map.subpaths(obj_path, strict=True):
assert subpath
if subpath[0] != ".":
element = subpath.split(".")[0]
assert element
elements.add(element)
return elements
def _collect_params_from_chz(
    obj: Any, obj_path: str, arg_map: ArgumentMap
) -> tuple[list[_Param], Callable[..., Any], list[Any]]:
    """Collect one parameter per field of a chz class.

    Returns:
        The parameters, ``obj`` itself as the constructor, and no variadic types.
    """
    # Generic aliases carry the actual class in __origin__
    chz_class = getattr(obj, "__origin__", obj)
    params = [
        _Param(
            name=field.logical_name,
            type=field.x_type,
            meta_factory=field.meta_factory,
            default=_Default.from_field(field),
            doc=field._doc,
            blueprint_cast=field._blueprint_cast,
            metadata=(field.metadata or {}),
        )
        for field in chz.chz_fields(chz_class).values()
    ]
    return params, obj, []
def _collect_params_from_sequence(
    obj: Any, obj_path: str, arg_map: ArgumentMap
) -> tuple[list[_Param], Callable[..., Any], list[Any]]:
    """Collect one parameter per index of a list, tuple or Sequence type.

    The number of parameters is determined by the highest index specified in ``arg_map``.

    Returns:
        The parameters (named "0", "1", ...), a constructor taking them as keyword arguments,
        and the element types of the sequence type.

    Raises:
        TypeError: If an index beyond the length of a heterogeneous tuple type was specified.
    """
    elements = _get_variadic_elements(obj_path, arg_map)
    max_element = max((int(e) for e in elements), default=-1)
    count = max_element + 1

    obj_origin = getattr(obj, "__origin__", obj)
    obj_type_construct = obj_origin

    type_for_index: Callable[[int], type]
    if obj_origin is list or obj_origin is collections.abc.Sequence:
        item_type = getattr(obj, "__args__", [object])[0]

        def type_for_index(i):
            return item_type

        variadic_types = [item_type]
        if obj_origin is collections.abc.Sequence:
            # An abstract Sequence is constructed as a tuple
            obj_type_construct = tuple
    elif obj_origin is tuple:
        args: tuple[Any, ...] = getattr(obj, "__args__", ())
        if len(args) == 2 and args[-1] is ...:
            # homogeneous tuple
            def type_for_index(i):
                return args[0]

            variadic_types = [args[0]]
        else:
            # heterogeneous tuple
            if max_element >= len(args):
                raise TypeError(
                    f"Tuple type {obj} for {obj_path!r} must take {len(args)} items; "
                    f"arguments for index {max_element} were specified"
                    + (
                        f". Homogeneous tuples should be typed as tuple[{type_repr(args[0])}, ...] not tuple[{type_repr(args[0])}]"
                        if len(args) == 1
                        else ""
                    )
                )

            def type_for_index(i):
                return args[i]

            variadic_types = list(args)
    else:
        raise AssertionError

    params = []
    for index in range(count):
        index_type = type_for_index(index)
        params.append(
            _Param(
                name=str(index),
                type=index_type,
                meta_factory=chz.factories.standard(index_type),
                default=None,
                doc="",
                blueprint_cast=None,
                metadata={},
            )
        )

    def sequence_constructor(**kwargs):
        # Elements arrive keyed by their index and are assembled in index order
        return obj_type_construct(kwargs[str(i)] for i in range(count))

    return params, sequence_constructor, variadic_types
def _collect_params_from_mapping(
    obj: Any, obj_path: str, arg_map: ArgumentMap
) -> tuple[list[_Param], Callable[..., Any], list[Any]] | ConstructionIssue:
    """Collect one parameter per key specified for a dict or Mapping type.

    Returns:
        The parameters, ``dict`` as the constructor and the value type. A ConstructionIssue
        is returned for non-str key types when no keys were specified.

    Raises:
        TypeError: If the type takes neither 0 nor 2 type arguments, or if keys were
            specified for a mapping whose key type is not str.
    """
    elements = _get_variadic_elements(obj_path, arg_map)

    args: tuple[Any, ...] = getattr(obj, "__args__", ())
    if len(args) not in (0, 2):
        raise TypeError(f"Dict type {obj} must take 0 or 2 items")

    if not args:
        # Unparametrised mapping: anything goes
        element_type = object
    else:
        if args[0] is not str:
            if elements:
                raise TypeError(f"Variadic dict type must take str keys, not {type_repr(args[0])}")
            return ConstructionIssue(
                f"Variadic dict type must take str keys, not {type_repr(args[0])}"
            )
        element_type = args[1]

    params = [
        _Param(
            name=element,
            type=element_type,
            meta_factory=chz.factories.standard(element_type),
            default=None,
            doc="",
            blueprint_cast=None,
            metadata={},
        )
        for element in elements
    ]
    return params, dict, [element_type]
def _collect_params_from_typed_dict(
    obj: Any, obj_path: str, arg_map: ArgumentMap
) -> tuple[list[_Param], Callable[..., Any], list[Any]]:
    """Collect one parameter per key declared on a TypedDict.

    Returns:
        The parameters, the TypedDict class as the constructor and the key annotations.
    """
    typed_dict = getattr(obj, "__origin__", obj)
    hints = typing.get_type_hints(typed_dict)

    params = []
    for key, annotation in hints.items():
        if key in typed_dict.__required_keys__:
            default = None
        else:
            # Mark the default as NotRequired to improve --help output
            # We don't actually use the default values in Blueprint since we let
            # instantiation handle insertion of default values
            default = _Default(value=typing.NotRequired, factory=MISSING)
        params.append(
            _Param(
                name=key,
                type=annotation,
                meta_factory=chz.factories.standard(annotation),
                default=default,
                doc="",
                blueprint_cast=None,
                metadata={},
            )
        )

    variadic_types = list(hints.values())
    return params, typed_dict, variadic_types
def _callable_display_name(obj: Any) -> str:
    """Return a printable name for ``obj``, even if it has no ``__name__``."""
    return getattr(obj, "__name__", None) or repr(obj)


def _collect_params_from_callable(
    obj: Any, obj_path: str, arg_map: ArgumentMap
) -> tuple[list[_Param], Callable[..., Any], list[Any]] | ConstructionIssue:
    """Collect parameters from the signature of an arbitrary callable.

    Positional-only parameters and ``*args`` elements are named by their index ("0", "1", ...).
    ``**kwargs`` either contributes the keys of an ``Unpack[TypedDict]`` annotation or one
    parameter per element specified in ``arg_map``.

    Callables without ``__name__`` / ``__qualname__`` (e.g. functools.partial objects or
    instances defining ``__call__``) are supported.

    Returns:
        The parameters, the constructor to call with them as keyword arguments and an empty
        list of variadic types, or a ConstructionIssue if the parameters cannot be collected.
    """
    # Note you probably don't want to call this if obj is a primitive
    try:
        signature = inspect.signature(obj)
    except ValueError:
        return ConstructionIssue(f"Failed to get signature for {_callable_display_name(obj)}")

    def plain_param(name: str, annotation: Any, default: _Default | None = None) -> _Param:
        # Parameters discovered from a signature carry no docs, casts or metadata
        return _Param(
            name=name,
            type=annotation,
            meta_factory=chz.factories.standard(annotation),
            default=default,
            doc="",
            blueprint_cast=None,
            metadata={},
        )

    # Not every callable has a __qualname__; without one we cannot guess an enclosing class
    qualname: str = getattr(obj, "__qualname__", "")

    obj_constructor = obj
    has_pos_only = False
    has_pos_or_kwarg = False
    elements: set[str] | None = None
    params: list[_Param] = []

    for i, (name, sigparam) in enumerate(signature.parameters.items()):
        param_annot: Any
        if sigparam.annotation is sigparam.empty:
            if i == 0 and "." in qualname:
                # potentially first parameter of a method, default the annotation to the class
                try:
                    cls = getattr(inspect.getmodule(obj), qualname.rsplit(".", 1)[0])
                    param_annot = cls
                except Exception:
                    param_annot = object
            else:
                param_annot = object
        else:
            param_annot = sigparam.annotation
        if isinstance(param_annot, str):
            param_annot = eval_in_context(param_annot, obj)

        if sigparam.kind == sigparam.POSITIONAL_ONLY:
            has_pos_only = True
            name = str(i)

        if sigparam.kind == sigparam.POSITIONAL_OR_KEYWORD:
            has_pos_or_kwarg = True

        if sigparam.kind == sigparam.VAR_POSITIONAL:
            if elements is None:
                elements = _get_variadic_elements(obj_path, arg_map)
            max_element = max((int(e) for e in elements if e.isdigit()), default=-1)
            if has_pos_or_kwarg and max_element >= 0:
                return ConstructionIssue(
                    "Cannot collect parameters with both positional-or-keyword and variadic positional parameters"
                )
            has_pos_only = True
            for j in range(i, max_element + 1):
                params.append(plain_param(str(j), param_annot))
            continue

        if sigparam.kind == sigparam.VAR_KEYWORD:
            if is_kwargs_unpack(param_annot):
                if len(param_annot.__args__) != 1 or not is_typed_dict(param_annot.__args__[0]):
                    return ConstructionIssue(
                        f"Cannot collect parameters from {_callable_display_name(obj)}, expected Unpack[TypedDict], not {param_annot}"
                    )
                for key, annotation in typing.get_type_hints(param_annot.__args__[0]).items():
                    # TODO: handle total=False and PEP 655
                    params.append(plain_param(key, annotation))
            else:
                if elements is None:
                    elements = _get_variadic_elements(obj_path, arg_map)
                for element in elements - {p.name for p in params}:
                    params.append(plain_param(element, param_annot))
            continue

        # It could be interesting to let function defaults be chz.Field :-)
        # TODO: could be fun to parse function docstring
        params.append(plain_param(name, param_annot, _Default.from_inspect_param(sigparam)))

    if has_pos_only:

        def positional_constructor(**kwargs):
            # Arguments keyed by a number are passed positionally, in numeric order
            positional = []
            kw = {}
            for k, v in kwargs.items():
                if k.isdigit():
                    positional.append((int(k), v))
                else:
                    kw[k] = v
            # Sort on the index only, so the values themselves are never compared
            positional.sort(key=lambda pair: pair[0])
            return obj(*(v for _, v in positional), **kw)

        obj_constructor = positional_constructor

    # Note params may be empty here if obj doesn't take any parameters.
    # This is usually okay, but has some interaction with fully defaulted subcomponents.
    # See test_nested_all_defaults and variants
    return params, obj_constructor, []
def _collect_params(
    obj: Any, obj_path: str, arg_map: ArgumentMap
) -> (
    ConstructionIssue
    | tuple[
        list[_Param],  # params discovered
        Callable[..., Any],  # constructor to call
        list[Any],  # vaguely like [p.type for p in params], used only for sanity checking
    ]
):
    """Discover the parameters needed to construct ``obj``.

    Dispatches on the kind of ``obj``: chz classes (also wrapped in functools.partial),
    sequences, mappings, TypedDicts and finally arbitrary callables. Primitives, Enum classes
    and anything else result in a ConstructionIssue.
    """
    obj_origin = getattr(obj, "__origin__", obj)

    if chz.is_chz(obj_origin):
        return _collect_params_from_chz(obj, obj_path, arg_map)

    if isinstance(obj, functools.partial) and chz.is_chz(obj.func):
        if obj.args:
            return ConstructionIssue(
                f"Cannot collect parameters from partial function of chz class "
                f"{type_repr(obj.func)} with positional arguments"
            )
        inner = _collect_params(obj.func, obj_path, arg_map)
        if isinstance(inner, ConstructionIssue):
            return inner
        params, _constructor, variadic_types = inner

        # Keywords bound by the partial act as defaults.
        # The actual value of the default should only matter for --help output
        new_params = [
            (
                dataclasses.replace(
                    param, default=_Default(value=obj.keywords[param.name], factory=MISSING)
                )
                if param.name in obj.keywords
                else param
            )
            for param in params
        ]
        return new_params, obj, variadic_types

    if obj_origin in {list, tuple, collections.abc.Sequence}:
        return _collect_params_from_sequence(obj, obj_path, arg_map)

    if obj_origin in {dict, collections.abc.Mapping}:
        return _collect_params_from_mapping(obj, obj_path, arg_map)

    if is_typed_dict(obj_origin):
        return _collect_params_from_typed_dict(obj, obj_path, arg_map)

    if obj_origin in {bool, int, float, str, bytes, None, type(None)}:
        return ConstructionIssue("Cannot collect parameters from primitive")

    # Only look at enum if it has been imported already
    if "enum" in sys.modules:
        import enum

        if type(obj) is enum.EnumMeta:
            return ConstructionIssue("Cannot collect parameters from Enum class")

    if not callable(obj):
        return ConstructionIssue(
            f"Could not collect parameters to construct {obj} of type {type_repr(obj)}"
        )
    return _collect_params_from_callable(obj, obj_path, arg_map)
# Key and value type variables for _WriteOnlyMapping. _V is declared contravariant; in the
# protocol below values are only ever passed in, never returned.
_K = TypeVar("_K")
_V = TypeVar("_V", contravariant=True)
# Structural (Protocol) type exposing only the write operations of a mapping. The construction
# helpers annotate their output parameters with it, so that reading from those parameters
# inside the helpers is flagged by type checkers.
class _WriteOnlyMapping(Generic[_K, _V], Protocol):
# Store a single key/value pair.
def __setitem__(self, __key: _K, __value: _V, /) -> None: ...
# Merge all entries of another mapping.
def update(self, __m: Mapping[_K, _V], /) -> None: ...
class ConstructionIssue:
    """A problem encountered during construction, returned as a value instead of raised.

    Attributes:
        issue: Human readable description of the problem.
    """

    def __init__(self, issue: str) -> None:
        self.issue: str = issue
def _construct_factory(
    obj: Callable[..., _T],
    obj_path: str,
    arg_map: ArgumentMap,
    *,
    # Output parameters, do not use within this function
    # Typing these as write-only should help prevent accidental unsound use within this function
    # See _MakeResult for docs about these parameters
    all_params: _WriteOnlyMapping[str, _Param],
    used_args: set[tuple[str, int]],
    meta_factory_value: _WriteOnlyMapping[str, Any],
    missing_params: list[str],
) -> dict[str, Evaluatable] | ConstructionIssue:
    """Build the value mapping that constructs ``obj`` at ``obj_path``.

    Every parameter of ``obj`` is constructed recursively. Parameters for which nothing needs
    to be passed are left out of the keyword arguments of the resulting Thunk.

    Returns:
        The value mapping including a Thunk for ``obj_path`` itself, or the first
        ConstructionIssue encountered.
    """
    collected = _collect_params(obj, obj_path, arg_map)
    # From here on only the collected constructor may be used
    del obj
    if isinstance(collected, ConstructionIssue):
        return collected
    params, obj_constructor, _ = collected

    # Ideas:
    # TODO: Allow automatically accessing any attribute on parent for factories?
    # This eases the responsibility of getting the right structure for the config and could mean
    # we don't need to support wildcards? Reduces problems of something not getting specified
    # correctly.
    # "If you need a value, move it one level up."
    # TODO: Allow factories to return blueprints? This would allow for better presets, e.g. you
    # could do model=d4-dev model.n_layers=5

    path_prefix = obj_path + "." if obj_path else ""
    kwargs: dict[str, ParamRef] = {}
    value_mapping: dict[str, Evaluatable] = {}

    for param in params:
        sub_value_mapping = _construct_param(
            param,
            obj_path,
            arg_map,
            all_params=all_params,
            used_args=used_args,
            meta_factory_value=meta_factory_value,
            missing_params=missing_params,
        )
        if isinstance(sub_value_mapping, ConstructionIssue):
            return sub_value_mapping
        if sub_value_mapping is not None:
            # The parameter is passed by reference to its own entry in the mapping
            value_mapping.update(sub_value_mapping)
            kwargs[param.name] = ParamRef(path_prefix + param.name)

    value_mapping[obj_path] = Thunk(obj_constructor, kwargs)
    return value_mapping
def _construct_unspecified_param(
param: _Param,
*,
param_path: str,
arg_map: ArgumentMap,
# Output parameters, do not use within this function
# See _MakeResult for docs about these parameters
all_params: _WriteOnlyMapping[str, _Param],
used_args: set[tuple[str, int]],
meta_factory_value: _WriteOnlyMapping[str, Any],
missing_params: list[str],
) -> dict[str, Evaluatable] | ConstructionIssue | None:
"""Handle a parameter for which no argument was specified directly.

_construct_param calls this when arg_map has no entry for param_path itself. If the
parameter's meta factory provides an unspecified factory, that factory is run through
_construct_factory so that specified subcomponents can be fed to it.

Returns:
A value mapping if the parameter is constructed via the factory, a ConstructionIssue
(forwarded only for the root parameter when it has no default), or None if no value is
passed for this parameter. When there is no default, the missing parameter(s) are
recorded in missing_params.
"""
if (
param.meta_factory is not None
and (factory := param.meta_factory.unspecified_factory()) is not None
):
assert callable(factory)
# Collect into fresh containers first, so that the outcome of the sub-construction can be
# inspected here (the output parameters are write-only)
sub_all_params: dict[str, _Param] = {}
sub_missing_params: list[str] = []
sub_used_args: set[tuple[str, int]] = set()
sub_meta_factory_value: dict[str, Any] = {}
value_mapping = _construct_factory(
factory,
param_path,
arg_map,
all_params=sub_all_params,
used_args=sub_used_args,
meta_factory_value=sub_meta_factory_value,
missing_params=sub_missing_params,
)
# Params, used args and meta factory values are merged into the outputs; missing params
# are only merged in selected cases below
all_params.update(sub_all_params)
used_args.update(sub_used_args) # TODO: should this be gated by use?
meta_factory_value.update(sub_meta_factory_value)
if isinstance(value_mapping, ConstructionIssue):
if param_path == "" and param.default is None:
# This is a little bit special case-y. But if we have a construction issue with
# the root param, it's much better to forward it than for the user to get an error
# about a missing required root argument.
return value_mapping
else:
# _construct_factory stores the Thunk for the object itself under its own path
thunk = value_mapping[param_path]
assert isinstance(thunk, Thunk)
# Only do this if we have subcomponents specified (including wildcards)
if thunk.kwargs:
meta_factory_value[param_path] = factory
missing_params.extend(sub_missing_params)
return value_mapping
# Alternatively, if a) we do not have a default, b) we're making a chz object
# c) we know instantiation would always work, that's fine too.
# A little special-case-y, but somewhat sane. It turns something that would
# error due to lack of default into something reasonable.
# See test_nested_all_defaults and variants
if (
param.default is None
and (
chz.is_chz(factory)
or (isinstance(factory, functools.partial) and chz.is_chz(factory.func))
)
and all(p.default is not None for p in sub_all_params.values())
):
assert not sub_missing_params
meta_factory_value[param_path] = factory
return value_mapping
# If we have a default, make sure we don't extend missing_params
if param.default is None:
if sub_missing_params:
missing_params.extend(sub_missing_params)
else:
# Happens if we collect no params, like non-chz field or variadics
missing_params.append(param_path)
else:
# If we have a default, do some validation about wildcards + variadics
_check_for_wildcard_matching_variadic_top_level(factory, param, param_path, arg_map)
return None
assert not sub_missing_params
# If we have no subcomponents specified or we have no factory, we don't add any kwargs
# When the object is created, this will be equivalent to:
# `attr = default` or `attr = default_factory()`
# If there is no default, we will raise MissingBlueprintArg, instead of relying on the
# normal Python error during instantiation. We also rely on raising ExtraneousBlueprintArg
# if there are arguments that go unused.
# (In the case of Blueprint implementation bugs, if we're missing a param, __init__ will
# have our back, but the extraneous logic has no backup)
if param.default is None:
missing_params.append(param_path)
return None
def _construct_param(
    param: _Param,
    obj_path: str,
    arg_map: ArgumentMap,
    *,
    # Output parameters, do not use within this function
    # See _MakeResult for docs about these parameters
    all_params: _WriteOnlyMapping[str, _Param],
    used_args: set[tuple[str, int]],
    meta_factory_value: _WriteOnlyMapping[str, Any],
    missing_params: list[str],
) -> dict[str, Evaluatable] | ConstructionIssue | None:
    """Work out how to produce the value of ``param`` from what ``arg_map`` specifies.

    The param is recorded in ``all_params`` under its full path. If ``arg_map`` has no entry for
    that path, the decision is delegated to ``_construct_unspecified_param``. Otherwise the
    matched argument is recorded in ``used_args`` and the specified value is interpreted by
    trying, in order:

    1. a plain value that already is an instance of ``param.type``
    2. a ``Castable`` that ``param.cast`` accepts
    3. a ``Reference`` to a param (a self reference is treated as if unspecified)
    4. a callable returning ``param.type``, which is handed to ``_construct_factory``
    5. a ``Castable`` that ``param.meta_factory.from_string`` turns into a factory

    Interpretations 1 and 2 are skipped when the param has a meta_factory and ``arg_map`` has
    arguments strictly below the param's path.

    Args:
        param: The parameter to construct.
        obj_path: Path of the object the parameter belongs to; empty for the root.
        arg_map: The arguments that were specified.
        all_params: Output; receives ``param`` keyed by its full path.
        used_args: Output; receives ``(key, layer_index)`` of the argument that was consumed.
        meta_factory_value: Output; receives the factory chosen for this param, if any.
        missing_params: Output; only passed through to the helpers this function calls.

    Returns:
        A mapping from param paths to Evaluatable values, a ``ConstructionIssue`` propagated
        from one of the helpers, or ``None`` if we don't need to pass any value. ``None`` doesn't
        mean there's an error, we might simply want the default or default_factory value.

    Raises:
        InvalidBlueprintArg: If the specified value can be interpreted neither as a value nor
            as a factory.
        TypeError: If the specified value has an unexpected type.
    """
    param_path: Final = (obj_path + "." if obj_path else "") + param.name
    all_params[param_path] = param

    def has_subcomponent_args() -> bool:
        # True if we have a meta_factory and we have args that are prefixed with the param path
        return param.meta_factory is not None and bool(
            arg_map.subpaths(param_path, strict=True)
        )

    def construct_with(
        factory: Callable[..., Any],
    ) -> dict[str, Evaluatable] | ConstructionIssue | None:
        # `attr = factory(...)`; the factory is only recorded if construction had no issue
        value_mapping = _construct_factory(
            factory,
            param_path,
            arg_map,
            all_params=all_params,
            used_args=used_args,
            meta_factory_value=meta_factory_value,
            missing_params=missing_params,
        )
        if isinstance(value_mapping, ConstructionIssue):
            return value_mapping
        meta_factory_value[param_path] = factory
        return value_mapping

    found_arg = arg_map.get_kv(param_path)

    # If nothing is specified, check if we have a factory we can feed subcomponents to and if there
    # are specified subcomponents we could feed to it. Otherwise, if a default or default_factory
    # exists, we'll use that.
    if found_arg is None:
        return _construct_unspecified_param(
            param,
            param_path=param_path,
            arg_map=arg_map,
            all_params=all_params,
            used_args=used_args,
            meta_factory_value=meta_factory_value,
            missing_params=missing_params,
        )

    used_args.add((found_arg.key, found_arg.layer_index))
    spec: object = found_arg.value

    # Something is specified, so we must either add something to kwargs or error out

    # If something is specified, and is of the expected type, we just assign it:
    # `attr = spec`
    # (ignore SpecialArg's here, in case param.type is object)
    is_plain_value = not isinstance(spec, SpecialArg) and is_subtype_instance(spec, param.type)
    if is_plain_value and not has_subcomponent_args():
        # TODO: deep copy?
        return {param_path: Value(spec)}

    # Otherwise, we see if we can cast it to the expected type:
    # `attr = trycast(spec.value, param.type)`
    # If we have a meta_factory and we have args that are prefixed with the param path, we
    # will always want to construct that (if we successfully casted here when subcomponents
    # are specified, we'd just fail later because those subcomponents would be extraneous)
    if isinstance(spec, Castable) and not has_subcomponent_args():
        try:
            casted_spec = param.cast(spec.value)
        except CastError:
            # Not castable to a value; the factory interpretations below get their turn
            pass
        else:
            return {param_path: Value(casted_spec)}

    # ..or if it's a Reference to some other parameter
    if isinstance(spec, Reference):
        if spec.ref == param_path:
            # If it's a self reference, treat it as if it were unspecified
            value_mapping = _construct_unspecified_param(
                param,
                param_path=param_path,
                arg_map=arg_map,
                all_params=all_params,
                used_args=used_args,
                meta_factory_value=meta_factory_value,
                missing_params=missing_params,
            )
            if isinstance(value_mapping, ConstructionIssue):
                return value_mapping
            if value_mapping is None and param.default is not None:
                # See test_blueprint_reference_wildcard_default
                # TODO: this is the only place we instantiate a default
                default = param.default.instantiate()
                return {param_path: Value(default)}
            return value_mapping
        return {param_path: ParamRef(spec.ref)}

    # Otherwise, see if it's something that can construct the expected type. For instance,
    # maybe it's a subclass of param.type, or more generally any `Callable[..., param.type]`,
    # in which case we do:
    # `attr = spec(...)`
    if is_subtype_instance(spec, Callable[..., param.type]):
        assert callable(spec)
        return construct_with(spec)

    # Otherwise, see if it's something that can be casted into something that can construct
    # the expected type. For instance, maybe it's a string that's the name of a subclass of
    # param.type or "module:func" where module.func is a `func: Callable[..., param.type]`.
    # `attr = trycast(spec, constructor_type)(...)`
    if isinstance(spec, Castable):
        if param.meta_factory is not None:
            factory: Callable[..., Any]
            try:
                factory = param.meta_factory.from_string(spec.value)
            except chz.factories.MetaFromString as e:
                # Redo the cast so the error can report why the value interpretation failed too
                cast_error = None
                try:
                    param.cast(spec.value)
                except CastError as e2:
                    cast_error = str(e2)
                if cast_error is None:
                    # The cast works, so the value interpretation was skipped above because
                    # subparameters were provided
                    subpaths = arg_map.subpaths(param_path, strict=True)
                    assert subpaths
                    cast_error = f"Not a value, since subparameters were provided (e.g. {join_arg_path(param_path, subpaths[0])!r})"
                raise InvalidBlueprintArg(
                    f"Could not interpret argument {spec.value!r} provided for param {param_path!r}...\n\n"
                    f"- Failed to interpret it as a value:\n{cast_error}\n\n"
                    f"- Failed to interpret it as a factory for polymorphic construction:\n{e}"
                ) from None
            assert callable(factory)
            return construct_with(factory)

        # This cast is just to raise the error we caught previously
        try:
            param.cast(spec.value)
        except CastError as e:
            raise InvalidBlueprintArg(
                f"Could not cast {spec.value!r} to {type_repr(param.type)}:\n{e}"
            ) from e
        # This next line should be unreachable...
        raise TypeError(
            f"Expected {param_path!r} to be castable to {type_repr(param.type)}, got {spec.value!r}"
        )

    # A value of the right type that was not assigned above: report the subparameters that
    # prevented it from being used as a value
    if is_plain_value and param.meta_factory is not None:
        subpaths = arg_map.subpaths(param_path, strict=True)
        if subpaths:
            raise InvalidBlueprintArg(
                f"Could not interpret {spec!r} provided for param {param_path!r} "
                f"as a value, since subparameters were provided "
                f"(e.g. {join_arg_path(param_path, subpaths[0])!r})"
            )

    raise TypeError(
        f"Expected {param_path!r} to be {type_repr(param.type)}, got {type_repr(_simplistic_type_of_value(spec))}"
    )
def _check_for_wildcard_matching_variadic_top_level(
    obj: object, param: _Param, obj_path: str, arg_map: ArgumentMap
) -> None:
    """Raise if a wildcard looks like it was meant to modify an opaque variadic default.

    The case we're checking here is if we:
    1) have a param with a default
    1.5) the default is not an empty tuple, nor a ``tuple`` / ``list`` / ``dict`` default_factory
    2) have a variadic factory for that param
    3) we do not find any variadic params

    Then we check if any wildcards would have matched a param if we had one, since it can be
    unintuitive that the default will not be affected by the wildcard (default / default_factory
    are opaque and have no interaction with wildcards beyond their presence or absence).
    See test_variadic_default_wildcard_error

    Args:
        obj: The factory / type for the param, as passed on to ``_collect_params``.
        param: The param being checked; it must have a default.
        obj_path: The path of the param.
        arg_map: The arguments that were specified.

    Raises:
        ConstructionException: If a specified argument matches a subparameter of a
            hypothetical variadic element of the param.
    """
    default = param.default
    assert default is not None

    # Tuple membership rather than a set, so an unhashable default_factory cannot raise here
    if (type(default.value) is tuple and default.value == ()) or default.factory in (
        tuple,
        list,
        dict,
    ):
        return

    result = _collect_params(obj, obj_path, arg_map)
    if isinstance(result, ConstructionIssue):
        return
    variadic_params, _, variadic_types = result
    if variadic_params:
        return

    # Also consider the types of the elements of the default value. dict.fromkeys de-duplicates
    # while keeping the order stable, so the same wildcard is reported on every run.
    if isinstance(default.value, (tuple, list)):
        variadic_types = list(dict.fromkeys([*variadic_types, *map(type, default.value)]))
    elif isinstance(default.value, dict):
        variadic_types = list(
            dict.fromkeys([*variadic_types, *map(type, default.value.values())])
        )
    if not variadic_types:
        return

    # Pretend there is one variadic element and see whether any argument would match one of
    # its subparameters
    placeholder_path = obj_path + ".__chz_empty_variadic"
    for element_type in variadic_types:
        result = _collect_params(element_type, placeholder_path, arg_map)
        if isinstance(result, ConstructionIssue):
            continue
        subparams, _, _ = result
        for subparam in subparams:
            found_arg = arg_map.get_kv(placeholder_path + "." + subparam.name)
            if found_arg is None:
                continue
            # The placeholder name is an implementation detail, use a readable path in the error
            param_path = obj_path + ".(variadic)." + subparam.name
            raise ConstructionException(
                f"\n\nYou've hit an interesting case.\n\n"
                f'The parameter "{obj_path}" is variadic ({type_repr(obj)}), but no '
                "parametrisation was found (either variadic subparameters or a polymorphic "
                "parametrisation).\n"
                f'This is fine in theory, because "{obj_path}" has a '
                f"default value.\n\n"
                f'However, you also specified the wildcard "{found_arg.key}" and you may '
                f'have expected it to modify the value of "{param_path}".\n'
                "This is not possible -- default values / default_factory results are "
                "opaque to chz. "
                "The only way in which default / default_factory interact with Blueprint "
                "is presence / absence. So out of caution, here's an error!\n\n"
                "If this error is a false positive, consider scoping the wildcard more "
                "narrowly or using exact keys. As always, appending --help to a chz command "
                "will show you what gets mapped to which param."
            )