# python/pyfury/serializer.py (544 lines of code) (raw):
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import array
import itertools
import os
import pickle
import typing
from weakref import WeakValueDictionary
import pyfury.lib.mmh3
from pyfury.buffer import Buffer
from pyfury.codegen import (
gen_write_nullable_basic_stmts,
gen_read_nullable_basic_stmts,
compile_function,
)
from pyfury.error import ClassNotCompatibleError
from pyfury.lib.collection import WeakIdentityKeyDictionary
from pyfury.resolver import NULL_FLAG, NOT_NULL_VALUE_FLAG
try:
import numpy as np
except ImportError:
np = None
from pyfury._fury import (
NOT_NULL_INT64_FLAG,
BufferObject,
)
# True when running on Windows; the typecode tables in this module use
# different keys for 32/64-bit integers depending on this flag.
_WINDOWS = os.name == "nt"
from pyfury._serialization import ENABLE_FURY_CYTHON_SERIALIZATION
if ENABLE_FURY_CYTHON_SERIALIZATION:
from pyfury._serialization import ( # noqa: F401, F811
Serializer,
CrossLanguageCompatibleSerializer,
BooleanSerializer,
ByteSerializer,
Int16Serializer,
Int32Serializer,
Int64Serializer,
Float32Serializer,
Float64Serializer,
StringSerializer,
DateSerializer,
TimestampSerializer,
CollectionSerializer,
ListSerializer,
TupleSerializer,
StringArraySerializer,
SetSerializer,
MapSerializer,
SubMapSerializer,
EnumSerializer,
SliceSerializer,
)
else:
from pyfury._serializer import ( # noqa: F401 # pylint: disable=unused-import
Serializer,
CrossLanguageCompatibleSerializer,
BooleanSerializer,
ByteSerializer,
Int16Serializer,
Int32Serializer,
Int64Serializer,
Float32Serializer,
Float64Serializer,
StringSerializer,
DateSerializer,
TimestampSerializer,
CollectionSerializer,
ListSerializer,
TupleSerializer,
StringArraySerializer,
SetSerializer,
MapSerializer,
SubMapSerializer,
EnumSerializer,
SliceSerializer,
)
from pyfury.type import (
Int16ArrayType,
Int32ArrayType,
Int64ArrayType,
Float32ArrayType,
Float64ArrayType,
BoolNDArrayType,
Int16NDArrayType,
Int32NDArrayType,
Int64NDArrayType,
Float32NDArrayType,
Float64NDArrayType,
TypeId,
)
class NoneSerializer(Serializer):
    """Serializer constructed with ``None`` as its type.

    The native ``write`` emits no bytes and ``read`` always yields ``None``;
    the cross-language ``xwrite``/``xread`` entry points are not implemented.
    """

    def __init__(self, fury):
        super().__init__(fury, None)
        # Switch off the ``need_to_write_ref`` flag for this serializer.
        self.need_to_write_ref = False

    def write(self, buffer, value):
        """Write nothing to ``buffer``."""
        return None

    def read(self, buffer):
        """Read nothing from ``buffer`` and return ``None``."""
        return None

    def xwrite(self, buffer, value):
        """Not supported."""
        raise NotImplementedError()

    def xread(self, buffer):
        """Not supported."""
        raise NotImplementedError()
class _PickleStub:
    """Empty placeholder class."""


class PickleStrongCacheStub:
    """Empty marker class handed to ``Serializer.__init__`` by
    ``PickleStrongCacheSerializer`` as its type."""


class PickleCacheStub:
    """Empty marker class handed to ``Serializer.__init__`` by
    ``PickleCacheSerializer`` as its type."""
class PickleStrongCacheSerializer(Serializer):
    """If we can't create weak ref to object, use this cache serializer instead.
    clear cache by threshold to avoid memory leak.

    Pickled bytes are memoised in a plain ``dict`` keyed by the value itself,
    so entries are held strongly until the cache is cleared.

    Args:
        fury: the owning Fury instance, forwarded to ``Serializer.__init__``.
        clear_threshold: the cache is emptied right after a write that leaves
            it holding exactly this many entries.
    """

    __slots__ = "_cached", "_clear_threshold", "_counter"

    def __init__(self, fury, clear_threshold: int = 1000):
        super().__init__(fury, PickleStrongCacheStub)
        self._cached = {}
        self._clear_threshold = clear_threshold

    def write(self, buffer, value):
        """Write the pickled form of ``value`` as size-prefixed bytes."""
        try:
            serialized = self._cached.get(value)
        except TypeError:
            # Unhashable values (lists, dicts, ...) cannot be dict keys.
            # Serialize them without caching instead of failing the write.
            buffer.write_bytes_and_size(pickle.dumps(value))
            return
        if serialized is None:
            serialized = pickle.dumps(value)
            self._cached[value] = serialized
        buffer.write_bytes_and_size(serialized)
        if len(self._cached) == self._clear_threshold:
            self._cached.clear()

    def read(self, buffer):
        """Read size-prefixed bytes from ``buffer`` and unpickle them."""
        # SECURITY: pickle.loads executes arbitrary code embedded in the
        # payload; only use this serializer on buffers from trusted sources.
        return pickle.loads(buffer.read_bytes_and_size())

    def xwrite(self, buffer, value):
        """Not supported."""
        raise NotImplementedError

    def xread(self, buffer):
        """Not supported."""
        raise NotImplementedError
class PickleCacheSerializer(Serializer):
    """Pickle-based serializer with weak caches on both the write and read side.

    ``write`` memoises ``(hash, pickled bytes)`` per object in a
    ``WeakIdentityKeyDictionary``. ``read`` memoises the unpickled object per
    hash in a ``WeakValueDictionary`` so a payload seen before is not
    unpickled again while the earlier result is still alive.
    """

    __slots__ = "_cached", "_reverse_cached"

    def __init__(self, fury):
        super().__init__(fury, PickleCacheStub)
        self._cached = WeakIdentityKeyDictionary()
        self._reverse_cached = WeakValueDictionary()

    def write(self, buffer, value):
        """Write an int64 hash of the pickled bytes, then the size-prefixed bytes."""
        cache = self._cached.get(value)
        if cache is None:
            serialized = pickle.dumps(value)
            value_hash = pyfury.lib.mmh3.hash_buffer(serialized)[0]
            cache = value_hash, serialized
            self._cached[value] = cache
        buffer.write_int64(cache[0])
        buffer.write_bytes_and_size(cache[1])

    def read(self, buffer):
        """Read the hash and return the cached object, or unpickle the payload.

        The payload is always consumed so the reader position stays aligned
        with what ``write`` produced.
        """
        value_hash = buffer.read_int64()
        value = self._reverse_cached.get(value_hash)
        if value is None:
            # SECURITY: pickle.loads executes arbitrary code embedded in the
            # payload; only use this serializer on trusted buffers.
            value = pickle.loads(buffer.read_bytes_and_size())
            self._reverse_cached[value_hash] = value
        else:
            # Skip the payload with the exact counterpart of
            # write_bytes_and_size. Decoding the size prefix by hand with
            # read_int32 would go out of sync whenever the prefix is not a
            # fixed 4-byte integer.
            buffer.read_bytes_and_size()
        return value

    def xwrite(self, buffer, value):
        """Not supported."""
        raise NotImplementedError

    def xread(self, buffer):
        """Not supported."""
        raise NotImplementedError
class PandasRangeIndexSerializer(Serializer):
    """Serializer for ``pandas.RangeIndex``.

    The wire layout is the three bounds ``start``, ``stop`` and ``step`` (each
    nullable), followed by ``dtype`` and ``name`` written through the
    reference-aware path.
    """

    __slots__ = "_cached"

    def __init__(self, fury):
        # Imported lazily so pandas is only needed once this serializer is built.
        import pandas as pd

        super().__init__(fury, pd.RangeIndex)

    def write(self, buffer, value):
        """Write ``value`` (a RangeIndex) to ``buffer``."""
        fury = self.fury
        for bound in (value.start, value.stop, value.step):
            if type(bound) is int:
                # Exact ints: a combined int16 flag, then the varint64 value.
                buffer.write_int16(NOT_NULL_INT64_FLAG)
                buffer.write_varint64(bound)
            elif bound is None:
                buffer.write_int8(NULL_FLAG)
            else:
                buffer.write_int8(NOT_NULL_VALUE_FLAG)
                fury.serialize_nonref(buffer, bound)
        fury.serialize_ref(buffer, value.dtype)
        fury.serialize_ref(buffer, value.name)

    def read(self, buffer):
        """Rebuild a RangeIndex from ``buffer``."""
        fury = self.fury
        bounds = []
        for _ in range(3):
            # NOTE(review): for int bounds the writer emitted an int16 flag;
            # presumably its remaining byte is consumed by deserialize_nonref
            # as the type marker -- confirm against the Fury implementation.
            if buffer.read_int8() == NULL_FLAG:
                bounds.append(None)
            else:
                bounds.append(fury.deserialize_nonref(buffer))
        start, stop, step = bounds
        dtype = fury.deserialize_ref(buffer)
        name = fury.deserialize_ref(buffer)
        return self.type_(start, stop, step, dtype=dtype, name=name)

    def xwrite(self, buffer, value):
        """Not supported."""
        raise NotImplementedError()

    def xread(self, buffer):
        """Not supported."""
        raise NotImplementedError()
# Module namespace captured for generated code: DataClassSerializer copies it
# into the context of its generated read method so names used there resolve.
_jit_context = locals()
# Generated read/write methods are used unless the ENABLE_FURY_PYTHON_JIT
# environment variable is set to something other than "true"/"1"
# (compared case-insensitively).
_ENABLE_FURY_PYTHON_JIT = os.getenv("ENABLE_FURY_PYTHON_JIT", "True").lower() in {
    "true",
    "1",
}
class DataClassSerializer(Serializer):
    """Serializer for classes whose fields are described by type hints.

    Field names come from ``typing.get_type_hints`` and are processed in
    sorted order. A per-class int32 "hash" (currently just the field count) is
    written first and checked on read. When ``_ENABLE_FURY_PYTHON_JIT`` is
    true, ``write``/``read`` are replaced on the instance by functions
    generated with ``compile_function``; otherwise the generic methods
    defined on the class are used.
    """

    def __init__(self, fury, clz: type):
        super().__init__(fury, clz)
        # This will get superclass type hints too.
        self._type_hints = typing.get_type_hints(clz)
        self._field_names = sorted(self._type_hints.keys())
        self._has_slots = hasattr(clz, "__slots__")
        # TODO compute hash
        self._hash = len(self._field_names)
        self._generated_write_method = self._gen_write_method()
        self._generated_read_method = self._gen_read_method()
        if _ENABLE_FURY_PYTHON_JIT:
            # don't use `__slots__`, which will make instance method readonly
            # Reuse the functions generated above rather than compiling the
            # same source a second time.
            self.write = self._generated_write_method
            self.read = self._generated_read_method

    @staticmethod
    def _match_basic_type(field_type):
        """Return ``bool``/``int``/``float``/``str`` if ``field_type`` is one
        of them, else ``None``."""
        if field_type is bool:
            return bool
        for basic in (int, float, str):
            if field_type == basic:
                return basic
        return None

    def _gen_write_method(self):
        """Generate and compile the specialised write function.

        The generated source is stored in ``self._write_method_code``.

        Returns:
            The compiled function taking ``(buffer, value)``.
        """
        context = {}
        counter = itertools.count(0)
        buffer, fury, value, value_dict = "buffer", "fury", "value", "value_dict"
        context[fury] = self.fury
        stmts = [
            f'"""write method for {self.type_}"""',
            f"{buffer}.write_int32({self._hash})",
        ]
        if not self._has_slots:
            stmts.append(f"{value_dict} = {value}.__dict__")
        for field_name in self._field_names:
            field_type = self._type_hints[field_name]
            field_value = f"field_value{next(counter)}"
            if not self._has_slots:
                stmts.append(f"{field_value} = {value_dict}['{field_name}']")
            else:
                stmts.append(f"{field_value} = {value}.{field_name}")
            basic_type = self._match_basic_type(field_type)
            if basic_type is not None:
                stmts.extend(
                    gen_write_nullable_basic_stmts(buffer, field_value, basic_type)
                )
            else:
                stmts.append(f"{fury}.write_ref_pyobject({buffer}, {field_value})")
        self._write_method_code, func = compile_function(
            f"write_{self.type_.__module__}_{self.type_.__qualname__}".replace(
                ".", "_"
            ),
            [buffer, value],
            stmts,
            context,
        )
        return func

    def _gen_read_method(self):
        """Generate and compile the specialised read function.

        The generated source is stored in ``self._read_method_code``.

        Returns:
            The compiled function taking ``(buffer)`` and returning the object.
            It raises ``ClassNotCompatibleError`` when the hash read from the
            buffer differs from ``self._hash``.
        """
        # Start from the module namespace so the generated code can resolve
        # names such as ClassNotCompatibleError.
        context = dict(_jit_context)
        buffer, fury, obj_class, obj, obj_dict = (
            "buffer",
            "fury",
            "obj_class",
            "obj",
            "obj_dict",
        )
        ref_resolver = "ref_resolver"
        context[fury] = self.fury
        context[obj_class] = self.type_
        context[ref_resolver] = self.fury.ref_resolver
        # Embedded via repr() so quotes in the class repr cannot break the
        # generated string literal.
        mismatch_detail = f" is not consistent with {self._hash} for {self.type_}"
        stmts = [
            f'"""read method for {self.type_}"""',
            f"{obj} = {obj_class}.__new__({obj_class})",
            f"{ref_resolver}.reference({obj})",
            f"read_hash = {buffer}.read_int32()",
            f"if read_hash != {self._hash}:",
            # Report the hash that was actually read, not the literal text
            # "read_hash".
            f"    raise ClassNotCompatibleError("
            f"'Hash ' + str(read_hash) + {mismatch_detail!r})",
        ]
        if not self._has_slots:
            stmts.append(f"{obj_dict} = {obj}.__dict__")

        def set_action(value: str):
            # Reads ``field_name`` from the enclosing loop at call time.
            if not self._has_slots:
                return f"{obj_dict}['{field_name}'] = {value}"
            else:
                return f"{obj}.{field_name} = {value}"

        for field_name in self._field_names:
            field_type = self._type_hints[field_name]
            basic_type = self._match_basic_type(field_type)
            if basic_type is not None:
                stmts.extend(
                    gen_read_nullable_basic_stmts(buffer, basic_type, set_action)
                )
            else:
                # Store through the same path as basic fields: for non-slots
                # classes this goes straight to __dict__, mirroring the write
                # side, and avoids a custom __setattr__ (e.g. frozen
                # dataclasses) rejecting the assignment.
                stmts.append(set_action(f"{fury}.read_ref_pyobject({buffer})"))
        stmts.append(f"return {obj}")
        self._read_method_code, func = compile_function(
            f"read_{self.type_.__module__}_{self.type_.__qualname__}".replace(".", "_"),
            [buffer],
            stmts,
            context,
        )
        return func

    def write(self, buffer, value):
        """Generic write: the hash, then every field via ``serialize_ref``."""
        buffer.write_int32(self._hash)
        for field_name in self._field_names:
            field_value = getattr(value, field_name)
            self.fury.serialize_ref(buffer, field_value)

    def read(self, buffer):
        """Generic read: check the hash, then set every field from ``deserialize_ref``.

        Raises:
            ClassNotCompatibleError: if the hash in the buffer differs from
                ``self._hash``.
        """
        hash_ = buffer.read_int32()
        if hash_ != self._hash:
            raise ClassNotCompatibleError(
                f"Hash {hash_} is not consistent with {self._hash} "
                f"for class {self.type_}",
            )
        obj = self.type_.__new__(self.type_)
        # Register the instance before reading its fields.
        self.fury.ref_resolver.reference(obj)
        for field_name in self._field_names:
            field_value = self.fury.deserialize_ref(buffer)
            setattr(
                obj,
                field_name,
                field_value,
            )
        return obj

    def xwrite(self, buffer: Buffer, value):
        """Not supported."""
        raise NotImplementedError

    def xread(self, buffer):
        """Not supported."""
        raise NotImplementedError
# Use numpy array or python array module.
# ``typecode_dict`` maps an ``array`` typecode to
# (item size in bytes, Fury array type, Fury type id).
# There is no entry for byte arrays: use bytes serializer for byte array.
# On Windows the 4- and 8-byte integer entries are keyed "l"/"q" rather than
# "i"/"l".
typecode_dict = {
    "h": (2, Int16ArrayType, TypeId.INT16_ARRAY),
    ("l" if _WINDOWS else "i"): (4, Int32ArrayType, TypeId.INT32_ARRAY),
    ("q" if _WINDOWS else "l"): (8, Int64ArrayType, TypeId.INT64_ARRAY),
    "f": (4, Float32ArrayType, TypeId.FLOAT32_ARRAY),
    "d": (8, Float64ArrayType, TypeId.FLOAT64_ARRAY),
}
# Reverse lookup derived from the table above: Fury type id -> typecode.
typeid_code = {
    type_id: code for code, (_itemsize, _ftype, type_id) in typecode_dict.items()
}
class PyArraySerializer(CrossLanguageCompatibleSerializer):
    """Serializer for ``array.array`` values of one fixed element type.

    The cross-language form is a varuint32 byte length followed by the raw
    buffer. The native form additionally prefixes the typecode string, so any
    typecode can be round-tripped.
    """

    typecode_dict = typecode_dict
    # typecode -> Fury array type, i.e. the middle column of ``typecode_dict``.
    typecodearray_type = {
        code: array_type
        for code, (_itemsize, array_type, _type_id) in typecode_dict.items()
    }

    def __init__(self, fury, ftype, type_id: str):
        super().__init__(fury, ftype)
        self.typecode = typeid_code[type_id]
        entry = typecode_dict[self.typecode]
        self.itemsize = entry[0]
        self.type_id = entry[2]

    def xwrite(self, buffer, value):
        """Write ``value`` in the cross-language layout after sanity checks."""
        assert value.itemsize == self.itemsize
        mv = memoryview(value)
        assert mv.format == self.typecode
        assert mv.itemsize == self.itemsize
        assert mv.c_contiguous  # TODO handle contiguous
        buffer.write_varuint32(len(value) * self.itemsize)
        buffer.write_buffer(value)

    def xread(self, buffer):
        """Read size-prefixed bytes into a new array of this serializer's typecode."""
        payload = buffer.read_bytes_and_size()
        result = array.array(self.typecode)
        result.frombytes(payload)
        return result

    def write(self, buffer, value: array.array):
        """Write typecode string, byte length, then the raw buffer."""
        byte_count = len(value) * value.itemsize
        buffer.write_string(value.typecode)
        buffer.write_varuint32(byte_count)
        buffer.write_buffer(value)

    def read(self, buffer):
        """Read the typecode string and payload written by ``write``."""
        code = buffer.read_string()
        payload = buffer.read_bytes_and_size()
        result = array.array(code)
        result.frombytes(payload)
        return result
class DynamicPyArraySerializer(Serializer):
    """Serializer for ``array.array`` values whose element type is only known
    at write time.

    The cross-language layout is: type id (varuint32), byte length
    (varuint32), raw element bytes. The native ``write``/``read`` delegate to
    Fury's unsupported-object handlers.
    """

    def xwrite(self, buffer, value):
        """Write type id, byte length and the array payload."""
        itemsize, ftype, type_id = typecode_dict[value.typecode]
        view = memoryview(value)
        nbytes = len(value) * itemsize
        buffer.write_varuint32(type_id)
        buffer.write_varuint32(nbytes)
        if not view.c_contiguous:
            buffer.write_bytes(value.tobytes())
        else:
            buffer.write_buffer(value)

    def xread(self, buffer):
        """Read an array written by ``xwrite``.

        Raises:
            KeyError: if the type id is not present in ``typeid_code``.
        """
        # Must mirror xwrite, which emits the id with write_varuint32;
        # decoding it as a signed varint would not match that encoding.
        type_id = buffer.read_varuint32()
        typecode = typeid_code[type_id]
        data = buffer.read_bytes_and_size()
        arr = array.array(typecode, [])
        arr.frombytes(data)
        return arr

    def write(self, buffer, value):
        """Write the pickle class id, then delegate to the unsupported-object writer."""
        buffer.write_varuint32(PickleSerializer.PICKLE_CLASS_ID)
        self.fury.handle_unsupported_write(buffer, value)

    def read(self, buffer):
        """Delegate to the unsupported-object reader."""
        return self.fury.handle_unsupported_read(buffer)
if np:
    # numpy dtype -> (item size in bytes, buffer format character,
    #                 Fury ndarray type, Fury type id).
    # There is no entry for byte arrays: use bytes serializer for byte array.
    # On Windows the int32/int64 format characters are "l"/"q" instead of
    # "i"/"l".
    _np_dtypes_dict = {
        np.dtype(scalar_type): (itemsize, format_char, array_type, type_id)
        for scalar_type, itemsize, format_char, array_type, type_id in (
            (np.bool_, 1, "?", BoolNDArrayType, TypeId.BOOL_ARRAY),
            (np.int16, 2, "h", Int16NDArrayType, TypeId.INT16_ARRAY),
            (
                np.int32,
                4,
                "l" if _WINDOWS else "i",
                Int32NDArrayType,
                TypeId.INT32_ARRAY,
            ),
            (
                np.int64,
                8,
                "q" if _WINDOWS else "l",
                Int64NDArrayType,
                TypeId.INT64_ARRAY,
            ),
            (np.float32, 4, "f", Float32NDArrayType, TypeId.FLOAT32_ARRAY),
            (np.float64, 8, "d", Float64NDArrayType, TypeId.FLOAT64_ARRAY),
        )
    }
else:
    # numpy is not installed: no dtype is supported.
    _np_dtypes_dict = {}
class Numpy1DArraySerializer(Serializer):
    """Serializer for numpy arrays of one fixed dtype.

    The cross-language layout is a varuint32 byte length followed by the raw
    element bytes. The native ``write``/``read`` delegate to Fury's
    unsupported-object handlers.

    Attributes set from ``_np_dtypes_dict[dtype]``:
        itemsize: element size in bytes.
        format: buffer format character of the dtype (e.g. ``"h"``).
        typecode: the Fury ndarray type for the dtype (not a format string).
        type_id: the Fury type id.
    """

    dtypes_dict = _np_dtypes_dict

    def __init__(self, fury, ftype, dtype):
        super().__init__(fury, ftype)
        self.dtype = dtype
        self.itemsize, self.format, self.typecode, self.type_id = _np_dtypes_dict[
            self.dtype
        ]

    def xwrite(self, buffer, value):
        """Write ``value`` in the cross-language layout after sanity checks."""
        assert value.itemsize == self.itemsize
        view = memoryview(value)
        # Compare against the format character. ``self.typecode`` holds the
        # Fury array type, so comparing the memoryview format string with it
        # could never succeed.
        assert view.format == self.format
        assert view.itemsize == self.itemsize
        nbytes = len(value) * self.itemsize
        buffer.write_varuint32(nbytes)
        # Bool and non-contiguous arrays are copied to bytes first; everything
        # else is written straight from the array's buffer.
        if self.dtype == np.dtype("bool") or not view.c_contiguous:
            buffer.write_bytes(value.tobytes())
        else:
            buffer.write_buffer(value)

    def xread(self, buffer):
        """Read size-prefixed bytes and view them as an array of ``self.dtype``."""
        data = buffer.read_bytes_and_size()
        return np.frombuffer(data, dtype=self.dtype)

    def write(self, buffer, value):
        """Write the pickle class id, then delegate to the unsupported-object writer."""
        buffer.write_int8(PickleSerializer.PICKLE_CLASS_ID)
        self.fury.handle_unsupported_write(buffer, value)

    def read(self, buffer):
        """Delegate to the unsupported-object reader."""
        return self.fury.handle_unsupported_read(buffer)
class NDArraySerializer(Serializer):
    """Serializer for numpy arrays whose dtype is only known at write time.

    The cross-language layout written by ``xwrite`` is: type id (varuint32),
    byte length (varuint32), raw element bytes. Cross-language reading is not
    implemented. The native ``write``/``read`` delegate to Fury's
    unsupported-object handlers.
    """

    def xwrite(self, buffer, value):
        """Write type id, total byte length and the array payload.

        Raises:
            KeyError: if ``value.dtype`` is not present in ``_np_dtypes_dict``.
        """
        itemsize, typecode, ftype, type_id = _np_dtypes_dict[value.dtype]
        view = memoryview(value)
        # ``value.nbytes`` covers every element. ``len(value) * itemsize``
        # only counts the first axis, so the length prefix would disagree
        # with the payload for arrays that are not 1-D.
        nbytes = value.nbytes
        buffer.write_varuint32(type_id)
        buffer.write_varuint32(nbytes)
        # Bool and non-contiguous arrays are copied to bytes first; everything
        # else is written straight from the array's buffer.
        if value.dtype == np.dtype("bool") or not view.c_contiguous:
            buffer.write_bytes(value.tobytes())
        else:
            buffer.write_buffer(value)

    def xread(self, buffer):
        """Not supported."""
        raise NotImplementedError("Multi-dimensional array not supported currently")

    def write(self, buffer, value):
        """Write the pickle class id, then delegate to the unsupported-object writer."""
        buffer.write_int8(PickleSerializer.PICKLE_CLASS_ID)
        self.fury.handle_unsupported_write(buffer, value)

    def read(self, buffer):
        """Delegate to the unsupported-object reader."""
        return self.fury.handle_unsupported_read(buffer)
class BytesSerializer(CrossLanguageCompatibleSerializer):
    """Serializer for ``bytes`` that goes through Fury's buffer-object channel."""

    def write(self, buffer, value):
        """Wrap ``value`` in a ``BytesBufferObject`` and hand it to Fury."""
        wrapped = BytesBufferObject(value)
        self.fury.write_buffer_object(buffer, wrapped)

    def read(self, buffer):
        """Read a buffer object through Fury and convert it to ``bytes``."""
        return self.fury.read_buffer_object(buffer).to_pybytes()
class BytesBufferObject(BufferObject):
    """``BufferObject`` implementation backed by an in-memory ``bytes`` value."""

    __slots__ = ("binary",)

    def __init__(self, binary: bytes):
        # The payload exposed by the methods below.
        self.binary = binary

    def total_bytes(self) -> int:
        """Return the payload length in bytes."""
        payload = self.binary
        return len(payload)

    def write_to(self, buffer: "Buffer"):
        """Append the payload to ``buffer``."""
        payload = self.binary
        buffer.write_bytes(payload)

    def to_buffer(self) -> "Buffer":
        """Return a new ``Buffer`` constructed from the payload."""
        payload = self.binary
        return Buffer(payload)
class PickleSerializer(Serializer):
    """Serializer that defers to Fury's unsupported-object handlers.

    Cross-language serialization is not available.
    """

    # Class id that other serializers in this module write before delegating
    # to ``handle_unsupported_write``.
    PICKLE_CLASS_ID = 96

    def write(self, buffer, value):
        """Delegate to ``fury.handle_unsupported_write``."""
        fury = self.fury
        fury.handle_unsupported_write(buffer, value)

    def read(self, buffer):
        """Delegate to ``fury.handle_unsupported_read`` and return its result."""
        fury = self.fury
        return fury.handle_unsupported_read(buffer)

    def xwrite(self, buffer, value):
        """Not supported."""
        raise NotImplementedError()

    def xread(self, buffer):
        """Not supported."""
        raise NotImplementedError()