python/pyfury/_struct.py (229 lines of code) (raw):
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import datetime
import enum
import logging
import typing
from pyfury.buffer import Buffer
from pyfury.error import ClassNotCompatibleError
from pyfury.serializer import (
ListSerializer,
MapSerializer,
PickleSerializer,
Serializer,
)
from pyfury.type import (
TypeVisitor,
infer_field,
TypeId,
Int8Type,
Int16Type,
Int32Type,
Int64Type,
Float32Type,
Float64Type,
is_py_array_type,
compute_string_hash,
is_primitive_type,
)
from pyfury.type import (
is_list_type,
is_map_type,
get_primitive_type_size,
is_primitive_array_type,
)
from pyfury.type import is_subclass
logger = logging.getLogger(__name__)

# Field types for which ComplexTypeVisitor.visit_other asks the class resolver
# for a serializer. Any other type that is neither an Enum subclass nor a
# py-array type makes visit_other return None.
basic_types = {
    # bool and the sized numeric types imported from pyfury.type.
    bool,
    Int8Type,
    Int16Type,
    Int32Type,
    Int64Type,
    Float32Type,
    Float64Type,
    # Plain builtin scalars, text and binary.
    int,
    float,
    str,
    bytes,
    # Date/time values.
    datetime.datetime,
    datetime.date,
    datetime.time,
}
class ComplexTypeVisitor(TypeVisitor):
    """Pick a serializer for each declared field type of a struct.

    List and dict types yield ListSerializer / MapSerializer built from
    recursively resolved element serializers. Enum subclasses and the types in
    ``basic_types`` (or py-array types) yield the class resolver's serializer.
    Everything else, including customized types, yields ``None``.
    """

    def __init__(
        self,
        fury,
    ):
        self.fury = fury

    def visit_list(self, field_name, elem_type, types_path=None):
        # The element type can itself be generic (e.g. List[Dict[str, str]]),
        # so it is resolved through this same visitor.
        item_serializer = infer_field(
            "item", elem_type, self, types_path=types_path
        )
        return ListSerializer(self.fury, list, item_serializer)

    def visit_dict(self, field_name, key_type, value_type, types_path=None):
        # Key and value types may be nested generics too
        # (e.g. Dict[str, Dict[str, str]]); resolve both recursively.
        k_serializer = infer_field("key", key_type, self, types_path=types_path)
        v_serializer = infer_field(
            "value", value_type, self, types_path=types_path
        )
        return MapSerializer(self.fury, dict, k_serializer, v_serializer)

    def visit_customized(self, field_name, type_, types_path=None):
        # No fixed serializer is chosen for customized types.
        return None

    def visit_other(self, field_name, type_, types_path=None):
        if is_subclass(type_, enum.Enum):
            return self.fury.class_resolver.get_serializer(type_)
        # Only a known set of basic types (plus py arrays) gets a fixed
        # serializer; anything else is left unresolved.
        if not (type_ in basic_types or is_py_array_type(type_)):
            return None
        resolved = self.fury.class_resolver.get_serializer(type_)
        # A pickle-based serializer is not acceptable for these field types.
        assert not isinstance(resolved, (PickleSerializer,))
        return resolved
def _get_hash(fury, field_names: list, type_hints: dict):
    """Compute the struct hash for the given fields.

    Each field's declared type is fed, in ``field_names`` order, through a
    ``StructHashVisitor`` and the visitor's accumulated hash is returned.

    Args:
        fury: Fury instance whose class resolver the visitor consults.
        field_names: Field names, in the order they are hashed.
        type_hints: Mapping from field name to its declared type.

    Returns:
        The accumulated hash; it is asserted to be non-zero.

    Raises:
        KeyError: if a name in ``field_names`` is missing from ``type_hints``.
    """
    visitor = StructHashVisitor(fury)
    for field_name in field_names:
        infer_field(field_name, type_hints[field_name], visitor, types_path=[])
    hash_ = visitor.get_hash()
    # ComplexObjectSerializer uses 0 to mean "hash not computed yet", so a
    # real hash of 0 would defeat its caching.
    assert hash_ != 0
    return hash_
# Type id assigned to fields that have no serializer.
_UNKNOWN_TYPE_ID = -1
# Types grouped with the "final" fields by _sort_fields.
_time_types = {datetime.date, datetime.datetime, datetime.timedelta}


def _sort_fields(class_resolver, field_names, serializers):
    """Reorder struct fields into their serialization layout.

    Groups are emitted in this order:

    1. primitive fields;
    2. final fields: strings, primitive arrays, enums and ``_time_types``;
    3. other fields, including those whose serializer is ``None``;
    4. list fields;
    5. map fields.

    Primitive fields are ordered by (compressed?, descending primitive size,
    name), where INT32/INT64/VAR_INT32/VAR_INT64 count as compressed. Every
    other group is ordered by (type id, name).

    Args:
        class_resolver: Resolver used to look up each serializer's type id.
        field_names: Field names, parallel to ``serializers``.
        serializers: Serializer per field, or ``None`` when unresolved.

    Returns:
        A ``(serializers, field_names)`` pair of lists in the new order.
    """
    # Phase 1: fields without a serializer go straight to the "other" group;
    # the rest get their type id looked up.
    others = []
    resolved = []
    for name, ser in zip(field_names, serializers):
        if ser is None:
            others.append((_UNKNOWN_TYPE_ID, ser, name))
            continue
        resolved.append(
            (class_resolver.get_classinfo(ser.type_).type_id, ser, name)
        )

    # Phase 2: bucket every resolved field by its type id / python type.
    primitives, finals, lists, maps = [], [], [], []
    for entry in resolved:
        tid, ser, _ = entry
        cls = ser.type_
        if is_primitive_type(tid):
            primitives.append(entry)
        elif is_list_type(cls):
            lists.append(entry)
        elif is_map_type(cls):
            maps.append(entry)
        elif (
            tid == TypeId.STRING
            or is_primitive_array_type(tid)
            or is_subclass(cls, enum.Enum)
            or cls in _time_types
        ):
            finals.append(entry)
        else:
            others.append(entry)

    def by_type_then_name(entry):
        return entry[0], entry[2]

    def primitive_key(entry):
        tid = entry[0]
        is_compressed = tid in (
            TypeId.INT32,
            TypeId.INT64,
            TypeId.VAR_INT32,
            TypeId.VAR_INT64,
        )
        # Uncompressed first, then larger types first, then by name.
        return int(is_compressed), -get_primitive_type_size(tid), entry[2]

    ordered = sorted(primitives, key=primitive_key)
    for group in (finals, others, lists, maps):
        ordered += sorted(group, key=by_type_then_name)
    return [entry[1] for entry in ordered], [entry[2] for entry in ordered]
class ComplexObjectSerializer(Serializer):
    """Field-by-field serializer for a class described by its type hints.

    Fields come from ``typing.get_type_hints(clz)``. Each gets a serializer
    from ``ComplexTypeVisitor`` (possibly ``None``), and the fields are then
    reordered by ``_sort_fields``. The payload is an int32 struct hash
    followed by every field written through ``fury.xserialize_ref`` in that
    order. ``write``/``read`` simply delegate to ``xwrite``/``xread``.
    """

    def __init__(self, fury, clz):
        super().__init__(fury, clz)
        self._type_hints = typing.get_type_hints(clz)
        # Start from alphabetical order; _sort_fields settles the final order.
        field_names = sorted(self._type_hints)
        visitor = ComplexTypeVisitor(fury)
        serializers = [
            infer_field(name, self._type_hints[name], visitor, types_path=[])
            for name in field_names
        ]
        self._serializers, self._field_names = _sort_fields(
            fury.class_resolver, field_names, serializers
        )
        # NOTE(review): imported locally, presumably to avoid a circular
        # import with the pyfury package — confirm.
        from pyfury import Language

        if self.fury.language == Language.PYTHON:
            logger.warning(
                "Type of class %s shouldn't be serialized using cross-language "
                "serializer",
                clz,
            )
        # 0 means "not computed yet"; see _struct_hash.
        self._hash = 0

    def _struct_hash(self):
        """Return the struct hash, computing and caching it on first use."""
        if self._hash == 0:
            self._hash = _get_hash(self.fury, self._field_names, self._type_hints)
        return self._hash

    def write(self, buffer, value):
        return self.xwrite(buffer, value)

    def read(self, buffer):
        return self.xread(buffer)

    def xwrite(self, buffer: Buffer, value):
        """Write the struct hash, then every field of ``value`` in order."""
        buffer.write_int32(self._struct_hash())
        for field_name, serializer in zip(self._field_names, self._serializers):
            field_value = getattr(value, field_name)
            self.fury.xserialize_ref(buffer, field_value, serializer=serializer)

    def xread(self, buffer):
        """Read a struct written by ``xwrite`` and return a new instance.

        Raises:
            ClassNotCompatibleError: if the stored hash differs from this
                class's struct hash.
        """
        expected = self._struct_hash()
        hash_ = buffer.read_int32()
        if hash_ != expected:
            raise ClassNotCompatibleError(
                f"Hash {hash_} is not consistent with {self._hash} "
                f"for class {self.type_}",
            )
        # Bypass __init__; the object is registered with the ref resolver
        # before its fields are read.
        obj = self.type_.__new__(self.type_)
        self.fury.ref_resolver.reference(obj)
        for field_name, serializer in zip(self._field_names, self._serializers):
            field_value = self.fury.xdeserialize_ref(buffer, serializer=serializer)
            setattr(obj, field_name, field_value)
        return obj
class StructHashVisitor(TypeVisitor):
    """Fold one integer per visited field type into a running hash.

    The hash starts at 17. Each visit mixes an id into it with
    ``_compute_field_hash``:

    * list and dict fields contribute ``abs`` of the resolver's type id for
      ``list`` / ``dict``;
    * customized types contribute their type id, or a string hash of
      namespace + typename for namespaced types;
    * other types contribute ``abs`` of their type id.

    Unknown types (no classinfo) contribute 0.
    """

    def __init__(
        self,
        fury,
    ):
        self.fury = fury
        self._hash = 17

    def _mix(self, field_id):
        # Fold one field's id into the running hash.
        self._hash = self._compute_field_hash(self._hash, field_id)

    def visit_list(self, field_name, elem_type, types_path=None):
        # TODO add list element type to hash.
        list_type_id = self.fury.class_resolver.get_classinfo(list).type_id
        self._mix(abs(list_type_id))

    def visit_dict(self, field_name, key_type, value_type, types_path=None):
        # TODO add map key/value type to hash.
        dict_type_id = self.fury.class_resolver.get_classinfo(dict).type_id
        self._mix(abs(dict_type_id))

    def visit_customized(self, field_name, type_, types_path=None):
        info = self.fury.class_resolver.get_classinfo(type_, create=False)
        if info is None:
            field_id = 0
        elif TypeId.is_namespaced_type(info.type_id):
            field_id = compute_string_hash(info.namespace + info.typename)
        else:
            field_id = info.type_id
        self._mix(field_id)

    def visit_other(self, field_name, type_, types_path=None):
        info = self.fury.class_resolver.get_classinfo(type_, create=False)
        field_id = 0
        if info is not None:
            serializer = info.serializer
            assert not isinstance(serializer, (PickleSerializer,))
            field_id = info.type_id
            assert field_id is not None, serializer
        self._mix(abs(field_id))

    @staticmethod
    def _compute_field_hash(hash_, id_):
        # Shrink the result below 2**31 - 1 so it fits the int32 the serializer
        # writes. NOTE(review): negative inputs are not bounded by this loop.
        mixed = hash_ * 31 + id_
        while mixed >= 2**31 - 1:
            mixed //= 7
        return mixed

    def get_hash(self):
        return self._hash