# Utilities for dataclasses that are very fast to serialize and deserialize, with limited data
# validation. Fields must not be tuples, since they get serialized and then deserialized as lists.
#
# The unit tests for this library show how to use it.
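#
# A minimal usage sketch (the class and field names below are hypothetical; the unit tests are
# the canonical examples):
#
#     @register_dataclass
#     @dataclass
#     class Point(FastDataclass):
#         x: float
#         y: float
#
#     buf = dumps(Point(x=1.0, y=2.0))
#     point = loads(buf)  # -> Point(x=1.0, y=2.0)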

import json
from dataclasses import dataclass, field, fields, is_dataclass
from functools import partial
from typing import Any

import orjson

dataclasses_by_name = {}
# TODO(williamrs): remove deserialization using fieldnames once all data is in the new format
dataclasses_by_fieldnames = {}


@dataclass
class FastDataclass:
    dataclass_name: str = field(init=False)

    def __post_init__(self) -> None:
        self.dataclass_name = self.__class__.__name__

    # Type checking is suppressed in the two methods below because mypy doesn't like that we're
    # setting attributes at runtime. We need to do that because if we defined the attributes at
    # the class level, they'd be shared across subclasses.
    @classmethod
    def field_renamed(cls, old_name: str, new_name: str) -> None:
        if not hasattr(cls, "field_renames"):
            cls.field_renames = {}  # type: ignore
        assert old_name not in cls.field_renames, f"{old_name} already renamed"  # type: ignore
        existing_field_names = [f.name for f in fields(cls)]
        assert new_name in existing_field_names, f"{new_name} not in {existing_field_names}"
        assert old_name not in existing_field_names, f"{old_name} still in {existing_field_names}"
        cls.field_renames[old_name] = new_name  # type: ignore

    @classmethod
    def field_deleted(cls, field_name: str) -> None:
        if not hasattr(cls, "deleted_fields"):
            cls.deleted_fields = []  # type: ignore
        assert field_name not in cls.deleted_fields, f"{field_name} already deleted"  # type: ignore
        existing_field_names = [f.name for f in fields(cls)]
        assert field_name not in existing_field_names, f"{field_name} still in use"
        cls.deleted_fields.append(field_name)  # type: ignore

    @classmethod
    def was_previously_named(cls, old_name: str) -> None:
        assert old_name not in dataclasses_by_name, f"{old_name} still in use as a dataclass name"
        dataclasses_by_name[old_name] = cls
        name_set = frozenset(f.name for f in fields(cls) if f.name != "dataclass_name")
        dataclasses_by_fieldnames[name_set] = cls
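

# A hedged sketch of how the hooks above might be used as a schema evolves (the class and field
# names here are hypothetical; the unit tests show real usage):
#
#     @register_dataclass
#     @dataclass
#     class Record(FastDataclass):  # formerly named "LegacyRecord"
#         value: int  # old data used the name "val"
#
#     Record.was_previously_named("LegacyRecord")
#     Record.field_renamed("val", "value")
#     Record.field_deleted("obsolete_flag")  # old data had a field that no longer exists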


def register_dataclass(cls):  # type: ignore
    assert is_dataclass(cls), "Only dataclasses can be registered."
    dataclasses_by_name[cls.__name__] = cls
    name_set = frozenset(f.name for f in fields(cls) if f.name != "dataclass_name")
    dataclasses_by_fieldnames[name_set] = cls
    return cls


def dumps(obj: Any) -> bytes:
    # orjson serializes dataclass instances natively; OPT_SERIALIZE_NUMPY additionally lets
    # numpy arrays serialize as JSON lists.
    return orjson.dumps(obj, option=orjson.OPT_SERIALIZE_NUMPY)
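
# For example, assuming numpy is installed (hypothetical data):
#
#     import numpy as np
#     dumps({"weights": np.array([1.0, 2.0])})  # -> b'{"weights":[1.0,2.0]}'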


def _object_hook(d: Any, backwards_compatible: bool = True) -> Any:
    # If d is a list, recurse into its elements.
    if isinstance(d, list):
        return [_object_hook(x, backwards_compatible=backwards_compatible) for x in d]
    # If d is not a dict, return it as is.
    if not isinstance(d, dict):
        return d
    cls = None
    if "dataclass_name" in d:
        if d["dataclass_name"] in dataclasses_by_name:
            cls = dataclasses_by_name[d["dataclass_name"]]
        else:
            assert backwards_compatible, (
                f"Dataclass {d['dataclass_name']} not found; set backwards_compatible=True if "
                f"you are okay with that."
            )
    else:
        # Load objects created without dataclass_name set: if backwards_compatible is True, try
        # our best to find a matching dataclass by field names.
        if backwards_compatible:
            d_fields = frozenset(d.keys())
            if d_fields in dataclasses_by_fieldnames:
                cls = dataclasses_by_fieldnames[d_fields]
            elif len(d_fields) > 0:
                # Check whether the fields are a subset of a registered dataclass's fields (the
                # dataclass may have had fields added since the data was created). Note that
                # this will fail if fields were removed from the dataclass.
                for key, possible_cls in dataclasses_by_fieldnames.items():
                    if d_fields.issubset(key):
                        cls = possible_cls
                        break
                else:
                    # for-else: this runs only if no registered dataclass matched.
                    print(f"Could not find dataclass for {d_fields}")
    if cls is not None:
        # Apply renames and deletions registered via field_renamed / field_deleted.
        if hasattr(cls, "field_renames"):
            d = {cls.field_renames.get(k, k): v for k, v in d.items()}
        if hasattr(cls, "deleted_fields"):
            d = {k: v for k, v in d.items() if k not in cls.deleted_fields}
    new_d = {
        k: _object_hook(v, backwards_compatible=backwards_compatible)
        for k, v in d.items()
        if k != "dataclass_name"
    }
    if cls is not None:
        return cls(**new_d)
    else:
        return new_d
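
# A sketch of the fieldname fallback above, reusing the hypothetical Point class from the module
# header: data written before dataclass_name existed, such as b'{"x": 1.0, "y": 2.0}', still
# deserializes as a Point because frozenset({"x", "y"}) matches Point's registered field-name set.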


def loads(s: str | bytes, backwards_compatible: bool = True) -> Any:
    # orjson doesn't support object hooks, so deserialization goes through the stdlib json module.
    return json.loads(
        s,
        object_hook=partial(_object_hook, backwards_compatible=backwards_compatible),
    )
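
# As the module header warns, tuples don't survive a round trip; they come back as lists:
#
#     loads(dumps({"pair": (1, 2)}), backwards_compatible=False)  # -> {"pair": [1, 2]}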