pyiceberg/avro/writer.py (144 lines of code) (raw):

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Classes for building the Writer tree.

Constructing a writer tree from the schema makes it easy
to decouple the writing implementation from the schema.
"""

from __future__ import annotations

from abc import abstractmethod
from dataclasses import dataclass
from dataclasses import field as dataclassfield
from typing import (
    Any,
    Dict,
    List,
    Optional,
    Tuple,
)
from uuid import UUID

from pyiceberg.avro.encoder import BinaryEncoder
from pyiceberg.typedef import Record
from pyiceberg.utils.decimal import decimal_required_bytes, decimal_to_bytes
from pyiceberg.utils.singleton import Singleton


@dataclass(frozen=True)
class Writer(Singleton):
    """Base class of every node in the writer tree.

    Each subclass encodes a single value of one type with a ``BinaryEncoder``.
    Writers are frozen dataclasses that derive from ``Singleton``.
    """

    @abstractmethod
    def write(self, encoder: BinaryEncoder, val: Any) -> Any:
        """Encode ``val`` with ``encoder``."""
        ...

    def __repr__(self) -> str:
        """Return string representation of this object."""
        return f"{self.__class__.__name__}()"


@dataclass(frozen=True)
class BooleanWriter(Writer):
    """Write a boolean via ``encoder.write_boolean``."""

    def write(self, encoder: BinaryEncoder, val: bool) -> None:
        encoder.write_boolean(val)


@dataclass(frozen=True)
class IntegerWriter(Writer):
    """Longs and ints are encoded the same way, and there is no long in Python."""

    def write(self, encoder: BinaryEncoder, val: int) -> None:
        encoder.write_int(val)


@dataclass(frozen=True)
class FloatWriter(Writer):
    """Write a float via ``encoder.write_float``."""

    def write(self, encoder: BinaryEncoder, val: float) -> None:
        encoder.write_float(val)


@dataclass(frozen=True)
class DoubleWriter(Writer):
    """Write a float via ``encoder.write_double``."""

    def write(self, encoder: BinaryEncoder, val: float) -> None:
        encoder.write_double(val)


# The temporal writers below all receive values that are already plain ints
# and reuse the int/long encoding (``encoder.write_int``).


@dataclass(frozen=True)
class DateWriter(Writer):
    """Write a date given as an int."""

    def write(self, encoder: BinaryEncoder, val: int) -> None:
        encoder.write_int(val)


@dataclass(frozen=True)
class TimeWriter(Writer):
    """Write a time given as an int."""

    def write(self, encoder: BinaryEncoder, val: int) -> None:
        encoder.write_int(val)


@dataclass(frozen=True)
class TimestampWriter(Writer):
    """Write a timestamp given as an int."""

    def write(self, encoder: BinaryEncoder, val: int) -> None:
        encoder.write_int(val)


@dataclass(frozen=True)
class TimestampNanoWriter(Writer):
    """Write a nanosecond timestamp given as an int."""

    def write(self, encoder: BinaryEncoder, val: int) -> None:
        encoder.write_int(val)


@dataclass(frozen=True)
class TimestamptzWriter(Writer):
    """Write a timestamp with time zone given as an int."""

    def write(self, encoder: BinaryEncoder, val: int) -> None:
        encoder.write_int(val)


@dataclass(frozen=True)
class TimestamptzNanoWriter(Writer):
    """Write a nanosecond timestamp with time zone given as an int."""

    def write(self, encoder: BinaryEncoder, val: int) -> None:
        encoder.write_int(val)


@dataclass(frozen=True)
class StringWriter(Writer):
    """Write a string via ``encoder.write_utf8``."""

    def write(self, encoder: BinaryEncoder, val: Any) -> None:
        encoder.write_utf8(val)


@dataclass(frozen=True)
class UUIDWriter(Writer):
    """Write a UUID as its 16 raw bytes, with no length prefix.

    Accepts either a ``uuid.UUID`` (its ``.bytes`` are written) or a value that
    is already the 16-byte serialized form, which is written unchanged.

    Raises:
        ValueError: If a non-``UUID`` value is not exactly 16 bytes long.
    """

    def write(self, encoder: BinaryEncoder, val: Any) -> None:
        if isinstance(val, UUID):
            encoder.write(val.bytes)
            return
        # Already-serialized input: enforce the same 16-byte width that
        # ``UUID.bytes`` always has, so a malformed value cannot corrupt the output.
        if len(val) != 16:
            raise ValueError(f"Expected 16 bytes, got {len(val)}")
        encoder.write(val)


@dataclass(frozen=True)
class UnknownWriter(Writer):
    """Delegate to ``encoder.write_unknown``."""

    def write(self, encoder: BinaryEncoder, val: Any) -> None:
        encoder.write_unknown(val)
@dataclass(frozen=True)
class FixedWriter(Writer):
    """Write a fixed-length byte sequence without a length prefix.

    Raises:
        ValueError: If the value's length differs from the configured length.
    """

    _len: int = dataclassfield()

    def write(self, encoder: BinaryEncoder, val: bytes) -> None:
        if len(val) != self._len:
            raise ValueError(f"Expected {self._len} bytes, got {len(val)}")
        encoder.write(val)

    def __len__(self) -> int:
        """Return the length of this object."""
        return self._len

    def __repr__(self) -> str:
        """Return string representation of this object."""
        return f"FixedWriter({self._len})"


@dataclass(frozen=True)
class BinaryWriter(Writer):
    """Variable byte length writer."""

    def write(self, encoder: BinaryEncoder, val: Any) -> None:
        encoder.write_bytes(val)


@dataclass(frozen=True)
class DecimalWriter(Writer):
    """Write a decimal as ``decimal_to_bytes`` output of a precision-derived width.

    The byte width is ``decimal_required_bytes(precision)``. ``scale`` is not used
    by ``write``; it takes part in equality and ``repr`` only.
    """

    precision: int = dataclassfield()
    scale: int = dataclassfield()

    def write(self, encoder: BinaryEncoder, val: Any) -> None:
        # No ``return``: like every other writer this is a pure side effect,
        # matching the ``-> None`` annotation.
        encoder.write(decimal_to_bytes(val, byte_length=decimal_required_bytes(self.precision)))

    def __repr__(self) -> str:
        """Return string representation of this object."""
        return f"DecimalWriter({self.precision}, {self.scale})"


@dataclass(frozen=True)
class OptionWriter(Writer):
    """Write a nullable value as a two-branch union.

    Branch index 0 is written for ``None``. Otherwise index 1 is written,
    followed by the value encoded with ``option``. This assumes the union is
    ordered ``[null, T]`` (NOTE(review): confirm against the schema builder).
    """

    option: Writer = dataclassfield()

    def write(self, encoder: BinaryEncoder, val: Any) -> None:
        if val is not None:
            encoder.write_int(1)
            self.option.write(encoder, val)
        else:
            encoder.write_int(0)


@dataclass(frozen=True)
class StructWriter(Writer):
    """Write a record field by field.

    ``field_writers`` holds ``(position, writer)`` pairs in output order. A
    position of ``None`` means the field is not read from the record; its writer
    is passed ``None`` instead.
    """

    field_writers: Tuple[Tuple[Optional[int], Writer], ...] = dataclassfield()

    def write(self, encoder: BinaryEncoder, val: Record) -> None:
        for pos, writer in self.field_writers:
            # When pos is None, then it is a default value
            writer.write(encoder, val[pos] if pos is not None else None)

    def __eq__(self, other: Any) -> bool:
        """Implement the equality operator for this object."""
        return self.field_writers == other.field_writers if isinstance(other, StructWriter) else False

    def __repr__(self) -> str:
        """Return string representation of this object."""
        return f"StructWriter(tuple(({','.join(repr(field) for field in self.field_writers)})))"

    def __hash__(self) -> int:
        """Return the hash of the writer as hash of this object."""
        return hash(self.field_writers)


@dataclass(frozen=True)
class ListWriter(Writer):
    """Write a list as one Avro array block followed by the end marker.

    The output is the item count, then each item encoded with
    ``element_writer``, then a zero count. An empty list is written as a single
    zero, which serves as the end marker on its own.
    """

    element_writer: Writer

    def write(self, encoder: BinaryEncoder, val: List[Any]) -> None:
        encoder.write_int(len(val))
        for v in val:
            self.element_writer.write(encoder, v)
        if len(val) > 0:
            # Zero-count block terminates the array.
            encoder.write_int(0)


@dataclass(frozen=True)
class MapWriter(Writer):
    """Write a dict as one Avro map block followed by the end marker.

    The output is the entry count, then each key and value (encoded with
    ``key_writer`` and ``value_writer``), then a zero count. An empty dict is
    written as a single zero.
    """

    key_writer: Writer
    value_writer: Writer

    def write(self, encoder: BinaryEncoder, val: Dict[Any, Any]) -> None:
        encoder.write_int(len(val))
        for k, v in val.items():
            self.key_writer.write(encoder, k)
            self.value_writer.write(encoder, v)
        if len(val) > 0:
            # Zero-count block terminates the map.
            encoder.write_int(0)


@dataclass(frozen=True)
class DefaultWriter(Writer):
    """Ignore the incoming value and always write the stored ``value`` with ``writer``."""

    writer: Writer
    value: Any

    def write(self, encoder: BinaryEncoder, _: Any) -> None:
        self.writer.write(encoder, self.value)