chz/blueprint/_blueprint.py (952 lines of code) (raw):

from __future__ import annotations import ast import collections.abc import dataclasses import functools import inspect import io import sys import textwrap import typing from dataclasses import dataclass from typing import Any, Callable, Final, Generic, Mapping, Protocol from typing_extensions import TypeVar import chz from chz.blueprint._argmap import ArgumentMap, Layer, join_arg_path from chz.blueprint._argv import argv_to_blueprint_args from chz.blueprint._entrypoint import ( ConstructionException, EntrypointHelpException, ExtraneousBlueprintArg, InvalidBlueprintArg, MissingBlueprintArg, ) from chz.blueprint._lazy import ( Evaluatable, ParamRef, Thunk, Value, check_reference_targets, evaluate, ) from chz.field import Field from chz.tiepin import ( CastError, _simplistic_try_cast, _simplistic_type_of_value, eval_in_context, is_kwargs_unpack, is_subtype_instance, is_typed_dict, type_repr, ) from chz.util import _MISSING_TYPE, MISSING _T = TypeVar("_T") _T_cov_def = TypeVar("_T_cov_def", covariant=True, default=Any) class SpecialArg: ... class Castable(SpecialArg): """A wrapper class for str if you want a Blueprint value to be magically type aware casted.""" def __init__(self, value: str) -> None: self.value = value def __repr__(self) -> str: return f"Castable({self.value!r})" def __hash__(self) -> int: return hash(self.value) def __eq__(self, other: object) -> bool: if not isinstance(other, Castable): try: return _simplistic_try_cast(self.value, type(other)) == other except CastError: return False return self.value == other.value class Reference(SpecialArg): """A reference to another parameter in a Blueprint.""" def __init__(self, ref: str) -> None: if "..." in ref: raise ValueError("Cannot use wildcard as a reference target") self.ref = ref def __repr__(self) -> str: return f"Reference({self.ref!r})" @dataclass(frozen=True) class _MakeResult: # `value_mapping` is a dictionary mapping from parameter paths to Evaluatable values. # This ultimately contains all the kinds of values we will use in instantiation. # See chz.blueprint._lazy.evaluate for an example of using Evaluatable. value_mapping: dict[str, Evaluatable] # `all_params` is a dictionary containing all parameters we discover, mapping from that param # path to the parameter. Note what parameters we discover will depend on polymorphic # construction via meta_factories. We use all_params to provide a useful --help (and various # other things, e.g. detect clobbering when checking for extraneous arguments) all_params: dict[str, _Param] # `used_args` is a set of (key, layer_index) tuples that we use to track which arguments from # arg_map we've used. We use this to check for extraneous arguments. used_args: set[tuple[str, int]] # `meta_factory_value` records what meta_factory we're using. This makes --help more # understandable in the presence of polymorphism, especially when factories come from # blueprint_unspecified. It's conceptually the same information as in Thunk.fn in value_mapping, # but preserves user input for variadics or generics (instead of being a constructor function) meta_factory_value: dict[str, Any] # `missing_params` is a list of parameters we know need are required but haven't been # specified. In theory, this is unnecessary because `__init__` will raise an error if # a required param is missing, but this improves diagnostics. missing_params: list[str] def _entrypoint_caster(o: str) -> object: raise chz.tiepin.CastError("Will not interpret entrypoint as a value") class Blueprint(Generic[_T_cov_def]): def __init__( self, target: chz.factories.MetaFactory | type[_T_cov_def] | Callable[..., _T_cov_def] ) -> None: """Instantiate a Blueprint. Args: target: The target object or callable we will instantiate or call. """ self.target = target if isinstance(target, chz.factories.MetaFactory): self.meta_factory = target if isinstance(target, chz.factories.standard): entrypoint_type = target.annotation entrypoint_doc = getattr(entrypoint_type, "__doc__", "") else: entrypoint_type = object entrypoint_doc = "" self.entrypoint_repr = type_repr(entrypoint_type) else: self.meta_factory = chz.factories.standard(target) entrypoint_type = target if self.meta_factory.unspecified_factory() is None: if not callable(target): raise ValueError(f"{target} is not callable") self.meta_factory = chz.factories.standard(object, unspecified=target) entrypoint_type = object self.entrypoint_repr = type_repr(target) entrypoint_doc = getattr(target, "__doc__", "") self.param = _Param( name="", type=entrypoint_type, meta_factory=self.meta_factory, default=None, doc=entrypoint_doc.strip() if entrypoint_doc else "", blueprint_cast=_entrypoint_caster, metadata={}, ) self._arg_map = ArgumentMap([]) def clone(self) -> Blueprint[_T_cov_def]: """Make a copy of this Blueprint.""" return Blueprint(self.target).apply(self) def apply( self, values: Blueprint[_T_cov_def] | Mapping[str, Any], layer_name: str | None = None, *, subpath: str | None = None, strict: bool = False, ) -> Blueprint[_T_cov_def]: """Modify this Blueprint by partially applying some arguments. Args: values: The values to partially apply to this Blueprint layer_name: The name of the layer to add (allows identification of the source of values) subpath: A subpath to nest the argument names under strict: Whether to eagerly check for extraneous arguments. This may not work well in cases where a polymorphic field is applied later. """ if isinstance(values, Mapping): self._arg_map.add_layer(Layer(values, layer_name).nest_subpath(subpath)) elif isinstance(values, Blueprint): if subpath is None: if values.target is not self.target: raise ValueError( f"Cannot apply Blueprint for {type_repr(values.target)} to Blueprint for " f"{type_repr(self.target)}" ) for layer in values._arg_map._layers: self._arg_map.add_layer(layer.nest_subpath(subpath)) else: raise TypeError(f"Expected dict or Blueprint, got {type(values)}") if strict: r = self._make_lazy() self._arg_map.check_extraneous( r.used_args, r.all_params.keys(), entrypoint_repr=self.entrypoint_repr ) return self def apply_from_argv( self, argv: list[str], allow_hyphens: bool = False, layer_name: str = "command line" ) -> Blueprint[_T_cov_def]: """Apply arguments from argv to this Blueprint.""" values = argv_to_blueprint_args( [a for a in argv if a != "--help"], allow_hyphens=allow_hyphens ) self.apply(values, layer_name=layer_name) if "--help" in argv: argv.remove("--help") raise EntrypointHelpException(self.get_help(color=sys.stdout.isatty())) return self def _make_lazy(self) -> _MakeResult: all_params: dict[str, _Param] = {} used_args: set[tuple[str, int]] = set() meta_factory_value: dict[str, Any] = {} missing_params: list[str] = [] value_mapping: dict[str, Evaluatable] | ConstructionIssue | None value_mapping = _construct_param( self.param, "", self._arg_map, all_params=all_params, used_args=used_args, meta_factory_value=meta_factory_value, missing_params=missing_params, ) if isinstance(value_mapping, ConstructionIssue): raise ConstructionException(value_mapping.issue) if value_mapping is None: # value_mapping is None if _construct_param / _construct_unspecified_param # wants us to fallback to the default or thinks we're missing required arguments # This is sort of equivalent to what happens in _construct_factory unspecified_factory = self.meta_factory.unspecified_factory() if unspecified_factory is None: raise ConstructionException( f"Cannot construct {self.target} because it has no unspecified factory" ) value_mapping = {"": Thunk(unspecified_factory, {})} if "" in missing_params: missing_params.remove("") return _MakeResult( value_mapping=value_mapping, all_params=all_params, used_args=used_args, meta_factory_value=meta_factory_value, missing_params=missing_params, ) def _make_from_make_result(self, r: _MakeResult) -> _T_cov_def: self._arg_map.check_extraneous( r.used_args, r.all_params.keys(), entrypoint_repr=self.entrypoint_repr ) check_reference_targets(r.value_mapping, r.all_params.keys()) # Note we check for extraneous args first, so we get better errors for typos if r.missing_params: raise MissingBlueprintArg( f"Missing required arguments for parameter(s): {', '.join(r.missing_params)}" ) # __chz_blueprint__ return evaluate(r.value_mapping) def make(self) -> _T_cov_def: """Instantiate or call the target object or callable.""" r = self._make_lazy() return self._make_from_make_result(r) def make_from_argv( self, argv: list[str] | None = None, allow_hyphens: bool = False ) -> _T_cov_def: """Like make, but suitable for command line entrypoints. This will apply arguments from argv to this Blueprint before attempting to make the target. If "--help" is in argv, this will print help text and exit. """ if argv is None: argv = sys.argv[1:] self.apply_from_argv(argv, allow_hyphens=allow_hyphens) return self.make() def get_help(self, *, color: bool = False) -> str: """Get help text for this Blueprint. Note that applied arguments may affect output here, e.g. in case of polymorphically constructed fields. """ r = self._make_lazy() f = io.StringIO() output = functools.partial(print, file=f) try: self._arg_map.check_extraneous( r.used_args, r.all_params.keys(), entrypoint_repr=self.entrypoint_repr ) except ExtraneousBlueprintArg as e: output(f"WARNING: {e}\n") try: check_reference_targets(r.value_mapping, r.all_params.keys()) except InvalidBlueprintArg as e: output(f"WARNING: {e}\n") if r.missing_params: output( f"WARNING: Missing required arguments for parameter(s): {', '.join(r.missing_params)}\n" ) output(f"Entry point: {self.entrypoint_repr}") output() if self.param.doc: output(textwrap.indent(self.param.doc, " ")) output() param_output = [] for param_path, param in r.all_params.items(): # TODO: cast or meta_factory info, not just type found_arg = self._arg_map.get_kv(param_path) if ( not isinstance(self.target, chz.factories.MetaFactory) and param_path == "" and found_arg is None ): # If we're not using root polymorphism, skip this param continue if found_arg is None: if param_path in r.meta_factory_value: found_arg_str = type_repr(r.meta_factory_value[param_path]) if color: found_arg_str += " \033[90m(meta_factory)\033[0m" else: found_arg_str += " (meta_factory)" elif param.default is not None: found_arg_str = param.default.to_help_str() if color: found_arg_str += " \033[90m(default)\033[0m" else: found_arg_str += " (default)" elif ( param.meta_factory is not None and (factory := param.meta_factory.unspecified_factory()) is not None and factory is not param.type ): if getattr(factory, "__name__", None) == "<lambda>": found_arg_str = _lambda_repr(factory) or type_repr(factory) else: found_arg_str = type_repr(factory) if color: found_arg_str += " \033[90m(blueprint_unspecified)\033[0m" else: found_arg_str += " (blueprint_unspecified)" else: found_arg_str = "-" else: if isinstance(found_arg.value, Castable): found_arg_str = repr(found_arg.value.value)[1:-1] elif isinstance(found_arg.value, Reference): found_arg_str = f"@={found_arg.value.ref}" else: found_arg_str = type_repr(found_arg.value) if found_arg.layer_name: if color: found_arg_str += ( f" \033[90m(from \033[94m{found_arg.layer_name}\033[90m)\033[0m" ) else: found_arg_str += f" (from {found_arg.layer_name})" param_output.append( (param_path or "<entrypoint>", type_repr(param.type), found_arg_str, param.doc) ) clip = 40 lens = tuple(min(clip, max(map(len, column))) for column in zip(*param_output)) output("Arguments:") for p, typ, arg, doc in param_output: col = 0 target_col = 0 line = io.StringIO() add = functools.partial(print, file=line, end="") raw_string = p add(" ") if color: add(f"\033[1m{raw_string}\033[0m") else: add(raw_string) col += 2 + len(raw_string) target_col += 2 + lens[0] pad = (target_col - col) if col <= target_col else (-len(raw_string)) % 20 add(" " * pad) col += pad raw_string = typ add(" ") add(raw_string) col += 2 + len(raw_string) target_col += 2 + lens[1] pad = (target_col - col) if col <= target_col else (-len(raw_string)) % 20 add(" " * pad) col += pad raw_string = arg add(" ") add(raw_string) col += 2 + len(raw_string) target_col += 2 + lens[2] pad = (target_col - col) if col <= target_col else (-len(raw_string)) % 20 add(" " * pad) col += pad raw_string = doc add(" ") if color: add(f"\033[90m{raw_string}\033[0m") else: add(raw_string) output(line.getvalue().rstrip()) return f.getvalue() def _lambda_repr(fn) -> str | None: try: src = inspect.getsource(fn).strip() tree = ast.parse(src) nodes = [node for node in ast.walk(tree) if isinstance(node, ast.Lambda)] if len(nodes) != 1: return None return ast.unparse(nodes[0]) except Exception: return None @dataclass(frozen=True, kw_only=True) class _Default: value: Any | _MISSING_TYPE factory: Callable[..., Any] | _MISSING_TYPE def to_help_str(self) -> str: if self.factory is not MISSING: if getattr(self.factory, "__name__", None) == "<lambda>": return f"({_lambda_repr(self.factory)})()" or "<default>" # type_repr works reasonably well for functions too return f"{type_repr(self.factory)}()" ret = repr(self.value) if len(ret) > 40: return "<default>" return ret def instantiate(self) -> Any: if not isinstance(self.factory, _MISSING_TYPE): return self.factory() return self.value @classmethod def from_field(cls, field: Field) -> _Default | None: if field._default is MISSING and field._default_factory is MISSING: return None return _Default(value=field._default, factory=field._default_factory) @classmethod def from_inspect_param(cls, sigparam: inspect.Parameter) -> _Default | None: if sigparam.default is sigparam.empty: return None return _Default(value=sigparam.default, factory=MISSING) @dataclass(frozen=True, kw_only=True) class _Param: name: str type: Any meta_factory: chz.factories.MetaFactory | None default: _Default | None doc: str blueprint_cast: Callable[[str], object] | None metadata: dict[str, Any] def cast(self, value: str) -> object: # If we have a field-level cast, always use that! if self.blueprint_cast is not None: return self.blueprint_cast(value) # If we have a meta_factory, route casting through it. This allows user expectations # of types that result from casting to better match their expectations from polymorphic # construction (e.g. using default_cls from chz.factories.subclass) if self.meta_factory is not None: return self.meta_factory.perform_cast(value) # TODO: maybe MetaFactory should have default impl and this should be: # return chz.factories.MetaFactory().perform_cast(value, self.type) return _simplistic_try_cast(value, self.type) def _get_variadic_elements(obj_path: str, arg_map: ArgumentMap) -> set[str]: elements = set() for subpath in arg_map.subpaths(obj_path, strict=True): assert subpath if subpath[0] != ".": element = subpath.split(".")[0] assert element elements.add(element) return elements def _collect_params_from_chz( obj: Any, obj_path: str, arg_map: ArgumentMap ) -> tuple[list[_Param], Callable[..., Any], list[Any]]: obj_origin = getattr(obj, "__origin__", obj) params = [] for field in chz.chz_fields(obj_origin).values(): params.append( _Param( name=field.logical_name, type=field.x_type, meta_factory=field.meta_factory, default=_Default.from_field(field), doc=field._doc, blueprint_cast=field._blueprint_cast, metadata=(field.metadata or {}), ) ) return params, obj, [] def _collect_params_from_sequence( obj: Any, obj_path: str, arg_map: ArgumentMap ) -> tuple[list[_Param], Callable[..., Any], list[Any]]: elements = _get_variadic_elements(obj_path, arg_map) max_element = max((int(e) for e in elements), default=-1) obj_origin = getattr(obj, "__origin__", obj) obj_type_construct = obj_origin type_for_index: Callable[[int], type] if obj_origin is list: element_type = getattr(obj, "__args__", [object])[0] type_for_index = lambda i: element_type variadic_types = [element_type] elif obj_origin is collections.abc.Sequence: element_type = getattr(obj, "__args__", [object])[0] type_for_index = lambda i: element_type variadic_types = [element_type] obj_type_construct = tuple elif obj_origin is tuple: args: tuple[Any, ...] = getattr(obj, "__args__", ()) if len(args) == 2 and args[-1] is ...: # homogeneous tuple type_for_index = lambda i: args[0] variadic_types = [args[0]] else: # heterogeneous tuple if max_element >= len(args): raise TypeError( f"Tuple type {obj} for {obj_path!r} must take {len(args)} items; " f"arguments for index {max_element} were specified" + ( f". Homogeneous tuples should be typed as tuple[{type_repr(args[0])}, ...] not tuple[{type_repr(args[0])}]" if len(args) == 1 else "" ) ) type_for_index = lambda i: args[i] variadic_types = list(args) else: raise AssertionError params = [] for i in range(max_element + 1): element_type = type_for_index(i) params.append( _Param( name=str(i), type=element_type, meta_factory=chz.factories.standard(element_type), default=None, doc="", blueprint_cast=None, metadata={}, ) ) def sequence_constructor(**kwargs): return obj_type_construct(kwargs[str(i)] for i in range(max_element + 1)) obj_constructor = sequence_constructor return params, obj_constructor, variadic_types def _collect_params_from_mapping( obj: Any, obj_path: str, arg_map: ArgumentMap ) -> tuple[list[_Param], Callable[..., Any], list[Any]] | ConstructionIssue: elements = _get_variadic_elements(obj_path, arg_map) args: tuple[Any, ...] = getattr(obj, "__args__", ()) if len(args) == 0: element_type = object elif len(args) == 2: if args[0] is not str: if elements: raise TypeError(f"Variadic dict type must take str keys, not {type_repr(args[0])}") return ConstructionIssue( f"Variadic dict type must take str keys, not {type_repr(args[0])}" ) element_type = args[1] else: raise TypeError(f"Dict type {obj} must take 0 or 2 items") params = [] for element in elements: params.append( _Param( name=element, type=element_type, meta_factory=chz.factories.standard(element_type), default=None, doc="", blueprint_cast=None, metadata={}, ) ) return params, dict, [element_type] def _collect_params_from_typed_dict( obj: Any, obj_path: str, arg_map: ArgumentMap ) -> tuple[list[_Param], Callable[..., Any], list[Any]]: obj_origin = getattr(obj, "__origin__", obj) params = [] variadic_types = [] for key, annotation in typing.get_type_hints(obj_origin).items(): required = key in obj_origin.__required_keys__ params.append( _Param( name=key, type=annotation, meta_factory=chz.factories.standard(annotation), # Mark the default as NotRequired to improve --help output # We don't actually use the default values in Blueprint since we let # instantiation handle insertion of default values default=(None if required else _Default(value=typing.NotRequired, factory=MISSING)), doc="", blueprint_cast=None, metadata={}, ) ) variadic_types.append(annotation) return params, obj_origin, variadic_types def _collect_params_from_callable( obj: Any, obj_path: str, arg_map: ArgumentMap ) -> tuple[list[_Param], Callable[..., Any], list[Any]] | ConstructionIssue: # Note you probably don't want to call this if obj is a primitive try: signature = inspect.signature(obj) except ValueError: return ConstructionIssue(f"Failed to get signature for {obj.__name__}") obj_constructor = obj has_pos_only = False has_pos_or_kwarg = False elements: set[str] | None = None params = [] for i, (name, sigparam) in enumerate(signature.parameters.items()): param_annot: Any if sigparam.annotation is sigparam.empty: if i == 0 and "." in obj.__qualname__: # potentially first parameter of a method, default the annotation to the class try: cls = getattr(inspect.getmodule(obj), obj.__qualname__.rsplit(".", 1)[0]) param_annot = cls except Exception: param_annot = object else: param_annot = object else: param_annot = sigparam.annotation if isinstance(param_annot, str): param_annot = eval_in_context(param_annot, obj) if sigparam.kind == sigparam.POSITIONAL_ONLY: has_pos_only = True name = str(i) if sigparam.kind == sigparam.POSITIONAL_OR_KEYWORD: has_pos_or_kwarg = True if sigparam.kind == sigparam.VAR_POSITIONAL: if elements is None: elements = _get_variadic_elements(obj_path, arg_map) max_element = max((int(e) for e in elements if e.isdigit()), default=-1) if has_pos_or_kwarg and max_element >= 0: return ConstructionIssue( "Cannot collect parameters with both positional-or-keyword and variadic positional parameters" ) has_pos_only = True for j in range(i, max_element + 1): params.append( _Param( name=str(j), type=param_annot, meta_factory=chz.factories.standard(param_annot), default=None, doc="", blueprint_cast=None, metadata={}, ) ) continue if sigparam.kind == sigparam.VAR_KEYWORD: if is_kwargs_unpack(param_annot): if len(param_annot.__args__) != 1 or not is_typed_dict(param_annot.__args__[0]): return ConstructionIssue( f"Cannot collect parameters from {obj.__name__}, expected Unpack[TypedDict], not {param_annot}" ) for key, annotation in typing.get_type_hints(param_annot.__args__[0]).items(): # TODO: handle total=False and PEP 655 params.append( _Param( name=key, type=annotation, meta_factory=chz.factories.standard(annotation), default=None, doc="", blueprint_cast=None, metadata={}, ) ) else: if elements is None: elements = _get_variadic_elements(obj_path, arg_map) for element in elements - {p.name for p in params}: params.append( _Param( name=element, type=param_annot, meta_factory=chz.factories.standard(param_annot), default=None, doc="", blueprint_cast=None, metadata={}, ) ) continue # It could be interesting to let function defaults be chz.Field :-) # TODO: could be fun to parse function docstring params.append( _Param( name=name, type=param_annot, meta_factory=chz.factories.standard(param_annot), default=_Default.from_inspect_param(sigparam), doc="", blueprint_cast=None, metadata={}, ) ) if has_pos_only: def positional_constructor(**kwargs): a = [] kw = {} for k, v in kwargs.items(): if k.isdigit(): a.append((int(k), v)) else: kw[k] = v a = [v for _, v in sorted(a)] return obj(*a, **kw) obj_constructor = positional_constructor # Note params may be empty here if obj doesn't take any parameters. # This is usually okay, but has some interaction with fully defaulted subcomponents. # See test_nested_all_defaults and variants return params, obj_constructor, [] def _collect_params( obj: Any, obj_path: str, arg_map: ArgumentMap ) -> ( ConstructionIssue | tuple[ list[_Param], # params discovered Callable[..., Any], # constructor to call list[Any], # vaguely like [p.type for p in params], used only for sanity checking ] ): obj_origin = getattr(obj, "__origin__", obj) if chz.is_chz(obj_origin): return _collect_params_from_chz(obj, obj_path, arg_map) if isinstance(obj, functools.partial) and chz.is_chz(obj.func): if obj.args: return ConstructionIssue( f"Cannot collect parameters from partial function of chz class " f"{type_repr(obj.func)} with positional arguments" ) result = _collect_params(obj.func, obj_path, arg_map) if isinstance(result, ConstructionIssue): return result params, _constructor, variadic_types = result new_params = [] for param in params: if param.name in obj.keywords: # The actual value of the default should only matter for --help output param = dataclasses.replace( param, default=_Default(value=obj.keywords[param.name], factory=MISSING) ) new_params.append(param) return new_params, obj, variadic_types if obj_origin in {list, tuple, collections.abc.Sequence}: return _collect_params_from_sequence(obj, obj_path, arg_map) if obj_origin in {dict, collections.abc.Mapping}: return _collect_params_from_mapping(obj, obj_path, arg_map) if is_typed_dict(obj_origin): return _collect_params_from_typed_dict(obj, obj_path, arg_map) if obj_origin in {bool, int, float, str, bytes, None, type(None)}: return ConstructionIssue("Cannot collect parameters from primitive") if "enum" in sys.modules: import enum if type(obj) is enum.EnumMeta: return ConstructionIssue("Cannot collect parameters from Enum class") if callable(obj): return _collect_params_from_callable(obj, obj_path, arg_map) return ConstructionIssue( f"Could not collect parameters to construct {obj} of type {type_repr(obj)}" ) _K = TypeVar("_K") _V = TypeVar("_V", contravariant=True) class _WriteOnlyMapping(Generic[_K, _V], Protocol): def __setitem__(self, __key: _K, __value: _V, /) -> None: ... def update(self, __m: Mapping[_K, _V], /) -> None: ... class ConstructionIssue: def __init__(self, issue: str) -> None: self.issue = issue def _construct_factory( obj: Callable[..., _T], obj_path: str, arg_map: ArgumentMap, *, # Output parameters, do not use within this function # Typing these as write-only should help prevent accidental unsound use within this function # See _MakeResult for docs about these parameters all_params: _WriteOnlyMapping[str, _Param], used_args: set[tuple[str, int]], meta_factory_value: _WriteOnlyMapping[str, Any], missing_params: list[str], ) -> dict[str, Evaluatable] | ConstructionIssue: result = _collect_params(obj, obj_path, arg_map) del obj if isinstance(result, ConstructionIssue): return result params, obj_constructor, _ = result # Ideas: # TODO: Allow automatically accessing any attribute on parent for factories? # This eases the responsibility of getting the right structure for the config and could mean # we don't need to support wildcards? Reduces problems of something not getting specified # correctly. # "If you need a value, move it one level up." # TODO: Allow factories to return blueprints? This would allow for better presets, e.g. you # could do model=d4-dev model.n_layers=5 kwargs: dict[str, ParamRef] = {} value_mapping: dict[str, Evaluatable] = {} for param in params: sub_value_mapping = _construct_param( param, obj_path, arg_map, all_params=all_params, used_args=used_args, meta_factory_value=meta_factory_value, missing_params=missing_params, ) if isinstance(sub_value_mapping, ConstructionIssue): return sub_value_mapping if sub_value_mapping is None: continue param_path = (obj_path + "." if obj_path else "") + param.name value_mapping.update(sub_value_mapping) kwargs[param.name] = ParamRef(param_path) value_mapping[obj_path] = Thunk(obj_constructor, kwargs) return value_mapping def _construct_unspecified_param( param: _Param, *, param_path: str, arg_map: ArgumentMap, # Output parameters, do not use within this function # See _MakeResult for docs about these parameters all_params: _WriteOnlyMapping[str, _Param], used_args: set[tuple[str, int]], meta_factory_value: _WriteOnlyMapping[str, Any], missing_params: list[str], ) -> dict[str, Evaluatable] | ConstructionIssue | None: if ( param.meta_factory is not None and (factory := param.meta_factory.unspecified_factory()) is not None ): assert callable(factory) sub_all_params: dict[str, _Param] = {} sub_missing_params: list[str] = [] sub_used_args: set[tuple[str, int]] = set() sub_meta_factory_value: dict[str, Any] = {} value_mapping = _construct_factory( factory, param_path, arg_map, all_params=sub_all_params, used_args=sub_used_args, meta_factory_value=sub_meta_factory_value, missing_params=sub_missing_params, ) all_params.update(sub_all_params) used_args.update(sub_used_args) # TODO: should this be gated by use? meta_factory_value.update(sub_meta_factory_value) if isinstance(value_mapping, ConstructionIssue): if param_path == "" and param.default is None: # This is a little bit special case-y. But if we have a construction issue with # the root param, it's much better to forward it than for the user to get an error # about a missing required root argument. return value_mapping else: thunk = value_mapping[param_path] assert isinstance(thunk, Thunk) # Only do this if we have subcomponents specified (including wildcards) if thunk.kwargs: meta_factory_value[param_path] = factory missing_params.extend(sub_missing_params) return value_mapping # Alternatively, if a) we do not have a default, b) we're making a chz object # c) we know instantiation would always work, that's fine too. # A little special-case-y, but somewhat sane. It turns something that would # error due to lack of default into something reasonable. # See test_nested_all_defaults and variants if ( param.default is None and ( chz.is_chz(factory) or (isinstance(factory, functools.partial) and chz.is_chz(factory.func)) ) and all(p.default is not None for p in sub_all_params.values()) ): assert not sub_missing_params meta_factory_value[param_path] = factory return value_mapping # If we have a default, make sure we don't extend missing_params if param.default is None: if sub_missing_params: missing_params.extend(sub_missing_params) else: # Happens if we collect no params, like non-chz field or variadics missing_params.append(param_path) else: # If we have a default, do some validation about wildcards + variadics _check_for_wildcard_matching_variadic_top_level(factory, param, param_path, arg_map) return None assert not sub_missing_params # If we have no subcomponents specified or we have no factory, we don't add any kwargs # When the object is created, this will be equivalent to: # `attr = default` or `attr = default_factory()` # If there is no default, we will raise MissingBlueprintArg, instead of relying on the # normal Python error during instantiation. We also rely on raising ExtraneousBlueprintArg # if there are arguments that go unused. # (In the case of Blueprint implementation bugs, if we're missing a param, __init__ will # have our back, but the extraneous logic has no backup) if param.default is None: missing_params.append(param_path) return None def _construct_param( param: _Param, obj_path: str, arg_map: ArgumentMap, *, # Output parameters, do not use within this function # See _MakeResult for docs about these parameters all_params: _WriteOnlyMapping[str, _Param], used_args: set[tuple[str, int]], meta_factory_value: _WriteOnlyMapping[str, Any], missing_params: list[str], ) -> dict[str, Evaluatable] | ConstructionIssue | None: # Returns None if we don't need to pass any value. This doesn't mean there's an error, # we might simply want the default or default_factory value. param_path: Final = (obj_path + "." if obj_path else "") + param.name all_params[param_path] = param found_arg = arg_map.get_kv(param_path) # If nothing is specified, check if we have a factory we can feed subcomponents to and if there # are specified subcomponents we could feed to it. Otherwise, if a default or default_factory # exists, we'll use that. if found_arg is None: return _construct_unspecified_param( param, param_path=param_path, arg_map=arg_map, all_params=all_params, used_args=used_args, meta_factory_value=meta_factory_value, missing_params=missing_params, ) used_args.add((found_arg.key, found_arg.layer_index)) spec: object = found_arg.value # Something is specified, so we must either add something to kwargs or error out # If something is specified, and is of the expected type, we just assign it: # `attr = spec` if not isinstance(spec, SpecialArg) and is_subtype_instance(spec, param.type): # (ignore SpecialArg's here, in case param.type is object) if not (param.meta_factory is not None and arg_map.subpaths(param_path, strict=True)): # TODO: deep copy? return {param_path: Value(spec)} # Otherwise, we see if we can cast it to the expected type: # `attr = trycast(spec.value, param.type)` if isinstance(spec, Castable): # If we have a meta_factory and we have args that are prefixed with the param path, we # will always want to construct that (if we successfully casted here when subcomponents # are specified, we'd just fail later because those subcomponents would be extraneous) if not (param.meta_factory is not None and arg_map.subpaths(param_path, strict=True)): try: casted_spec = param.cast(spec.value) return {param_path: Value(casted_spec)} except CastError: pass # ..or if it's a Reference to some other parameter if isinstance(spec, Reference): if spec.ref == param_path: # If it's a self reference, treat it as if it were unspecified value_mapping = _construct_unspecified_param( param, param_path=param_path, arg_map=arg_map, all_params=all_params, used_args=used_args, meta_factory_value=meta_factory_value, missing_params=missing_params, ) if isinstance(value_mapping, ConstructionIssue): return value_mapping if value_mapping is None and param.default is not None: # See test_blueprint_reference_wildcard_default # TODO: this is the only place we instantiate a default default = param.default.instantiate() return {param_path: Value(default)} return value_mapping return {param_path: ParamRef(spec.ref)} # Otherwise, see if it's something that can construct the expected type. For instance, # maybe it's a subclass of param.type, or more generally any `Callable[..., param.type]`, # in which case we do: # `attr = spec(...)` factory: Callable[..., Any] if is_subtype_instance(spec, Callable[..., param.type]): assert callable(spec) factory = spec value_mapping = _construct_factory( factory, param_path, arg_map, all_params=all_params, used_args=used_args, meta_factory_value=meta_factory_value, missing_params=missing_params, ) if isinstance(value_mapping, ConstructionIssue): return value_mapping meta_factory_value[param_path] = factory return value_mapping # Otherwise, see if it's something that can be casted into something that can construct # the expected type. For instance, maybe it's a string that's the name of a subclass of # param.type or "module:func" where module.func is a `func: Callable[..., param.type]`. # `attr = trycast(spec, constructor_type)(...)` if isinstance(spec, Castable): if param.meta_factory is not None: try: factory = param.meta_factory.from_string(spec.value) except chz.factories.MetaFromString as e: cast_error = None try: param.cast(spec.value) except CastError as e2: cast_error = str(e2) if cast_error is None: subpaths = arg_map.subpaths(param_path, strict=True) assert subpaths cast_error = f"Not a value, since subparameters were provided (e.g. {join_arg_path(param_path, subpaths[0])!r})" raise InvalidBlueprintArg( f"Could not interpret argument {spec.value!r} provided for param {param_path!r}...\n\n" f"- Failed to interpret it as a value:\n{cast_error}\n\n" f"- Failed to interpret it as a factory for polymorphic construction:\n{e}" ) from None assert callable(factory) value_mapping = _construct_factory( factory, param_path, arg_map, all_params=all_params, used_args=used_args, meta_factory_value=meta_factory_value, missing_params=missing_params, ) if isinstance(value_mapping, ConstructionIssue): return value_mapping meta_factory_value[param_path] = factory return value_mapping # This cast is just to raise the error we caught previously try: param.cast(spec.value) except CastError as e: raise InvalidBlueprintArg( f"Could not cast {spec.value!r} to {type_repr(param.type)}:\n{e}" ) from e # This next line should be unreachable... raise TypeError( f"Expected {param_path!r} to be castable to {type_repr(param.type)}, got {spec.value!r}" ) if not isinstance(spec, SpecialArg) and is_subtype_instance(spec, param.type): if param.meta_factory is not None: subpaths = arg_map.subpaths(param_path, strict=True) if subpaths: raise InvalidBlueprintArg( f"Could not interpret {spec!r} provided for param {param_path!r} " f"as a value, since subparameters were provided " f"(e.g. {join_arg_path(param_path, subpaths[0])!r})" ) raise TypeError( f"Expected {param_path!r} to be {type_repr(param.type)}, got {type_repr(_simplistic_type_of_value(spec))}" ) def _check_for_wildcard_matching_variadic_top_level( obj: object, param: _Param, obj_path: str, arg_map: ArgumentMap ): assert param.default is not None if ( type(param.default.value) is tuple and param.default.value == () ) or param.default.factory in {tuple, list, dict}: return result = _collect_params(obj, obj_path, arg_map) if isinstance(result, ConstructionIssue): return variadic_params, _, variadic_types = result if variadic_params: return if isinstance(param.default.value, (tuple, list)): variadic_types = list( set(variadic_types) | {type(element) for element in param.default.value} ) elif isinstance(param.default.value, dict): variadic_types = list( set(variadic_types) | {type(element) for element in param.default.value.values()} ) if not variadic_types: return # The case we're checking here is if we: # 1) have a param with a default # 1.5) the default is not an empty tuple or list or dict # 2) have a variadic factory for that param # 3) we do not find any variadic params # Then we check if any wildcards would have matched a param if we had one, since it can be # unintuitive that the default will not be affected by the wildcard (default / default_factory # are opaque and have no interaction with wildcards beyond their presence or absence). # See test_variadic_default_wildcard_error for element_type in variadic_types: result = _collect_params(element_type, obj_path + ".__chz_empty_variadic", arg_map) if isinstance(result, ConstructionIssue): continue subparams, _, _ = result for subparam in subparams: param_path = obj_path + ".__chz_empty_variadic." + subparam.name found_arg = arg_map.get_kv(param_path) param_path = obj_path + ".(variadic)." + subparam.name if found_arg is not None: raise ConstructionException( f"\n\nYou've hit an interesting case.\n\n" f'The parameter "{obj_path}" is variadic ({type_repr(obj)}), but no ' "parametrisation was found (either variadic subparameters or a polymorphic " "parametrisation).\n" f'This is fine in theory, because "{obj_path}" has a ' f"default value.\n\n" f'However, you also specified the wildcard "{found_arg.key}" and you may ' f'have expected it to modify the value of "{param_path}".\n' "This is not possible -- default values / default_factory results are " "opaque to chz. " "The only way in which default / default_factory interact with Blueprint " "is presence / absence. So out of caution, here's an error!\n\n" "If this error is a false positive, consider scoping the wildcard more " "narrowly or using exact keys. As always, appending --help to a chz command " "will show you what gets mapped to which param." )