python/pyfury/resolver.py (132 lines of code) (raw):

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import logging
from abc import ABC, abstractmethod
from typing import Dict, List, Any, Tuple

try:
    import numpy as np
except ImportError:
    np = None

logger = logging.getLogger(__name__)


# Head flags: a single signed byte written in front of every value.

# The value is null; the resolvers write nothing else for it.
NULL_FLAG = -3
# The value is a back-reference to an object that was already written; the
# flag is followed by the varuint32 reference id. The flag byte itself marks
# the reference, so no extra byte is spent to say "this is a REF".
REF_FLAG = -2
# The value is non-null and is not reference-tracked (written by
# `NoRefResolver`).
NOT_NULL_VALUE_FLAG = -1
# The value is non-null, reference-tracked, and seen for the first time.
REF_VALUE_FLAG = 0


class RefResolver(ABC):
    """Interface for writing and reading null/reference head flags."""

    @abstractmethod
    def write_ref_or_null(self, buffer, obj):
        """
        Write the reference id and tag for `obj` if it has been written
        previously, write a null/not-null tag otherwise.

        Returns
        -------
        True if no more bytes need to be written for the object.
        """

    @abstractmethod
    def read_ref_or_null(self, buffer):
        """
        Read the head flag of the next value.

        Returns
        -------
        `REF_FLAG` if a reference to a previously read object was read.
        `NULL_FLAG` if the object is null.
        `NOT_NULL_VALUE_FLAG` if the object is not null and was written
        without reference tracking.
        `REF_VALUE_FLAG` if the object is not null and is read for the
        first time.
        """

    @abstractmethod
    def preserve_ref_id(self) -> int:
        """
        Preserve a reference id, which is used by `set_read_object` to set up
        the reference for an object that is first deserialized.

        Returns
        -------
        a reference id or -1 if reference is not enabled.
        """

    @abstractmethod
    def try_preserve_ref_id(self, buffer) -> int:
        """
        Read the head flag and preserve a reference id when needed.

        The returned id is `>= NOT_NULL_VALUE_FLAG` if the value is not null
        and still has to be read. If the value is a referencable value, the
        returned id is the one produced by `preserve_ref_id`.

        Returns
        -------
        a reference id
        """

    @abstractmethod
    def reference(self, obj):
        """
        Call this method immediately after a composite object such as an
        object array/map/collection/bean is created, so that circular
        references can be deserialized correctly.
        """

    @abstractmethod
    def get_read_object(self, id_=None):
        """
        Returns
        -------
        the object for the specified id.
        """

    @abstractmethod
    def set_read_object(self, id_, obj):
        """
        Sets the id for an object that has been read.

        Parameters
        ----------
        id_: int
            The id returned by `preserve_ref_id`.
        obj:
            the object that has been read
        """

    @abstractmethod
    def reset(self):
        """Reset both the write-side and the read-side state."""

    @abstractmethod
    def reset_write(self):
        """Reset the write-side state."""

    @abstractmethod
    def reset_read(self):
        """Reset the read-side state."""


class MapRefResolver(RefResolver):
    """Resolver that tracks references by object identity."""

    written_objects: Dict[int, Tuple[int, Any]]  # id(obj) -> (ref_id, obj)
    read_objects: List[Any]  # ref_id -> object (None until it is set)
    read_ref_ids: List[int]  # stack of preserved ids awaiting `reference`
    read_object: Any  # object resolved by the most recent REF_FLAG read

    def __init__(self):
        self.written_objects = {}
        self.read_objects = []
        self.read_ref_ids = []
        self.read_object = None

    def write_ref_or_null(self, buffer, obj):
        """
        Write the head flag for `obj`, plus the reference id when `obj` was
        written before.

        Returns
        -------
        True if `obj` is None or a back-reference was written, i.e. the
        caller has nothing more to write; False if the caller must now
        write the object's data.
        """
        if obj is None:
            buffer.write_int8(NULL_FLAG)
            return True
        object_id = id(obj)
        written = self.written_objects.get(object_id)
        if written is not None:
            # The obj has been written previously: emit a back-reference.
            buffer.write_int8(REF_FLAG)
            buffer.write_varuint32(written[0])
            return True
        # Hold the object itself to avoid tmp object gc when serializing
        # nested fields/objects; a collected object's id() could be reused.
        self.written_objects[object_id] = (len(self.written_objects), obj)
        buffer.write_int8(REF_VALUE_FLAG)
        return False

    def _read_head_flag(self, buffer):
        """
        Read one head flag and update `self.read_object` accordingly.

        On `REF_FLAG` the reference id is read as well and the referenced
        object is stored in `self.read_object`; for any other flag
        `self.read_object` is cleared.

        Returns
        -------
        the head flag that was read.
        """
        head_flag = buffer.read_int8()
        if head_flag == REF_FLAG:
            # read reference id and get object from reference resolver
            ref_id = buffer.read_varuint32()
            self.read_object = self.get_read_object(ref_id)
        else:
            self.read_object = None
        return head_flag

    def read_ref_or_null(self, buffer):
        """
        Read the head flag of the next value; on `REF_FLAG` the referenced
        object becomes available through `get_read_object()`.

        Returns
        -------
        the head flag that was read.
        """
        return self._read_head_flag(buffer)

    def preserve_ref_id(self) -> int:
        """
        Reserve the next reference id and push it on the pending-id stack.

        Returns
        -------
        the reserved reference id.
        """
        next_read_ref_id = len(self.read_objects)
        # Placeholder slot; filled in later through `set_read_object`.
        self.read_objects.append(None)
        self.read_ref_ids.append(next_read_ref_id)
        return next_read_ref_id

    def try_preserve_ref_id(self, buffer) -> int:
        """
        Read the head flag and reserve a reference id when the value is read
        for the first time.

        Returns
        -------
        a freshly preserved reference id if the flag is `REF_VALUE_FLAG`,
        otherwise the head flag itself.
        """
        head_flag = self._read_head_flag(buffer)
        if head_flag == REF_VALUE_FLAG:
            return self.preserve_ref_id()
        # `head_flag` except `REF_FLAG` can be used as stub reference id
        # because we use `refId >= NOT_NULL_VALUE_FLAG` to read data.
        return head_flag

    def reference(self, obj):
        """
        Bind `obj` to the most recently preserved reference id.

        Raises
        ------
        IndexError
            if no preserved reference id is pending.
        """
        ref_id = self.read_ref_ids.pop()
        self.set_read_object(ref_id, obj)

    def get_read_object(self, id_=None):
        """
        Returns
        -------
        the object registered under `id_`, or the object resolved by the
        most recent head-flag read when `id_` is None.
        """
        if id_ is None:
            return self.read_object
        return self.read_objects[id_]

    def set_read_object(self, id_, obj):
        """
        Store `obj` under a previously preserved reference id.

        Negative ids (stub ids used when no reference was preserved) are
        ignored.

        Raises
        ------
        RuntimeError
            if `id_` is non-negative but was never preserved.
        """
        if id_ < 0:
            return
        if id_ >= len(self.read_objects):
            raise RuntimeError(f"Ref id {id_} invalid")
        self.read_objects[id_] = obj

    def reset(self):
        """Reset both the write-side and the read-side state."""
        self.reset_write()
        self.reset_read()

    def reset_write(self):
        """Forget every written object."""
        self.written_objects.clear()

    def reset_read(self):
        """Forget every read object and pending reference id."""
        self.read_objects.clear()
        self.read_ref_ids.clear()
        self.read_object = None


class NoRefResolver(RefResolver):
    """Resolver used without reference tracking: only null-ness is tagged."""

    def write_ref_or_null(self, buffer, obj):
        """
        Write `NULL_FLAG` for None and `NOT_NULL_VALUE_FLAG` otherwise.

        Returns
        -------
        True if `obj` is None, i.e. nothing more needs to be written.
        """
        if obj is None:
            buffer.write_int8(NULL_FLAG)
            return True
        buffer.write_int8(NOT_NULL_VALUE_FLAG)
        return False

    def read_ref_or_null(self, buffer):
        """
        Returns
        -------
        the head flag read from `buffer`.
        """
        return buffer.read_int8()

    def preserve_ref_id(self) -> int:
        """
        Returns
        -------
        -1, because reference tracking is not enabled.
        """
        return -1

    def try_preserve_ref_id(self, buffer) -> int:
        """
        Returns
        -------
        the head flag read from `buffer`, used as a stub reference id.
        """
        # `NOT_NULL_VALUE_FLAG` can be used as stub reference id because we
        # use `refId >= NOT_NULL_VALUE_FLAG` to read data.
        return buffer.read_int8()

    def reference(self, obj):
        """Do nothing: no references are tracked."""

    def get_read_object(self, id_=None):
        """
        Returns
        -------
        None, because no read objects are tracked.
        """
        return None

    def set_read_object(self, id_, obj):
        """Do nothing: no read objects are tracked."""

    def reset(self):
        """Do nothing: there is no state to reset."""

    def reset_write(self):
        """Do nothing: there is no state to reset."""

    def reset_read(self):
        """Do nothing: there is no state to reset."""