python/pyfury/_registry.py (512 lines of code) (raw):

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import array
import dataclasses
import datetime
import enum
import functools
import logging
from typing import TypeVar, Union
from enum import Enum

from pyfury._serialization import ENABLE_FURY_CYTHON_SERIALIZATION
from pyfury import Language
from pyfury.error import TypeUnregisteredError
from pyfury.serializer import (
    Serializer,
    Numpy1DArraySerializer,
    NDArraySerializer,
    PyArraySerializer,
    DynamicPyArraySerializer,
    _PickleStub,
    PickleStrongCacheStub,
    PickleCacheStub,
    NoneSerializer,
    BooleanSerializer,
    ByteSerializer,
    Int16Serializer,
    Int32Serializer,
    Int64Serializer,
    Float32Serializer,
    Float64Serializer,
    StringSerializer,
    DateSerializer,
    TimestampSerializer,
    BytesSerializer,
    ListSerializer,
    TupleSerializer,
    MapSerializer,
    SetSerializer,
    EnumSerializer,
    SliceSerializer,
    PickleCacheSerializer,
    PickleStrongCacheSerializer,
    PickleSerializer,
    DataClassSerializer,
)
from pyfury._struct import ComplexObjectSerializer
from pyfury.meta.metastring import MetaStringEncoder, MetaStringDecoder
from pyfury.type import (
    TypeId,
    Int8Type,
    Int16Type,
    Int32Type,
    Int64Type,
    Float32Type,
    Float64Type,
    load_class,
)
from pyfury._fury import (
    DYNAMIC_TYPE_ID,
    # preserve 0 as flag for class id not set in `ClassInfo`
    NO_CLASS_ID,
)

try:
    import numpy as np
except ImportError:
    np = None

logger = logging.getLogger(__name__)


if ENABLE_FURY_CYTHON_SERIALIZATION:
    from pyfury._serialization import ClassInfo
else:

    class ClassInfo:
        """Pure-Python record describing one registered class.

        Defined only when the Cython serialization module is disabled;
        otherwise ``ClassInfo`` is imported from ``pyfury._serialization``.
        """

        __slots__ = (
            "cls",
            "type_id",
            "serializer",
            "namespace_bytes",
            "typename_bytes",
            "dynamic_type",
        )

        def __init__(
            self,
            cls: type = None,
            type_id: int = NO_CLASS_ID,
            serializer: Serializer = None,
            namespace_bytes=None,
            typename_bytes=None,
            dynamic_type: bool = False,
        ):
            self.cls = cls
            self.type_id = type_id
            self.serializer = serializer
            # Encoded namespace / type name; both stay None for classes
            # registered without a type name.
            self.namespace_bytes = namespace_bytes
            self.typename_bytes = typename_bytes
            # When true, `ClassResolver.write_typeinfo` writes nothing for
            # this class.
            self.dynamic_type = dynamic_type

        def __repr__(self):
            return (
                f"ClassInfo(cls={self.cls}, type_id={self.type_id}, "
                f"serializer={self.serializer})"
            )


class ClassResolver:
    """Registry mapping classes, type ids and type names to `ClassInfo`.

    Lookups are kept in three indexes:

    * ``_classes_info``: class -> ClassInfo
    * ``_type_id_to_classinfo``: type id -> ClassInfo
    * ``_named_type_to_classinfo`` / ``_ns_type_to_classinfo``:
      (namespace, typename) as strings / as encoded meta-string bytes
      -> ClassInfo
    """

    __slots__ = (
        "fury",
        "_metastr_to_str",
        "_type_id_counter",
        "_classes_info",
        "_hash_to_metastring",
        "_metastr_to_class",
        "_hash_to_classinfo",
        "_dynamic_id_to_classinfo_list",
        "_dynamic_id_to_metastr_list",
        "_dynamic_write_string_id",
        "_dynamic_written_metastr",
        "_ns_type_to_classinfo",
        "_named_type_to_classinfo",
        "namespace_encoder",
        "namespace_decoder",
        "typename_encoder",
        "typename_decoder",
        "require_registration",
        "metastring_resolver",
        "language",
        "_type_id_to_classinfo",
    )

    def __init__(self, fury):
        """Create an empty resolver bound to ``fury``.

        No types are registered here; call `initialize` afterwards.
        """
        self.fury = fury
        self.metastring_resolver = fury.metastring_resolver
        self.language = fury.language
        self.require_registration = fury.require_class_registration
        self._metastr_to_str = dict()
        self._metastr_to_class = dict()
        self._hash_to_metastring = dict()
        self._hash_to_classinfo = dict()
        self._dynamic_written_metastr = []
        self._type_id_to_classinfo = dict()
        # `_next_type_id` increments before returning, so the first
        # auto-assigned id is 65.
        self._type_id_counter = 64
        self._dynamic_write_string_id = 0
        # hold objects to avoid gc, since `flat_hash_map/vector` doesn't
        # hold python reference.
        self._classes_info = dict()
        self._ns_type_to_classinfo = dict()
        self._named_type_to_classinfo = dict()
        self.namespace_encoder = MetaStringEncoder(".", "_")
        self.namespace_decoder = MetaStringDecoder(".", "_")
        self.typename_encoder = MetaStringEncoder("$", "_")
        self.typename_decoder = MetaStringDecoder("$", "_")

    def initialize(self):
        """Register the built-in types.

        The cross-language set is always registered; the Python-only set is
        added when the owning Fury instance uses ``Language.PYTHON``.
        """
        self._initialize_xlang()
        if self.fury.language == Language.PYTHON:
            self._initialize_py()

    def _initialize_py(self):
        """Register the Python-only built-ins (pickle stubs, None, tuple, slice)."""
        register = functools.partial(self._register_type, internal=True)
        register(
            _PickleStub,
            type_id=PickleSerializer.PICKLE_CLASS_ID,
            serializer=PickleSerializer,
        )
        register(
            PickleStrongCacheStub,
            type_id=97,
            serializer=PickleStrongCacheSerializer(self.fury),
        )
        register(
            PickleCacheStub,
            type_id=98,
            serializer=PickleCacheSerializer(self.fury),
        )
        # No explicit id: `_register_type` assigns the next free one.
        register(type(None), serializer=NoneSerializer)
        register(tuple, serializer=TupleSerializer)
        register(slice, serializer=SliceSerializer)

    def _initialize_xlang(self):
        """Register the built-in types that carry a fixed `TypeId`."""
        register = functools.partial(self._register_type, internal=True)
        register(None, type_id=TypeId.NA, serializer=NoneSerializer)
        register(bool, type_id=TypeId.BOOL, serializer=BooleanSerializer)
        register(Int8Type, type_id=TypeId.INT8, serializer=ByteSerializer)
        register(Int16Type, type_id=TypeId.INT16, serializer=Int16Serializer)
        register(Int32Type, type_id=TypeId.INT32, serializer=Int32Serializer)
        register(Int64Type, type_id=TypeId.INT64, serializer=Int64Serializer)
        register(int, type_id=TypeId.INT64, serializer=Int64Serializer)
        register(
            Float32Type,
            type_id=TypeId.FLOAT32,
            serializer=Float32Serializer,
        )
        register(
            Float64Type,
            type_id=TypeId.FLOAT64,
            serializer=Float64Serializer,
        )
        register(float, type_id=TypeId.FLOAT64, serializer=Float64Serializer)
        register(str, type_id=TypeId.STRING, serializer=StringSerializer)
        # TODO(chaokunyang) DURATION DECIMAL
        register(
            datetime.datetime, type_id=TypeId.TIMESTAMP, serializer=TimestampSerializer
        )
        register(datetime.date, type_id=TypeId.LOCAL_DATE, serializer=DateSerializer)
        register(bytes, type_id=TypeId.BINARY, serializer=BytesSerializer)
        for _itemsize, ftype, typeid in PyArraySerializer.typecode_dict.values():
            register(
                ftype,
                type_id=typeid,
                serializer=PyArraySerializer(self.fury, ftype, typeid),
            )
        register(
            array.array, type_id=DYNAMIC_TYPE_ID, serializer=DynamicPyArraySerializer
        )
        if np:
            # overwrite pyarray with same type id.
            # if pyarray are needed, one must annotate that value with XXXArrayType
            # as a field of a struct.
            for dtype, (
                _itemsize,
                _format,
                ftype,
                typeid,
            ) in Numpy1DArraySerializer.dtypes_dict.items():
                register(
                    ftype,
                    type_id=typeid,
                    serializer=Numpy1DArraySerializer(self.fury, ftype, dtype),
                )
            register(np.ndarray, type_id=DYNAMIC_TYPE_ID, serializer=NDArraySerializer)
        register(list, type_id=TypeId.LIST, serializer=ListSerializer)
        register(set, type_id=TypeId.SET, serializer=SetSerializer)
        register(dict, type_id=TypeId.MAP, serializer=MapSerializer)
        try:
            import pyarrow as pa
            from pyfury.format.serializer import (
                ArrowRecordBatchSerializer,
                ArrowTableSerializer,
            )

            register(
                pa.RecordBatch,
                type_id=TypeId.ARROW_RECORD_BATCH,
                serializer=ArrowRecordBatchSerializer,
            )
            register(
                pa.Table, type_id=TypeId.ARROW_TABLE, serializer=ArrowTableSerializer
            )
        except Exception as e:
            # Arrow support is best-effort: any failure here leaves the
            # resolver usable, but record why the types were skipped.
            logger.debug("Skip registering pyarrow serializers: %s", e)

    def register_type(
        self,
        cls: Union[type, TypeVar],
        *,
        type_id: int = None,
        namespace: str = None,
        typename: str = None,
        serializer=None,
    ):
        """Register ``cls`` as a user (non-internal) type.

        Forwards every argument to `_register_type`; see it for the rules
        on ``type_id`` / ``typename`` and the errors raised.
        """
        return self._register_type(
            cls,
            type_id=type_id,
            namespace=namespace,
            typename=typename,
            serializer=serializer,
        )

    def _register_type(
        self,
        cls: Union[type, TypeVar],
        *,
        type_id: int = None,
        namespace: str = None,
        typename: str = None,
        serializer=None,
        internal=False,
    ):
        """Register class with given type id or typename.

        If typename is not None, it will be used for cross-language
        serialization.

        ``serializer`` may be a `Serializer` instance or a callable that
        builds one. A callable is tried as ``serializer(fury, cls)``, then
        ``serializer(fury)``, then ``serializer()``.

        Raises
        ------
        TypeError
            If both ``typename`` and ``type_id`` are given, or if ``cls``
            has already been registered.
        """
        if serializer is not None and not isinstance(serializer, Serializer):
            # Catch `Exception` rather than `BaseException` so that
            # KeyboardInterrupt / SystemExit raised while constructing the
            # serializer are not swallowed by the fallbacks.
            try:
                serializer = serializer(self.fury, cls)
            except Exception:
                try:
                    serializer = serializer(self.fury)
                except Exception:
                    serializer = serializer()
        # Number of distinct non-None values among (typename, type_id).
        n_params = len({typename, type_id, None}) - 1
        if n_params == 0 and typename is None:
            type_id = self._next_type_id()
        if n_params == 2:
            raise TypeError(
                f"type name {typename} and id {type_id} should not be set at the same time"
            )
        if type_id not in {0, None}:
            # multiple classes can have the same type id
            if type_id in self._type_id_to_classinfo and cls in self._classes_info:
                raise TypeError(f"{cls} registered already")
        elif cls in self._classes_info:
            raise TypeError(f"{cls} registered already")
        register_type = (
            self._register_xtype
            if self.fury.language == Language.XLANG
            else self._register_pytype
        )
        return register_type(
            cls,
            type_id=type_id,
            namespace=namespace,
            typename=typename,
            serializer=serializer,
            internal=internal,
        )

    def _register_xtype(
        self,
        cls: Union[type, TypeVar],
        *,
        type_id: int = None,
        namespace: str = None,
        typename: str = None,
        serializer=None,
        internal=False,
    ):
        """Register ``cls`` for cross-language mode.

        Without a serializer, enums get an `EnumSerializer` and every other
        class a `ComplexObjectSerializer`. For those, and for non-internal
        types that supply a serializer, the final type id is the ``NAMED_*``
        constant when ``type_id`` is None, otherwise
        ``(type_id << 8) + <kind>``. Internal types that supply a serializer
        keep the id they were given.
        """
        if serializer is None:
            if issubclass(cls, enum.Enum):
                serializer = EnumSerializer(self.fury, cls)
                type_id = (
                    TypeId.NAMED_ENUM
                    if type_id is None
                    else ((type_id << 8) + TypeId.ENUM)
                )
            else:
                serializer = ComplexObjectSerializer(self.fury, cls)
                type_id = (
                    TypeId.NAMED_STRUCT
                    if type_id is None
                    else ((type_id << 8) + TypeId.STRUCT)
                )
        elif not internal:
            type_id = (
                TypeId.NAMED_EXT if type_id is None else ((type_id << 8) + TypeId.EXT)
            )
        return self.__register_type(
            cls,
            type_id=type_id,
            serializer=serializer,
            namespace=namespace,
            typename=typename,
            internal=internal,
        )

    def _register_pytype(
        self,
        cls: Union[type, TypeVar],
        *,
        type_id: int = None,
        namespace: str = None,
        typename: str = None,
        serializer: Serializer = None,
        internal: bool = False,
    ):
        """Register ``cls`` for Python mode; arguments are forwarded unchanged."""
        return self.__register_type(
            cls,
            type_id=type_id,
            namespace=namespace,
            typename=typename,
            serializer=serializer,
            internal=internal,
        )

    def __register_type(
        self,
        cls: Union[type, TypeVar],
        *,
        type_id: int = None,
        namespace: str = None,
        typename: str = None,
        serializer: Serializer = None,
        internal: bool = False,
    ):
        """Build the `ClassInfo` for ``cls`` and store it in the indexes.

        Returns the new `ClassInfo`.
        """
        # NOTE(review): `type_id` is compared with 0 here and below, so a
        # None id raises TypeError. `_register_pytype` forwards its
        # `type_id` unchanged and `_register_type` only auto-assigns an id
        # when `typename` is None -- confirm whether name-only registration
        # under Language.PYTHON is meant to be supported.
        dynamic_type = type_id < 0
        if not internal and serializer is None:
            serializer = self._create_serializer(cls)
        if typename is None:
            classinfo = ClassInfo(cls, type_id, serializer, None, None, dynamic_type)
        else:
            if namespace is None:
                # "pkg.mod.Name" -> namespace "pkg.mod", typename "Name".
                splits = typename.rsplit(".", 1)
                if len(splits) == 2:
                    namespace, typename = splits
            ns_metastr = self.namespace_encoder.encode(namespace or "")
            ns_meta_bytes = self.metastring_resolver.get_metastr_bytes(ns_metastr)
            type_metastr = self.typename_encoder.encode(typename)
            type_meta_bytes = self.metastring_resolver.get_metastr_bytes(type_metastr)
            classinfo = ClassInfo(
                cls, type_id, serializer, ns_meta_bytes, type_meta_bytes, dynamic_type
            )
            self._named_type_to_classinfo[(namespace, typename)] = classinfo
            self._ns_type_to_classinfo[(ns_meta_bytes, type_meta_bytes)] = classinfo
        self._classes_info[cls] = classinfo
        if type_id > 0 and (
            self.language == Language.PYTHON or not TypeId.is_namespaced_type(type_id)
        ):
            # An internal registration never replaces an existing entry for
            # the same id; a non-internal one always does.
            if type_id not in self._type_id_to_classinfo or not internal:
                self._type_id_to_classinfo[type_id] = classinfo
        return classinfo

    def _next_type_id(self):
        """Advance the id counter and return the next id not yet indexed."""
        type_id = self._type_id_counter = self._type_id_counter + 1
        while type_id in self._type_id_to_classinfo:
            type_id = self._type_id_counter = self._type_id_counter + 1
        return type_id

    def register_serializer(self, cls: Union[type, TypeVar], serializer):
        """Attach ``serializer`` to an already registered ``cls``.

        Under ``Language.PYTHON`` the serializer is stored on the existing
        `ClassInfo`. Otherwise the class is re-indexed in
        ``_type_id_to_classinfo`` under an ``EXT`` / ``NAMED_EXT`` type id
        when the serializer differs from the current one.

        Raises
        ------
        TypeUnregisteredError
            If ``cls`` has not been registered.
        """
        assert isinstance(cls, (type, TypeVar)), cls
        if cls not in self._classes_info:
            raise TypeUnregisteredError(f"{cls} not registered")
        classinfo = self._classes_info[cls]
        if self.fury.language == Language.PYTHON:
            classinfo.serializer = serializer
            return
        # NOTE(review): on this path neither `classinfo.serializer` nor
        # `classinfo.type_id` is updated, and `pop` raises KeyError when the
        # previous id was never stored in `_type_id_to_classinfo` --
        # confirm this is intended.
        type_id = prev_type_id = classinfo.type_id
        self._type_id_to_classinfo.pop(prev_type_id)
        if classinfo.serializer is not serializer:
            # Keep the upper bits of the id and replace the low byte.
            if classinfo.typename_bytes is not None:
                type_id = classinfo.type_id & 0xFFFFFF00 | TypeId.NAMED_EXT
            else:
                type_id = classinfo.type_id & 0xFFFFFF00 | TypeId.EXT
        self._type_id_to_classinfo[type_id] = classinfo

    def get_serializer(self, cls: type):
        """Return the serializer for ``cls``, creating one if needed.

        Shorthand for ``self.get_classinfo(cls).serializer``.
        """
        return self.get_classinfo(cls).serializer

    def get_classinfo(self, cls, create=True):
        """Return the `ClassInfo` for ``cls``.

        A registered class whose serializer is still None gets one created
        on first access. For an unregistered class, None is returned when
        ``create`` is false; otherwise the class is registered on the fly
        under its module and qualified name.

        Raises
        ------
        TypeUnregisteredError
            If ``cls`` is unregistered and the language is not Python, or
            registration is required and ``cls`` is not an enum, or no type
            id could be chosen for the created serializer.
        """
        class_info = self._classes_info.get(cls)
        if class_info is not None:
            if class_info.serializer is None:
                class_info.serializer = self._create_serializer(cls)
            return class_info
        elif not create:
            return None
        if self.language != Language.PYTHON or (
            self.require_registration and not issubclass(cls, Enum)
        ):
            raise TypeUnregisteredError(f"{cls} not registered")
        logger.info("Class %s not registered", cls)
        serializer = self._create_serializer(cls)
        type_id = None
        if self.language == Language.PYTHON:
            if isinstance(serializer, EnumSerializer):
                type_id = TypeId.NAMED_ENUM
            elif type(serializer) is PickleSerializer:
                type_id = PickleSerializer.PICKLE_CLASS_ID
            if not self.require_registration:
                if isinstance(serializer, DataClassSerializer):
                    type_id = TypeId.NAMED_STRUCT
        if type_id is None:
            raise TypeUnregisteredError(
                f"{cls} must be registered using `fury.register_type` API"
            )
        return self.__register_type(
            cls,
            type_id=type_id,
            namespace=cls.__module__,
            typename=cls.__qualname__,
            serializer=serializer,
        )

    def _create_serializer(self, cls):
        """Choose and build a serializer for ``cls``.

        The first class in ``cls.__mro__`` whose registered serializer
        reports ``support_subclass()`` provides the serializer type.
        Otherwise dataclasses get `DataClassSerializer`, enums get
        `EnumSerializer`, and everything else gets `PickleSerializer`.
        """
        for clz in cls.__mro__:
            class_info = self._classes_info.get(clz)
            if (
                class_info
                and class_info.serializer
                and class_info.serializer.support_subclass()
            ):
                serializer = type(class_info.serializer)(self.fury, cls)
                break
        else:
            # Reached only when the loop finished without `break`.
            if dataclasses.is_dataclass(cls):
                from pyfury import DataClassSerializer

                serializer = DataClassSerializer(self.fury, cls)
            elif issubclass(cls, enum.Enum):
                serializer = EnumSerializer(self.fury, cls)
            else:
                serializer = PickleSerializer(self.fury, cls)
        return serializer

    def _load_metabytes_to_classinfo(self, ns_metabytes, type_metabytes):
        """Resolve encoded namespace / type-name bytes to a `ClassInfo`.

        Tries the bytes-keyed index first, then the decoded
        ``(namespace, typename)`` index, and finally loads the class with
        ``load_class`` and looks it up through `get_classinfo`. Results of
        the slower paths are cached in the bytes-keyed index.
        """
        typeinfo = self._ns_type_to_classinfo.get((ns_metabytes, type_metabytes))
        if typeinfo is not None:
            return typeinfo
        ns = ns_metabytes.decode(self.namespace_decoder)
        typename = type_metabytes.decode(self.typename_decoder)
        # the hash computed between languages may be different.
        typeinfo = self._named_type_to_classinfo.get((ns, typename))
        if typeinfo is not None:
            self._ns_type_to_classinfo[(ns_metabytes, type_metabytes)] = typeinfo
            return typeinfo
        cls = load_class(ns + "#" + typename)
        classinfo = self.get_classinfo(cls)
        self._ns_type_to_classinfo[(ns_metabytes, type_metabytes)] = classinfo
        return classinfo

    def write_typeinfo(self, buffer, classinfo):
        """Write the type header for ``classinfo`` to ``buffer``.

        Nothing is written for dynamic types. Otherwise the type id is
        written as a varuint32, followed by the namespace and type-name
        meta strings when the low byte of the id is a namespaced type.
        """
        if classinfo.dynamic_type:
            return
        type_id = classinfo.type_id
        internal_type_id = type_id & 0xFF
        buffer.write_varuint32(type_id)
        if TypeId.is_namespaced_type(internal_type_id):
            self.metastring_resolver.write_meta_string_bytes(
                buffer, classinfo.namespace_bytes
            )
            self.metastring_resolver.write_meta_string_bytes(
                buffer, classinfo.typename_bytes
            )

    def read_typeinfo(self, buffer):
        """Read a type header from ``buffer`` and return its `ClassInfo`.

        Raises
        ------
        TypeUnregisteredError
            If a namespaced type is read whose name is not registered.
        """
        type_id = buffer.read_varuint32()
        internal_type_id = type_id & 0xFF
        if TypeId.is_namespaced_type(internal_type_id):
            ns_metabytes = self.metastring_resolver.read_meta_string_bytes(buffer)
            type_metabytes = self.metastring_resolver.read_meta_string_bytes(buffer)
            typeinfo = self._ns_type_to_classinfo.get((ns_metabytes, type_metabytes))
            if typeinfo is None:
                ns = ns_metabytes.decode(self.namespace_decoder)
                typename = type_metabytes.decode(self.typename_decoder)
                typeinfo = self._named_type_to_classinfo.get((ns, typename))
                if typeinfo is not None:
                    # Cache under the bytes key for later reads.
                    self._ns_type_to_classinfo[
                        (ns_metabytes, type_metabytes)
                    ] = typeinfo
                    return typeinfo
                # TODO(chaokunyang) generate a dynamic class and serializer
                # when meta share is enabled.
                name = ns + "." + typename if ns else typename
                raise TypeUnregisteredError(f"{name} not registered")
            return typeinfo
        else:
            # NOTE(review): an unknown non-namespaced id surfaces as a bare
            # KeyError rather than TypeUnregisteredError -- confirm intended.
            return self._type_id_to_classinfo[type_id]

    def reset(self):
        """No-op."""
        pass

    def reset_read(self):
        """No-op."""
        pass

    def reset_write(self):
        """No-op."""
        pass