pyiceberg/schema.py (1,034 lines of code) (raw):
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=W0511
from __future__ import annotations
import itertools
from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import cached_property, partial, singledispatch
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Generic,
List,
Literal,
Optional,
Set,
Tuple,
TypeVar,
Union,
)
from pydantic import Field, PrivateAttr, model_validator
from pyiceberg.exceptions import ResolveError
from pyiceberg.typedef import EMPTY_DICT, IcebergBaseModel, StructProtocol
from pyiceberg.types import (
BinaryType,
BooleanType,
DateType,
DecimalType,
DoubleType,
FixedType,
FloatType,
IcebergType,
IntegerType,
ListType,
LongType,
MapType,
NestedField,
PrimitiveType,
StringType,
StructType,
TimestampNanoType,
TimestampType,
TimestamptzNanoType,
TimestamptzType,
TimeType,
UnknownType,
UUIDType,
)
if TYPE_CHECKING:
import pyarrow as pa
from pyiceberg.table.name_mapping import (
NameMapping,
)
T = TypeVar("T")
P = TypeVar("P")
INITIAL_SCHEMA_ID = 0
class Schema(IcebergBaseModel):
"""A table Schema.
Example:
>>> from pyiceberg import schema
>>> from pyiceberg import types
"""
type: Literal["struct"] = "struct"
fields: Tuple[NestedField, ...] = Field(default_factory=tuple)
schema_id: int = Field(alias="schema-id", default=INITIAL_SCHEMA_ID)
identifier_field_ids: List[int] = Field(alias="identifier-field-ids", default_factory=list)
_name_to_id: Dict[str, int] = PrivateAttr()
def __init__(self, *fields: NestedField, **data: Any):
if fields:
data["fields"] = fields
super().__init__(**data)
self._name_to_id = index_by_name(self)
def __str__(self) -> str:
"""Return the string representation of the Schema class."""
return "table {\n" + "\n".join([" " + str(field) for field in self.columns]) + "\n}"
def __repr__(self) -> str:
"""Return the string representation of the Schema class."""
return f"Schema({', '.join(repr(column) for column in self.columns)}, schema_id={self.schema_id}, identifier_field_ids={self.identifier_field_ids})"
def __len__(self) -> int:
"""Return the length of an instance of the Literal class."""
return len(self.fields)
def __eq__(self, other: Any) -> bool:
"""Return the equality of two instances of the Schema class."""
if not other:
return False
if not isinstance(other, Schema):
return False
if len(self.columns) != len(other.columns):
return False
identifier_field_ids_is_equal = self.identifier_field_ids == other.identifier_field_ids
schema_is_equal = all(lhs == rhs for lhs, rhs in zip(self.columns, other.columns))
return identifier_field_ids_is_equal and schema_is_equal
@model_validator(mode="after")
def check_schema(self) -> Schema:
if self.identifier_field_ids:
for field_id in self.identifier_field_ids:
self._validate_identifier_field(field_id)
return self
@property
def columns(self) -> Tuple[NestedField, ...]:
"""A tuple of the top-level fields."""
return self.fields
@cached_property
def _lazy_id_to_field(self) -> Dict[int, NestedField]:
"""Return an index of field ID to NestedField instance.
This is calculated once when called for the first time. Subsequent calls to this method will use a cached index.
"""
return index_by_id(self)
@cached_property
def _lazy_id_to_parent(self) -> Dict[int, int]:
"""Returns an index of field ID to parent field IDs.
This is calculated once when called for the first time. Subsequent calls to this method will use a cached index.
"""
return _index_parents(self)
@cached_property
def _lazy_name_to_id_lower(self) -> Dict[str, int]:
"""Return an index of lower-case field names to field IDs.
This is calculated once when called for the first time. Subsequent calls to this method will use a cached index.
"""
return {name.lower(): field_id for name, field_id in self._name_to_id.items()}
@cached_property
def _lazy_id_to_name(self) -> Dict[int, str]:
"""Return an index of field ID to full name.
This is calculated once when called for the first time. Subsequent calls to this method will use a cached index.
"""
return index_name_by_id(self)
@cached_property
def _lazy_id_to_accessor(self) -> Dict[int, Accessor]:
"""Return an index of field ID to accessor.
This is calculated once when called for the first time. Subsequent calls to this method will use a cached index.
"""
return build_position_accessors(self)
def as_struct(self) -> StructType:
"""Return the schema as a struct."""
return StructType(*self.fields)
def as_arrow(self) -> "pa.Schema":
"""Return the schema as an Arrow schema."""
from pyiceberg.io.pyarrow import schema_to_pyarrow
return schema_to_pyarrow(self)
def find_field(self, name_or_id: Union[str, int], case_sensitive: bool = True) -> NestedField:
"""Find a field using a field name or field ID.
Args:
name_or_id (Union[str, int]): Either a field name or a field ID.
case_sensitive (bool, optional): Whether to perform a case-sensitive lookup using a field name. Defaults to True.
Raises:
ValueError: When the value cannot be found.
Returns:
NestedField: The matched NestedField.
"""
if isinstance(name_or_id, int):
if name_or_id not in self._lazy_id_to_field:
raise ValueError(f"Could not find field with id: {name_or_id}")
return self._lazy_id_to_field[name_or_id]
if case_sensitive:
field_id = self._name_to_id.get(name_or_id)
else:
field_id = self._lazy_name_to_id_lower.get(name_or_id.lower())
if field_id is None:
raise ValueError(f"Could not find field with name {name_or_id}, case_sensitive={case_sensitive}")
return self._lazy_id_to_field[field_id]
def find_type(self, name_or_id: Union[str, int], case_sensitive: bool = True) -> IcebergType:
"""Find a field type using a field name or field ID.
Args:
name_or_id (Union[str, int]): Either a field name or a field ID.
case_sensitive (bool, optional): Whether to perform a case-sensitive lookup using a field name. Defaults to True.
Returns:
NestedField: The type of the matched NestedField.
"""
field = self.find_field(name_or_id=name_or_id, case_sensitive=case_sensitive)
if not field:
raise ValueError(f"Could not find field with name or id {name_or_id}, case_sensitive={case_sensitive}")
return field.field_type
@property
def highest_field_id(self) -> int:
return max(self._lazy_id_to_name.keys(), default=0)
@cached_property
def name_mapping(self) -> NameMapping:
from pyiceberg.table.name_mapping import create_mapping_from_schema
return create_mapping_from_schema(self)
def find_column_name(self, column_id: int) -> Optional[str]:
"""Find a column name given a column ID.
Args:
column_id (int): The ID of the column.
Returns:
str: The column name (or None if the column ID cannot be found).
"""
return self._lazy_id_to_name.get(column_id)
@property
def column_names(self) -> List[str]:
"""
Return a list of all the column names, including nested fields.
Excludes short names.
Returns:
List[str]: The column names.
"""
return list(self._lazy_id_to_name.values())
def accessor_for_field(self, field_id: int) -> Accessor:
"""Find a schema position accessor given a field ID.
Args:
field_id (int): The ID of the field.
Raises:
ValueError: When the value cannot be found.
Returns:
Accessor: An accessor for the given field ID.
"""
if field_id not in self._lazy_id_to_accessor:
raise ValueError(f"Could not find accessor for field with id: {field_id}")
return self._lazy_id_to_accessor[field_id]
def identifier_field_names(self) -> Set[str]:
"""Return the names of the identifier fields.
Returns:
Set of names of the identifier fields
"""
ids = set()
for field_id in self.identifier_field_ids:
column_name = self.find_column_name(field_id)
if column_name is None:
raise ValueError(f"Could not find identifier column id: {field_id}")
ids.add(column_name)
return ids
def select(self, *names: str, case_sensitive: bool = True) -> Schema:
"""Return a new schema instance pruned to a subset of columns.
Args:
names (List[str]): A list of column names.
case_sensitive (bool, optional): Whether to perform a case-sensitive lookup for each column name. Defaults to True.
Returns:
Schema: A new schema with pruned columns.
Raises:
ValueError: If a column is selected that doesn't exist.
"""
try:
if case_sensitive:
ids = {self._name_to_id[name] for name in names}
else:
ids = {self._lazy_name_to_id_lower[name.lower()] for name in names}
except KeyError as e:
raise ValueError(f"Could not find column: {e}") from e
return prune_columns(self, ids)
@property
def field_ids(self) -> Set[int]:
"""Return the IDs of the current schema."""
return set(self._name_to_id.values())
def _validate_identifier_field(self, field_id: int) -> None:
"""Validate that the field with the given ID is a valid identifier field.
Args:
field_id: The ID of the field to validate.
Raises:
ValueError: If the field is not valid.
"""
field = self.find_field(field_id)
if not field.field_type.is_primitive:
raise ValueError(f"Identifier field {field_id} invalid: not a primitive type field")
if not field.required:
raise ValueError(f"Identifier field {field_id} invalid: not a required field")
if isinstance(field.field_type, (DoubleType, FloatType)):
raise ValueError(f"Identifier field {field_id} invalid: must not be float or double field")
# Check whether the nested field is in a chain of required struct fields
# Exploring from root for better error message for list and map types
parent_id = self._lazy_id_to_parent.get(field.field_id)
fields: List[int] = []
while parent_id is not None:
fields.append(parent_id)
parent_id = self._lazy_id_to_parent.get(parent_id)
while fields:
parent = self.find_field(fields.pop())
if not parent.field_type.is_struct:
raise ValueError(f"Cannot add field {field.name} as an identifier field: must not be nested in {parent}")
if not parent.required:
raise ValueError(
f"Cannot add field {field.name} as an identifier field: must not be nested in an optional field {parent}"
)
def check_format_version_compatibility(self, format_version: int) -> None:
"""Check that the schema is compatible for the given table format version.
Args:
format_version: The Iceberg table format version.
Raises:
ValueError: If the schema is not compatible for the format version.
"""
for field in self._lazy_id_to_field.values():
if format_version < field.field_type.minimum_format_version():
raise ValueError(
f"{field.field_type} is only supported in {field.field_type.minimum_format_version()} or higher. Current format version is: {format_version}"
)
class SchemaVisitor(Generic[T], ABC):
def before_field(self, field: NestedField) -> None:
"""Override this method to perform an action immediately before visiting a field."""
def after_field(self, field: NestedField) -> None:
"""Override this method to perform an action immediately after visiting a field."""
def before_list_element(self, element: NestedField) -> None:
"""Override this method to perform an action immediately before visiting an element within a ListType."""
self.before_field(element)
def after_list_element(self, element: NestedField) -> None:
"""Override this method to perform an action immediately after visiting an element within a ListType."""
self.after_field(element)
def before_map_key(self, key: NestedField) -> None:
"""Override this method to perform an action immediately before visiting a key within a MapType."""
self.before_field(key)
def after_map_key(self, key: NestedField) -> None:
"""Override this method to perform an action immediately after visiting a key within a MapType."""
self.after_field(key)
def before_map_value(self, value: NestedField) -> None:
"""Override this method to perform an action immediately before visiting a value within a MapType."""
self.before_field(value)
def after_map_value(self, value: NestedField) -> None:
"""Override this method to perform an action immediately after visiting a value within a MapType."""
self.after_field(value)
@abstractmethod
def schema(self, schema: Schema, struct_result: T) -> T:
"""Visit a Schema."""
@abstractmethod
def struct(self, struct: StructType, field_results: List[T]) -> T:
"""Visit a StructType."""
@abstractmethod
def field(self, field: NestedField, field_result: T) -> T:
"""Visit a NestedField."""
@abstractmethod
def list(self, list_type: ListType, element_result: T) -> T:
"""Visit a ListType."""
@abstractmethod
def map(self, map_type: MapType, key_result: T, value_result: T) -> T:
"""Visit a MapType."""
@abstractmethod
def primitive(self, primitive: PrimitiveType) -> T:
"""Visit a PrimitiveType."""
class PreOrderSchemaVisitor(Generic[T], ABC):
@abstractmethod
def schema(self, schema: Schema, struct_result: Callable[[], T]) -> T:
"""Visit a Schema."""
@abstractmethod
def struct(self, struct: StructType, field_results: List[Callable[[], T]]) -> T:
"""Visit a StructType."""
@abstractmethod
def field(self, field: NestedField, field_result: Callable[[], T]) -> T:
"""Visit a NestedField."""
@abstractmethod
def list(self, list_type: ListType, element_result: Callable[[], T]) -> T:
"""Visit a ListType."""
@abstractmethod
def map(self, map_type: MapType, key_result: Callable[[], T], value_result: Callable[[], T]) -> T:
"""Visit a MapType."""
@abstractmethod
def primitive(self, primitive: PrimitiveType) -> T:
"""Visit a PrimitiveType."""
class SchemaWithPartnerVisitor(Generic[P, T], ABC):
def before_field(self, field: NestedField, field_partner: Optional[P]) -> None:
"""Override this method to perform an action immediately before visiting a field."""
def after_field(self, field: NestedField, field_partner: Optional[P]) -> None:
"""Override this method to perform an action immediately after visiting a field."""
def before_list_element(self, element: NestedField, element_partner: Optional[P]) -> None:
"""Override this method to perform an action immediately before visiting an element within a ListType."""
self.before_field(element, element_partner)
def after_list_element(self, element: NestedField, element_partner: Optional[P]) -> None:
"""Override this method to perform an action immediately after visiting an element within a ListType."""
self.after_field(element, element_partner)
def before_map_key(self, key: NestedField, key_partner: Optional[P]) -> None:
"""Override this method to perform an action immediately before visiting a key within a MapType."""
self.before_field(key, key_partner)
def after_map_key(self, key: NestedField, key_partner: Optional[P]) -> None:
"""Override this method to perform an action immediately after visiting a key within a MapType."""
self.after_field(key, key_partner)
def before_map_value(self, value: NestedField, value_partner: Optional[P]) -> None:
"""Override this method to perform an action immediately before visiting a value within a MapType."""
self.before_field(value, value_partner)
def after_map_value(self, value: NestedField, value_partner: Optional[P]) -> None:
"""Override this method to perform an action immediately after visiting a value within a MapType."""
self.after_field(value, value_partner)
@abstractmethod
def schema(self, schema: Schema, schema_partner: Optional[P], struct_result: T) -> T:
"""Visit a schema with a partner."""
@abstractmethod
def struct(self, struct: StructType, struct_partner: Optional[P], field_results: List[T]) -> T:
"""Visit a struct type with a partner."""
@abstractmethod
def field(self, field: NestedField, field_partner: Optional[P], field_result: T) -> T:
"""Visit a nested field with a partner."""
@abstractmethod
def list(self, list_type: ListType, list_partner: Optional[P], element_result: T) -> T:
"""Visit a list type with a partner."""
@abstractmethod
def map(self, map_type: MapType, map_partner: Optional[P], key_result: T, value_result: T) -> T:
"""Visit a map type with a partner."""
@abstractmethod
def primitive(self, primitive: PrimitiveType, primitive_partner: Optional[P]) -> T:
"""Visit a primitive type with a partner."""
class PrimitiveWithPartnerVisitor(SchemaWithPartnerVisitor[P, T]):
def primitive(self, primitive: PrimitiveType, primitive_partner: Optional[P]) -> T:
"""Visit a PrimitiveType."""
if isinstance(primitive, BooleanType):
return self.visit_boolean(primitive, primitive_partner)
elif isinstance(primitive, IntegerType):
return self.visit_integer(primitive, primitive_partner)
elif isinstance(primitive, LongType):
return self.visit_long(primitive, primitive_partner)
elif isinstance(primitive, FloatType):
return self.visit_float(primitive, primitive_partner)
elif isinstance(primitive, DoubleType):
return self.visit_double(primitive, primitive_partner)
elif isinstance(primitive, DecimalType):
return self.visit_decimal(primitive, primitive_partner)
elif isinstance(primitive, DateType):
return self.visit_date(primitive, primitive_partner)
elif isinstance(primitive, TimeType):
return self.visit_time(primitive, primitive_partner)
elif isinstance(primitive, TimestampType):
return self.visit_timestamp(primitive, primitive_partner)
elif isinstance(primitive, TimestampNanoType):
return self.visit_timestamp_ns(primitive, primitive_partner)
elif isinstance(primitive, TimestamptzType):
return self.visit_timestamptz(primitive, primitive_partner)
elif isinstance(primitive, TimestamptzNanoType):
return self.visit_timestamptz_ns(primitive, primitive_partner)
elif isinstance(primitive, StringType):
return self.visit_string(primitive, primitive_partner)
elif isinstance(primitive, UUIDType):
return self.visit_uuid(primitive, primitive_partner)
elif isinstance(primitive, FixedType):
return self.visit_fixed(primitive, primitive_partner)
elif isinstance(primitive, BinaryType):
return self.visit_binary(primitive, primitive_partner)
elif isinstance(primitive, UnknownType):
return self.visit_unknown(primitive, primitive_partner)
else:
raise ValueError(f"Type not recognized: {primitive}")
@abstractmethod
def visit_boolean(self, boolean_type: BooleanType, partner: Optional[P]) -> T:
"""Visit a BooleanType."""
@abstractmethod
def visit_integer(self, integer_type: IntegerType, partner: Optional[P]) -> T:
"""Visit a IntegerType."""
@abstractmethod
def visit_long(self, long_type: LongType, partner: Optional[P]) -> T:
"""Visit a LongType."""
@abstractmethod
def visit_float(self, float_type: FloatType, partner: Optional[P]) -> T:
"""Visit a FloatType."""
@abstractmethod
def visit_double(self, double_type: DoubleType, partner: Optional[P]) -> T:
"""Visit a DoubleType."""
@abstractmethod
def visit_decimal(self, decimal_type: DecimalType, partner: Optional[P]) -> T:
"""Visit a DecimalType."""
@abstractmethod
def visit_date(self, date_type: DateType, partner: Optional[P]) -> T:
"""Visit a DecimalType."""
@abstractmethod
def visit_time(self, time_type: TimeType, partner: Optional[P]) -> T:
"""Visit a DecimalType."""
@abstractmethod
def visit_timestamp(self, timestamp_type: TimestampType, partner: Optional[P]) -> T:
"""Visit a TimestampType."""
@abstractmethod
def visit_timestamp_ns(self, timestamp_ns_type: TimestampNanoType, partner: Optional[P]) -> T:
"""Visit a TimestampNanoType."""
@abstractmethod
def visit_timestamptz(self, timestamptz_type: TimestamptzType, partner: Optional[P]) -> T:
"""Visit a TimestamptzType."""
@abstractmethod
def visit_timestamptz_ns(self, timestamptz_ns_type: TimestamptzNanoType, partner: Optional[P]) -> T:
"""Visit a TimestamptzNanoType."""
@abstractmethod
def visit_string(self, string_type: StringType, partner: Optional[P]) -> T:
"""Visit a StringType."""
@abstractmethod
def visit_uuid(self, uuid_type: UUIDType, partner: Optional[P]) -> T:
"""Visit a UUIDType."""
@abstractmethod
def visit_fixed(self, fixed_type: FixedType, partner: Optional[P]) -> T:
"""Visit a FixedType."""
@abstractmethod
def visit_binary(self, binary_type: BinaryType, partner: Optional[P]) -> T:
"""Visit a BinaryType."""
@abstractmethod
def visit_unknown(self, unknown_type: UnknownType, partner: Optional[P]) -> T:
"""Visit a UnknownType."""
class PartnerAccessor(Generic[P], ABC):
@abstractmethod
def schema_partner(self, partner: Optional[P]) -> Optional[P]:
"""Return the equivalent of the schema as a struct."""
@abstractmethod
def field_partner(self, partner_struct: Optional[P], field_id: int, field_name: str) -> Optional[P]:
"""Return the equivalent struct field by name or id in the partner struct."""
@abstractmethod
def list_element_partner(self, partner_list: Optional[P]) -> Optional[P]:
"""Return the equivalent list element in the partner list."""
@abstractmethod
def map_key_partner(self, partner_map: Optional[P]) -> Optional[P]:
"""Return the equivalent map key in the partner map."""
@abstractmethod
def map_value_partner(self, partner_map: Optional[P]) -> Optional[P]:
"""Return the equivalent map value in the partner map."""
@singledispatch
def visit_with_partner(
schema_or_type: Union[Schema, IcebergType], partner: P, visitor: SchemaWithPartnerVisitor[T, P], accessor: PartnerAccessor[P]
) -> T:
raise ValueError(f"Unsupported type: {schema_or_type}")
@visit_with_partner.register(Schema)
def _(schema: Schema, partner: P, visitor: SchemaWithPartnerVisitor[P, T], accessor: PartnerAccessor[P]) -> T:
struct_partner = accessor.schema_partner(partner)
return visitor.schema(schema, partner, visit_with_partner(schema.as_struct(), struct_partner, visitor, accessor)) # type: ignore
@visit_with_partner.register(StructType)
def _(struct: StructType, partner: P, visitor: SchemaWithPartnerVisitor[P, T], accessor: PartnerAccessor[P]) -> T:
field_results = []
for field in struct.fields:
field_partner = accessor.field_partner(partner, field.field_id, field.name)
visitor.before_field(field, field_partner)
try:
field_result = visit_with_partner(field.field_type, field_partner, visitor, accessor) # type: ignore
field_results.append(visitor.field(field, field_partner, field_result))
finally:
visitor.after_field(field, field_partner)
return visitor.struct(struct, partner, field_results)
@visit_with_partner.register(ListType)
def _(list_type: ListType, partner: P, visitor: SchemaWithPartnerVisitor[P, T], accessor: PartnerAccessor[P]) -> T:
element_partner = accessor.list_element_partner(partner)
visitor.before_list_element(list_type.element_field, element_partner)
try:
element_result = visit_with_partner(list_type.element_type, element_partner, visitor, accessor) # type: ignore
finally:
visitor.after_list_element(list_type.element_field, element_partner)
return visitor.list(list_type, partner, element_result)
@visit_with_partner.register(MapType)
def _(map_type: MapType, partner: P, visitor: SchemaWithPartnerVisitor[P, T], accessor: PartnerAccessor[P]) -> T:
key_partner = accessor.map_key_partner(partner)
visitor.before_map_key(map_type.key_field, key_partner)
try:
key_result = visit_with_partner(map_type.key_type, key_partner, visitor, accessor) # type: ignore
finally:
visitor.after_map_key(map_type.key_field, key_partner)
value_partner = accessor.map_value_partner(partner)
visitor.before_map_value(map_type.value_field, value_partner)
try:
value_result = visit_with_partner(map_type.value_type, value_partner, visitor, accessor) # type: ignore
finally:
visitor.after_map_value(map_type.value_field, value_partner)
return visitor.map(map_type, partner, key_result, value_result)
@visit_with_partner.register(PrimitiveType)
def _(primitive: PrimitiveType, partner: P, visitor: SchemaWithPartnerVisitor[P, T], _: PartnerAccessor[P]) -> T:
return visitor.primitive(primitive, partner)
class SchemaVisitorPerPrimitiveType(SchemaVisitor[T], ABC):
def primitive(self, primitive: PrimitiveType) -> T:
"""Visit a PrimitiveType."""
if isinstance(primitive, FixedType):
return self.visit_fixed(primitive)
elif isinstance(primitive, DecimalType):
return self.visit_decimal(primitive)
elif isinstance(primitive, BooleanType):
return self.visit_boolean(primitive)
elif isinstance(primitive, IntegerType):
return self.visit_integer(primitive)
elif isinstance(primitive, LongType):
return self.visit_long(primitive)
elif isinstance(primitive, FloatType):
return self.visit_float(primitive)
elif isinstance(primitive, DoubleType):
return self.visit_double(primitive)
elif isinstance(primitive, DateType):
return self.visit_date(primitive)
elif isinstance(primitive, TimeType):
return self.visit_time(primitive)
elif isinstance(primitive, TimestampType):
return self.visit_timestamp(primitive)
elif isinstance(primitive, TimestampNanoType):
return self.visit_timestamp_ns(primitive)
elif isinstance(primitive, TimestamptzType):
return self.visit_timestamptz(primitive)
elif isinstance(primitive, TimestamptzNanoType):
return self.visit_timestamptz_ns(primitive)
elif isinstance(primitive, StringType):
return self.visit_string(primitive)
elif isinstance(primitive, UUIDType):
return self.visit_uuid(primitive)
elif isinstance(primitive, BinaryType):
return self.visit_binary(primitive)
elif isinstance(primitive, UnknownType):
return self.visit_unknown(primitive)
else:
raise ValueError(f"Type not recognized: {primitive}")
@abstractmethod
def visit_fixed(self, fixed_type: FixedType) -> T:
"""Visit a FixedType."""
@abstractmethod
def visit_decimal(self, decimal_type: DecimalType) -> T:
"""Visit a DecimalType."""
@abstractmethod
def visit_boolean(self, boolean_type: BooleanType) -> T:
"""Visit a BooleanType."""
@abstractmethod
def visit_integer(self, integer_type: IntegerType) -> T:
"""Visit a IntegerType."""
@abstractmethod
def visit_long(self, long_type: LongType) -> T:
"""Visit a LongType."""
@abstractmethod
def visit_float(self, float_type: FloatType) -> T:
"""Visit a FloatType."""
@abstractmethod
def visit_double(self, double_type: DoubleType) -> T:
"""Visit a DoubleType."""
@abstractmethod
def visit_date(self, date_type: DateType) -> T:
"""Visit a DecimalType."""
@abstractmethod
def visit_time(self, time_type: TimeType) -> T:
"""Visit a DecimalType."""
@abstractmethod
def visit_timestamp(self, timestamp_type: TimestampType) -> T:
"""Visit a TimestampType."""
@abstractmethod
def visit_timestamp_ns(self, timestamp_type: TimestampNanoType) -> T:
"""Visit a TimestampNanoType."""
@abstractmethod
def visit_timestamptz(self, timestamptz_type: TimestamptzType) -> T:
"""Visit a TimestamptzType."""
@abstractmethod
def visit_timestamptz_ns(self, timestamptz_ns_type: TimestamptzNanoType) -> T:
"""Visit a TimestamptzNanoType."""
@abstractmethod
def visit_string(self, string_type: StringType) -> T:
"""Visit a StringType."""
@abstractmethod
def visit_uuid(self, uuid_type: UUIDType) -> T:
"""Visit a UUIDType."""
@abstractmethod
def visit_binary(self, binary_type: BinaryType) -> T:
"""Visit a BinaryType."""
@abstractmethod
def visit_unknown(self, unknown_type: UnknownType) -> T:
"""Visit a UnknownType."""
@dataclass(init=True, eq=True, frozen=True)
class Accessor:
"""An accessor for a specific position in a container that implements the StructProtocol."""
position: int
inner: Optional[Accessor] = None
def __str__(self) -> str:
"""Return the string representation of the Accessor class."""
return f"Accessor(position={self.position},inner={self.inner})"
def __repr__(self) -> str:
"""Return the string representation of the Accessor class."""
return self.__str__()
def get(self, container: StructProtocol) -> Any:
"""Return the value at self.position in `container`.
Args:
container (StructProtocol): A container to access at position `self.position`.
Returns:
Any: The value at position `self.position` in the container.
"""
pos = self.position
val = container[pos]
inner = self
while inner.inner:
inner = inner.inner
val = val[inner.position]
return val
@singledispatch
def visit(obj: Union[Schema, IcebergType], visitor: SchemaVisitor[T]) -> T:
"""Apply a schema visitor to any point within a schema.
The function traverses the schema in post-order fashion.
Args:
obj (Union[Schema, IcebergType]): An instance of a Schema or an IcebergType.
visitor (SchemaVisitor[T]): An instance of an implementation of the generic SchemaVisitor base class.
Raises:
NotImplementedError: If attempting to visit an unrecognized object type.
"""
raise NotImplementedError(f"Cannot visit non-type: {obj}")
@visit.register(Schema)
def _(obj: Schema, visitor: SchemaVisitor[T]) -> T:
"""Visit a Schema with a concrete SchemaVisitor."""
return visitor.schema(obj, visit(obj.as_struct(), visitor))
@visit.register(StructType)
def _(obj: StructType, visitor: SchemaVisitor[T]) -> T:
"""Visit a StructType with a concrete SchemaVisitor."""
results = []
for field in obj.fields:
visitor.before_field(field)
result = visit(field.field_type, visitor)
visitor.after_field(field)
results.append(visitor.field(field, result))
return visitor.struct(obj, results)
@visit.register(ListType)
def _(obj: ListType, visitor: SchemaVisitor[T]) -> T:
"""Visit a ListType with a concrete SchemaVisitor."""
visitor.before_list_element(obj.element_field)
result = visit(obj.element_type, visitor)
visitor.after_list_element(obj.element_field)
return visitor.list(obj, result)
@visit.register(MapType)
def _(obj: MapType, visitor: SchemaVisitor[T]) -> T:
"""Visit a MapType with a concrete SchemaVisitor."""
visitor.before_map_key(obj.key_field)
key_result = visit(obj.key_type, visitor)
visitor.after_map_key(obj.key_field)
visitor.before_map_value(obj.value_field)
value_result = visit(obj.value_type, visitor)
visitor.after_map_value(obj.value_field)
return visitor.map(obj, key_result, value_result)
@visit.register(PrimitiveType)
def _(obj: PrimitiveType, visitor: SchemaVisitor[T]) -> T:
"""Visit a PrimitiveType with a concrete SchemaVisitor."""
return visitor.primitive(obj)
@singledispatch
def pre_order_visit(obj: Union[Schema, IcebergType], visitor: PreOrderSchemaVisitor[T]) -> T:
"""Apply a schema visitor to any point within a schema.
The function traverses the schema in pre-order fashion. This is a slimmed down version
compared to the post-order traversal (missing before and after methods), mostly
because we don't use the pre-order traversal much.
Args:
obj (Union[Schema, IcebergType]): An instance of a Schema or an IcebergType.
visitor (PreOrderSchemaVisitor[T]): An instance of an implementation of the generic PreOrderSchemaVisitor base class.
Raises:
NotImplementedError: If attempting to visit an unrecognized object type.
"""
raise NotImplementedError(f"Cannot visit non-type: {obj}")
@pre_order_visit.register(Schema)
def _(obj: Schema, visitor: PreOrderSchemaVisitor[T]) -> T:
"""Visit a Schema with a concrete PreOrderSchemaVisitor."""
return visitor.schema(obj, lambda: pre_order_visit(obj.as_struct(), visitor))
@pre_order_visit.register(StructType)
def _(obj: StructType, visitor: PreOrderSchemaVisitor[T]) -> T:
"""Visit a StructType with a concrete PreOrderSchemaVisitor."""
return visitor.struct(
obj,
[
partial(
lambda field: visitor.field(field, partial(lambda field: pre_order_visit(field.field_type, visitor), field)),
field,
)
for field in obj.fields
],
)
@pre_order_visit.register(ListType)
def _(obj: ListType, visitor: PreOrderSchemaVisitor[T]) -> T:
"""Visit a ListType with a concrete PreOrderSchemaVisitor."""
return visitor.list(obj, lambda: pre_order_visit(obj.element_type, visitor))
@pre_order_visit.register(MapType)
def _(obj: MapType, visitor: PreOrderSchemaVisitor[T]) -> T:
"""Visit a MapType with a concrete PreOrderSchemaVisitor."""
return visitor.map(obj, lambda: pre_order_visit(obj.key_type, visitor), lambda: pre_order_visit(obj.value_type, visitor))
@pre_order_visit.register(PrimitiveType)
def _(obj: PrimitiveType, visitor: PreOrderSchemaVisitor[T]) -> T:
"""Visit a PrimitiveType with a concrete PreOrderSchemaVisitor."""
return visitor.primitive(obj)
class _IndexById(SchemaVisitor[Dict[int, NestedField]]):
"""A schema visitor for generating a field ID to NestedField index."""
def __init__(self) -> None:
self._index: Dict[int, NestedField] = {}
def schema(self, schema: Schema, struct_result: Dict[int, NestedField]) -> Dict[int, NestedField]:
return self._index
def struct(self, struct: StructType, field_results: List[Dict[int, NestedField]]) -> Dict[int, NestedField]:
return self._index
def field(self, field: NestedField, field_result: Dict[int, NestedField]) -> Dict[int, NestedField]:
"""Add the field ID to the index."""
self._index[field.field_id] = field
return self._index
def list(self, list_type: ListType, element_result: Dict[int, NestedField]) -> Dict[int, NestedField]:
"""Add the list element ID to the index."""
self._index[list_type.element_field.field_id] = list_type.element_field
return self._index
def map(
self, map_type: MapType, key_result: Dict[int, NestedField], value_result: Dict[int, NestedField]
) -> Dict[int, NestedField]:
"""Add the key ID and value ID as individual items in the index."""
self._index[map_type.key_field.field_id] = map_type.key_field
self._index[map_type.value_field.field_id] = map_type.value_field
return self._index
def primitive(self, primitive: PrimitiveType) -> Dict[int, NestedField]:
return self._index
def index_by_id(schema_or_type: Union[Schema, IcebergType]) -> Dict[int, NestedField]:
"""Generate an index of field IDs to NestedField instances.
Args:
schema_or_type (Union[Schema, IcebergType]): A schema or type to index.
Returns:
Dict[int, NestedField]: An index of field IDs to NestedField instances.
"""
return visit(schema_or_type, _IndexById())
class _IndexParents(SchemaVisitor[Dict[int, int]]):
def __init__(self) -> None:
self.id_to_parent: Dict[int, int] = {}
self.id_stack: List[int] = []
def before_field(self, field: NestedField) -> None:
self.id_stack.append(field.field_id)
def after_field(self, field: NestedField) -> None:
self.id_stack.pop()
def schema(self, schema: Schema, struct_result: Dict[int, int]) -> Dict[int, int]:
return self.id_to_parent
def struct(self, struct: StructType, field_results: List[Dict[int, int]]) -> Dict[int, int]:
for field in struct.fields:
parent_id = self.id_stack[-1] if self.id_stack else None
if parent_id is not None:
# fields in the root struct are not added
self.id_to_parent[field.field_id] = parent_id
return self.id_to_parent
def field(self, field: NestedField, field_result: Dict[int, int]) -> Dict[int, int]:
return self.id_to_parent
def list(self, list_type: ListType, element_result: Dict[int, int]) -> Dict[int, int]:
self.id_to_parent[list_type.element_id] = self.id_stack[-1]
return self.id_to_parent
def map(self, map_type: MapType, key_result: Dict[int, int], value_result: Dict[int, int]) -> Dict[int, int]:
self.id_to_parent[map_type.key_id] = self.id_stack[-1]
self.id_to_parent[map_type.value_id] = self.id_stack[-1]
return self.id_to_parent
def primitive(self, primitive: PrimitiveType) -> Dict[int, int]:
return self.id_to_parent
def _index_parents(schema_or_type: Union[Schema, IcebergType]) -> Dict[int, int]:
"""Generate an index of field IDs to their parent field IDs.
Args:
schema_or_type (Union[Schema, IcebergType]): A schema or type to index.
Returns:
Dict[int, int]: An index of field IDs to their parent field IDs.
"""
return visit(schema_or_type, _IndexParents())
class _IndexByName(SchemaVisitor[Dict[str, int]]):
"""A schema visitor for generating a field name to field ID index."""
def __init__(self) -> None:
self._index: Dict[str, int] = {}
self._short_name_to_id: Dict[str, int] = {}
self._combined_index: Dict[str, int] = {}
self._field_names: List[str] = []
self._short_field_names: List[str] = []
def before_map_value(self, value: NestedField) -> None:
if not isinstance(value.field_type, StructType):
self._short_field_names.append(value.name)
self._field_names.append(value.name)
def after_map_value(self, value: NestedField) -> None:
if not isinstance(value.field_type, StructType):
self._short_field_names.pop()
self._field_names.pop()
def before_list_element(self, element: NestedField) -> None:
"""Short field names omit element when the element is a StructType."""
if not isinstance(element.field_type, StructType):
self._short_field_names.append(element.name)
self._field_names.append(element.name)
def after_list_element(self, element: NestedField) -> None:
if not isinstance(element.field_type, StructType):
self._short_field_names.pop()
self._field_names.pop()
def before_field(self, field: NestedField) -> None:
"""Store the field name."""
self._field_names.append(field.name)
self._short_field_names.append(field.name)
def after_field(self, field: NestedField) -> None:
"""Remove the last field name stored."""
self._field_names.pop()
self._short_field_names.pop()
def schema(self, schema: Schema, struct_result: Dict[str, int]) -> Dict[str, int]:
return self._index
def struct(self, struct: StructType, field_results: List[Dict[str, int]]) -> Dict[str, int]:
return self._index
def field(self, field: NestedField, field_result: Dict[str, int]) -> Dict[str, int]:
"""Add the field name to the index."""
self._add_field(field.name, field.field_id)
return self._index
def list(self, list_type: ListType, element_result: Dict[str, int]) -> Dict[str, int]:
"""Add the list element name to the index."""
self._add_field(list_type.element_field.name, list_type.element_field.field_id)
return self._index
def map(self, map_type: MapType, key_result: Dict[str, int], value_result: Dict[str, int]) -> Dict[str, int]:
"""Add the key name and value name as individual items in the index."""
self._add_field(map_type.key_field.name, map_type.key_field.field_id)
self._add_field(map_type.value_field.name, map_type.value_field.field_id)
return self._index
def _add_field(self, name: str, field_id: int) -> None:
"""Add a field name to the index, mapping its full name to its field ID.
Args:
name (str): The field name.
field_id (int): The field ID.
Raises:
ValueError: If the field name is already contained in the index.
"""
full_name = name
if self._field_names:
full_name = ".".join([".".join(self._field_names), name])
if full_name in self._index:
raise ValueError(f"Invalid schema, multiple fields for name {full_name}: {self._index[full_name]} and {field_id}")
self._index[full_name] = field_id
if self._short_field_names:
short_name = ".".join([".".join(self._short_field_names), name])
self._short_name_to_id[short_name] = field_id
def primitive(self, primitive: PrimitiveType) -> Dict[str, int]:
return self._index
def by_name(self) -> Dict[str, int]:
"""Return an index of combined full and short names.
Note: Only short names that do not conflict with full names are included.
"""
combined_index = self._short_name_to_id.copy()
combined_index.update(self._index)
return combined_index
def by_id(self) -> Dict[int, str]:
"""Return an index of ID to full names."""
id_to_full_name = {value: key for key, value in self._index.items()}
return id_to_full_name
def index_by_name(schema_or_type: Union[Schema, IcebergType]) -> Dict[str, int]:
"""Generate an index of field names to field IDs.
Args:
schema_or_type (Union[Schema, IcebergType]): A schema or type to index.
Returns:
Dict[str, int]: An index of field names to field IDs.
"""
if len(schema_or_type.fields) > 0:
indexer = _IndexByName()
visit(schema_or_type, indexer)
return indexer.by_name()
else:
return EMPTY_DICT
def index_name_by_id(schema_or_type: Union[Schema, IcebergType]) -> Dict[int, str]:
"""Generate an index of field IDs full field names.
Args:
schema_or_type (Union[Schema, IcebergType]): A schema or type to index.
Returns:
Dict[str, int]: An index of field IDs to full names.
"""
indexer = _IndexByName()
visit(schema_or_type, indexer)
return indexer.by_id()
Position = int
class _BuildPositionAccessors(SchemaVisitor[Dict[Position, Accessor]]):
"""A schema visitor for generating a field ID to accessor index.
Example:
>>> from pyiceberg.schema import Schema
>>> from pyiceberg.types import *
>>> schema = Schema(
... NestedField(field_id=2, name="id", field_type=IntegerType(), required=False),
... NestedField(field_id=1, name="data", field_type=StringType(), required=True),
... NestedField(
... field_id=3,
... name="location",
... field_type=StructType(
... NestedField(field_id=5, name="latitude", field_type=FloatType(), required=False),
... NestedField(field_id=6, name="longitude", field_type=FloatType(), required=False),
... ),
... required=True,
... ),
... schema_id=1,
... identifier_field_ids=[1],
... )
>>> result = build_position_accessors(schema)
>>> expected = {
... 2: Accessor(position=0, inner=None),
... 1: Accessor(position=1, inner=None),
... 5: Accessor(position=2, inner=Accessor(position=0, inner=None)),
... 6: Accessor(position=2, inner=Accessor(position=1, inner=None))
... 3: Accessor(position=2, inner=None),
... }
>>> result == expected
True
"""
def schema(self, schema: Schema, struct_result: Dict[Position, Accessor]) -> Dict[Position, Accessor]:
return struct_result
def struct(self, struct: StructType, field_results: List[Dict[Position, Accessor]]) -> Dict[Position, Accessor]:
result = {}
for position, field in enumerate(struct.fields):
if field_results[position]:
for inner_field_id, acc in field_results[position].items():
result[inner_field_id] = Accessor(position, inner=acc)
result[field.field_id] = Accessor(position)
return result
def field(self, field: NestedField, field_result: Dict[Position, Accessor]) -> Dict[Position, Accessor]:
return field_result
def list(self, list_type: ListType, element_result: Dict[Position, Accessor]) -> Dict[Position, Accessor]:
return {}
def map(
self, map_type: MapType, key_result: Dict[Position, Accessor], value_result: Dict[Position, Accessor]
) -> Dict[Position, Accessor]:
return {}
def primitive(self, primitive: PrimitiveType) -> Dict[Position, Accessor]:
return {}
def build_position_accessors(schema_or_type: Union[Schema, IcebergType]) -> Dict[int, Accessor]:
"""Generate an index of field IDs to schema position accessors.
Args:
schema_or_type (Union[Schema, IcebergType]): A schema or type to index.
Returns:
Dict[int, Accessor]: An index of field IDs to accessors.
"""
return visit(schema_or_type, _BuildPositionAccessors())
def assign_fresh_schema_ids(schema_or_type: Union[Schema, IcebergType], next_id: Optional[Callable[[], int]] = None) -> Schema:
"""Traverses the schema, and sets new IDs."""
return pre_order_visit(schema_or_type, _SetFreshIDs(next_id_func=next_id))
class _SetFreshIDs(PreOrderSchemaVisitor[IcebergType]):
"""Traverses the schema and assigns monotonically increasing ids."""
old_id_to_new_id: Dict[int, int]
def __init__(self, next_id_func: Optional[Callable[[], int]] = None) -> None:
self.old_id_to_new_id = {}
counter = itertools.count(1)
self.next_id_func = next_id_func if next_id_func is not None else lambda: next(counter)
def _get_and_increment(self, current_id: int) -> int:
new_id = self.next_id_func()
self.old_id_to_new_id[current_id] = new_id
return new_id
def schema(self, schema: Schema, struct_result: Callable[[], StructType]) -> Schema:
return Schema(
*struct_result().fields,
identifier_field_ids=[self.old_id_to_new_id[field_id] for field_id in schema.identifier_field_ids],
)
def struct(self, struct: StructType, field_results: List[Callable[[], IcebergType]]) -> StructType:
new_ids = [self._get_and_increment(field.field_id) for field in struct.fields]
new_fields = []
for field_id, field, field_type in zip(new_ids, struct.fields, field_results):
new_fields.append(
NestedField(
field_id=field_id,
name=field.name,
field_type=field_type(),
required=field.required,
doc=field.doc,
)
)
return StructType(*new_fields)
def field(self, field: NestedField, field_result: Callable[[], IcebergType]) -> IcebergType:
return field_result()
def list(self, list_type: ListType, element_result: Callable[[], IcebergType]) -> ListType:
element_id = self._get_and_increment(list_type.element_id)
return ListType(
element_id=element_id,
element=element_result(),
element_required=list_type.element_required,
)
def map(self, map_type: MapType, key_result: Callable[[], IcebergType], value_result: Callable[[], IcebergType]) -> MapType:
key_id = self._get_and_increment(map_type.key_id)
value_id = self._get_and_increment(map_type.value_id)
return MapType(
key_id=key_id,
key_type=key_result(),
value_id=value_id,
value_type=value_result(),
value_required=map_type.value_required,
)
def primitive(self, primitive: PrimitiveType) -> PrimitiveType:
return primitive
# Implementation copied from Apache Iceberg repo.
def make_compatible_name(name: str) -> str:
if not _valid_avro_name(name):
return _sanitize_name(name)
return name
def _valid_avro_name(name: str) -> bool:
length = len(name)
assert length > 0, ValueError("Can not validate empty avro name")
first = name[0]
if not (first.isalpha() or first == "_"):
return False
for character in name[1:]:
if not (character.isalnum() or character == "_"):
return False
return True
def _sanitize_name(name: str) -> str:
sb = []
first = name[0]
if not (first.isalpha() or first == "_"):
sb.append(_sanitize_char(first))
else:
sb.append(first)
for character in name[1:]:
if not (character.isalnum() or character == "_"):
sb.append(_sanitize_char(character))
else:
sb.append(character)
return "".join(sb)
def _sanitize_char(character: str) -> str:
return "_" + character if character.isdigit() else "_x" + hex(ord(character))[2:].upper()
def sanitize_column_names(schema: Schema) -> Schema:
"""Sanitize column names to make them compatible with Avro.
The column name should be starting with '_' or digit followed by a string only contains '_', digit or alphabet,
otherwise it will be sanitized to conform the avro naming convention.
Args:
schema: The schema to be sanitized.
Returns:
The sanitized schema.
"""
result = visit(schema.as_struct(), _SanitizeColumnsVisitor())
return Schema(
*(result or StructType()).fields,
schema_id=schema.schema_id,
identifier_field_ids=schema.identifier_field_ids,
)
class _SanitizeColumnsVisitor(SchemaVisitor[Optional[IcebergType]]):
def schema(self, schema: Schema, struct_result: Optional[IcebergType]) -> Optional[IcebergType]:
return struct_result
def field(self, field: NestedField, field_result: Optional[IcebergType]) -> Optional[IcebergType]:
return NestedField(
field_id=field.field_id,
name=make_compatible_name(field.name),
field_type=field_result,
doc=field.doc,
required=field.required,
)
def struct(self, struct: StructType, field_results: List[Optional[IcebergType]]) -> Optional[IcebergType]:
return StructType(*[field for field in field_results if field is not None])
def list(self, list_type: ListType, element_result: Optional[IcebergType]) -> Optional[IcebergType]:
return ListType(element_id=list_type.element_id, element_type=element_result, element_required=list_type.element_required)
def map(
self, map_type: MapType, key_result: Optional[IcebergType], value_result: Optional[IcebergType]
) -> Optional[IcebergType]:
return MapType(
key_id=map_type.key_id,
value_id=map_type.value_id,
key_type=key_result,
value_type=value_result,
value_required=map_type.value_required,
)
def primitive(self, primitive: PrimitiveType) -> Optional[IcebergType]:
return primitive
def prune_columns(schema: Schema, selected: Set[int], select_full_types: bool = True) -> Schema:
"""Prunes a column by only selecting a set of field-ids.
Args:
schema: The schema to be pruned.
selected: The field-ids to be included.
select_full_types: Return the full struct when a subset is recorded
Returns:
The pruned schema.
"""
result = visit(schema.as_struct(), _PruneColumnsVisitor(selected, select_full_types))
return Schema(
*(result or StructType()).fields,
schema_id=schema.schema_id,
identifier_field_ids=list(selected.intersection(schema.identifier_field_ids)),
)
class _PruneColumnsVisitor(SchemaVisitor[Optional[IcebergType]]):
selected: Set[int]
select_full_types: bool
def __init__(self, selected: Set[int], select_full_types: bool):
self.selected = selected
self.select_full_types = select_full_types
def schema(self, schema: Schema, struct_result: Optional[IcebergType]) -> Optional[IcebergType]:
return struct_result
def struct(self, struct: StructType, field_results: List[Optional[IcebergType]]) -> Optional[IcebergType]:
fields = struct.fields
selected_fields = []
same_type = True
for idx, projected_type in enumerate(field_results):
field = fields[idx]
if field.field_type == projected_type:
selected_fields.append(field)
elif projected_type is not None:
same_type = False
# Type has changed, create a new field with the projected type
selected_fields.append(
NestedField(
field_id=field.field_id,
name=field.name,
field_type=projected_type,
doc=field.doc,
required=field.required,
)
)
if selected_fields:
if len(selected_fields) == len(fields) and same_type is True:
# Nothing has changed, and we can return the original struct
return struct
else:
return StructType(*selected_fields)
return None
def field(self, field: NestedField, field_result: Optional[IcebergType]) -> Optional[IcebergType]:
if field.field_id in self.selected:
if self.select_full_types:
return field.field_type
elif field.field_type.is_struct:
return self._project_selected_struct(field_result)
else:
if not field.field_type.is_primitive:
raise ValueError(
f"Cannot explicitly project List or Map types, {field.field_id}:{field.name} of type {field.field_type} was selected"
)
# Selected non-struct field
return field.field_type
elif field_result is not None:
# This field wasn't selected but a subfield was so include that
return field_result
else:
return None
def list(self, list_type: ListType, element_result: Optional[IcebergType]) -> Optional[IcebergType]:
if list_type.element_id in self.selected:
if self.select_full_types:
return list_type
elif list_type.element_type and list_type.element_type.is_struct:
projected_struct = self._project_selected_struct(element_result)
return self._project_list(list_type, projected_struct)
else:
if not list_type.element_type.is_primitive:
raise ValueError(
f"Cannot explicitly project List or Map types, {list_type.element_id} of type {list_type.element_type} was selected"
)
return list_type
elif element_result is not None:
return self._project_list(list_type, element_result)
else:
return None
def map(
self, map_type: MapType, key_result: Optional[IcebergType], value_result: Optional[IcebergType]
) -> Optional[IcebergType]:
if map_type.value_id in self.selected:
if self.select_full_types:
return map_type
elif map_type.value_type and map_type.value_type.is_struct:
projected_struct = self._project_selected_struct(value_result)
return self._project_map(map_type, projected_struct)
if not map_type.value_type.is_primitive:
raise ValueError(
f"Cannot explicitly project List or Map types, Map value {map_type.value_id} of type {map_type.value_type} was selected"
)
return map_type
elif value_result is not None:
return self._project_map(map_type, value_result)
elif map_type.key_id in self.selected:
return map_type
return None
def primitive(self, primitive: PrimitiveType) -> Optional[IcebergType]:
return None
@staticmethod
def _project_selected_struct(projected_field: Optional[IcebergType]) -> StructType:
if projected_field and not isinstance(projected_field, StructType):
raise ValueError("Expected a struct")
if projected_field is None:
return StructType()
else:
return projected_field
@staticmethod
def _project_list(list_type: ListType, element_result: IcebergType) -> ListType:
if list_type.element_type == element_result:
return list_type
else:
return ListType(
element_id=list_type.element_id, element_type=element_result, element_required=list_type.element_required
)
@staticmethod
def _project_map(map_type: MapType, value_result: IcebergType) -> MapType:
if map_type.value_type == value_result:
return map_type
else:
return MapType(
key_id=map_type.key_id,
value_id=map_type.value_id,
key_type=map_type.key_type,
value_type=value_result,
value_required=map_type.value_required,
)
@singledispatch
def promote(file_type: IcebergType, read_type: IcebergType) -> IcebergType:
"""Promotes reading a file type to a read type.
Args:
file_type (IcebergType): The type of the Avro file.
read_type (IcebergType): The requested read type.
Raises:
ResolveError: If attempting to resolve an unrecognized object type.
"""
if file_type == read_type:
return file_type
else:
raise ResolveError(f"Cannot promote {file_type} to {read_type}")
@promote.register(IntegerType)
def _(file_type: IntegerType, read_type: IcebergType) -> IcebergType:
if isinstance(read_type, LongType):
# Ints/Longs are binary compatible in Avro, so this is okay
return read_type
else:
raise ResolveError(f"Cannot promote an int to {read_type}")
@promote.register(FloatType)
def _(file_type: FloatType, read_type: IcebergType) -> IcebergType:
if isinstance(read_type, DoubleType):
# A double type is wider
return read_type
else:
raise ResolveError(f"Cannot promote an float to {read_type}")
@promote.register(StringType)
def _(file_type: StringType, read_type: IcebergType) -> IcebergType:
if isinstance(read_type, BinaryType):
return read_type
else:
raise ResolveError(f"Cannot promote an string to {read_type}")
@promote.register(BinaryType)
def _(file_type: BinaryType, read_type: IcebergType) -> IcebergType:
if isinstance(read_type, StringType):
return read_type
else:
raise ResolveError(f"Cannot promote an binary to {read_type}")
@promote.register(DecimalType)
def _(file_type: DecimalType, read_type: IcebergType) -> IcebergType:
if isinstance(read_type, DecimalType):
if file_type.precision <= read_type.precision and file_type.scale == file_type.scale:
return read_type
else:
raise ResolveError(f"Cannot reduce precision from {file_type} to {read_type}")
else:
raise ResolveError(f"Cannot promote an decimal to {read_type}")
@promote.register(FixedType)
def _(file_type: FixedType, read_type: IcebergType) -> IcebergType:
if isinstance(read_type, UUIDType) and len(file_type) == 16:
# Since pyarrow reads parquet UUID as fixed 16-byte binary, the promotion is needed to ensure read compatibility
return read_type
else:
raise ResolveError(f"Cannot promote {file_type} to {read_type}")
def _check_schema_compatible(requested_schema: Schema, provided_schema: Schema) -> None:
"""
Check if the `provided_schema` is compatible with `requested_schema`.
Both Schemas must have valid IDs and share the same ID for the same field names.
Two schemas are considered compatible when:
1. All `required` fields in `requested_schema` are present and are also `required` in the `provided_schema`
2. Field Types are consistent for fields that are present in both schemas. I.e. the field type
in the `provided_schema` can be promoted to the field type of the same field ID in `requested_schema`
Raises:
ValueError: If the schemas are not compatible.
"""
pre_order_visit(requested_schema, _SchemaCompatibilityVisitor(provided_schema))
class _SchemaCompatibilityVisitor(PreOrderSchemaVisitor[bool]):
provided_schema: Schema
def __init__(self, provided_schema: Schema):
from rich.console import Console
from rich.table import Table as RichTable
self.provided_schema = provided_schema
self.rich_table = RichTable(show_header=True, header_style="bold")
self.rich_table.add_column("")
self.rich_table.add_column("Table field")
self.rich_table.add_column("Dataframe field")
self.console = Console(record=True)
def _is_field_compatible(self, lhs: NestedField) -> bool:
# Validate nullability first.
# An optional field can be missing in the provided schema
# But a required field must exist as a required field
try:
rhs = self.provided_schema.find_field(lhs.field_id)
except ValueError:
if lhs.required:
self.rich_table.add_row("❌", str(lhs), "Missing")
return False
else:
self.rich_table.add_row("✅", str(lhs), "Missing")
return True
if lhs.required and not rhs.required:
self.rich_table.add_row("❌", str(lhs), str(rhs))
return False
# Check type compatibility
if lhs.field_type == rhs.field_type:
self.rich_table.add_row("✅", str(lhs), str(rhs))
return True
# We only check that the parent node is also of the same type.
# We check the type of the child nodes when we traverse them later.
elif any(
(isinstance(lhs.field_type, container_type) and isinstance(rhs.field_type, container_type))
for container_type in {StructType, MapType, ListType}
):
self.rich_table.add_row("✅", str(lhs), str(rhs))
return True
else:
try:
# If type can be promoted to the requested schema
# it is considered compatible
promote(rhs.field_type, lhs.field_type)
self.rich_table.add_row("✅", str(lhs), str(rhs))
return True
except ResolveError:
self.rich_table.add_row("❌", str(lhs), str(rhs))
return False
def schema(self, schema: Schema, struct_result: Callable[[], bool]) -> bool:
if not (result := struct_result()):
self.console.print(self.rich_table)
raise ValueError(f"Mismatch in fields:\n{self.console.export_text()}")
return result
def struct(self, struct: StructType, field_results: List[Callable[[], bool]]) -> bool:
results = [result() for result in field_results]
return all(results)
def field(self, field: NestedField, field_result: Callable[[], bool]) -> bool:
return self._is_field_compatible(field) and field_result()
def list(self, list_type: ListType, element_result: Callable[[], bool]) -> bool:
return self._is_field_compatible(list_type.element_field) and element_result()
def map(self, map_type: MapType, key_result: Callable[[], bool], value_result: Callable[[], bool]) -> bool:
return all(
[
self._is_field_compatible(map_type.key_field),
self._is_field_compatible(map_type.value_field),
key_result(),
value_result(),
]
)
def primitive(self, primitive: PrimitiveType) -> bool:
return True