in chz/tiepin.py [0:0]
# Every exception type ``ast.literal_eval`` is documented to raise on malformed input.
# Catching only ValueError/SyntaxError lets e.g. "{[]: 1}" (TypeError: unhashable type)
# escape as a non-CastError, which also short-circuits the union fallback loop below.
_LITERAL_EVAL_ERRORS = (ValueError, TypeError, SyntaxError, MemoryError, RecursionError)


def _simplistic_try_cast(inst_str: str, typ: TypeForm):
    """Convert the string ``inst_str`` into a value of type ``typ``, best effort.

    Dispatches on ``typ.__origin__`` (or ``typ`` itself when it has no origin):

    - unions: each member is tried in the order given by
      ``_sort_for_union_preference`` and the first successful cast wins.
    - ``Any`` / ``object``: a Python literal (via ``ast.literal_eval``), then the
      lowercase spellings ``true``/``false`` and ``none``/``null``/``NULL``,
      otherwise the raw string.
    - ``Literal[...]``: cast to each literal's runtime type, then check membership.
    - ``None`` / ``NoneType``: only the string ``"None"``.
    - ``bool``: ``t``/``true``/``True``/``1`` or ``f``/``false``/``False``/``0``.
    - ``str``, ``float``, ``int``: the builtin constructors.
    - ``list`` / ``Sequence`` / ``Iterable``: an empty string gives ``[]``; a
      literal starting with ``[`` or ``(`` is evaluated and type-checked;
      otherwise comma-separated items are cast to the element type. Non-``list``
      origins return a tuple for non-empty input.
    - ``tuple``: comma-separated items, supporting bare ``tuple``, homogeneous
      ``tuple[T, ...]``, fixed-length heterogeneous tuples and PEP 646 unpacking.
    - ``dict`` / ``Mapping``: an empty string gives ``{}``; otherwise only a
      ``{...}`` literal that type-checks against ``typ``.
    - ``Callable``: a ``module_name:attribute`` reference.
    - ``torch.dtype``, ``enum.Enum`` subclasses, ``fractions.Fraction`` and
      ``pathlib.Path``: only when the module is already in ``sys.modules``.
    - any type that defines ``__chz_cast__``.

    Args:
        inst_str: The raw string to convert.
        typ: The target type form.

    Returns:
        The converted value.

    Raises:
        CastError: If ``inst_str`` cannot be converted to ``typ``, or if ``typ``
            is not a type object this function recognises.
    """
    origin = getattr(typ, "__origin__", typ)
    if is_union_type(origin):
        # sort str to last spot
        args = _sort_for_union_preference(getattr(typ, "__args__", ()))
        for arg in args:
            try:
                return _simplistic_try_cast(inst_str, arg)
            except CastError:
                pass
        raise CastError(f"Could not cast {repr(inst_str)} to {type_repr(typ)}")

    if origin is typing.Any or origin is typing_extensions.Any or origin is object:
        try:
            return ast.literal_eval(inst_str)
        except _LITERAL_EVAL_ERRORS:
            pass
        # Also accept some lowercase spellings
        if inst_str in {"true", "false"}:
            return inst_str == "true"
        if inst_str in {"none", "null", "NULL"}:
            return None
        return inst_str

    if origin is typing.Literal or origin is typing_extensions.Literal:
        # Group the allowed values by runtime type so each type is cast only once.
        values_by_type = {}
        for arg in getattr(typ, "__args__", ()):
            values_by_type.setdefault(type(arg), []).append(arg)
        for literal_typ, literal_values in values_by_type.items():
            try:
                value = _simplistic_try_cast(inst_str, literal_typ)
                if value in literal_values:
                    return value
            except CastError:
                pass
        raise CastError(f"Could not cast {repr(inst_str)} to {type_repr(typ)}")

    if origin is None or origin is type(None):
        if inst_str == "None":
            return None
        raise CastError(f"Could not cast {repr(inst_str)} to {type_repr(typ)}")

    if origin is bool:
        if inst_str in {"t", "true", "True", "1"}:
            return True
        if inst_str in {"f", "false", "False", "0"}:
            return False
        raise CastError(f"Could not cast {repr(inst_str)} to {type_repr(typ)}")

    if origin is str:
        return inst_str

    if origin is float:
        try:
            return float(inst_str)
        except ValueError as e:
            raise CastError(f"Could not cast {repr(inst_str)} to {type_repr(typ)}") from e

    if origin is int:
        try:
            return int(inst_str)
        except ValueError as e:
            raise CastError(f"Could not cast {repr(inst_str)} to {type_repr(typ)}") from e

    if origin is list or origin is collections.abc.Sequence or origin is collections.abc.Iterable:
        if not inst_str:
            return []
        if inst_str[0] in {"[", "("}:
            # Python literal syntax: evaluate, then validate against the full type.
            try:
                value = ast.literal_eval(inst_str)
            except _LITERAL_EVAL_ERRORS:
                raise CastError(f"Could not cast {repr(inst_str)} to {type_repr(typ)}") from None
            if is_subtype_instance(value, typ):
                return value
            raise CastError(f"Could not cast {repr(inst_str)} to {type_repr(typ)}")
        args = getattr(typ, "__args__", ())
        item_type = args[0] if args else typing.Any
        # inst_str is non-empty here, so split always yields at least one item.
        ret = [_simplistic_try_cast(item, item_type) for item in inst_str.split(",")]
        if origin is list:
            return ret
        return tuple(ret)

    if origin is tuple:
        args = getattr(typ, "__args__", ())
        inst_items = inst_str.split(",") if inst_str else []
        if len(args) == 0:
            return tuple(inst_items)
        if len(args) == 2 and args[-1] is ...:
            item_type = args[0]
            return tuple(_simplistic_try_cast(item, item_type) for item in inst_items)
        num_unpack = sum(is_args_unpack(arg) for arg in args)
        if num_unpack == 0:
            # Great, normal heterogenous tuple
            if len(args) != len(inst_items):
                raise CastError(
                    f"Could not cast {repr(inst_str)} to {type_repr(typ)} because of length mismatch"
                    + (
                        f". Homogeneous tuples should be typed as tuple[{type_repr(args[0])}, ...] not tuple[{type_repr(args[0])}]"
                        if len(args) == 1
                        else ""
                    )
                )
            return tuple(
                _simplistic_try_cast(item, item_typ) for item, item_typ in zip(inst_items, args)
            )
        # Cursed PEP 646 stuff
        return _cast_unpacked_tuples(inst_items, args)

    if origin is dict or origin is collections.abc.Mapping:
        if not inst_str:
            return {}
        if inst_str[0] == "{":
            try:
                value = ast.literal_eval(inst_str)
            except _LITERAL_EVAL_ERRORS:
                raise CastError(f"Could not cast {repr(inst_str)} to {type_repr(typ)}") from None
            if is_subtype_instance(value, typ):
                return value
        # Only a type-checking ``{...}`` literal is accepted for mappings.
        raise CastError(f"Could not cast {repr(inst_str)} to {type_repr(typ)}")

    if origin is collections.abc.Callable:
        # TODO: also support type, callback protocols
        # TODO: unify with factories.from_string
        # TODO: needs module context
        if ":" not in inst_str:
            raise CastError(
                f"Could not cast {repr(inst_str)} to {type_repr(typ)}. Try using a fully qualified name, e.g. module_name:{inst_str}"
            )
        # NOTE(review): only CastError is rewrapped below; confirm that
        # _module_from_name/_module_getattr report failures as CastError.
        try:
            module_name, var = inst_str.split(":", 1)
            module = _module_from_name(module_name)
            value = _module_getattr(module, var)
            if not is_subtype_instance(value, typ):
                raise CastError(f"{type_repr(value)} is not a subtype of {type_repr(typ)}")
        except CastError as e:
            raise CastError(
                f"Could not cast {repr(inst_str)} to {type_repr(typ)}. {e}"
            ) from None
        return value

    # The optional integrations below only apply when the module has already been
    # imported by someone else, so casting never triggers a heavy import itself.
    if "torch" in sys.modules:
        import torch

        if origin is torch.dtype:
            value = getattr(torch, inst_str, None)
            # isinstance already rejects the None default.
            if isinstance(value, torch.dtype):
                return value
            raise CastError(f"Could not cast {repr(inst_str)} to {type_repr(typ)}")

    if "enum" in sys.modules:
        import enum

        if isinstance(origin, type) and issubclass(origin, enum.Enum):
            try:
                # Look up by name
                return origin[inst_str]
            except KeyError:
                pass
            # Fallback to looking up by value
            for member in origin:
                try:
                    value = _simplistic_try_cast(inst_str, type(member.value))
                except CastError:
                    continue
                if value == member.value:
                    return member
            raise CastError(f"Could not cast {repr(inst_str)} to {type_repr(typ)}")

    if "fractions" in sys.modules:
        import fractions

        if origin is fractions.Fraction:
            try:
                return fractions.Fraction(inst_str)
            except ValueError as e:
                raise CastError(f"Could not cast {repr(inst_str)} to {type_repr(typ)}") from e

    if "pathlib" in sys.modules:
        import pathlib

        if origin is pathlib.Path:
            return pathlib.Path(inst_str)

    if hasattr(origin, "__chz_cast__"):
        return origin.__chz_cast__(inst_str)
    if not isinstance(origin, type):
        raise CastError(f"Unrecognised type object {type_repr(typ)}")
    raise CastError(f"Could not cast {repr(inst_str)} to {type_repr(typ)}")