pyiceberg/avro/writer.py (144 lines of code) (raw):
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Classes for building the Writer tree.
Constructing a writer tree from the schema makes it easy
to decouple the writing implementation from the schema.
"""
from __future__ import annotations
from abc import abstractmethod
from dataclasses import dataclass
from dataclasses import field as dataclassfield
from typing import (
Any,
Dict,
List,
Optional,
Tuple,
)
from uuid import UUID
from pyiceberg.avro.encoder import BinaryEncoder
from pyiceberg.typedef import Record
from pyiceberg.utils.decimal import decimal_required_bytes, decimal_to_bytes
from pyiceberg.utils.singleton import Singleton
@dataclass(frozen=True)
class Writer(Singleton):
@abstractmethod
def write(self, encoder: BinaryEncoder, val: Any) -> Any: ...
def __repr__(self) -> str:
"""Return string representation of this object."""
return f"{self.__class__.__name__}()"
@dataclass(frozen=True)
class BooleanWriter(Writer):
def write(self, encoder: BinaryEncoder, val: bool) -> None:
encoder.write_boolean(val)
@dataclass(frozen=True)
class IntegerWriter(Writer):
"""Longs and ints are encoded the same way, and there is no long in Python."""
def write(self, encoder: BinaryEncoder, val: int) -> None:
encoder.write_int(val)
@dataclass(frozen=True)
class FloatWriter(Writer):
def write(self, encoder: BinaryEncoder, val: float) -> None:
encoder.write_float(val)
@dataclass(frozen=True)
class DoubleWriter(Writer):
def write(self, encoder: BinaryEncoder, val: float) -> None:
encoder.write_double(val)
@dataclass(frozen=True)
class DateWriter(Writer):
def write(self, encoder: BinaryEncoder, val: int) -> None:
encoder.write_int(val)
@dataclass(frozen=True)
class TimeWriter(Writer):
def write(self, encoder: BinaryEncoder, val: int) -> None:
encoder.write_int(val)
@dataclass(frozen=True)
class TimestampWriter(Writer):
def write(self, encoder: BinaryEncoder, val: int) -> None:
encoder.write_int(val)
@dataclass(frozen=True)
class TimestampNanoWriter(Writer):
def write(self, encoder: BinaryEncoder, val: int) -> None:
encoder.write_int(val)
@dataclass(frozen=True)
class TimestamptzWriter(Writer):
def write(self, encoder: BinaryEncoder, val: int) -> None:
encoder.write_int(val)
@dataclass(frozen=True)
class TimestamptzNanoWriter(Writer):
def write(self, encoder: BinaryEncoder, val: int) -> None:
encoder.write_int(val)
@dataclass(frozen=True)
class StringWriter(Writer):
def write(self, encoder: BinaryEncoder, val: Any) -> None:
encoder.write_utf8(val)
@dataclass(frozen=True)
class UUIDWriter(Writer):
def write(self, encoder: BinaryEncoder, val: UUID) -> None:
encoder.write(val.bytes)
@dataclass(frozen=True)
class UnknownWriter(Writer):
def write(self, encoder: BinaryEncoder, val: Any) -> None:
encoder.write_unknown(val)
@dataclass(frozen=True)
class FixedWriter(Writer):
_len: int = dataclassfield()
def write(self, encoder: BinaryEncoder, val: bytes) -> None:
if len(val) != self._len:
raise ValueError(f"Expected {self._len} bytes, got {len(val)}")
encoder.write(val)
def __len__(self) -> int:
"""Return the length of this object."""
return self._len
def __repr__(self) -> str:
"""Return string representation of this object."""
return f"FixedWriter({self._len})"
@dataclass(frozen=True)
class BinaryWriter(Writer):
"""Variable byte length writer."""
def write(self, encoder: BinaryEncoder, val: Any) -> None:
encoder.write_bytes(val)
@dataclass(frozen=True)
class DecimalWriter(Writer):
precision: int = dataclassfield()
scale: int = dataclassfield()
def write(self, encoder: BinaryEncoder, val: Any) -> None:
return encoder.write(decimal_to_bytes(val, byte_length=decimal_required_bytes(self.precision)))
def __repr__(self) -> str:
"""Return string representation of this object."""
return f"DecimalWriter({self.precision}, {self.scale})"
@dataclass(frozen=True)
class OptionWriter(Writer):
option: Writer = dataclassfield()
def write(self, encoder: BinaryEncoder, val: Any) -> None:
if val is not None:
encoder.write_int(1)
self.option.write(encoder, val)
else:
encoder.write_int(0)
@dataclass(frozen=True)
class StructWriter(Writer):
field_writers: Tuple[Tuple[Optional[int], Writer], ...] = dataclassfield()
def write(self, encoder: BinaryEncoder, val: Record) -> None:
for pos, writer in self.field_writers:
# When pos is None, then it is a default value
writer.write(encoder, val[pos] if pos is not None else None)
def __eq__(self, other: Any) -> bool:
"""Implement the equality operator for this object."""
return self.field_writers == other.field_writers if isinstance(other, StructWriter) else False
def __repr__(self) -> str:
"""Return string representation of this object."""
return f"StructWriter(tuple(({','.join(repr(field) for field in self.field_writers)})))"
def __hash__(self) -> int:
"""Return the hash of the writer as hash of this object."""
return hash(self.field_writers)
@dataclass(frozen=True)
class ListWriter(Writer):
element_writer: Writer
def write(self, encoder: BinaryEncoder, val: List[Any]) -> None:
encoder.write_int(len(val))
for v in val:
self.element_writer.write(encoder, v)
if len(val) > 0:
encoder.write_int(0)
@dataclass(frozen=True)
class MapWriter(Writer):
key_writer: Writer
value_writer: Writer
def write(self, encoder: BinaryEncoder, val: Dict[Any, Any]) -> None:
encoder.write_int(len(val))
for k, v in val.items():
self.key_writer.write(encoder, k)
self.value_writer.write(encoder, v)
if len(val) > 0:
encoder.write_int(0)
@dataclass(frozen=True)
class DefaultWriter(Writer):
writer: Writer
value: Any
def write(self, encoder: BinaryEncoder, _: Any) -> None:
self.writer.write(encoder, self.value)