# pyiceberg/avro/resolver.py (357 lines of code) (raw):

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=arguments-renamed,unused-argument
from enum import Enum
from typing import (
    Callable,
    Dict,
    List,
    Optional,
    Tuple,
    Union,
)

from pyiceberg.avro.decoder import BinaryDecoder
from pyiceberg.avro.reader import (
    BinaryReader,
    BooleanReader,
    DateReader,
    DecimalReader,
    DefaultReader,
    DoubleReader,
    FixedReader,
    FloatReader,
    IntegerReader,
    ListReader,
    MapReader,
    NoneReader,
    OptionReader,
    Reader,
    StringReader,
    StructReader,
    TimeReader,
    TimestampNanoReader,
    TimestampReader,
    TimestamptzNanoReader,
    TimestamptzReader,
    UnknownReader,
    UUIDReader,
)
from pyiceberg.avro.writer import (
    BinaryWriter,
    BooleanWriter,
    DateWriter,
    DecimalWriter,
    DefaultWriter,
    DoubleWriter,
    FixedWriter,
    FloatWriter,
    IntegerWriter,
    ListWriter,
    MapWriter,
    OptionWriter,
    StringWriter,
    StructWriter,
    TimestampNanoWriter,
    TimestamptzNanoWriter,
    TimestamptzWriter,
    TimestampWriter,
    TimeWriter,
    UnknownWriter,
    UUIDWriter,
    Writer,
)
from pyiceberg.exceptions import ResolveError
from pyiceberg.schema import (
    PartnerAccessor,
    PrimitiveWithPartnerVisitor,
    Schema,
    SchemaVisitorPerPrimitiveType,
    promote,
    visit,
    visit_with_partner,
)
from pyiceberg.typedef import EMPTY_DICT, Record, StructProtocol
from pyiceberg.types import (
    BinaryType,
    BooleanType,
    DateType,
    DecimalType,
    DoubleType,
    FixedType,
    FloatType,
    IcebergType,
    IntegerType,
    ListType,
    LongType,
    MapType,
    NestedField,
    PrimitiveType,
    StringType,
    StructType,
    TimestampNanoType,
    TimestampType,
    TimestamptzNanoType,
    TimestamptzType,
    TimeType,
    UnknownType,
    UUIDType,
)

# Key used in `read_types` for the top-level struct. It is also used as the
# index (-1) that selects the innermost field id from the resolver's context stack.
STRUCT_ROOT = -1


def construct_reader(
    file_schema: Union[Schema, IcebergType], read_types: Dict[int, Callable[..., StructProtocol]] = EMPTY_DICT
) -> Reader:
    """Construct a reader from a file schema.

    The file schema is resolved against itself, so every field in the file is read.

    Args:
        file_schema (Schema | IcebergType): The schema of the Avro file.
        read_types (Dict[int, Callable[..., StructProtocol]]): Constructors for structs for certain field-ids

    Raises:
        NotImplementedError: If attempting to resolve an unrecognized object type.
    """
    return resolve_reader(file_schema, file_schema, read_types)


def construct_writer(file_schema: Union[Schema, IcebergType]) -> Writer:
    """Construct a writer from a file schema.

    Args:
        file_schema (Schema | IcebergType): The schema of the Avro file.

    Raises:
        NotImplementedError: If attempting to resolve an unrecognized object type.
    """
    return visit(file_schema, CONSTRUCT_WRITER_VISITOR)


class ConstructWriter(SchemaVisitorPerPrimitiveType[Writer]):
    """Construct a writer tree from an Iceberg schema."""

    def schema(self, schema: Schema, struct_result: Writer) -> Writer:
        return struct_result

    def struct(self, struct: StructType, field_results: List[Writer]) -> Writer:
        # Record and file layout are the same here, so field N is written from position N.
        return StructWriter(tuple(enumerate(field_results)))

    def field(self, field: NestedField, field_result: Writer) -> Writer:
        return field_result if field.required else OptionWriter(field_result)

    def list(self, list_type: ListType, element_result: Writer) -> Writer:
        # Wrap optional elements the same way WriteSchemaResolver.list does, so that the
        # equal-schema shortcut in resolve_writer builds the same writer tree as the resolver.
        return ListWriter(element_result if list_type.element_required else OptionWriter(element_result))

    def map(self, map_type: MapType, key_result: Writer, value_result: Writer) -> Writer:
        # Wrap optional values the same way WriteSchemaResolver.map does (see list above).
        return MapWriter(key_result, value_result if map_type.value_required else OptionWriter(value_result))

    def visit_fixed(self, fixed_type: FixedType) -> Writer:
        return FixedWriter(len(fixed_type))

    def visit_decimal(self, decimal_type: DecimalType) -> Writer:
        return DecimalWriter(decimal_type.precision, decimal_type.scale)

    def visit_boolean(self, boolean_type: BooleanType) -> Writer:
        return BooleanWriter()

    def visit_integer(self, integer_type: IntegerType) -> Writer:
        return IntegerWriter()

    def visit_long(self, long_type: LongType) -> Writer:
        # Longs use the same writer as integers.
        return IntegerWriter()

    def visit_float(self, float_type: FloatType) -> Writer:
        return FloatWriter()

    def visit_double(self, double_type: DoubleType) -> Writer:
        return DoubleWriter()

    def visit_date(self, date_type: DateType) -> Writer:
        return DateWriter()

    def visit_time(self, time_type: TimeType) -> Writer:
        return TimeWriter()

    def visit_timestamp(self, timestamp_type: TimestampType) -> Writer:
        return TimestampWriter()

    def visit_timestamp_ns(self, timestamp_ns_type: TimestampNanoType) -> Writer:
        return TimestampNanoWriter()

    def visit_timestamptz(self, timestamptz_type: TimestamptzType) -> Writer:
        return TimestamptzWriter()

    def visit_timestamptz_ns(self, timestamptz_ns_type: TimestamptzNanoType) -> Writer:
        return TimestamptzNanoWriter()

    def visit_string(self, string_type: StringType) -> Writer:
        return StringWriter()

    def visit_uuid(self, uuid_type: UUIDType) -> Writer:
        return UUIDWriter()

    def visit_binary(self, binary_type: BinaryType) -> Writer:
        return BinaryWriter()

    def visit_unknown(self, unknown_type: UnknownType) -> Writer:
        return UnknownWriter()


# The visitor holds no state, so a single shared instance is reused by construct_writer.
CONSTRUCT_WRITER_VISITOR = ConstructWriter()


def resolve_writer(
    record_schema: Union[Schema, IcebergType],
    file_schema: Union[Schema, IcebergType],
) -> Writer:
    """Resolve the record and file schema to produce a writer.

    Args:
        record_schema (Schema | IcebergType): The schema of the record in memory.
        file_schema (Schema | IcebergType): The schema of the file that will be written

    Raises:
        NotImplementedError: If attempting to resolve an unrecognized object type.
    """
    if record_schema == file_schema:
        # Nothing to resolve: build the writer straight from the file schema.
        return construct_writer(file_schema)
    return visit_with_partner(file_schema, record_schema, WriteSchemaResolver(), SchemaPartnerAccessor())  # type: ignore


def resolve_reader(
    file_schema: Union[Schema, IcebergType],
    read_schema: Union[Schema, IcebergType],
    read_types: Dict[int, Callable[..., StructProtocol]] = EMPTY_DICT,
    read_enums: Dict[int, Callable[..., Enum]] = EMPTY_DICT,
) -> Reader:
    """Resolve the file and read schema to produce a reader.

    Args:
        file_schema (Schema | IcebergType): The schema of the Avro file.
        read_schema (Schema | IcebergType): The requested read schema which is equal, subset or superset of the file schema.
        read_types (Dict[int, Callable[..., StructProtocol]]): A dict of types to use for struct data.
        read_enums (Dict[int, Callable[..., Enum]]): A dict of fields that have to be converted to an enum.

    Raises:
        NotImplementedError: If attempting to resolve an unrecognized object type.
    """
    return visit_with_partner(file_schema, read_schema, ReadSchemaResolver(read_types, read_enums), SchemaPartnerAccessor())  # type: ignore


class EnumReader(Reader):
    """An Enum reader to wrap primitive values into an Enum."""

    __slots__ = ("enum", "reader")
    enum: Callable[..., Enum]
    reader: Reader

    def __init__(self, enum: Callable[..., Enum], reader: Reader) -> None:
        self.enum = enum
        self.reader = reader

    def read(self, decoder: BinaryDecoder) -> Enum:
        """Read the underlying value and pass it to the enum callable."""
        return self.enum(self.reader.read(decoder))

    def skip(self, decoder: BinaryDecoder) -> None:
        """Skip the underlying value.

        Delegates to the wrapped reader so the decoder advances past the value;
        doing nothing here would leave the decoder positioned on this field.
        """
        self.reader.skip(decoder)


class WriteSchemaResolver(PrimitiveWithPartnerVisitor[IcebergType, Writer]):
    """Build a writer tree for the file schema, using the record schema as partner."""

    def schema(self, file_schema: Schema, record_schema: Optional[IcebergType], result: Writer) -> Writer:
        return result

    def struct(self, file_schema: StructType, record_struct: Optional[IcebergType], file_writers: List[Writer]) -> Writer:
        """Pair every file field's writer with the position of that field in the record.

        Raises:
            ResolveError: If the partner of the file struct is not a struct.
            ValueError: If a required file field is missing from the record and has no write default.
        """
        if not isinstance(record_struct, StructType):
            raise ResolveError(f"File/write schema are not aligned for struct, got {record_struct}")

        record_struct_positions: Dict[int, int] = {field.field_id: pos for pos, field in enumerate(record_struct.fields)}
        results: List[Tuple[Optional[int], Writer]] = []

        for writer, file_field in zip(file_writers, file_schema.fields):
            if file_field.field_id in record_struct_positions:
                results.append((record_struct_positions[file_field.field_id], writer))
            elif file_field.required:
                if file_field.write_default is not None:
                    # The field is not in the record, but there is a write default value
                    results.append((None, DefaultWriter(writer=writer, value=file_field.write_default)))
                else:
                    raise ValueError(f"Field is required, and there is no write default: {file_field}")
            else:
                # Optional field that is not in the record: no record position to take a value from
                results.append((None, writer))

        return StructWriter(field_writers=tuple(results))

    def field(self, file_field: NestedField, record_type: Optional[IcebergType], field_writer: Writer) -> Writer:
        return field_writer if file_field.required else OptionWriter(field_writer)

    def list(self, file_list_type: ListType, file_list: Optional[IcebergType], element_writer: Writer) -> Writer:
        return ListWriter(element_writer if file_list_type.element_required else OptionWriter(element_writer))

    def map(
        self, file_map_type: MapType, file_primitive: Optional[IcebergType], key_writer: Writer, value_writer: Writer
    ) -> Writer:
        return MapWriter(key_writer, value_writer if file_map_type.value_required else OptionWriter(value_writer))

    def primitive(self, file_primitive: PrimitiveType, record_primitive: Optional[IcebergType]) -> Writer:
        if record_primitive is not None:
            # ensure that the type can be projected to the expected
            if file_primitive != record_primitive:
                promote(record_primitive, file_primitive)

        # The writer is always picked from the file type; the record type is only validated above.
        return super().primitive(file_primitive, file_primitive)

    def visit_boolean(self, boolean_type: BooleanType, partner: Optional[IcebergType]) -> Writer:
        return BooleanWriter()

    def visit_integer(self, integer_type: IntegerType, partner: Optional[IcebergType]) -> Writer:
        return IntegerWriter()

    def visit_long(self, long_type: LongType, partner: Optional[IcebergType]) -> Writer:
        # Longs use the same writer as integers.
        return IntegerWriter()

    def visit_float(self, float_type: FloatType, partner: Optional[IcebergType]) -> Writer:
        return FloatWriter()

    def visit_double(self, double_type: DoubleType, partner: Optional[IcebergType]) -> Writer:
        return DoubleWriter()

    def visit_decimal(self, decimal_type: DecimalType, partner: Optional[IcebergType]) -> Writer:
        return DecimalWriter(decimal_type.precision, decimal_type.scale)

    def visit_date(self, date_type: DateType, partner: Optional[IcebergType]) -> Writer:
        return DateWriter()

    def visit_time(self, time_type: TimeType, partner: Optional[IcebergType]) -> Writer:
        return TimeWriter()

    def visit_timestamp(self, timestamp_type: TimestampType, partner: Optional[IcebergType]) -> Writer:
        return TimestampWriter()

    def visit_timestamp_ns(self, timestamp_ns_type: TimestampNanoType, partner: Optional[IcebergType]) -> Writer:
        return TimestampNanoWriter()

    def visit_timestamptz(self, timestamptz_type: TimestamptzType, partner: Optional[IcebergType]) -> Writer:
        return TimestamptzWriter()

    def visit_timestamptz_ns(self, timestamptz_ns_type: TimestamptzNanoType, partner: Optional[IcebergType]) -> Writer:
        return TimestamptzNanoWriter()

    def visit_string(self, string_type: StringType, partner: Optional[IcebergType]) -> Writer:
        return StringWriter()

    def visit_uuid(self, uuid_type: UUIDType, partner: Optional[IcebergType]) -> Writer:
        return UUIDWriter()

    def visit_fixed(self, fixed_type: FixedType, partner: Optional[IcebergType]) -> Writer:
        return FixedWriter(len(fixed_type))

    def visit_binary(self, binary_type: BinaryType, partner: Optional[IcebergType]) -> Writer:
        return BinaryWriter()

    def visit_unknown(self, unknown_type: UnknownType, partner: Optional[IcebergType]) -> Writer:
        return UnknownWriter()


class ReadSchemaResolver(PrimitiveWithPartnerVisitor[IcebergType, Reader]):
    """Build a reader tree for the file schema, using the requested read schema as partner."""

    __slots__ = ("read_types", "read_enums", "context")
    read_types: Dict[int, Callable[..., StructProtocol]]
    read_enums: Dict[int, Callable[..., Enum]]
    # Stack of the field ids currently being visited, innermost last.
    context: List[int]

    def __init__(
        self,
        read_types: Dict[int, Callable[..., StructProtocol]] = EMPTY_DICT,
        read_enums: Dict[int, Callable[..., Enum]] = EMPTY_DICT,
    ) -> None:
        self.read_types = read_types
        self.read_enums = read_enums
        self.context = []

    def schema(self, schema: Schema, expected_schema: Optional[IcebergType], result: Reader) -> Reader:
        return result

    def before_field(self, field: NestedField, field_partner: Optional[NestedField]) -> None:
        self.context.append(field.field_id)

    def after_field(self, field: NestedField, field_partner: Optional[NestedField]) -> None:
        self.context.pop()

    def struct(self, struct: StructType, expected_struct: Optional[IcebergType], field_readers: List[Reader]) -> Reader:
        """Pair every file field's reader with the position of that field in the read schema.

        Raises:
            ResolveError: If the partner is not a struct, or if a required read field is
                missing from the file and has no initial default.
        """
        # The innermost field id on the stack identifies this struct; the top-level struct uses STRUCT_ROOT.
        read_struct_id = self.context[STRUCT_ROOT] if len(self.context) > 0 else STRUCT_ROOT
        struct_callable = self.read_types.get(read_struct_id, Record)

        # NOTE(review): this is a truthiness test, not `is None` -- confirm the intended
        # behaviour for an expected struct that evaluates as falsy (e.g. one without fields).
        if not expected_struct:
            return StructReader(tuple(enumerate(field_readers)), struct_callable, struct)

        if not isinstance(expected_struct, StructType):
            raise ResolveError(f"File/read schema are not aligned for struct, got {expected_struct}")

        expected_positions: Dict[int, int] = {field.field_id: pos for pos, field in enumerate(expected_struct.fields)}

        # first, add readers for the file fields that must be in order
        results: List[Tuple[Optional[int], Reader]] = [
            (
                expected_positions.get(field.field_id),
                # Check if we need to convert it to an Enum
                result_reader if not (enum_type := self.read_enums.get(field.field_id)) else EnumReader(enum_type, result_reader),
            )
            for field, result_reader in zip(struct.fields, field_readers)
        ]

        file_fields = {field.field_id for field in struct.fields}
        for pos, read_field in enumerate(expected_struct.fields):
            if read_field.field_id not in file_fields:
                if isinstance(read_field, NestedField) and read_field.initial_default is not None:
                    # The field is not in the file, but there is a default value
                    # and that one can be required
                    results.append((pos, DefaultReader(read_field.initial_default)))
                elif read_field.required:
                    raise ResolveError(f"{read_field} is non-optional, and not part of the file schema")
                else:
                    # Just set the new field to None
                    results.append((pos, NoneReader()))

        return StructReader(tuple(results), struct_callable, expected_struct)

    def field(self, field: NestedField, expected_field: Optional[IcebergType], field_reader: Reader) -> Reader:
        return field_reader if field.required else OptionReader(field_reader)

    def list(self, list_type: ListType, expected_list: Optional[IcebergType], element_reader: Reader) -> Reader:
        if expected_list and not isinstance(expected_list, ListType):
            raise ResolveError(f"File/read schema are not aligned for list, got {expected_list}")

        return ListReader(element_reader if list_type.element_required else OptionReader(element_reader))

    def map(self, map_type: MapType, expected_map: Optional[IcebergType], key_reader: Reader, value_reader: Reader) -> Reader:
        if expected_map and not isinstance(expected_map, MapType):
            raise ResolveError(f"File/read schema are not aligned for map, got {expected_map}")

        return MapReader(key_reader, value_reader if map_type.value_required else OptionReader(value_reader))

    def primitive(self, primitive: PrimitiveType, expected_primitive: Optional[IcebergType]) -> Reader:
        if expected_primitive is not None:
            if not isinstance(expected_primitive, PrimitiveType):
                raise ResolveError(f"File/read schema are not aligned for {primitive}, got {expected_primitive}")

            # ensure that the type can be projected to the expected
            if primitive != expected_primitive:
                promote(primitive, expected_primitive)

        return super().primitive(primitive, expected_primitive)

    def visit_boolean(self, boolean_type: BooleanType, partner: Optional[IcebergType]) -> Reader:
        return BooleanReader()

    def visit_integer(self, integer_type: IntegerType, partner: Optional[IcebergType]) -> Reader:
        return IntegerReader()

    def visit_long(self, long_type: LongType, partner: Optional[IcebergType]) -> Reader:
        # Longs use the same reader as integers.
        return IntegerReader()

    def visit_float(self, float_type: FloatType, partner: Optional[IcebergType]) -> Reader:
        return FloatReader()

    def visit_double(self, double_type: DoubleType, partner: Optional[IcebergType]) -> Reader:
        return DoubleReader()

    def visit_decimal(self, decimal_type: DecimalType, partner: Optional[IcebergType]) -> Reader:
        return DecimalReader(decimal_type.precision, decimal_type.scale)

    def visit_date(self, date_type: DateType, partner: Optional[IcebergType]) -> Reader:
        return DateReader()

    def visit_time(self, time_type: TimeType, partner: Optional[IcebergType]) -> Reader:
        return TimeReader()

    def visit_timestamp(self, timestamp_type: TimestampType, partner: Optional[IcebergType]) -> Reader:
        return TimestampReader()

    def visit_timestamp_ns(self, timestamp_ns_type: TimestampNanoType, partner: Optional[IcebergType]) -> Reader:
        return TimestampNanoReader()

    def visit_timestamptz(self, timestamptz_type: TimestamptzType, partner: Optional[IcebergType]) -> Reader:
        return TimestamptzReader()

    def visit_timestamptz_ns(self, timestamptz_ns_type: TimestamptzNanoType, partner: Optional[IcebergType]) -> Reader:
        return TimestamptzNanoReader()

    def visit_string(self, string_type: StringType, partner: Optional[IcebergType]) -> Reader:
        return StringReader()

    def visit_uuid(self, uuid_type: UUIDType, partner: Optional[IcebergType]) -> Reader:
        return UUIDReader()

    def visit_fixed(self, fixed_type: FixedType, partner: Optional[IcebergType]) -> Reader:
        return FixedReader(len(fixed_type))

    def visit_binary(self, binary_type: BinaryType, partner: Optional[IcebergType]) -> Reader:
        return BinaryReader()

    def visit_unknown(self, unknown_type: UnknownType, partner: Optional[IcebergType]) -> Reader:
        return UnknownReader()


class SchemaPartnerAccessor(PartnerAccessor[IcebergType]):
    """Look up the partner type of each node, raising ResolveError when the partner has a different kind."""

    def schema_partner(self, partner: Optional[IcebergType]) -> Optional[IcebergType]:
        if isinstance(partner, Schema):
            return partner.as_struct()

        raise ResolveError(f"File/read schema are not aligned for schema, got {partner}")

    def field_partner(self, partner: Optional[IcebergType], field_id: int, field_name: str) -> Optional[IcebergType]:
        # Fields are matched by id; the name is not used.
        if isinstance(partner, StructType):
            field = partner.field(field_id)
        else:
            raise ResolveError(f"File/read schema are not aligned for struct, got {partner}")

        # None when the partner struct has no field with this id
        return field.field_type if field else None

    def list_element_partner(self, partner_list: Optional[IcebergType]) -> Optional[IcebergType]:
        if isinstance(partner_list, ListType):
            return partner_list.element_type

        raise ResolveError(f"File/read schema are not aligned for list, got {partner_list}")

    def map_key_partner(self, partner_map: Optional[IcebergType]) -> Optional[IcebergType]:
        if isinstance(partner_map, MapType):
            return partner_map.key_type

        raise ResolveError(f"File/read schema are not aligned for map, got {partner_map}")

    def map_value_partner(self, partner_map: Optional[IcebergType]) -> Optional[IcebergType]:
        if isinstance(partner_map, MapType):
            return partner_map.value_type

        raise ResolveError(f"File/read schema are not aligned for map, got {partner_map}")