python/pyfury/type.py (306 lines of code) (raw):
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import array
import dataclasses
import importlib
import inspect
import typing
from typing import TypeVar
from abc import ABC, abstractmethod
try:
import numpy as np
ndarray = np.ndarray
except ImportError:
np, ndarray = None, None
# modified from `fluent python`
def record_class_factory(cls_name, field_names):
    """
    record_factory: create simple classes just for holding data fields

    >>> Dog = record_class_factory('Dog', 'name weight owner')
    >>> rex = Dog('Rex', 30, 'Bob')
    >>> rex
    Dog(name='Rex', weight=30, owner='Bob')
    >>> name, weight, _ = rex
    >>> name, weight
    ('Rex', 30)
    >>> "{2}'s dog weighs {1}kg".format(*rex)
    "Bob's dog weighs 30kg"
    >>> rex.weight = 32
    >>> rex
    Dog(name='Rex', weight=32, owner='Bob')
    >>> Dog.__mro__
    (<class '_util.Dog'>, <class 'object'>)

    The factory also accepts a list or tuple of identifiers:

    >>> Dog = record_class_factory('Dog', ['name', 'weight', 'owner'])
    >>> Dog.__slots__
    ('name', 'weight', 'owner')

    Args:
        cls_name: name of the class to create. The class is also stored in
            this module's globals under that name.
        field_names: either a string of identifiers separated by spaces
            and/or commas, or an iterable of identifiers.

    Returns:
        The newly created class, whose ``__slots__`` are the field names.
        Instances are iterable, comparable, hashable and picklable.
    """
    try:
        field_names = field_names.replace(",", " ").split()
    except AttributeError:  # no .replace or .split
        pass  # assume it's already a sequence of identifiers
    field_names = tuple(field_names)

    def __init__(self, *args, **kwargs):
        # Positional values are matched to slots in declaration order;
        # keyword values override/extend them.
        attrs = dict(zip(self.__slots__, args))
        attrs.update(kwargs)
        for name, value in attrs.items():
            setattr(self, name, value)

    def __iter__(self):
        for name in self.__slots__:
            yield getattr(self, name)

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return False
        if not self.__slots__ == other.__slots__:
            return False
        # Unset slots are treated as None on both sides.
        return all(
            getattr(self, name, None) == getattr(other, name, None)
            for name in self.__slots__
        )

    def __hash__(self):
        # A tuple (not a list) is required here: lists are unhashable, so
        # hashing a list would raise TypeError for every instance.
        return hash(tuple(getattr(self, name, None) for name in self.__slots__))

    def __str__(self):
        values = ", ".join("{}={!r}".format(*i) for i in zip(self.__slots__, self))
        return values

    def __repr__(self):
        values = ", ".join("{}={!r}".format(*i) for i in zip(self.__slots__, self))
        return "{}({})".format(self.__class__.__name__, values)

    def __reduce__(self):
        # Rebuild from the class plus the positional field values.
        return self.__class__, tuple(self)

    cls_attrs = dict(
        __slots__=field_names,
        __init__=__init__,
        __iter__=__iter__,
        __eq__=__eq__,
        __hash__=__hash__,
        __str__=__str__,
        __repr__=__repr__,
        __reduce__=__reduce__,
    )
    cls_ = type(cls_name, (object,), cls_attrs)
    # combined with __reduce__ to make it picklable: pickle looks the class
    # up by name in its defining module.
    globals()[cls_name] = cls_
    return cls_
def get_qualified_classname(obj):
    """Return ``"<module>.<name>"`` for a class, or for the class of an instance.

    Args:
        obj: a class, or any object whose type should be described.

    Returns:
        The type's ``__module__`` and ``__name__`` joined by a dot.
    """
    if isinstance(obj, type):
        klass = obj
    else:
        klass = type(obj)
    return klass.__module__ + "." + klass.__name__
class TypeId:
    """
    Fury type for cross-language serialization.
    See `org.apache.fury.types.Type`
    """

    # null value
    NA = 0
    # a boolean value (true or false).
    BOOL = 1
    # an 8-bit signed integer.
    INT8 = 2
    # a 16-bit signed integer.
    INT16 = 3
    # a 32-bit signed integer.
    INT32 = 4
    # a 32-bit signed integer which uses fury var_int32 encoding.
    VAR_INT32 = 5
    # a 64-bit signed integer.
    INT64 = 6
    # a 64-bit signed integer which uses fury PVL encoding.
    VAR_INT64 = 7
    # a 64-bit signed integer which uses fury SLI encoding.
    SLI_INT64 = 8
    # a 16-bit floating point number.
    FLOAT16 = 9
    # a 32-bit floating point number.
    FLOAT32 = 10
    # a 64-bit floating point number including NaN and Infinity.
    FLOAT64 = 11
    # a text string encoded using Latin1/UTF16/UTF-8 encoding.
    STRING = 12
    # a data type consisting of a set of named values. Rust enums with non-predefined field values are not
    # supported as an enum.
    ENUM = 13
    # an enum whose value will be serialized as the registered name.
    NAMED_ENUM = 14
    # a morphic(final) type serialized by Fury Struct serializer. i.e., it doesn't have subclasses. Suppose we're
    # deserializing `List[SomeClass]`, we can save dynamic serializer dispatch since `SomeClass` is morphic(final).
    STRUCT = 15
    # a morphic(final) type serialized by Fury compatible Struct serializer.
    COMPATIBLE_STRUCT = 16
    # a `struct` whose type mapping will be encoded as a name.
    NAMED_STRUCT = 17
    # a `compatible_struct` whose type mapping will be encoded as a name.
    NAMED_COMPATIBLE_STRUCT = 18
    # a type which will be serialized by a customized serializer.
    EXT = 19
    # an `ext` type whose type mapping will be encoded as a name.
    NAMED_EXT = 20
    # a sequence of objects.
    LIST = 21
    # an unordered set of unique elements.
    SET = 22
    # a map of key-value pairs. Mutable types such as `list/map/set/array/tensor/arrow` are not allowed as map keys.
    MAP = 23
    # an absolute length of time, independent of any calendar/timezone, as a count of nanoseconds.
    DURATION = 24
    # a point in time, independent of any calendar/timezone, as a count of nanoseconds. The count is relative
    # to an epoch at UTC midnight on January 1, 1970.
    TIMESTAMP = 25
    # a naive date without timezone. The count is days relative to an epoch at UTC midnight on Jan 1, 1970.
    LOCAL_DATE = 26
    # exact decimal value represented as an integer value in two's complement.
    DECIMAL = 27
    # a variable-length array of bytes.
    BINARY = 28
    # a multidimensional array in which every sub-array can have a different size but all have the same type.
    # Only numeric components are allowed. Other arrays will be taken as List. The implementation should support
    # interoperability between array and list.
    ARRAY = 29
    # one dimensional bool array.
    BOOL_ARRAY = 30
    # one dimensional int8 array.
    INT8_ARRAY = 31
    # one dimensional int16 array.
    INT16_ARRAY = 32
    # one dimensional int32 array.
    INT32_ARRAY = 33
    # one dimensional int64 array.
    INT64_ARRAY = 34
    # one dimensional half_float_16 array.
    FLOAT16_ARRAY = 35
    # one dimensional float32 array.
    FLOAT32_ARRAY = 36
    # one dimensional float64 array.
    FLOAT64_ARRAY = 37
    # an arrow [record batch](https://arrow.apache.org/docs/cpp/tables.html#record-batches) object.
    ARROW_RECORD_BATCH = 38
    # an arrow [table](https://arrow.apache.org/docs/cpp/tables.html#tables) object.
    ARROW_TABLE = 39
    # BOUND id remains at 64
    BOUND = 64

    @staticmethod
    def is_namespaced_type(type_id: int) -> bool:
        """Return True if `type_id` is a member of the module-level `__NAMESPACED_TYPES__` set."""
        # `__NAMESPACED_TYPES__` is a module global resolved at call time. The
        # name ends with two underscores, so it is not subject to private-name
        # mangling inside this class body.
        return type_id in __NAMESPACED_TYPES__
# The `NAMED_*` type ids, i.e. the types that are identified by a name
# (see the comments on the corresponding `TypeId` members).
# Membership is tested by `TypeId.is_namespaced_type`.
__NAMESPACED_TYPES__ = {
    TypeId.NAMED_EXT,
    TypeId.NAMED_ENUM,
    TypeId.NAMED_STRUCT,
    TypeId.NAMED_COMPATIBLE_STRUCT,
}
# Marker type variables named after fixed-width numeric types. Each is bound
# to the Python builtin (`int` or `float`) that carries the value.
# NOTE(review): presumably used as field annotations to select a fixed-width
# encoding - confirm against the serializer code.
Int8Type = TypeVar("Int8Type", bound=int)
Int16Type = TypeVar("Int16Type", bound=int)
Int32Type = TypeVar("Int32Type", bound=int)
Int64Type = TypeVar("Int64Type", bound=int)
Float32Type = TypeVar("Float32Type", bound=float)
Float64Type = TypeVar("Float64Type", bound=float)
# Python-side types that `is_primitive_type` reports as primitive.
# Note that `bool` is not a member.
_primitive_types = {
    int,
    float,
    Int8Type,
    Int16Type,
    Int32Type,
    Int64Type,
    Float32Type,
    Float64Type,
}
# `TypeId` values that `is_primitive_type` reports as primitive when it is
# given an integer type id. The VAR_*/SLI_* integer encodings are not listed.
_primitive_types_ids = {
    TypeId.BOOL,
    TypeId.INT8,
    TypeId.INT16,
    TypeId.INT32,
    TypeId.INT64,
    TypeId.FLOAT16,
    TypeId.FLOAT32,
    TypeId.FLOAT64,
}
# `Union[type, TypeVar]` is not supported in py3.6, so skip adding type hints for `type_` # noqa: E501
# See more at https://github.com/python/typing/issues/492 and
# https://stackoverflow.com/questions/69427175/how-to-pass-forwardref-as-args-to-typevar-in-python-3-6 # noqa: E501
def is_primitive_type(type_) -> bool:
    """Tell whether `type_` denotes a primitive type.

    Args:
        type_: either an integer `TypeId` value or a Python type / TypeVar.

    Returns:
        For an exact `int` argument, whether it is in `_primitive_types_ids`;
        otherwise whether it is in `_primitive_types`.
    """
    # Exact `int` check on purpose: a `bool` argument is not treated as an id.
    lookup = _primitive_types_ids if type(type_) is int else _primitive_types
    return type_ in lookup
# Size associated with each primitive type id, as returned by
# `get_primitive_type_size`.
# NOTE(review): values look like byte widths (INT16 -> 2, INT64 -> 8); for the
# VAR_* ids this would be the fixed-width size, not the encoded size - confirm.
_primitive_type_sizes = {
    TypeId.BOOL: 1,
    TypeId.INT8: 1,
    TypeId.INT16: 2,
    TypeId.INT32: 4,
    TypeId.VAR_INT32: 4,
    TypeId.INT64: 8,
    TypeId.VAR_INT64: 8,
    TypeId.FLOAT16: 2,
    TypeId.FLOAT32: 4,
    TypeId.FLOAT64: 8,
}
def get_primitive_type_size(type_id) -> int:
    """Look up the size registered for `type_id` in `_primitive_type_sizes`.

    Returns:
        The registered size, or -1 when `type_id` has no entry.
    """
    if type_id in _primitive_type_sizes:
        return _primitive_type_sizes[type_id]
    return -1
# Marker type variables for one-dimensional primitive arrays.
# The `*ArrayType` markers are bound to `array.ArrayType` (except
# `BoolArrayType`, which has no bound); the `*NDArrayType` markers are bound
# to `ndarray`, which is None when numpy could not be imported.
# Int8ArrayType = TypeVar("Int8ArrayType", bound=array.ArrayType)
BoolArrayType = TypeVar("BoolArrayType")
Int16ArrayType = TypeVar("Int16ArrayType", bound=array.ArrayType)
Int32ArrayType = TypeVar("Int32ArrayType", bound=array.ArrayType)
Int64ArrayType = TypeVar("Int64ArrayType", bound=array.ArrayType)
Float32ArrayType = TypeVar("Float32ArrayType", bound=array.ArrayType)
Float64ArrayType = TypeVar("Float64ArrayType", bound=array.ArrayType)
BoolNDArrayType = TypeVar("BoolNDArrayType", bound=ndarray)
Int16NDArrayType = TypeVar("Int16NDArrayType", bound=ndarray)
Int32NDArrayType = TypeVar("Int32NDArrayType", bound=ndarray)
Int64NDArrayType = TypeVar("Int64NDArrayType", bound=ndarray)
Float32NDArrayType = TypeVar("Float32NDArrayType", bound=ndarray)
Float64NDArrayType = TypeVar("Float64NDArrayType", bound=ndarray)
# Markers accepted by `is_py_array_type`. `BoolArrayType` is not a member.
_py_array_types = {
    # Int8ArrayType,
    Int16ArrayType,
    Int32ArrayType,
    Int64ArrayType,
    Float32ArrayType,
    Float64ArrayType,
}
# numpy-backed array markers.
_np_array_types = {
    BoolNDArrayType,
    Int16NDArrayType,
    Int32NDArrayType,
    Int64NDArrayType,
    Float32NDArrayType,
    Float64NDArrayType,
}
# Every array marker; used by `is_primitive_array_type` for non-int arguments.
_primitive_array_types = _py_array_types.union(_np_array_types)
def is_py_array_type(type_) -> bool:
    """Return whether `type_` is one of the markers in `_py_array_types`."""
    is_member = type_ in _py_array_types
    return is_member
# `TypeId` values that `is_primitive_array_type` reports as primitive arrays
# when it is given an integer type id.
# NOTE(review): `TypeId.FLOAT16_ARRAY` is not listed - confirm this is intended.
_primitive_array_type_ids = {
    TypeId.BOOL_ARRAY,
    TypeId.INT8_ARRAY,
    TypeId.INT16_ARRAY,
    TypeId.INT32_ARRAY,
    TypeId.INT64_ARRAY,
    TypeId.FLOAT32_ARRAY,
    TypeId.FLOAT64_ARRAY,
}
def is_primitive_array_type(type_) -> bool:
    """Tell whether `type_` denotes a one-dimensional primitive array.

    Args:
        type_: either an integer `TypeId` value or an array marker type.

    Returns:
        For an exact `int` argument, whether it is in
        `_primitive_array_type_ids`; otherwise whether it is in
        `_primitive_array_types`.
    """
    if type(type_) is int:
        lookup = _primitive_array_type_ids
    else:
        lookup = _primitive_array_types
    return type_ in lookup
def is_list_type(type_):
    """Return True if `type_` is a class that is a subclass of `typing.List`.

    Anything `issubclass` rejects with a TypeError yields False.
    """
    try:
        # type_ may not be an instance of type
        result = issubclass(type_, typing.List)
    except TypeError:
        result = False
    return result
def is_map_type(type_):
    """Return True if `type_` is a class that is a subclass of `typing.Dict`.

    Anything `issubclass` rejects with a TypeError yields False.
    """
    try:
        # type_ may not be an instance of type
        result = issubclass(type_, typing.Dict)
    except TypeError:
        result = False
    return result
def is_subclass(from_type, to_type):
    """`issubclass` that answers False instead of raising TypeError.

    Args:
        from_type: candidate subclass; may be a non-class object.
        to_type: candidate base class (or whatever `issubclass` accepts).
    """
    try:
        verdict = issubclass(from_type, to_type)
    except TypeError:
        verdict = False
    return verdict
class TypeVisitor(ABC):
    """Abstract callback interface invoked by `infer_field` for a field type.

    Every hook receives the field name, the relevant type information and
    `types_path`, the list of types traversed so far.
    """

    @abstractmethod
    def visit_list(self, field_name, elem_type, types_path=None):
        """Handle a parameterized list type whose element type is `elem_type`."""

    @abstractmethod
    def visit_dict(self, field_name, key_type, value_type, types_path=None):
        """Handle a parameterized dict type with the given key and value types."""

    @abstractmethod
    def visit_customized(self, field_name, type_, types_path=None):
        """Handle a non-generic type that exposes `__annotations__`."""

    @abstractmethod
    def visit_other(self, field_name, type_, types_path=None):
        """Handle any other non-generic type."""
def infer_field(field_name, type_, visitor: TypeVisitor, types_path=None):
    """Dispatch a field's declared type to the matching `visitor` hook.

    Args:
        field_name: name of the field being inspected; passed through.
        type_: the declared type of the field.
        visitor: receiver of the `visit_*` callback.
        types_path: types traversed so far. It is copied, never mutated, and
            `type_` is appended to the copy handed to the visitor.

    Returns:
        Whatever the selected `visit_*` method returns.

    Raises:
        TypeError: if `type_` is parameterized but is neither a list nor a
            dict type.
    """
    path = [*(types_path or [])]
    path.append(type_)

    # `typing.get_origin` / `typing.get_args` are not available on every
    # Python version; fall back to the dunder attributes when they are missing.
    if hasattr(typing, "get_origin"):
        origin = typing.get_origin(type_)
    else:
        origin = getattr(type_, "__origin__", type_)
    if not origin:
        origin = type_
    if hasattr(typing, "get_args"):
        args = typing.get_args(type_)
    else:
        args = getattr(type_, "__args__", ())

    if not args:
        # Non-parameterized type: functions and objects without annotations
        # go to `visit_other`, everything else to `visit_customized`.
        if is_function(origin) or not hasattr(origin, "__annotations__"):
            return visitor.visit_other(field_name, type_, types_path=path)
        return visitor.visit_customized(field_name, type_, types_path=path)

    if origin is list or origin == typing.List:
        return visitor.visit_list(field_name, args[0], types_path=path)
    if origin is dict or origin == typing.Dict:
        key_type, value_type = args
        return visitor.visit_dict(
            field_name, key_type, value_type, types_path=path
        )
    raise TypeError(
        f"Collection types should be {list, dict} instead of {type_}"
    )
def is_function(func):
    """Return True for a plain Python function or a Cython function.

    The Cython case is delegated to `is_cython_function`.
    """
    if inspect.isfunction(func):
        return True
    return is_cython_function(func)
def is_cython_function(func):
    """Return True if `func` has a `func_name` attribute that is not None."""
    marker = getattr(func, "func_name", None)
    return marker is not None
def compute_string_hash(string):
    """Compute a deterministic integer hash of `string`.

    The UTF-8 bytes are folded with ``acc = acc * 31 + byte`` starting from
    17, then the result is integer-divided by 7 until it is below
    ``2**31 - 1``.

    Args:
        string: the text to hash.

    Returns:
        A non-negative int strictly less than ``2**31 - 1``.
    """
    limit = 2**31 - 1
    acc = 17
    for byte in string.encode("utf-8"):
        acc = 31 * acc + byte
    while not acc < limit:
        acc //= 7
    return acc
def qualified_class_name(cls):
    """Return ``"<module>#<name>"`` for `cls`.

    A `TypeVar` contributes its `__name__`; any other object contributes its
    `__qualname__`.
    """
    module = cls.__module__
    leaf = cls.__name__ if isinstance(cls, TypeVar) else cls.__qualname__
    return module + "#" + leaf
def load_class(classname: str):
    """Resolve a ``"<module>#<qualified name>"`` string to the object it names.

    Args:
        classname: module path and dotted attribute path separated by the
            last ``#``.

    Returns:
        The attribute reached by following the dotted path from the module.

    Raises:
        Exception: if the module cannot be imported, or an attribute along
            the dotted path does not exist. The original error is chained.
    """
    mod_name, cls_name = classname.rsplit("#", 1)
    try:
        module = importlib.import_module(mod_name)
    except ImportError as ex:
        raise Exception(f"Can't import module {mod_name}") from ex
    target = module
    try:
        # Walk nested names such as "Outer.Inner" one attribute at a time.
        for part in cls_name.split("."):
            target = getattr(target, part)
    except AttributeError as ex:
        raise Exception(f"Can't import class {cls_name} from module {mod_name}") from ex
    return target
# This method is derived from https://github.com/ericvsmith/dataclasses/blob/5f6568c3468f872e8f447dc20666628387786397/dataclass_tools.py.
def dataslots(cls):
    """Return a copy of dataclass `cls` that stores its fields in `__slots__`.

    Intended for use as a class decorator applied on top of
    `dataclasses.dataclass`.

    Args:
        cls: a dataclass that does not define `__slots__` itself.

    Returns:
        A new class with the same name, bases and qualified name whose
        `__slots__` are the dataclass field names.

    Raises:
        TypeError: if `cls` already specifies `__slots__`.
    """
    # Need to create a new class, since we can't set __slots__
    # after a class has been created.
    # Make sure __slots__ isn't already set.
    if "__slots__" in cls.__dict__:  # pragma: no cover
        raise TypeError(f"{cls.__name__} already specifies __slots__")
    # Create a new dict for our new class.
    cls_dict = dict(cls.__dict__)
    field_names = tuple(f.name for f in dataclasses.fields(cls))
    cls_dict["__slots__"] = field_names
    for field_name in field_names:
        # Remove our attributes, if present. They'll still be
        # available in _MARKER.
        cls_dict.pop(field_name, None)
    # Remove __dict__ itself.
    cls_dict.pop("__dict__", None)
    # The `__weakref__` descriptor belongs to the original class and does not
    # apply to instances of the new one, so drop it as well.
    cls_dict.pop("__weakref__", None)
    # And finally create the class.
    qualname = getattr(cls, "__qualname__", None)
    cls = type(cls)(cls.__name__, cls.__bases__, cls_dict)
    if qualname is not None:
        cls.__qualname__ = qualname
    return cls