in src/huggingface_hub/dataclasses.py [0:0]
def wrap(cls: Type[T]) -> Type[T]:
    """Make the dataclass `cls` strict, modifying it in place.

    The following is installed on `cls`:

    - `__validators__`: for each dataclass field, a type validator (from `_create_type_validator`)
      followed by the custom validator(s) declared in the field's `metadata["validator"]`.
    - `__setattr__`: runs the field's validators before every attribute assignment.
    - when the enclosing decorator's `accept_kwargs` flag is set: an `__init__` that also accepts
      keyword arguments that are not dataclass fields (stored as plain attributes), and a
      `__repr__` that lists those extra attributes with a `*` prefix.
    - `__class_validators__` / `validate()`: every callable attribute whose name starts with
      `validate_` is a class validator; `validate()` runs all of them on an instance.
    - an `__init__` wrapper that calls `validate()` after initialization.

    Args:
        cls: The class to make strict. It must already be a dataclass.

    Returns:
        The same class object, modified in place.

    Raises:
        StrictDataclassDefinitionError: If `cls` is not a dataclass, if a field declares a validator
            rejected by `_is_validator`, if a `validate_*` method does not take exactly one
            argument, or if `cls` already has a `validate` attribute not added by this decorator.
    """
    if not hasattr(cls, "__dataclass_fields__"):
        raise StrictDataclassDefinitionError(
            f"Class '{cls.__name__}' must be a dataclass before applying @strict."
        )

    # List and store validators: the type validator always comes first, followed by the
    # custom validator(s) declared in the field metadata under the "validator" key.
    field_validators: Dict[str, List[Validator_T]] = {}
    for f in fields(cls):  # type: ignore [arg-type]
        validators = [_create_type_validator(f)]
        custom_validators = f.metadata.get("validator")
        if custom_validators is not None:
            # A single validator is accepted as shorthand for a one-element list.
            if not isinstance(custom_validators, list):
                custom_validators = [custom_validators]
            for validator in custom_validators:
                if not _is_validator(validator):
                    raise StrictDataclassDefinitionError(
                        f"Invalid validator for field '{f.name}': {validator}. Must be a callable taking a single argument."
                    )
            validators.extend(custom_validators)
        field_validators[f.name] = validators
    cls.__validators__ = field_validators  # type: ignore

    # Override __setattr__ to validate fields on assignment
    original_setattr = cls.__setattr__

    def __strict_setattr__(self: Any, name: str, value: Any) -> None:
        """Custom __setattr__ method for strict dataclasses."""
        # Run all validators; names without registered validators are assigned unchecked.
        for validator in self.__validators__.get(name, []):
            try:
                validator(value)
            except (ValueError, TypeError) as e:
                raise StrictDataclassFieldValidationError(field=name, cause=e) from e
        # If validation passed, set the attribute
        original_setattr(self, name, value)

    cls.__setattr__ = __strict_setattr__  # type: ignore[method-assign]

    if accept_kwargs:
        # (optional) Override __init__ to accept arbitrary keyword arguments
        original_init = cls.__init__
        # Computed once at decoration time rather than on every instantiation.
        # NOTE(review): `fields()` does not return InitVar pseudo-fields, so an InitVar passed by
        # keyword is treated as an extra attribute instead of being forwarded — confirm intended.
        dataclass_fields = {f.name for f in fields(cls)}  # type: ignore [arg-type]

        @wraps(original_init)
        def __init__(self, *args: Any, **kwargs: Any) -> None:
            # Positional arguments and keyword arguments naming a dataclass field are handed to
            # the original __init__ unchanged.
            standard_kwargs = {k: v for k, v in kwargs.items() if k in dataclass_fields}
            original_init(self, *args, **standard_kwargs)
            # Add any additional kwargs as attributes (through the strict __setattr__)
            for name, value in kwargs.items():
                if name not in dataclass_fields:
                    self.__setattr__(name, value)

        cls.__init__ = __init__  # type: ignore[method-assign]

        # (optional) Override __repr__ to include additional kwargs
        original_repr = cls.__repr__

        @wraps(original_repr)
        def __repr__(self) -> str:
            # Call the original __repr__ to get the standard fields
            standard_repr = original_repr(self)
            # Prefix additional kwargs with '*' to show they are not part of the dataclass
            additional_kwargs = [
                f"*{k}={v!r}"
                for k, v in self.__dict__.items()
                if k not in cls.__dataclass_fields__  # type: ignore [attr-defined]
            ]
            if not additional_kwargs:
                return standard_repr
            # Splice the extras in before the closing ')' (assumes the original repr ends with
            # ')'). Only add a separator when the standard repr already lists something, so a
            # repr like 'Config()' becomes 'Config(*a=1)' rather than 'Config(, *a=1)'.
            head = standard_repr[:-1]
            separator = "" if head.endswith("(") else ", "
            return f"{head}{separator}{', '.join(additional_kwargs)})"

        cls.__repr__ = __repr__  # type: ignore [method-assign]

    # List all public methods starting with `validate_` => class validators.
    class_validators = []
    for name in dir(cls):
        if not name.startswith("validate_"):
            continue
        method = getattr(cls, name)
        if not callable(method):
            continue
        n_params = len(inspect.signature(method).parameters)
        if n_params != 1:
            # Report zero-argument validators accurately instead of claiming "more than one".
            problem = "takes no argument" if n_params == 0 else "takes more than one argument"
            raise StrictDataclassDefinitionError(
                f"Class '{cls.__name__}' has a class validator '{name}' that {problem}."
                " Class validators must take only 'self' as an argument. Methods starting with 'validate_'"
                " are considered to be class validators."
            )
        class_validators.append(method)
    cls.__class_validators__ = class_validators  # type: ignore [attr-defined]

    # Add `validate` method to the class, but first check if it already exists
    def validate(self: T) -> None:
        """Run class validators on the instance."""
        for validator in cls.__class_validators__:  # type: ignore [attr-defined]
            try:
                validator(self)
            except (ValueError, TypeError) as e:
                raise StrictDataclassClassValidationError(validator=validator.__name__, cause=e) from e

    # Marker attribute: a `validate` created by this decorator (e.g. on a parent class) may be
    # overridden, while a user-defined `validate` makes decoration fail instead of being replaced.
    validate.__is_defined_by_strict_decorator__ = True  # type: ignore [attr-defined]

    if hasattr(cls, "validate"):
        if not getattr(cls.validate, "__is_defined_by_strict_decorator__", False):  # type: ignore [attr-defined]
            raise StrictDataclassDefinitionError(
                f"Class '{cls.__name__}' already implements a method called 'validate'."
                " This method name is reserved when using the @strict decorator on a dataclass."
                " If you want to keep your own method, please rename it."
            )

    cls.validate = validate  # type: ignore

    # Run class validators after initialization. `initial_init` is the kwargs-accepting
    # __init__ installed above when `accept_kwargs` is set, otherwise the class's own __init__.
    initial_init = cls.__init__

    @wraps(initial_init)
    def init_with_validate(self, *args, **kwargs) -> None:
        """Run class validators after initialization."""
        initial_init(self, *args, **kwargs)  # type: ignore [call-arg]
        cls.validate(self)  # type: ignore [attr-defined]

    setattr(cls, "__init__", init_with_validate)

    return cls