python/pyfury/format/infer.py (139 lines of code) (raw):
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import datetime
import typing
import pyarrow as pa
from functools import partial
from typing import Optional
from pyfury.type import get_qualified_classname, TypeVisitor, infer_field
# Cache of ``id(schema)`` -> class, populated by ``get_cls_by_schema``.
__class_map__ = dict()
# Holds a strong reference to every cached schema so that its ``id()`` cannot
# be handed to a different schema object while the cache entry is alive.
__schemas__ = dict()
def get_cls_by_schema(schema):
    """Return the Python class associated with ``schema``, creating it once.

    Results are cached per schema object, keyed by ``id(schema)``; the schema
    itself is also retained so the id stays unique while cached.

    If the schema metadata has a ``b"cls"`` entry, it is treated as a dotted
    path whose last component is an attribute of the module named by the
    rest, and that attribute is imported. Otherwise a record class named
    ``"Record" + str(id(schema))`` is generated by ``record_class_factory``
    with one name per schema field.
    """
    key = id(schema)
    if key in __class_map__:
        return __class_map__[key]

    metadata = schema.metadata if schema.metadata is not None else {}
    qualified_name = metadata.get(b"cls", b"").decode()
    if not qualified_name:
        from pyfury.type import record_class_factory

        cls_ = record_class_factory(
            "Record" + str(key), [field.name for field in schema]
        )
    else:
        import importlib

        module_path, attr_name = qualified_name.rsplit(".", 1)
        cls_ = getattr(importlib.import_module(module_path), attr_name)

    __class_map__[key] = cls_
    __schemas__[key] = schema
    return cls_
def remove_schema(schema):
    """Forget the class cached for ``schema`` by ``get_cls_by_schema``.

    Both cache entries are dropped. Releasing only the schema reference would
    allow ``id(schema)`` to be reused by a new schema object, which would then
    be handed the stale class still stored in ``__class_map__``.

    Raises:
        KeyError: if ``schema`` is not currently registered.
    """
    key = id(schema)
    __schemas__.pop(key)
    # Evict the cached class as well; tolerate a missing entry so cleanup of
    # an inconsistent cache cannot itself fail.
    __class_map__.pop(key, None)
def reset():
    """Empty both the schema-to-class cache and the retained schemas."""
    for cache in (__class_map__, __schemas__):
        cache.clear()
# Arrow type factories and Python/typing types accepted as field annotations.
_supported_types = {
    pa.bool_,
    pa.int8,
    pa.int16,
    pa.int32,
    pa.int64,
    pa.float32,
    pa.float64,
    str,
    bytes,
    typing.List,
    typing.Dict,
}
# Maps each accepted annotation to the pyarrow callable that builds its Arrow
# type: pyarrow factories map to themselves, Python/typing types map to the
# equivalent pyarrow factory.
_supported_types_mapping = {t: t for t in _supported_types}
_supported_types_mapping.update(
    {
        str: pa.utf8,
        bytes: pa.binary,
        list: pa.list_,
        dict: pa.map_,
        typing.List: pa.list_,
        typing.Dict: pa.map_,
        bool: pa.bool_,
        datetime.date: pa.date32,
        datetime.datetime: partial(pa.timestamp, "us"),
    }
)
# Names of every accepted annotation for error messages. Derived from the
# mapping, which is what the visitor actually consults, so the message does
# not omit types such as bool/date/datetime. Sorted because the collections
# above are keyed by objects whose hash is typically identity-based, so their
# iteration order (and hence the message) could otherwise vary between runs.
_supported_types_str = sorted(
    {
        f"{t.__module__}.{getattr(t, '__name__', t)}"
        for t in _supported_types_mapping
    }
)
def infer_schema(clz, types_path=None) -> pa.Schema:
    """Build an Arrow schema from the type hints of ``clz``.

    Each annotated attribute becomes one field, in sorted name order, inferred
    through ``infer_field`` with an ``ArrowTypeVisitor``. The qualified class
    name is stored in the schema metadata under ``"cls"``.

    Args:
        clz: class whose type hints describe the fields.
        types_path: optional sequence of enclosing types, copied into a new
            list and passed to ``infer_field`` for context in error messages.
    """
    path = list(types_path) if types_path else []
    hints = typing.get_type_hints(clz)
    fields = [
        infer_field(name, hint, ArrowTypeVisitor(), types_path=path)
        for name, hint in sorted(hints.items())
    ]
    return pa.schema(fields, metadata={"cls": get_qualified_classname(clz)})
class ArrowTypeVisitor(TypeVisitor):
    """``TypeVisitor`` that turns Python type hints into ``pyarrow`` fields."""

    def visit_list(self, field_name, elem_type, types_path=None):
        """Return a list field whose element type is inferred recursively."""
        # Recursion handles nested hints such as List[Dict[str, str]].
        elem_field = infer_field("item", elem_type, self, types_path=types_path)
        return pa.field(field_name, pa.list_(elem_field.type))

    def visit_dict(self, field_name, key_type, value_type, types_path=None):
        """Return a map field whose key and value types are inferred recursively."""
        # Recursion handles nested hints such as Dict[str, Dict[str, str]].
        key_field = infer_field("key", key_type, self, types_path=types_path)
        value_field = infer_field("value", value_type, self, types_path=types_path)
        return pa.field(field_name, pa.map_(key_field.type, value_field.type))

    def visit_customized(self, field_name, type_, types_path=None):
        """Return a struct field built from the annotated fields of ``type_``.

        ``type_`` is a user class (pojo). ``types_path`` is forwarded so that an
        unsupported type nested inside it still reports the full path in the
        ``TypeError`` message.
        """
        pojo_schema = infer_schema(type_, types_path=types_path)
        return pa.field(
            field_name,
            pa.struct(list(pojo_schema)),
            metadata={"cls": get_qualified_classname(type_)},
        )

    def visit_other(self, field_name, type_, types_path=None):
        """Return a field for a leaf type listed in ``_supported_types_mapping``.

        Raises:
            TypeError: if ``type_`` has no entry in ``_supported_types_mapping``.
        """
        # Consult _supported_types_mapping rather than _supported_types,
        # because typing.List/typing.Dict's origin will be list/dict.
        arrow_type_func = _supported_types_mapping.get(type_)
        if arrow_type_func is None:
            raise TypeError(
                f"Type {type_} not supported, currently only "
                f"compositions of {_supported_types_str} are supported. "
                f"types_path is {types_path}"
            )
        return pa.field(field_name, arrow_type_func())
def infer_data_type(clz) -> Optional[pa.DataType]:
    """Return the Arrow data type inferred for ``clz``, or ``None``.

    ``None`` is returned whenever inference raises ``TypeError``, for example
    when ``ArrowTypeVisitor.visit_other`` rejects an unsupported type.
    """
    try:
        field = infer_field("", clz, ArrowTypeVisitor())
        data_type = field.type
    except TypeError:
        data_type = None
    return data_type
def get_type_id(clz) -> Optional[int]:
    """Return the Arrow type id inferred for ``clz``.

    Returns:
        The ``id`` of the inferred Arrow data type, or ``None`` when
        ``infer_data_type`` could not infer one.
    """
    type_ = infer_data_type(clz)
    # Test for None explicitly: pyarrow's StructType defines __len__, so a
    # struct type with zero fields is falsy and a plain truthiness check would
    # wrongly report it as not inferable.
    return None if type_ is None else type_.id
def compute_schema_hash(schema: pa.Schema):
    """Fold the type of every field in ``schema`` into a single integer hash.

    The hash starts from the seed 17 and mixes in each field's type, in schema
    order, via ``_compute_hash``. Field names and metadata do not contribute.
    """
    result = 17
    field_types = (field.type for field in schema)
    for field_type in field_types:
        result = _compute_hash(result, field_type)
    return result
def _compute_hash(hash_: int, type_: pa.DataType) -> int:
    """Mix ``type_`` and, recursively, its child types into ``hash_``.

    The type's Arrow id is folded in as ``hash_ * 31 + id``. Then list element
    types, map key/item types, or struct field types are folded in depth-first,
    in that order.

    Args:
        hash_: running hash value.
        type_: Arrow data type to mix in.

    Returns:
        The updated hash value.

    Raises:
        AssertionError: if ``type_`` is a nested type other than list, map or
            struct (only checked when assertions are enabled).
    """
    while True:
        h = hash_ * 31 + type_.id
        # Keep the running hash within the signed 64-bit range: if the next
        # value would exceed 2**63 - 1, shrink the current hash and retry.
        # NOTE(review): this arithmetic presumably has to match another Fury
        # implementation's schema hash; confirm before changing anything here.
        if h > 2**63 - 1:
            hash_ = hash_ >> 2
        else:
            hash_ = h
            break
    types = []
    if isinstance(type_, pa.ListType):
        types.append(type_.value_type)
    elif isinstance(type_, pa.MapType):
        types.append(type_.key_type)
        types.append(type_.item_type)
    elif isinstance(type_, pa.StructType):
        types.extend([f.type for f in type_])
    else:
        # NOTE(review): this check is an assert, so it is skipped under
        # `python -O`; other nested types would then be hashed by their own
        # id only, ignoring their children.
        assert (
            type_.num_fields == 0
        ), f"field type should not be nested, but got type {type_}."
    for t in types:
        hash_ = _compute_hash(hash_, t)
    return hash_