neuron-explainer/neuron_explainer/fast_dataclasses/fast_dataclasses.py (60 lines of code) (raw):

# Utilities for dataclasses that are very fast to serialize and deserialize, with limited data # validation. Fields must not be tuples, since they get serialized and then deserialized as lists. # # The unit tests for this library show how to use it. import json from dataclasses import dataclass, field, fields, is_dataclass from functools import partial from typing import Any, Union import orjson dataclasses_by_name = {} dataclasses_by_fieldnames = {} @dataclass class FastDataclass: dataclass_name: str = field(init=False) def __post_init__(self) -> None: self.dataclass_name = self.__class__.__name__ def register_dataclass(cls): # type: ignore assert is_dataclass(cls), "Only dataclasses can be registered." dataclasses_by_name[cls.__name__] = cls name_set = frozenset(f.name for f in fields(cls) if f.name != "dataclass_name") dataclasses_by_fieldnames[name_set] = cls return cls def dumps(obj: Any) -> bytes: return orjson.dumps(obj, option=orjson.OPT_SERIALIZE_NUMPY) def _object_hook(d: Any, backwards_compatible: bool = True) -> Any: # If d is a list, recurse. if isinstance(d, list): return [_object_hook(x, backwards_compatible=backwards_compatible) for x in d] # If d is not a dict, return it as is. if not isinstance(d, dict): return d cls = None if "dataclass_name" in d: if d["dataclass_name"] in dataclasses_by_name: cls = dataclasses_by_name[d["dataclass_name"]] else: assert backwards_compatible, ( f"Dataclass {d['dataclass_name']} not found, set backwards_compatible=True if you " f"are okay with that." ) # Load objects created without dataclass_name set. else: # Try our best to find a dataclass if backwards_compatible is True. if backwards_compatible: d_fields = frozenset(d.keys()) if d_fields in dataclasses_by_fieldnames: cls = dataclasses_by_fieldnames[d_fields] elif len(d_fields) > 0: # Check if the fields are a subset of a dataclass (if the dataclass had extra fields # added since the data was created). Note that this will fail if fields were removed # from the dataclass. for key, possible_cls in dataclasses_by_fieldnames.items(): if d_fields.issubset(key): cls = possible_cls break else: print(f"Could not find dataclass for {d_fields} {cls}") new_d = { k: _object_hook(v, backwards_compatible=backwards_compatible) for k, v in d.items() if k != "dataclass_name" } if cls is not None: return cls(**new_d) else: return new_d def loads(s: Union[str, bytes], backwards_compatible: bool = True) -> Any: return json.loads( s, object_hook=partial(_object_hook, backwards_compatible=backwards_compatible), )