core/maxframe/serialization/serializables/core.py (366 lines of code) (raw):

# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import weakref
from collections import OrderedDict
from typing import Any, Dict, List, Optional, Tuple, Type

import msgpack

from ...errors import MaxFrameDeprecationError
from ...lib.mmh3 import hash
from ...utils import no_default
from ..core import Placeholder, Serializer, buffered, load_type
from .field import Field
from .field_type import DictType, ListType, PrimitiveFieldType, TupleType

try:
    from ..deserializer import get_legacy_module_name
except ImportError:
    get_legacy_module_name = lambda x: x

logger = logging.getLogger(__name__)
# Key under which the set of already-logged deprecation keys is kept in a
# serialization context (see SerializableSerializer._log_legacy).
_deprecate_log_key = "_SER_DEPRECATE_LOGGED"


def _is_field_primitive_compound(field: Field):
    """
    Tell whether a field is classified as "primitive".

    A field carrying an ``on_serialize`` or ``on_deserialize`` hook is never
    primitive. Otherwise its ``field_type`` must be a ``PrimitiveFieldType``,
    a list / tuple type whose element types are all primitive (recursively)
    or ``Ellipsis``, or a dict type whose key and value types are both
    ``PrimitiveFieldType`` or ``Ellipsis``.
    """
    if field.on_serialize is not None or field.on_deserialize is not None:
        return False

    def check_type(field_type):
        if isinstance(field_type, PrimitiveFieldType):
            return True
        if isinstance(field_type, (ListType, TupleType)):
            if all(
                check_type(element_type) or element_type is Ellipsis
                for element_type in field_type._field_types
            ):
                return True
        if isinstance(field_type, DictType):
            # dict key / value types are checked one level deep only
            if all(
                isinstance(element_type, PrimitiveFieldType) or element_type is Ellipsis
                for element_type in (field_type.key_type, field_type.value_type)
            ):
                return True
        return False

    return check_type(field.field_type)


class SerializableMeta(type):
    """
    Metaclass collecting ``Field`` attributes of a class and its bases.

    Fields declared for the first time become ``__slots__`` entries whose
    slot descriptors are bound to the ``Field`` objects. The created class
    receives bookkeeping attributes (``_FIELDS``, ``_FIELD_ORDER``,
    ``_PRIMITIVE_FIELDS``, ``_NON_PRIMITIVE_FIELDS``, per-class field counts
    and name hashes) that ``SerializableSerializer`` reads.
    """

    def __new__(mcs, name: str, bases: Tuple[Type], properties: Dict):
        # All the fields including misc fields.
        legacy_name_hash = hash(
            f"{get_legacy_module_name(properties.get('__module__'))}.{name}"
        )
        name_hash = hash(
            f"{properties.get('__module__')}.{properties.get('__qualname__')}"
        )
        all_fields = dict()
        # mapping field names to base classes
        field_to_cls_hash = dict()
        # mapping legacy name hash to name hashes
        legacy_to_new_name_hash = {legacy_name_hash: name_hash}

        for base in bases:
            if not hasattr(base, "_FIELDS"):
                continue
            all_fields.update(base._FIELDS)
            field_to_cls_hash.update(base._FIELD_TO_NAME_HASH)
            legacy_to_new_name_hash.update(base._LEGACY_TO_NEW_NAME_HASH)

        properties_without_fields = {}
        properties_field_slot_names = []
        for k, v in properties.items():
            if not isinstance(v, Field):
                properties_without_fields[k] = v
                continue

            field = all_fields.get(k)
            # record the field for the class being created
            field_to_cls_hash[k] = name_hash
            if field is None:
                # brand-new field: needs a slot of its own
                properties_field_slot_names.append(k)
            else:
                # field overrides one from a base class: reuse the accessors
                # already bound to the inherited slot
                v.name = field.name
                v.get = field.get
                v.set = field.set
                v.__delete__ = field.__delete__
            all_fields[k] = v

        # Make field order deterministic to serialize it as list instead of dict.
        field_order = list(all_fields)
        primitive_fields = []
        primitive_field_names = set()
        non_primitive_fields = []
        for field_name, v in all_fields.items():
            if _is_field_primitive_compound(v):
                primitive_fields.append(v)
                primitive_field_names.add(field_name)
            else:
                non_primitive_fields.append(v)

        # count number of fields for every base class
        cls_to_primitive_field_count = OrderedDict()
        cls_to_non_primitive_field_count = OrderedDict()
        for field_name in field_order:
            cls_hash = field_to_cls_hash[field_name]
            if field_name in primitive_field_names:
                cls_to_primitive_field_count[cls_hash] = (
                    cls_to_primitive_field_count.get(cls_hash, 0) + 1
                )
            else:
                cls_to_non_primitive_field_count[cls_hash] = (
                    cls_to_non_primitive_field_count.get(cls_hash, 0) + 1
                )

        slots = set(properties.pop("__slots__", set()))
        slots.update(properties_field_slot_names)

        properties = properties_without_fields

        # todo remove this prop when all versions below v1.0.0rc1 are eliminated
        properties["_LEGACY_NAME_HASH"] = legacy_name_hash

        properties["_NAME_HASH"] = name_hash
        properties["_LEGACY_TO_NEW_NAME_HASH"] = legacy_to_new_name_hash

        properties["_FIELDS"] = all_fields
        properties["_FIELD_ORDER"] = field_order
        properties["_FIELD_TO_NAME_HASH"] = field_to_cls_hash
        properties["_PRIMITIVE_FIELDS"] = primitive_fields
        properties["_CLS_TO_PRIMITIVE_FIELD_COUNT"] = OrderedDict(
            cls_to_primitive_field_count
        )
        properties["_NON_PRIMITIVE_FIELDS"] = non_primitive_fields
        properties["_CLS_TO_NON_PRIMITIVE_FIELD_COUNT"] = OrderedDict(
            cls_to_non_primitive_field_count
        )
        properties["__slots__"] = tuple(slots)

        clz = type.__new__(mcs, name, bases, properties)
        # Bind slot member_descriptor with field.
        for slot_name in properties_field_slot_names:
            member_descriptor = getattr(clz, slot_name)
            field = all_fields[slot_name]
            field.name = member_descriptor.__name__
            field.get = member_descriptor.__get__
            field.set = member_descriptor.__set__
            field.__delete__ = member_descriptor.__delete__
            setattr(clz, slot_name, field)

        return clz


class Serializable(metaclass=SerializableMeta):
    """
    Base class of objects whose state is described by ``Field`` attributes.

    Subclasses may set ``_cache_primitive_serial`` to let the serializer cache
    the packed primitive values per instance, and ``_ignore_non_existing_keys``
    to make ``__init__`` silently drop keyword arguments that match no field.
    """

    __slots__ = ("__weakref__",)

    _cache_primitive_serial = False
    _ignore_non_existing_keys = False

    # The attributes below are filled in by SerializableMeta.
    _LEGACY_NAME_HASH: int
    _NAME_HASH: int
    _LEGACY_TO_NEW_NAME_HASH: Dict[int, int]

    _FIELDS: Dict[str, Field]
    _FIELD_ORDER: List[str]
    _FIELD_TO_NAME_HASH: Dict[str, int]
    _PRIMITIVE_FIELDS: List[Field]
    _CLS_TO_PRIMITIVE_FIELD_COUNT: Dict[int, int]
    _NON_PRIMITIVE_FIELDS: List[Field]
    _CLS_TO_NON_PRIMITIVE_FIELD_COUNT: Dict[int, int]

    def __init__(self, *args, **kwargs):
        """
        Assign field values from positional (in ``_FIELD_ORDER``) and keyword
        arguments; keyword arguments win over positional ones.

        Raises ``KeyError`` for a name that is not a field unless
        ``_ignore_non_existing_keys`` is set.
        """
        fields = self._FIELDS
        field_order = self._FIELD_ORDER
        assert len(args) <= len(field_order)
        if args:  # pragma: no cover
            values = dict(zip(field_order, args))
            values.update(kwargs)
        else:
            values = kwargs
        for k, v in values.items():
            try:
                fields[k].set(self, v)
            except KeyError:
                if not self._ignore_non_existing_keys:
                    raise

    def __on_deserialize__(self):
        """Hook invoked by the serializer once all field values are set."""
        pass

    def __repr__(self):
        # only the slots declared on the concrete class are listed; unset
        # slots are shown as None
        values = ", ".join(
            [
                "{}={!r}".format(slot, getattr(self, slot, None))
                for slot in self.__slots__
            ]
        )
        return "{}({})".format(self.__class__.__name__, values)

    def copy_to(self, target: "Serializable") -> "Serializable":
        """
        Copy every set field of this object onto ``target`` and return it.

        A field missing on ``target`` under its own name is looked up again
        with a leading underscore. Fields without a value are skipped.
        """
        copied_fields = target._FIELDS
        for k, field in self._FIELDS.items():
            try:
                # Slightly faster than getattr.
                value = field.get(self, k)
                try:
                    copied_fields[k].set(target, value)
                except KeyError:
                    copied_fields["_" + k].set(target, value)
            except AttributeError:
                # raised when the field has no value on this object
                continue
        return target

    def copy(self) -> "Serializable":
        """Return a new instance of the same type with the fields copied."""
        return self.copy_to(type(self)())


# Per-object cache of packed primitive values, used only for objects whose
# class sets ``_cache_primitive_serial``; weak keys avoid keeping them alive.
_primitive_serial_cache = weakref.WeakKeyDictionary()


class _NoFieldValue:
    """Type of the sentinel marking a field that has no value."""

    pass


_no_field_value = _NoFieldValue()


def _to_primitive_placeholder(v: Any) -> Any:
    """Replace the "no value" markers with ``{}`` so msgpack can pack them."""
    if v is _no_field_value or v is no_default:
        return {}
    return v


def _restore_primitive_placeholder(v: Any) -> Any:
    """
    Inverse of ``_to_primitive_placeholder``: an empty plain ``dict`` becomes
    ``_no_field_value``; any other value is returned unchanged.
    """
    if type(v) is dict:
        if v == {}:
            return _no_field_value
        else:
            return v
    else:
        return v


class SerializableSerializer(Serializer):
    """
    Leverage DictSerializer to perform serde.
    """

    @classmethod
    def _log_legacy(cls, context: Dict, key: Any, msg: str, *args, **kwargs):
        """
        Log ``msg`` at most once per ``key`` within a serialization context.

        The logging level is taken from the ``level`` keyword argument
        (``logging.WARNING`` by default); remaining arguments are forwarded
        to ``logger.log``.
        """
        level = kwargs.pop("level", logging.WARNING)
        try:
            logged_keys = context[_deprecate_log_key]
        except KeyError:
            logged_keys = context[_deprecate_log_key] = set()
        if key not in logged_keys:
            logged_keys.add(key)
            logger.log(level, msg, *args, **kwargs)

    @classmethod
    def _get_obj_field_count_key(cls, obj: Serializable, legacy: bool = False):
        """Build the public-data key holding field counts for ``obj``'s class."""
        return f"FC_{obj._NAME_HASH if not legacy else obj._LEGACY_NAME_HASH}"

    @classmethod
    def _get_field_values(cls, obj: Serializable, fields):
        """
        Collect the values of ``fields`` on ``obj`` as a list, applying each
        field's ``on_serialize`` hook; unset fields yield ``_no_field_value``.
        """
        values = []
        for field in fields:
            try:
                value = field.get(obj)
                if field.on_serialize is not None:
                    value = field.on_serialize(value)
            except AttributeError:
                # Most field values are not None, serialize by list is more
                # efficient than dict.
                value = _no_field_value
            values.append(value)
        return values

    @buffered
    def serial(self, obj: Serializable, context: Dict):
        """
        Serialize ``obj`` into a header of ``[class path, primitive values]``
        plus a single sub-list holding the non-primitive field values.

        The per-class field counts are stored once per context as public data
        so the receiving side can align fields even if its class definitions
        differ.
        """
        if obj._cache_primitive_serial and obj in _primitive_serial_cache:
            primitive_vals = _primitive_serial_cache[obj]
        else:
            primitive_vals = self._get_field_values(obj, obj._PRIMITIVE_FIELDS)
            # replace _no_field_value as {} to make them msgpack-serializable
            primitive_vals = [_to_primitive_placeholder(v) for v in primitive_vals]
            if obj._cache_primitive_serial:
                # cached form is the packed bytes, not the list
                primitive_vals = msgpack.dumps(primitive_vals)
                _primitive_serial_cache[obj] = primitive_vals

        compound_vals = self._get_field_values(obj, obj._NON_PRIMITIVE_FIELDS)
        cls_module = f"{type(obj).__module__}#{type(obj).__qualname__}"

        field_count_key = self._get_obj_field_count_key(obj)
        if not self.is_public_data_exist(context, field_count_key):
            # store field distribution for current Serializable
            counts = [
                list(obj._CLS_TO_PRIMITIVE_FIELD_COUNT.items()),
                list(obj._CLS_TO_NON_PRIMITIVE_FIELD_COUNT.items()),
            ]
            field_count_data = msgpack.dumps(counts)
            self.put_public_data(context, field_count_key, field_count_data)
        return [cls_module, primitive_vals], [compound_vals], False

    @staticmethod
    def _set_field_value(obj: Serializable, field: Field, value):
        """
        Assign ``value`` to ``field`` on ``obj``, applying ``on_deserialize``.

        ``_no_field_value`` is ignored. For a ``Placeholder`` the assignment
        is registered as a callback on the placeholder instead of being done
        immediately.
        """
        if value is _no_field_value:
            return
        if type(value) is Placeholder:
            if field.on_deserialize is not None:
                value.callbacks.append(
                    lambda v: field.set(obj, field.on_deserialize(v))
                )
            else:
                value.callbacks.append(lambda v: field.set(obj, v))
        else:
            if field.on_deserialize is not None:
                field.set(obj, field.on_deserialize(value))
            else:
                field.set(obj, value)

    @classmethod
    def _prune_server_fields(
        cls,
        client_cls_to_field_count: Optional[Dict[int, int]],
        server_cls_to_field_count: Dict[int, int],
        server_fields: list,
        legacy_to_new_hash: Dict[int, int],
    ) -> list:
        """
        Drop server-side fields belonging to classes the client did not send.

        ``server_fields`` is returned untouched when both sides list the same
        class hashes. Otherwise only the fields of classes whose hash (or
        whose legacy hash) appears in ``client_cls_to_field_count`` are kept.
        """
        if set(client_cls_to_field_count.keys()) == set(
            server_cls_to_field_count.keys()
        ):
            return server_fields
        new_to_legacy_hash = {v: k for k, v in legacy_to_new_hash.items()}
        ret_server_fields = []
        server_pos = 0
        for cls_hash, count in server_cls_to_field_count.items():
            if (
                cls_hash in client_cls_to_field_count
                or new_to_legacy_hash.get(cls_hash) in client_cls_to_field_count
            ):
                ret_server_fields.extend(server_fields[server_pos : server_pos + count])
            # always advance, whether or not this class was kept
            server_pos += count
        return ret_server_fields

    @classmethod
    def _set_field_values(
        cls,
        obj: Serializable,
        values: List[Any],
        client_cls_to_field_count: Optional[Dict[int, int]],
        is_primitive: bool = True,
    ):
        """
        Assign deserialized ``values`` to the fields of ``obj``.

        ``values`` is laid out class by class as described by
        ``client_cls_to_field_count``. For every class the values are paired
        positionally with the fields this side defines for the same class;
        values without a counterpart on this side are dropped, and extra
        fields on this side are left unset.
        """
        obj_class = type(obj)
        legacy_to_new_hash = obj_class._LEGACY_TO_NEW_NAME_HASH

        if is_primitive:
            server_cls_to_field_count = obj_class._CLS_TO_PRIMITIVE_FIELD_COUNT
            field_def_list = obj_class._PRIMITIVE_FIELDS
        else:
            server_cls_to_field_count = obj_class._CLS_TO_NON_PRIMITIVE_FIELD_COUNT
            field_def_list = obj_class._NON_PRIMITIVE_FIELDS

        server_fields = cls._prune_server_fields(
            client_cls_to_field_count,
            server_cls_to_field_count,
            field_def_list,
            legacy_to_new_hash,
        )

        field_num, server_field_num = 0, 0
        for cls_hash, count in client_cls_to_field_count.items():
            # number of fields this side defines for the class; legacy clients
            # send the legacy hash, so retry through the legacy mapping. It is
            # possible that the class does not exist at server side at all,
            # in which case the count is 0.
            server_count = server_cls_to_field_count.get(cls_hash)
            if server_count is None:
                server_count = server_cls_to_field_count.get(
                    legacy_to_new_hash.get(cls_hash), 0
                )

            # cut values and fields given field distribution
            # at client and server end. The field slice is measured from the
            # server offset and capped by the server-side count, so a size
            # mismatch in one class can neither truncate nor spill into the
            # fields of the following classes.
            cls_fields = server_fields[
                server_field_num : server_field_num + min(count, server_count)
            ]
            cls_values = values[field_num : field_num + count]
            for field, value in zip(cls_fields, cls_values):
                if is_primitive:
                    value = _restore_primitive_placeholder(value)
                if not is_primitive or value is not _no_field_value:
                    cls._set_field_value(obj, field, value)

            field_num += count
            server_field_num += server_count

    def deserial(self, serialized: List, context: Dict, subs: List) -> Serializable:
        """
        Rebuild a ``Serializable`` from the output of ``serial``.

        The instance is created without calling ``__init__``; field values
        are then assigned and ``__on_deserialize__`` is invoked.

        Raises ``MaxFrameDeprecationError`` when no field-count data can be
        found in the context under either the current or the legacy key.
        """
        obj_class_name, primitives = serialized
        obj_class = load_type(obj_class_name, Serializable)

        # primitives arrive packed when the sender cached them as bytes
        if type(primitives) is not list:
            primitives = msgpack.loads(primitives)

        obj = obj_class.__new__(obj_class)

        field_count_data = self.get_public_data(
            context, self._get_obj_field_count_key(obj)
        )
        if field_count_data is None:
            # try using legacy field count key to get counts
            field_count_data = self.get_public_data(
                context, self._get_obj_field_count_key(obj, legacy=True)
            )
            if field_count_data is None:
                self._log_legacy(
                    context,
                    ("MISSING_CLASS", obj_class_name),
                    "Field count info of %s not found in serialized data",
                    obj_class_name,
                    level=logging.ERROR,
                )
                raise MaxFrameDeprecationError(
                    "Failed to deserialize request. Please upgrade your "
                    "MaxFrame client to the latest release."
                )
            else:
                self._log_legacy(
                    context,
                    ("LEGACY_CLASS", obj_class_name),
                    "Class %s used in legacy client",
                    obj_class_name,
                )

        cls_to_prim_key, cls_to_non_prim_key = msgpack.loads(field_count_data)
        # counts were packed as lists of (class hash, count) pairs
        cls_to_prim_key = dict(cls_to_prim_key)
        cls_to_non_prim_key = dict(cls_to_non_prim_key)

        if primitives:
            self._set_field_values(obj, primitives, cls_to_prim_key, True)
        if obj_class._NON_PRIMITIVE_FIELDS:
            self._set_field_values(obj, subs[0], cls_to_non_prim_key, False)
        obj.__on_deserialize__()
        return obj


class NoFieldValueSerializer(Serializer):
    """Serializer of the ``_no_field_value`` sentinel; carries no payload."""

    def serial(self, obj, context):
        return [], [], True

    def deserial(self, serialized, context, subs):
        # always hand back the module-level singleton
        return _no_field_value


SerializableSerializer.register(Serializable)
NoFieldValueSerializer.register(_NoFieldValue)