serde-generate/runtime/python/serde_binary/__init__.py (311 lines of code) (raw):
# Copyright (c) Facebook, Inc. and its affiliates
# SPDX-License-Identifier: MIT OR Apache-2.0
"""
Module describing the "binary" serialization formats.
Note: This internal module is currently only meant to share code between the BCS and bincode formats. Internal APIs could change in the future.
"""
import dataclasses
import collections
import io
import typing
from typing import get_type_hints
import serde_types as st
@dataclasses.dataclass
class BinarySerializer:
    """Serialization primitives for binary formats (abstract class).

    "Binary" serialization formats may differ in the way they encode sequence lengths, variant
    index, and how they sort map entries (or not). Concrete subclasses implement
    `serialize_len`, `serialize_variant_index` and `sort_map_entries`, and override
    `serialize_f32`, `serialize_f64` and `serialize_char` when the format supports them.

    Fixed-width integers are written in little-endian byte order.
    """

    # Buffer that every serialization primitive appends to.
    output: io.BytesIO
    # Remaining number of nested structs/variants allowed; `None` disables the limit.
    container_depth_budget: typing.Optional[int]
    # Primitive type -> bound serialization method; populated by `__post_init__`.
    primitive_type_serializer: typing.Mapping = dataclasses.field(init=False)

    def __post_init__(self):
        """Build the dispatch table used by `serialize_any` for primitive types."""
        self.primitive_type_serializer = {
            bool: self.serialize_bool,
            st.uint8: self.serialize_u8,
            st.uint16: self.serialize_u16,
            st.uint32: self.serialize_u32,
            st.uint64: self.serialize_u64,
            st.uint128: self.serialize_u128,
            st.int8: self.serialize_i8,
            st.int16: self.serialize_i16,
            st.int32: self.serialize_i32,
            st.int64: self.serialize_i64,
            st.int128: self.serialize_i128,
            st.float32: self.serialize_f32,
            st.float64: self.serialize_f64,
            st.unit: self.serialize_unit,
            st.char: self.serialize_char,
            str: self.serialize_str,
            bytes: self.serialize_bytes,
        }

    def serialize_bytes(self, value: bytes):
        """Write a length prefix (see `serialize_len`) followed by the raw bytes."""
        self.serialize_len(len(value))
        self.output.write(value)

    def serialize_str(self, value: str):
        """Write `value` encoded as UTF-8, with the same layout as `serialize_bytes`."""
        self.serialize_bytes(value.encode())

    def serialize_unit(self, value: st.unit):
        """The unit value occupies zero bytes, so nothing is written."""

    def serialize_bool(self, value: bool):
        """Write a single byte holding `int(value)` (0 for False, 1 for True)."""
        self.output.write(int(value).to_bytes(1, "little", signed=False))

    # Fixed-width integers: `int.to_bytes` raises OverflowError when the value
    # does not fit in the requested width/signedness.

    def serialize_u8(self, value: st.uint8):
        self.output.write(int(value).to_bytes(1, "little", signed=False))

    def serialize_u16(self, value: st.uint16):
        self.output.write(int(value).to_bytes(2, "little", signed=False))

    def serialize_u32(self, value: st.uint32):
        self.output.write(int(value).to_bytes(4, "little", signed=False))

    def serialize_u64(self, value: st.uint64):
        self.output.write(int(value).to_bytes(8, "little", signed=False))

    def serialize_u128(self, value: st.uint128):
        self.output.write(int(value).to_bytes(16, "little", signed=False))

    def serialize_i8(self, value: st.int8):
        self.output.write(int(value).to_bytes(1, "little", signed=True))

    def serialize_i16(self, value: st.int16):
        self.output.write(int(value).to_bytes(2, "little", signed=True))

    def serialize_i32(self, value: st.int32):
        self.output.write(int(value).to_bytes(4, "little", signed=True))

    def serialize_i64(self, value: st.int64):
        self.output.write(int(value).to_bytes(8, "little", signed=True))

    def serialize_i128(self, value: st.int128):
        self.output.write(int(value).to_bytes(16, "little", signed=True))

    def serialize_f32(self, value: st.float32):
        """Format-specific; not supported by the base class."""
        raise NotImplementedError

    def serialize_f64(self, value: st.float64):
        """Format-specific; not supported by the base class."""
        raise NotImplementedError

    def serialize_char(self, value: st.char):
        """Format-specific; not supported by the base class."""
        raise NotImplementedError

    def get_buffer_offset(self) -> int:
        """Return the number of bytes currently held in `output`."""
        return len(self.output.getbuffer())

    def get_buffer(self) -> bytes:
        """Return a copy of everything written to `output` so far."""
        return self.output.getvalue()

    def increase_container_depth(self):
        """Consume one unit of the nesting budget before entering a struct or variant.

        Raises:
            st.SerializationError: if the budget is already exhausted.
        """
        if self.container_depth_budget is not None:
            if self.container_depth_budget == 0:
                raise st.SerializationError("Exceeded maximum container depth")
            self.container_depth_budget -= 1

    def decrease_container_depth(self):
        """Give back the unit consumed by the matching `increase_container_depth`."""
        if self.container_depth_budget is not None:
            self.container_depth_budget += 1

    def serialize_len(self, value: int):
        """Write a sequence/map/bytes length; the encoding is format-specific."""
        raise NotImplementedError

    def serialize_variant_index(self, value: int):
        """Write the index of an enum variant; the encoding is format-specific."""
        raise NotImplementedError

    def sort_map_entries(self, offsets: typing.List[int]):
        """Reorder (or not) the map entries just written to `output`.

        `offsets[i]` is the position in `output` where the i-th serialized
        key/value entry starts; the last entry runs to the end of `output`.
        """
        raise NotImplementedError

    def serialize_any(self, obj: typing.Any, obj_type):
        """Serialize `obj` according to the type description `obj_type`.

        Supported descriptions are the primitive types registered in
        `primitive_type_serializer`, `typing.Sequence[T]`, `typing.Tuple[...]`,
        `typing.Optional[T]`, `typing.Dict[K, V]`, dataclasses (structs), and enum
        classes with a `VARIANTS` list whose variant classes define `INDEX`.

        Raises:
            st.SerializationError: if `obj` does not match `obj_type`, or the
                container depth budget is exhausted.
            AssertionError: if a generic type description is malformed.
        """
        if obj_type in self.primitive_type_serializer:
            self.primitive_type_serializer[obj_type](obj)
        elif hasattr(obj_type, "__origin__"):
            self._serialize_generic(obj, obj_type)
        else:
            self._serialize_struct_or_variant(obj, obj_type)

    def _serialize_generic(self, obj: typing.Any, obj_type):
        """Serialize `obj` for a parameterized `typing` generic (Sequence, Tuple, Optional, Dict)."""
        origin = obj_type.__origin__
        types = obj_type.__args__
        if origin == collections.abc.Sequence:  # Sequence[T]: length, then items
            assert len(types) == 1
            item_type = types[0]
            self.serialize_len(len(obj))
            for item in obj:
                self.serialize_any(item, item_type)
        elif origin == tuple:  # Tuple[T1, ..., Tn]: items only, no length prefix
            if types == ((),):
                # Some Python versions spell the empty tuple type `Tuple[()]`
                # with `__args__ == ((),)` instead of `()`.
                types = ()
            if len(obj) != len(types):
                # Otherwise extra elements would crash and missing ones would be
                # silently dropped from the output.
                raise st.SerializationError("Wrong Value for the type", obj, obj_type)
            for item, item_type in zip(obj, types):
                self.serialize_any(item, item_type)
        elif origin == typing.Union:  # Optional[T]: tag byte, then the value if present
            assert len(types) == 2 and types[1] == type(None)
            if obj is None:
                self.output.write(b"\x00")
            else:
                self.output.write(b"\x01")
                self.serialize_any(obj, types[0])
        elif origin == dict:  # Dict[K, V]: length, then key/value entries
            assert len(types) == 2
            key_type, value_type = types
            self.serialize_len(len(obj))
            offsets = []
            for key, value in obj.items():
                # Record where each entry starts so the format can reorder entries.
                offsets.append(self.get_buffer_offset())
                self.serialize_any(key, key_type)
                self.serialize_any(value, value_type)
            self.sort_map_entries(offsets)
        else:
            raise st.SerializationError("Unexpected type", obj_type)

    def _serialize_struct_or_variant(self, obj: typing.Any, obj_type):
        """Serialize a struct (dataclass) or an enum value (class with `VARIANTS`)."""
        variant_index = None
        if not dataclasses.is_dataclass(obj_type):  # Enum
            if not hasattr(obj_type, "VARIANTS"):
                raise st.SerializationError("Unexpected type", obj_type)
            if not hasattr(obj, "INDEX"):
                raise st.SerializationError("Wrong Value for the type", obj, obj_type)
            variant_index = obj.__class__.INDEX
            variants = obj_type.VARIANTS
            # Reject out-of-range indices: a negative one would otherwise wrap
            # around in `VARIANTS[...]` and be written to the output.
            if not 0 <= variant_index < len(variants):
                raise st.SerializationError("Wrong Value for the type", obj, obj_type)
            # Proceed with the variant's own dataclass.
            obj_type = variants[variant_index]
            if not dataclasses.is_dataclass(obj_type):
                raise st.SerializationError("Unexpected type", obj_type)
        # pyre-ignore
        if not isinstance(obj, obj_type):
            raise st.SerializationError("Wrong Value for the type", obj, obj_type)
        if variant_index is not None:
            # Written only once the value is known to match its type.
            self.serialize_variant_index(variant_index)
        # Content of struct or variant.
        type_hints = get_type_hints(obj_type)
        self.increase_container_depth()
        try:
            for field in dataclasses.fields(obj_type):
                # `getattr` (unlike `__dict__`) also works for slotted dataclasses.
                self.serialize_any(getattr(obj, field.name), type_hints[field.name])
        finally:
            # Keep the budget consistent even if a nested value fails to serialize.
            self.decrease_container_depth()
@dataclasses.dataclass
class BinaryDeserializer:
    """Deserialization primitives for binary formats (abstract class).

    "Binary" serialization formats may differ in the way they encode sequence lengths, variant
    index, and how they verify the ordering of keys in map entries (or not). Concrete
    subclasses implement `deserialize_len`, `deserialize_variant_index` and
    `check_that_key_slices_are_increasing`, and override `deserialize_f32`,
    `deserialize_f64` and `deserialize_char` when the format supports them.

    Fixed-width integers are read in little-endian byte order.
    """

    # Buffer consumed from its current position by every deserialization primitive.
    input: io.BytesIO
    # Remaining number of nested structs/variants allowed; `None` disables the limit.
    container_depth_budget: typing.Optional[int]
    # Primitive type -> bound deserialization method; populated by `__post_init__`.
    primitive_type_deserializer: typing.Mapping = dataclasses.field(init=False)

    def __post_init__(self):
        """Build the dispatch table used by `deserialize_any` for primitive types."""
        self.primitive_type_deserializer = {
            bool: self.deserialize_bool,
            st.uint8: self.deserialize_u8,
            st.uint16: self.deserialize_u16,
            st.uint32: self.deserialize_u32,
            st.uint64: self.deserialize_u64,
            st.uint128: self.deserialize_u128,
            st.int8: self.deserialize_i8,
            st.int16: self.deserialize_i16,
            st.int32: self.deserialize_i32,
            st.int64: self.deserialize_i64,
            st.int128: self.deserialize_i128,
            st.float32: self.deserialize_f32,
            st.float64: self.deserialize_f64,
            st.unit: self.deserialize_unit,
            st.char: self.deserialize_char,
            str: self.deserialize_str,
            bytes: self.deserialize_bytes,
        }

    def read(self, length: int) -> bytes:
        """Consume and return exactly `length` bytes from `input`.

        Raises:
            st.DeserializationError: if `length` is negative or fewer than
                `length` bytes remain.
        """
        if length < 0:
            # `BytesIO.read` treats a negative size as "read everything", which
            # would silently bypass the length check below.
            raise st.DeserializationError("Invalid negative length", length)
        value = self.input.read(length)
        if value is None or len(value) < length:
            raise st.DeserializationError("Input is too short")
        return value

    def deserialize_bytes(self) -> bytes:
        """Read a length prefix (see `deserialize_len`) followed by that many raw bytes."""
        length = self.deserialize_len()
        return self.read(length)

    def deserialize_str(self) -> str:
        """Read bytes as in `deserialize_bytes` and decode them as UTF-8.

        Raises:
            st.DeserializationError: if the bytes are not valid UTF-8.
        """
        content = self.deserialize_bytes()
        try:
            return content.decode()
        except UnicodeDecodeError as err:
            raise st.DeserializationError("Invalid unicode string:", content) from err

    def deserialize_unit(self) -> st.unit:
        """The unit value occupies zero bytes; nothing is consumed and None is returned."""

    def deserialize_bool(self) -> bool:
        """Read one byte; only 0 (False) and 1 (True) are accepted.

        Raises:
            st.DeserializationError: for any other byte value.
        """
        b = int.from_bytes(self.read(1), byteorder="little", signed=False)
        if b == 0:
            return False
        elif b == 1:
            return True
        else:
            raise st.DeserializationError("Unexpected boolean value:", b)

    def deserialize_u8(self) -> st.uint8:
        return st.uint8(int.from_bytes(self.read(1), byteorder="little", signed=False))

    def deserialize_u16(self) -> st.uint16:
        return st.uint16(int.from_bytes(self.read(2), byteorder="little", signed=False))

    def deserialize_u32(self) -> st.uint32:
        return st.uint32(int.from_bytes(self.read(4), byteorder="little", signed=False))

    def deserialize_u64(self) -> st.uint64:
        return st.uint64(int.from_bytes(self.read(8), byteorder="little", signed=False))

    def deserialize_u128(self) -> st.uint128:
        return st.uint128(
            int.from_bytes(self.read(16), byteorder="little", signed=False)
        )

    def deserialize_i8(self) -> st.int8:
        return st.int8(int.from_bytes(self.read(1), byteorder="little", signed=True))

    def deserialize_i16(self) -> st.int16:
        return st.int16(int.from_bytes(self.read(2), byteorder="little", signed=True))

    def deserialize_i32(self) -> st.int32:
        return st.int32(int.from_bytes(self.read(4), byteorder="little", signed=True))

    def deserialize_i64(self) -> st.int64:
        return st.int64(int.from_bytes(self.read(8), byteorder="little", signed=True))

    def deserialize_i128(self) -> st.int128:
        return st.int128(int.from_bytes(self.read(16), byteorder="little", signed=True))

    def deserialize_f32(self) -> st.float32:
        """Format-specific; not supported by the base class."""
        raise NotImplementedError

    def deserialize_f64(self) -> st.float64:
        """Format-specific; not supported by the base class."""
        raise NotImplementedError

    def deserialize_char(self) -> st.char:
        """Format-specific; not supported by the base class."""
        raise NotImplementedError

    def get_buffer_offset(self) -> int:
        """Return the current read position in `input`."""
        return self.input.tell()

    def get_remaining_buffer(self) -> bytes:
        """Return a copy of the unread bytes of `input` without consuming them."""
        buf = self.input.getbuffer()
        offset = self.get_buffer_offset()
        return bytes(buf[offset:])

    def increase_container_depth(self):
        """Consume one unit of the nesting budget before entering a struct or variant.

        Raises:
            st.DeserializationError: if the budget is already exhausted.
        """
        if self.container_depth_budget is not None:
            if self.container_depth_budget == 0:
                raise st.DeserializationError("Exceeded maximum container depth")
            self.container_depth_budget -= 1

    def decrease_container_depth(self):
        """Give back the unit consumed by the matching `increase_container_depth`."""
        if self.container_depth_budget is not None:
            self.container_depth_budget += 1

    def deserialize_len(self) -> int:
        """Read a sequence/map/bytes length; the encoding is format-specific."""
        raise NotImplementedError

    def deserialize_variant_index(self) -> int:
        """Read the index of an enum variant; the encoding is format-specific."""
        raise NotImplementedError

    def check_that_key_slices_are_increasing(
        self, slice1: typing.Tuple[int, int], slice2: typing.Tuple[int, int]
    ) -> bool:
        """Check the ordering of two consecutive serialized map keys.

        Each slice is a `(start, end)` pair of offsets into `input` covering one
        serialized key; `slice1` precedes `slice2`. `deserialize_any` ignores the
        return value, so an implementation that enforces an ordering must raise
        `st.DeserializationError` itself.
        """
        raise NotImplementedError

    def deserialize_any(self, obj_type) -> typing.Any:
        """Deserialize a value described by `obj_type` from `input`.

        Supported descriptions are the primitive types registered in
        `primitive_type_deserializer`, `typing.Sequence[T]` (returned as a list),
        `typing.Tuple[...]`, `typing.Optional[T]`, `typing.Dict[K, V]`,
        dataclasses (structs), and enum classes with a `VARIANTS` list.

        Raises:
            st.DeserializationError: on malformed input, an unsupported type, or
                when the container depth budget is exhausted.
            AssertionError: if a generic type description is malformed.
        """
        if obj_type in self.primitive_type_deserializer:
            return self.primitive_type_deserializer[obj_type]()
        if hasattr(obj_type, "__origin__"):
            return self._deserialize_generic(obj_type)
        return self._deserialize_struct_or_variant(obj_type)

    def _deserialize_generic(self, obj_type) -> typing.Any:
        """Deserialize a parameterized `typing` generic (Sequence, Tuple, Optional, Dict)."""
        origin = obj_type.__origin__
        types = obj_type.__args__
        if origin == collections.abc.Sequence:  # Sequence[T]: length, then items
            assert len(types) == 1
            item_type = types[0]
            length = self.deserialize_len()
            return [self.deserialize_any(item_type) for _ in range(length)]
        if origin == tuple:  # Tuple[T1, ..., Tn]: items only, no length prefix
            if types == ((),):
                # Some Python versions spell the empty tuple type `Tuple[()]`
                # with `__args__ == ((),)` instead of `()`.
                return ()
            return tuple(self.deserialize_any(item_type) for item_type in types)
        if origin == typing.Union:  # Optional[T]: tag byte, then the value if present
            assert len(types) == 2 and types[1] == type(None)
            tag = self.read(1)[0]
            if tag == 0:
                return None
            if tag == 1:
                return self.deserialize_any(types[0])
            raise st.DeserializationError("Wrong tag for Option value")
        if origin == dict:  # Dict[K, V]: length, then key/value entries
            assert len(types) == 2
            key_type, value_type = types
            length = self.deserialize_len()
            result = {}
            previous_key_slice = None
            for _ in range(length):
                key_start = self.get_buffer_offset()
                key = self.deserialize_any(key_type)
                key_slice = (key_start, self.get_buffer_offset())
                value = self.deserialize_any(value_type)
                if previous_key_slice is not None:
                    # The hook raises if the format requires ordered keys and
                    # they are not; its return value is not used.
                    self.check_that_key_slices_are_increasing(
                        previous_key_slice, key_slice
                    )
                previous_key_slice = key_slice
                result[key] = value
            return result
        raise st.DeserializationError("Unexpected type", obj_type)

    def _deserialize_struct_or_variant(self, obj_type) -> typing.Any:
        """Deserialize a struct (dataclass) or an enum value (class with `VARIANTS`)."""
        if dataclasses.is_dataclass(obj_type):
            fields = dataclasses.fields(obj_type)
            type_hints = get_type_hints(obj_type)
            args = []
            kwargs = {}
            self.increase_container_depth()
            try:
                for field in fields:
                    value = self.deserialize_any(type_hints[field.name])
                    # Keyword-only dataclass fields cannot be passed positionally.
                    if getattr(field, "kw_only", False) is True:
                        kwargs[field.name] = value
                    else:
                        args.append(value)
            finally:
                # Keep the budget consistent even if a nested value fails.
                self.decrease_container_depth()
            return obj_type(*args, **kwargs)
        if hasattr(obj_type, "VARIANTS"):
            variant_index = self.deserialize_variant_index()
            variants = obj_type.VARIANTS
            if variant_index not in range(len(variants)):
                raise st.DeserializationError(
                    "Unexpected variant index", variant_index
                )
            # The variant class is itself a dataclass, handled by the branch above.
            return self.deserialize_any(variants[variant_index])
        raise st.DeserializationError("Unexpected type", obj_type)