pyiceberg/io/pyarrow.py (1,894 lines of code) (raw):
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=redefined-outer-name,arguments-renamed,fixme
"""FileIO implementation for reading and writing table files that uses pyarrow.fs.
This file contains a FileIO implementation that relies on the filesystem interface provided
by PyArrow. The concrete PyArrow filesystem (S3/OSS, HDFS/ViewFS, GCS or local) is chosen from
the URI scheme of each location, so supporting another storage type means mapping its scheme
to a PyArrow filesystem in `PyArrowFileIO._initialize_fs`.
"""
from __future__ import annotations
import concurrent.futures
import fnmatch
import functools
import itertools
import logging
import operator
import os
import re
import uuid
import warnings
from abc import ABC, abstractmethod
from concurrent.futures import Future
from copy import copy
from dataclasses import dataclass
from enum import Enum
from functools import lru_cache, singledispatch
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Generic,
Iterable,
Iterator,
List,
Optional,
Set,
Tuple,
TypeVar,
Union,
cast,
)
from urllib.parse import urlparse
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.dataset as ds
import pyarrow.lib
import pyarrow.parquet as pq
from pyarrow import ChunkedArray
from pyarrow.fs import (
FileInfo,
FileSystem,
FileType,
FSSpecHandler,
)
from sortedcontainers import SortedList
from pyiceberg.conversions import to_bytes
from pyiceberg.exceptions import ResolveError
from pyiceberg.expressions import AlwaysTrue, BooleanExpression, BoundIsNaN, BoundIsNull, BoundTerm, Not, Or
from pyiceberg.expressions.literals import Literal
from pyiceberg.expressions.visitors import (
BoundBooleanExpressionVisitor,
bind,
extract_field_ids,
translate_column_names,
)
from pyiceberg.expressions.visitors import visit as boolean_expression_visit
from pyiceberg.io import (
AWS_ACCESS_KEY_ID,
AWS_REGION,
AWS_ROLE_ARN,
AWS_ROLE_SESSION_NAME,
AWS_SECRET_ACCESS_KEY,
AWS_SESSION_TOKEN,
GCS_DEFAULT_LOCATION,
GCS_SERVICE_HOST,
GCS_TOKEN,
GCS_TOKEN_EXPIRES_AT_MS,
HDFS_HOST,
HDFS_KERB_TICKET,
HDFS_PORT,
HDFS_USER,
PYARROW_USE_LARGE_TYPES_ON_READ,
S3_ACCESS_KEY_ID,
S3_CONNECT_TIMEOUT,
S3_ENDPOINT,
S3_FORCE_VIRTUAL_ADDRESSING,
S3_PROXY_URI,
S3_REGION,
S3_REQUEST_TIMEOUT,
S3_RESOLVE_REGION,
S3_ROLE_ARN,
S3_ROLE_SESSION_NAME,
S3_SECRET_ACCESS_KEY,
S3_SESSION_TOKEN,
FileIO,
InputFile,
InputStream,
OutputFile,
OutputStream,
_parse_location,
)
from pyiceberg.manifest import (
DataFile,
DataFileContent,
FileFormat,
)
from pyiceberg.partitioning import PartitionField, PartitionFieldValue, PartitionKey, PartitionSpec, partition_record_value
from pyiceberg.schema import (
Accessor,
PartnerAccessor,
PreOrderSchemaVisitor,
Schema,
SchemaVisitorPerPrimitiveType,
SchemaWithPartnerVisitor,
_check_schema_compatible,
build_position_accessors,
pre_order_visit,
promote,
prune_columns,
sanitize_column_names,
visit,
visit_with_partner,
)
from pyiceberg.table.locations import load_location_provider
from pyiceberg.table.metadata import TableMetadata
from pyiceberg.table.name_mapping import NameMapping, apply_name_mapping
from pyiceberg.table.puffin import PuffinFile
from pyiceberg.transforms import IdentityTransform, TruncateTransform
from pyiceberg.typedef import EMPTY_DICT, Properties, Record
from pyiceberg.types import (
BinaryType,
BooleanType,
DateType,
DecimalType,
DoubleType,
FixedType,
FloatType,
IcebergType,
IntegerType,
ListType,
LongType,
MapType,
NestedField,
PrimitiveType,
StringType,
StructType,
TimestampNanoType,
TimestampType,
TimestamptzNanoType,
TimestamptzType,
TimeType,
UnknownType,
UUIDType,
)
from pyiceberg.utils.concurrent import ExecutorFactory
from pyiceberg.utils.config import Config
from pyiceberg.utils.datetime import millis_to_datetime
from pyiceberg.utils.decimal import unscaled_to_decimal
from pyiceberg.utils.deprecated import deprecation_message
from pyiceberg.utils.properties import get_first_property_value, property_as_bool, property_as_int
from pyiceberg.utils.singleton import Singleton
from pyiceberg.utils.truncate import truncate_upper_bound_binary_string, truncate_upper_bound_text_string
if TYPE_CHECKING:
from pyiceberg.table import FileScanTask, WriteTask
logger = logging.getLogger(__name__)
# Default stream buffer size in bytes (1 MiB); the default for PyArrowFile/PyArrowFileIO streams and delete-file reads.
ONE_MEGABYTE = 1024 * 1024
# FileIO property key that PyArrowFileIO.new_input/new_output read to override the stream buffer size.
BUFFER_SIZE = "buffer-size"
# NOTE(review): metadata key, presumably for a serialized Iceberg schema; not used in this chunk — confirm at usage sites.
ICEBERG_SCHEMA = b"iceberg.schema"
# The PARQUET: in front means that it is Parquet specific, in this case the field_id
PYARROW_PARQUET_FIELD_ID_KEY = b"PARQUET:field_id"
# Arrow field-metadata key under which _ConvertToArrowSchema stores an Iceberg field's doc string.
PYARROW_FIELD_DOC_KEY = b"doc"
# NOTE(review): presumably child field names for list elements and map keys/values; not used in this chunk.
LIST_ELEMENT_NAME = "element"
MAP_KEY_NAME = "key"
MAP_VALUE_NAME = "value"
DOC = "doc"
# NOTE(review): presumably timezone spellings treated as UTC; not used in this chunk — confirm at usage sites.
UTC_ALIASES = {"UTC", "+00:00", "Etc/UTC", "Z"}
T = TypeVar("T")
@lru_cache
def _cached_resolve_s3_region(bucket: str) -> Optional[str]:
    """Look up the region hosting an S3 bucket, memoizing the answer per bucket name.

    Args:
        bucket: Name of the S3 bucket to look up.

    Returns:
        The region reported by PyArrow, or None when the lookup raised OSError/TypeError
        (a warning is logged in that case). Failed lookups are memoized as well.
    """
    from pyarrow.fs import resolve_s3_region

    region: Optional[str]
    try:
        region = resolve_s3_region(bucket=bucket)
    except (OSError, TypeError):
        logger.warning(f"Unable to resolve region for bucket {bucket}")
        region = None
    return region
class UnsupportedPyArrowTypeException(Exception):
    """Raised when a PyArrow type cannot be converted to a corresponding Iceberg type.

    Attributes:
        field: The PyArrow field whose type could not be converted.
    """

    def __init__(self, field: pa.Field, *args: Any):
        super().__init__(*args)
        self.field = field
class PyArrowLocalFileSystem(pyarrow.fs.LocalFileSystem):
    """Local PyArrow filesystem that creates missing parent directories before writing a file."""

    def open_output_stream(self, path: str, *args: Any, **kwargs: Any) -> pyarrow.NativeFile:
        """Open ``path`` for writing after recursively creating its parent directory.

        LocalFileSystem requires the parent directories to exist before an output stream can be opened.
        """
        parent_directory = os.path.dirname(path)
        self.create_dir(parent_directory, recursive=True)
        return super().open_output_stream(path, *args, **kwargs)
class PyArrowFile(InputFile, OutputFile):
    """A combined InputFile and OutputFile implementation that uses a pyarrow filesystem to generate pyarrow.lib.NativeFile instances.

    Args:
        location (str): A URI or a path to a local file.
        path (str): The path handed to ``fs`` for every filesystem call made on this file.
        fs (FileSystem): The pyarrow filesystem that performs the I/O.
        buffer_size (int): Buffer size in bytes for non-seekable input streams and for output
            streams. Defaults to 1 MiB.

    Attributes:
        location(str): The URI or path to a local file for a PyArrowFile instance.

    Examples:
        >>> from pyiceberg.io.pyarrow import PyArrowFile
        >>> # Instances are normally obtained from PyArrowFileIO:
        >>> # input_file = PyArrowFileIO().new_input("s3://foo/bar.txt")
        >>> # Read the contents of the PyArrowFile instance
        >>> # Make sure that you have permissions to read/write
        >>> # file_content = input_file.open().read()
        >>> # output_file = PyArrowFileIO().new_output("s3://baz/qux.txt")
        >>> # Write bytes to a file
        >>> # Make sure that you have permissions to read/write
        >>> # output_file.create().write(b'foobytes')
    """

    # Filesystem used for all I/O on this file.
    _filesystem: FileSystem
    # Path passed to the filesystem calls.
    _path: str
    # Buffer size in bytes for buffered streams.
    _buffer_size: int

    def __init__(self, location: str, path: str, fs: FileSystem, buffer_size: int = ONE_MEGABYTE):
        self._filesystem = fs
        self._path = path
        self._buffer_size = buffer_size
        super().__init__(location=location)

    def _file_info(self) -> FileInfo:
        """Retrieve a pyarrow.fs.FileInfo object for the location.

        Raises:
            FileNotFoundError: If nothing exists at self.location.
            PermissionError: If the file at self.location cannot be accessed due to a permission error such as
                an AWS error code 15.
        """
        try:
            file_info = self._filesystem.get_file_info(self._path)
        except OSError as e:
            # errno 13 is EACCES; AWS error code 15 is also treated as access denied.
            if e.errno == 13 or "AWS Error [code 15]" in str(e):
                raise PermissionError(f"Cannot get file info, access denied: {self.location}") from e
            raise  # pragma: no cover - If some other kind of OSError, raise the raw error
        if file_info.type == FileType.NotFound:
            raise FileNotFoundError(f"Cannot get file info, file not found: {self.location}")
        return file_info

    def __len__(self) -> int:
        """Return the total length of the file, in bytes."""
        file_info = self._file_info()
        return file_info.size

    def exists(self) -> bool:
        """Check whether the location exists."""
        try:
            self._file_info()  # raises FileNotFoundError if it does not exist
            return True
        except FileNotFoundError:
            return False

    def open(self, seekable: bool = True) -> InputStream:
        """Open the location using a PyArrow FileSystem inferred from the location.

        Args:
            seekable: If the stream should support seek, or if it is consumed sequential.

        Returns:
            pyarrow.lib.NativeFile: A NativeFile instance for the file located at `self.location`.

        Raises:
            FileNotFoundError: If the file at self.location does not exist.
            PermissionError: If the file at self.location cannot be accessed due to a permission error such as
                an AWS error code 15.
        """
        try:
            if seekable:
                input_file = self._filesystem.open_input_file(self._path)
            else:
                input_file = self._filesystem.open_input_stream(self._path, buffer_size=self._buffer_size)
        except (FileNotFoundError, PermissionError):
            # Already the exception types callers expect; propagate them unchanged.
            raise
        except OSError as e:
            # errno 2 is ENOENT and errno 13 is EACCES; some filesystems only signal these in the message.
            if e.errno == 2 or "Path does not exist" in str(e):
                raise FileNotFoundError(f"Cannot open file, does not exist: {self.location}") from e
            elif e.errno == 13 or "AWS Error [code 15]" in str(e):
                raise PermissionError(f"Cannot open file, access denied: {self.location}") from e
            raise  # pragma: no cover - If some other kind of OSError, raise the raw error
        return input_file

    def create(self, overwrite: bool = False) -> OutputStream:
        """Create a writable pyarrow.lib.NativeFile for this PyArrowFile's location.

        Args:
            overwrite (bool): Whether to overwrite the file if it already exists.

        Returns:
            pyarrow.lib.NativeFile: A NativeFile instance for the file located at self.location.

        Raises:
            FileExistsError: If the file already exists at `self.location` and `overwrite` is False.
            PermissionError: If the file cannot be checked or created due to a permission error such as
                an AWS error code 15.

        Note:
            This retrieves a pyarrow NativeFile by opening an output stream. If overwrite is set to False,
            a check is first performed to verify that the file does not exist. This is not thread-safe and
            a possibility does exist that the file can be created by a concurrent process after the existence
            check yet before the output stream is created. In such a case, the default pyarrow behavior will
            truncate the contents of the existing file when opening the output stream.
        """
        # Kept outside the try below: FileExistsError is an OSError subclass and must not be
        # routed through the OSError -> PermissionError translation.
        if not overwrite and self.exists():
            raise FileExistsError(f"Cannot create file, already exists: {self.location}")
        try:
            output_file = self._filesystem.open_output_stream(self._path, buffer_size=self._buffer_size)
        except PermissionError:
            raise
        except OSError as e:
            if e.errno == 13 or "AWS Error [code 15]" in str(e):
                raise PermissionError(f"Cannot create file, access denied: {self.location}") from e
            raise  # pragma: no cover - If some other kind of OSError, raise the raw error
        return output_file

    def to_input_file(self) -> PyArrowFile:
        """Return this PyArrowFile as an InputFile.

        This method is included to abide by the OutputFile abstract base class. Since this implementation uses a single
        PyArrowFile class (as opposed to separate InputFile and OutputFile implementations), the instance itself is
        returned; no copy is made.
        """
        return self
class PyArrowFileIO(FileIO):
    """FileIO implementation backed by PyArrow filesystems.

    A PyArrow ``FileSystem`` is built lazily per ``(scheme, netloc)`` pair and memoized with
    ``functools.lru_cache`` (default size). The cache is dropped when pickling and rebuilt on unpickling.
    """

    fs_by_scheme: Callable[[str, Optional[str]], FileSystem]

    def __init__(self, properties: Properties = EMPTY_DICT):
        # Memoize filesystem construction per (scheme, netloc).
        self.fs_by_scheme: Callable[[str, Optional[str]], FileSystem] = lru_cache(self._initialize_fs)
        super().__init__(properties=properties)

    @staticmethod
    def parse_location(location: str) -> Tuple[str, str, str]:
        """Split a location into ``(scheme, netloc, path)``, where ``path`` is what the PyArrow filesystem expects.

        Locations without a scheme are treated as local files and made absolute. For ``hdfs``/``viewfs``
        the path excludes the netloc; for every other scheme the netloc is prefixed to the path.
        """
        uri = urlparse(location)
        if not uri.scheme:
            return "file", uri.netloc, os.path.abspath(location)
        elif uri.scheme in ("hdfs", "viewfs"):
            return uri.scheme, uri.netloc, uri.path
        else:
            return uri.scheme, uri.netloc, f"{uri.netloc}{uri.path}"

    def _initialize_fs(self, scheme: str, netloc: Optional[str] = None) -> FileSystem:
        """Initialize FileSystem for different scheme.

        Raises:
            ValueError: If the scheme is not supported.
        """
        if scheme in {"oss"}:
            return self._initialize_oss_fs()
        elif scheme in {"s3", "s3a", "s3n"}:
            return self._initialize_s3_fs(netloc)
        elif scheme in {"hdfs", "viewfs"}:
            return self._initialize_hdfs_fs(scheme, netloc)
        elif scheme in {"gs", "gcs"}:
            return self._initialize_gcs_fs()
        elif scheme in {"file"}:
            return self._initialize_local_fs()
        else:
            raise ValueError(f"Unrecognized filesystem type in URI: {scheme}")

    def _s3_compatible_client_kwargs(self) -> Dict[str, Any]:
        """Build the S3FileSystem keyword arguments shared by the S3 and OSS filesystems (everything but region/addressing)."""
        client_kwargs: Dict[str, Any] = {
            "endpoint_override": self.properties.get(S3_ENDPOINT),
            "access_key": get_first_property_value(self.properties, S3_ACCESS_KEY_ID, AWS_ACCESS_KEY_ID),
            "secret_key": get_first_property_value(self.properties, S3_SECRET_ACCESS_KEY, AWS_SECRET_ACCESS_KEY),
            "session_token": get_first_property_value(self.properties, S3_SESSION_TOKEN, AWS_SESSION_TOKEN),
        }
        if proxy_uri := self.properties.get(S3_PROXY_URI):
            client_kwargs["proxy_options"] = proxy_uri
        if connect_timeout := self.properties.get(S3_CONNECT_TIMEOUT):
            client_kwargs["connect_timeout"] = float(connect_timeout)
        if request_timeout := self.properties.get(S3_REQUEST_TIMEOUT):
            client_kwargs["request_timeout"] = float(request_timeout)
        if role_arn := get_first_property_value(self.properties, S3_ROLE_ARN, AWS_ROLE_ARN):
            client_kwargs["role_arn"] = role_arn
        if session_name := get_first_property_value(self.properties, S3_ROLE_SESSION_NAME, AWS_ROLE_SESSION_NAME):
            client_kwargs["session_name"] = session_name
        return client_kwargs

    def _initialize_oss_fs(self) -> FileSystem:
        """Create an S3-compatible filesystem for ``oss://`` locations."""
        from pyarrow.fs import S3FileSystem

        client_kwargs = self._s3_compatible_client_kwargs()
        client_kwargs["region"] = get_first_property_value(self.properties, S3_REGION, AWS_REGION)
        # Unlike plain S3, virtual addressing is enabled unless the property explicitly disables it.
        client_kwargs["force_virtual_addressing"] = property_as_bool(self.properties, S3_FORCE_VIRTUAL_ADDRESSING, True)
        return S3FileSystem(**client_kwargs)

    def _initialize_s3_fs(self, netloc: Optional[str]) -> FileSystem:
        """Create an S3 filesystem, resolving the bucket region from ``netloc`` when needed."""
        from pyarrow.fs import S3FileSystem

        provided_region = get_first_property_value(self.properties, S3_REGION, AWS_REGION)

        # Do this when we don't provide the region at all, or when we explicitly enable it
        if provided_region is None or property_as_bool(self.properties, S3_RESOLVE_REGION, False) is True:
            # Resolve region from netloc(bucket), fallback to user-provided region
            # Only supported by buckets hosted by S3
            bucket_region = _cached_resolve_s3_region(bucket=netloc) or provided_region
            if provided_region is not None and bucket_region != provided_region:
                logger.warning(
                    f"PyArrow FileIO overriding S3 bucket region for bucket {netloc}: "
                    f"provided region {provided_region}, actual region {bucket_region}"
                )
        else:
            bucket_region = provided_region

        client_kwargs = self._s3_compatible_client_kwargs()
        client_kwargs["region"] = bucket_region
        # Only pass force_virtual_addressing when it is configured; otherwise PyArrow's default applies.
        if self.properties.get(S3_FORCE_VIRTUAL_ADDRESSING) is not None:
            client_kwargs["force_virtual_addressing"] = property_as_bool(self.properties, S3_FORCE_VIRTUAL_ADDRESSING, False)
        return S3FileSystem(**client_kwargs)

    def _initialize_hdfs_fs(self, scheme: str, netloc: Optional[str]) -> FileSystem:
        """Create a Hadoop filesystem from the URI's netloc, or from the HDFS_* properties when there is none."""
        from pyarrow.fs import HadoopFileSystem

        if netloc:
            # The HDFS_* properties are not consulted when the location names a host.
            return HadoopFileSystem.from_uri(f"{scheme}://{netloc}")

        hdfs_kwargs: Dict[str, Any] = {}
        if host := self.properties.get(HDFS_HOST):
            hdfs_kwargs["host"] = host
        if port := self.properties.get(HDFS_PORT):
            # port should be an integer type
            hdfs_kwargs["port"] = int(port)
        if user := self.properties.get(HDFS_USER):
            hdfs_kwargs["user"] = user
        if kerb_ticket := self.properties.get(HDFS_KERB_TICKET):
            hdfs_kwargs["kerb_ticket"] = kerb_ticket
        return HadoopFileSystem(**hdfs_kwargs)

    def _initialize_gcs_fs(self) -> FileSystem:
        """Create a GCS filesystem configured from the GCS_* properties."""
        from pyarrow.fs import GcsFileSystem

        gcs_kwargs: Dict[str, Any] = {}
        if access_token := self.properties.get(GCS_TOKEN):
            gcs_kwargs["access_token"] = access_token
        if expiration := self.properties.get(GCS_TOKEN_EXPIRES_AT_MS):
            gcs_kwargs["credential_token_expiration"] = millis_to_datetime(int(expiration))
        if bucket_location := self.properties.get(GCS_DEFAULT_LOCATION):
            gcs_kwargs["default_bucket_location"] = bucket_location
        if endpoint := self.properties.get(GCS_SERVICE_HOST):
            url_parts = urlparse(endpoint)
            gcs_kwargs["scheme"] = url_parts.scheme
            gcs_kwargs["endpoint_override"] = url_parts.netloc
        return GcsFileSystem(**gcs_kwargs)

    def _initialize_local_fs(self) -> FileSystem:
        """Create the local filesystem (which creates parent directories on write)."""
        return PyArrowLocalFileSystem()

    def _new_pyarrow_file(self, location: str) -> PyArrowFile:
        """Build a PyArrowFile for ``location`` using the cached filesystem for its scheme."""
        scheme, netloc, path = self.parse_location(location)
        return PyArrowFile(
            fs=self.fs_by_scheme(scheme, netloc),
            location=location,
            path=path,
            buffer_size=int(self.properties.get(BUFFER_SIZE, ONE_MEGABYTE)),
        )

    def new_input(self, location: str) -> PyArrowFile:
        """Get a PyArrowFile instance to read bytes from the file at the given location.

        Args:
            location (str): A URI or a path to a local file.

        Returns:
            PyArrowFile: A PyArrowFile instance for the given location.
        """
        return self._new_pyarrow_file(location)

    def new_output(self, location: str) -> PyArrowFile:
        """Get a PyArrowFile instance to write bytes to the file at the given location.

        Args:
            location (str): A URI or a path to a local file.

        Returns:
            PyArrowFile: A PyArrowFile instance for the given location.
        """
        return self._new_pyarrow_file(location)

    def delete(self, location: Union[str, InputFile, OutputFile]) -> None:
        """Delete the file at the given location.

        Args:
            location (Union[str, InputFile, OutputFile]): The URI to the file--if an InputFile instance or an OutputFile instance is provided,
                the location attribute for that instance is used as the location to delete.

        Raises:
            FileNotFoundError: When the file at the provided location does not exist.
            PermissionError: If the file at the provided location cannot be accessed due to a permission error such as
                an AWS error code 15.
        """
        str_location = location.location if isinstance(location, (InputFile, OutputFile)) else location
        scheme, netloc, path = self.parse_location(str_location)
        fs = self.fs_by_scheme(scheme, netloc)

        try:
            fs.delete_file(path)
        except (FileNotFoundError, PermissionError):
            raise
        except OSError as e:
            # Report the string location, not the repr of an InputFile/OutputFile argument.
            if e.errno == 2 or "Path does not exist" in str(e):
                raise FileNotFoundError(f"Cannot delete file, does not exist: {str_location}") from e
            elif e.errno == 13 or "AWS Error [code 15]" in str(e):
                raise PermissionError(f"Cannot delete file, access denied: {str_location}") from e
            raise  # pragma: no cover - If some other kind of OSError, raise the raw error

    def __getstate__(self) -> Dict[str, Any]:
        """Create a dictionary of the PyArrowFileIO fields used when pickling."""
        fileio_copy = copy(self.__dict__)
        # The lru_cache wrapper holds live filesystems; drop it and rebuild on unpickle.
        fileio_copy["fs_by_scheme"] = None
        return fileio_copy

    def __setstate__(self, state: Dict[str, Any]) -> None:
        """Deserialize the state into a PyArrowFileIO instance."""
        self.__dict__ = state
        self.fs_by_scheme = lru_cache(self._initialize_fs)
def schema_to_pyarrow(
    schema: Union[Schema, IcebergType],
    metadata: Dict[bytes, bytes] = EMPTY_DICT,
    include_field_ids: bool = True,
) -> pa.schema:
    """Convert an Iceberg schema, or a single Iceberg type, into its PyArrow counterpart.

    Args:
        schema: An Iceberg ``Schema`` (producing a PyArrow schema) or an Iceberg type
            (producing the matching PyArrow data type).
        metadata: Schema-level key/value metadata attached to a produced PyArrow schema.
        include_field_ids: When True, each field carries its Iceberg field id under the
            ``PARQUET:field_id`` metadata key.
    """
    converter = _ConvertToArrowSchema(metadata, include_field_ids)
    return visit(schema, converter)
class _ConvertToArrowSchema(SchemaVisitorPerPrimitiveType[pa.DataType]):
    """Schema visitor that maps Iceberg schemas and types onto PyArrow schemas and data types.

    Strings, binaries and lists map to the ``large_*`` Arrow variants.
    """

    # Key/value metadata attached to the resulting top-level pa.schema.
    _metadata: Dict[bytes, bytes]

    def __init__(self, metadata: Dict[bytes, bytes] = EMPTY_DICT, include_field_ids: bool = True) -> None:
        self._metadata = metadata
        self._include_field_ids = include_field_ids

    # ----- nested types -------------------------------------------------------------------------

    def schema(self, _: Schema, struct_result: pa.StructType) -> pa.schema:
        # The fields of the top-level struct become the schema's columns.
        return pa.schema([*struct_result], metadata=self._metadata)

    def struct(self, _: StructType, field_results: List[pa.DataType]) -> pa.DataType:
        return pa.struct(field_results)

    def field(self, field: NestedField, field_result: pa.DataType) -> pa.Field:
        """Wrap a converted type in a pa.Field carrying the doc and, optionally, the field id as metadata."""
        field_metadata = {}
        if field.doc:
            field_metadata[PYARROW_FIELD_DOC_KEY] = field.doc
        if self._include_field_ids:
            field_metadata[PYARROW_PARQUET_FIELD_ID_KEY] = str(field.field_id)
        return pa.field(name=field.name, type=field_result, nullable=field.optional, metadata=field_metadata)

    def list(self, list_type: ListType, element_result: pa.DataType) -> pa.DataType:
        return pa.large_list(value_type=self.field(list_type.element_field, element_result))

    def map(self, map_type: MapType, key_result: pa.DataType, value_result: pa.DataType) -> pa.DataType:
        return pa.map_(
            key_type=self.field(map_type.key_field, key_result),
            item_type=self.field(map_type.value_field, value_result),
        )

    # ----- primitive types ----------------------------------------------------------------------

    def visit_fixed(self, fixed_type: FixedType) -> pa.DataType:
        return pa.binary(len(fixed_type))

    def visit_decimal(self, decimal_type: DecimalType) -> pa.DataType:
        return pa.decimal128(decimal_type.precision, decimal_type.scale)

    def visit_boolean(self, _: BooleanType) -> pa.DataType:
        return pa.bool_()

    def visit_integer(self, _: IntegerType) -> pa.DataType:
        return pa.int32()

    def visit_long(self, _: LongType) -> pa.DataType:
        return pa.int64()

    def visit_float(self, _: FloatType) -> pa.DataType:
        # Iceberg float is a 32-bit IEEE 754 value.
        return pa.float32()

    def visit_double(self, _: DoubleType) -> pa.DataType:
        # Iceberg double is a 64-bit IEEE 754 value.
        return pa.float64()

    def visit_date(self, _: DateType) -> pa.DataType:
        return pa.date32()

    def visit_time(self, _: TimeType) -> pa.DataType:
        return pa.time64("us")

    def visit_timestamp(self, _: TimestampType) -> pa.DataType:
        return pa.timestamp(unit="us")

    def visit_timestamp_ns(self, _: TimestampNanoType) -> pa.DataType:
        return pa.timestamp(unit="ns")

    def visit_timestamptz(self, _: TimestamptzType) -> pa.DataType:
        return pa.timestamp(unit="us", tz="UTC")

    def visit_timestamptz_ns(self, _: TimestamptzNanoType) -> pa.DataType:
        return pa.timestamp(unit="ns", tz="UTC")

    def visit_string(self, _: StringType) -> pa.DataType:
        return pa.large_string()

    def visit_uuid(self, _: UUIDType) -> pa.DataType:
        # UUIDs are carried as their 16 raw bytes.
        return pa.binary(16)

    def visit_unknown(self, _: UnknownType) -> pa.DataType:
        return pa.null()

    def visit_binary(self, _: BinaryType) -> pa.DataType:
        return pa.large_binary()
def _convert_scalar(value: Any, iceberg_type: IcebergType) -> pa.scalar:
    """Wrap ``value`` in a PyArrow scalar typed like the given primitive Iceberg type.

    Raises:
        ValueError: If ``iceberg_type`` is not a primitive type.
    """
    if isinstance(iceberg_type, PrimitiveType):
        arrow_type = schema_to_pyarrow(iceberg_type)
        return pa.scalar(value=value, type=arrow_type)
    raise ValueError(f"Expected primitive type, got: {iceberg_type}")
class _ConvertToArrowExpression(BoundBooleanExpressionVisitor[pc.Expression]):
    """Translate a bound Iceberg boolean expression into a ``pyarrow.compute`` expression.

    Columns are referenced by the bound field's name, and literals are converted to the field's Arrow type.
    """

    @staticmethod
    def _field_ref(term: BoundTerm[Any]) -> pc.Expression:
        """Arrow field reference for the column behind ``term``."""
        return pc.field(term.ref().field.name)

    @staticmethod
    def _typed_scalar(term: BoundTerm[Any], literal: Literal[Any]) -> pa.Scalar:
        """Arrow scalar for ``literal``, typed like the column behind ``term``."""
        return _convert_scalar(literal.value, term.ref().field.field_type)

    @staticmethod
    def _typed_array(term: BoundTerm[Any], literals: Set[Any]) -> pa.Array:
        """Arrow array of ``literals``, typed like the column behind ``term``."""
        return pa.array(literals, type=schema_to_pyarrow(term.ref().field.field_type))

    # ----- set membership -----

    def visit_in(self, term: BoundTerm[Any], literals: Set[Any]) -> pc.Expression:
        candidates = self._typed_array(term, literals)
        return self._field_ref(term).isin(candidates)

    def visit_not_in(self, term: BoundTerm[Any], literals: Set[Any]) -> pc.Expression:
        candidates = self._typed_array(term, literals)
        return ~self._field_ref(term).isin(candidates)

    # ----- NaN / null checks -----

    def visit_is_nan(self, term: BoundTerm[Any]) -> pc.Expression:
        return pc.is_nan(self._field_ref(term))

    def visit_not_nan(self, term: BoundTerm[Any]) -> pc.Expression:
        return ~pc.is_nan(self._field_ref(term))

    def visit_is_null(self, term: BoundTerm[Any]) -> pc.Expression:
        # NaN is deliberately not considered null here.
        return self._field_ref(term).is_null(nan_is_null=False)

    def visit_not_null(self, term: BoundTerm[Any]) -> pc.Expression:
        return self._field_ref(term).is_valid()

    # ----- comparisons -----

    def visit_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
        return self._field_ref(term) == self._typed_scalar(term, literal)

    def visit_not_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
        return self._field_ref(term) != self._typed_scalar(term, literal)

    def visit_greater_than_or_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
        return self._field_ref(term) >= self._typed_scalar(term, literal)

    def visit_greater_than(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
        return self._field_ref(term) > self._typed_scalar(term, literal)

    def visit_less_than(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
        return self._field_ref(term) < self._typed_scalar(term, literal)

    def visit_less_than_or_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
        return self._field_ref(term) <= self._typed_scalar(term, literal)

    # ----- string prefix -----

    def visit_starts_with(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
        return pc.starts_with(self._field_ref(term), literal.value)

    def visit_not_starts_with(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
        return ~pc.starts_with(self._field_ref(term), literal.value)

    # ----- constants and boolean connectives -----

    def visit_true(self) -> pc.Expression:
        return pc.scalar(True)

    def visit_false(self) -> pc.Expression:
        return pc.scalar(False)

    def visit_not(self, child_result: pc.Expression) -> pc.Expression:
        return ~child_result

    def visit_and(self, left_result: pc.Expression, right_result: pc.Expression) -> pc.Expression:
        return left_result & right_result

    def visit_or(self, left_result: pc.Expression, right_result: pc.Expression) -> pc.Expression:
        return left_result | right_result
class _NullNaNUnmentionedTermsCollector(BoundBooleanExpressionVisitor[None]):
    """Partition the bound terms of an expression by whether they receive an explicit null / NaN check.

    A term moves to ``is_null_or_not_bound_terms`` as soon as any is_null/not_null predicate references
    it; otherwise it is kept in ``null_unmentioned_bound_terms``. NaN checks are tracked the same way with
    the ``*_nan_*`` sets. The two sets of each pair never share a term.
    """

    # BoundTerms which have either is_null or is_not_null appearing at least once in the boolean expr.
    is_null_or_not_bound_terms: set[BoundTerm[Any]]
    # The remaining BoundTerms appearing in the boolean expr.
    null_unmentioned_bound_terms: set[BoundTerm[Any]]
    # BoundTerms which have either is_nan or is_not_nan appearing at least once in the boolean expr.
    is_nan_or_not_bound_terms: set[BoundTerm[Any]]
    # The remaining BoundTerms appearing in the boolean expr.
    nan_unmentioned_bound_terms: set[BoundTerm[Any]]

    def __init__(self) -> None:
        super().__init__()
        self.is_null_or_not_bound_terms = set()
        self.null_unmentioned_bound_terms = set()
        self.is_nan_or_not_bound_terms = set()
        self.nan_unmentioned_bound_terms = set()

    # ----- bookkeeping helpers -----

    def _handle_explicit_is_null_or_not(self, term: BoundTerm[Any]) -> None:
        """Record that ``term`` is explicitly checked with is_null or not_null."""
        self.null_unmentioned_bound_terms.discard(term)
        self.is_null_or_not_bound_terms.add(term)

    def _handle_null_unmentioned(self, term: BoundTerm[Any]) -> None:
        """Record ``term`` as lacking a null check, unless one has already been seen for it."""
        if term in self.is_null_or_not_bound_terms:
            return
        self.null_unmentioned_bound_terms.add(term)

    def _handle_explicit_is_nan_or_not(self, term: BoundTerm[Any]) -> None:
        """Record that ``term`` is explicitly checked with is_nan or not_nan."""
        self.nan_unmentioned_bound_terms.discard(term)
        self.is_nan_or_not_bound_terms.add(term)

    def _handle_nan_unmentioned(self, term: BoundTerm[Any]) -> None:
        """Record ``term`` as lacking a NaN check, unless one has already been seen for it."""
        if term in self.is_nan_or_not_bound_terms:
            return
        self.nan_unmentioned_bound_terms.add(term)

    def _handle_plain_predicate(self, term: BoundTerm[Any]) -> None:
        """Handle a predicate that is neither a null check nor a NaN check."""
        self._handle_null_unmentioned(term)
        self._handle_nan_unmentioned(term)

    # ----- predicates without null / NaN semantics -----

    def visit_in(self, term: BoundTerm[Any], literals: Set[Any]) -> None:
        self._handle_plain_predicate(term)

    def visit_not_in(self, term: BoundTerm[Any], literals: Set[Any]) -> None:
        self._handle_plain_predicate(term)

    def visit_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
        self._handle_plain_predicate(term)

    def visit_not_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
        self._handle_plain_predicate(term)

    def visit_greater_than_or_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
        self._handle_plain_predicate(term)

    def visit_greater_than(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
        self._handle_plain_predicate(term)

    def visit_less_than(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
        self._handle_plain_predicate(term)

    def visit_less_than_or_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
        self._handle_plain_predicate(term)

    def visit_starts_with(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
        self._handle_plain_predicate(term)

    def visit_not_starts_with(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
        self._handle_plain_predicate(term)

    # ----- explicit NaN / null checks -----

    def visit_is_nan(self, term: BoundTerm[Any]) -> None:
        self._handle_null_unmentioned(term)
        self._handle_explicit_is_nan_or_not(term)

    def visit_not_nan(self, term: BoundTerm[Any]) -> None:
        self._handle_null_unmentioned(term)
        self._handle_explicit_is_nan_or_not(term)

    def visit_is_null(self, term: BoundTerm[Any]) -> None:
        self._handle_explicit_is_null_or_not(term)
        self._handle_nan_unmentioned(term)

    def visit_not_null(self, term: BoundTerm[Any]) -> None:
        self._handle_explicit_is_null_or_not(term)
        self._handle_nan_unmentioned(term)

    # ----- constants and connectives carry no terms -----

    def visit_true(self) -> None:
        return None

    def visit_false(self) -> None:
        return None

    def visit_not(self, child_result: None) -> None:
        return None

    def visit_and(self, left_result: None, right_result: None) -> None:
        return None

    def visit_or(self, left_result: None, right_result: None) -> None:
        return None

    def collect(
        self,
        expr: BooleanExpression,
    ) -> None:
        """Walk ``expr`` and fill the four term sets (null-checked vs. not, NaN-checked vs. not)."""
        boolean_expression_visit(expr, self)
def expression_to_pyarrow(expr: BooleanExpression) -> pc.Expression:
    """Convert a bound Iceberg boolean expression into the equivalent ``pyarrow.compute`` expression."""
    converter = _ConvertToArrowExpression()
    return boolean_expression_visit(expr, converter)
def _expression_to_complementary_pyarrow(expr: BooleanExpression) -> pc.Expression:
    """Complementary filter conversion function of expression_to_pyarrow.

    Could not use expression_to_pyarrow(Not(expr)) to achieve this complementary effect because ~ in
    pyarrow.compute.Expression does not handle null. Instead, ``Not(expr)`` is OR-ed with ``IsNull`` for
    every term that ``expr`` never null-checks, then with ``IsNaN`` for every term it never NaN-checks
    (presumably so rows on which ``expr`` evaluates to null are still selected).
    """
    collector = _NullNaNUnmentionedTermsCollector()
    collector.collect(expr)

    def by_field_name(term: BoundTerm[Any]) -> str:
        return term.ref().field.name

    # Sorting by field name keeps the layout of the generated expression deterministic.
    extra_disjuncts: List[BooleanExpression] = [
        *(BoundIsNull(term=term) for term in sorted(collector.null_unmentioned_bound_terms, key=by_field_name)),
        *(BoundIsNaN(term=term) for term in sorted(collector.nan_unmentioned_bound_terms, key=by_field_name)),
    ]

    preserve_expr: BooleanExpression = Not(expr)
    for disjunct in extra_disjuncts:
        preserve_expr = Or(preserve_expr, disjunct)
    return expression_to_pyarrow(preserve_expr)
@lru_cache
def _get_file_format(file_format: FileFormat, **kwargs: Any) -> ds.FileFormat:
    """Return a memoized PyArrow dataset file format for an Iceberg file format.

    Args:
        file_format: The Iceberg file format; only ``FileFormat.PARQUET`` is supported.
        **kwargs: Options forwarded to ``ds.ParquetFileFormat``. Results are cached with
            ``lru_cache``, so every value must be hashable (e.g. a tuple rather than a list).

    Raises:
        ValueError: If ``file_format`` is not Parquet.
    """
    if file_format != FileFormat.PARQUET:
        raise ValueError(f"Unsupported file format: {file_format}")
    return ds.ParquetFileFormat(**kwargs)
def _construct_fragment(fs: FileSystem, data_file: DataFile, file_format_kwargs: Dict[str, Any] = EMPTY_DICT) -> ds.Fragment:
    """Create a PyArrow dataset fragment for ``data_file`` on filesystem ``fs``.

    ``file_format_kwargs`` are forwarded to ``_get_file_format``, which caches on them, so the values must be hashable.
    """
    path = PyArrowFileIO.parse_location(data_file.file_path)[2]
    file_format = _get_file_format(data_file.file_format, **file_format_kwargs)
    return file_format.make_fragment(path, fs)
def _read_deletes(fs: FileSystem, data_file: DataFile) -> Dict[str, pa.ChunkedArray]:
    """Read a delete file and group the deleted positions by referenced data file.

    Args:
        fs: Filesystem used to open ``data_file``.
        data_file: The delete file (Parquet position deletes or Puffin).

    Returns:
        A mapping from data file path to the ``pos`` values deleted in that file. A Parquet
        delete file whose table has no chunks yields an empty mapping.

    Raises:
        ValueError: If the delete file format is not supported.
    """
    if data_file.file_format == FileFormat.PARQUET:
        delete_fragment = _construct_fragment(
            fs,
            data_file,
            # Read file_path dictionary-encoded so each distinct path is enumerated once below.
            file_format_kwargs={"dictionary_columns": ("file_path",), "pre_buffer": True, "buffer_size": ONE_MEGABYTE},
        )
        table = ds.Scanner.from_fragment(fragment=delete_fragment).to_table()
        # Give all chunks one shared dictionary so the first chunk's dictionary lists every path.
        table = table.unify_dictionaries()
        file_path_column = table.column("file_path")
        if file_path_column.num_chunks == 0:
            # No chunks (e.g. an empty delete file): nothing to delete, and chunks[0] would fail.
            return {}
        return {
            file.as_py(): table.filter(pc.field("file_path") == file).column("pos")
            for file in file_path_column.chunks[0].dictionary
        }
    elif data_file.file_format == FileFormat.PUFFIN:
        _, _, path = PyArrowFileIO.parse_location(data_file.file_path)
        with fs.open_input_file(path) as fi:
            payload = fi.read()
        return PuffinFile(payload).to_vector()
    else:
        raise ValueError(f"Delete file format not supported: {data_file.file_format}")
def _combine_positional_deletes(positional_deletes: List[pa.ChunkedArray], start_index: int, end_index: int) -> pa.Array:
    """Return the batch-relative indices of the rows in ``[start_index, end_index)`` that survive deletes.

    Args:
        positional_deletes: Deleted file positions, possibly coming from several delete files.
        start_index: File position of the first row of the batch (inclusive).
        end_index: File position just past the last row of the batch (exclusive).

    Returns:
        An int64 array of surviving row offsets relative to ``start_index``, suitable for
        ``RecordBatch.take``.
    """
    if len(positional_deletes) == 1:
        all_chunks = positional_deletes[0]
    else:
        all_chunks = pa.chunked_array(itertools.chain.from_iterable(arr.chunks for arr in positional_deletes))

    # Create the full range array with pyarrow. The explicit int64 type keeps an empty range
    # (a zero-row batch) from being inferred as the null type, which would not match the
    # positions in the value set below.
    full_range = pa.array(range(start_index, end_index), type=pa.int64())
    # When available, replace with Arrow generator to improve performance
    # See https://github.com/apache/iceberg-python/issues/1271 for details

    # Filter out values in all_chunks from full_range
    result = pc.filter(full_range, pc.invert(pc.is_in(full_range, value_set=all_chunks)))

    # Subtract the start_index from each element in the result
    return pc.subtract(result, pa.scalar(start_index, type=pa.int64()))
def pyarrow_to_schema(
    schema: pa.Schema, name_mapping: Optional[NameMapping] = None, downcast_ns_timestamp_to_us: bool = False
) -> Schema:
    """Convert a pyarrow schema into an Iceberg ``Schema``.

    When every field, list element, map key and map value carries a field ID in its
    metadata, those IDs are used. Otherwise the schema is converted without IDs and
    ``name_mapping`` is applied to it.

    Raises:
        ValueError: If the schema lacks field IDs and no ``name_mapping`` was given.
    """
    if visit_pyarrow(schema, _HasIds()):
        converter = _ConvertToIceberg(downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
        return visit_pyarrow(schema, converter)
    if name_mapping is None:
        raise ValueError(
            "Parquet file does not have field-ids and the Iceberg table does not have 'schema.name-mapping.default' defined"
        )
    id_less_schema = _pyarrow_to_schema_without_ids(schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
    return apply_name_mapping(id_less_schema, name_mapping)
def _pyarrow_to_schema_without_ids(schema: pa.Schema, downcast_ns_timestamp_to_us: bool = False) -> Schema:
    """Convert ``schema`` into an Iceberg ``Schema`` in which every field ID is -1."""
    id_less_visitor = _ConvertToIcebergWithoutIDs(downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
    return visit_pyarrow(schema, id_less_visitor)
def _pyarrow_schema_ensure_large_types(schema: pa.Schema) -> pa.Schema:
    """Return ``schema`` with string/binary widened to their large variants and lists rebuilt as large lists."""
    widening_visitor = _ConvertToLargeTypes()
    return visit_pyarrow(schema, widening_visitor)
def _pyarrow_schema_ensure_small_types(schema: pa.Schema) -> pa.Schema:
    """Return ``schema`` with large_string/large_binary narrowed to string/binary and lists rebuilt as regular lists."""
    narrowing_visitor = _ConvertToSmallTypes()
    return visit_pyarrow(schema, narrowing_visitor)
@singledispatch
def visit_pyarrow(obj: Union[pa.DataType, pa.Schema], visitor: PyArrowSchemaVisitor[T]) -> T:
    """Apply a pyarrow schema visitor to any point within a schema.

    The function traverses the schema in post-order fashion: the children of a node are
    visited before the visitor callback for the node itself is invoked. Concrete handlers
    are registered per pyarrow class (Schema, Field, struct, list, map, dictionary and
    primitive data types); this base implementation only handles unregistered objects.

    Args:
        obj (Union[pa.DataType, pa.Schema]): A pyarrow Schema, Field or DataType.
        visitor (PyArrowSchemaVisitor[T]): An instance of an implementation of the generic PyArrowSchemaVisitor base class.

    Raises:
        NotImplementedError: If attempting to visit an unrecognized object type.
    """
    raise NotImplementedError(f"Cannot visit non-type: {obj}")
@visit_pyarrow.register(pa.Schema)
def _(obj: pa.Schema, visitor: PyArrowSchemaVisitor[T]) -> T:
    # Visit the schema's top-level fields as one struct, then let the visitor wrap that result.
    top_level_result = visit_pyarrow(pa.struct(obj), visitor)
    return visitor.schema(obj, top_level_result)
@visit_pyarrow.register(pa.StructType)
def _(obj: pa.StructType, visitor: PyArrowSchemaVisitor[T]) -> T:
    # Post-order: every child field (in declaration order) is visited before the struct itself.
    child_results = [visit_pyarrow(child_field, visitor) for child_field in obj]
    return visitor.struct(obj, child_results)
@visit_pyarrow.register(pa.ListType)
@visit_pyarrow.register(pa.FixedSizeListType)
@visit_pyarrow.register(pa.LargeListType)
def _(obj: Union[pa.ListType, pa.LargeListType, pa.FixedSizeListType], visitor: PyArrowSchemaVisitor[T]) -> T:
    # The element *field* goes to the before/after hooks; only the element *type* is visited.
    element_field = obj.value_field
    visitor.before_list_element(element_field)
    element_result = visit_pyarrow(obj.value_type, visitor)
    visitor.after_list_element(element_field)
    return visitor.list(obj, element_result)
@visit_pyarrow.register(pa.MapType)
def _(obj: pa.MapType, visitor: PyArrowSchemaVisitor[T]) -> T:
    # Keys are visited before values; each visit is bracketed by the matching hooks.
    key_field = obj.key_field
    visitor.before_map_key(key_field)
    key_result = visit_pyarrow(obj.key_type, visitor)
    visitor.after_map_key(key_field)

    item_field = obj.item_field
    visitor.before_map_value(item_field)
    value_result = visit_pyarrow(obj.item_type, visitor)
    visitor.after_map_value(item_field)

    return visitor.map(obj, key_result, value_result)
@visit_pyarrow.register(pa.DictionaryType)
def _(obj: pa.DictionaryType, visitor: PyArrowSchemaVisitor[T]) -> T:
    """Visit a dictionary-encoded type as its value type.

    Parquet has no dictionary type. Dictionary encoding is handled as an encoding
    detail, not as a separate type. We follow this approach when determining the
    Iceberg type, as PyIceberg only supports Parquet for now.
    """
    # Lazy %-style arguments: the message is only formatted when the WARNING level is enabled.
    logger.warning(
        "Iceberg does not have a dictionary type. %s will be inferred as %s on read.", type(obj), obj.value_type
    )
    return visit_pyarrow(obj.value_type, visitor)
@visit_pyarrow.register(pa.Field)
def _(obj: pa.Field, visitor: PyArrowSchemaVisitor[T]) -> T:
    visitor.before_field(obj)
    try:
        type_result = visit_pyarrow(obj.type, visitor)
    except TypeError as e:
        # Unsupported types surface as TypeError; re-raise with the offending column attached.
        raise UnsupportedPyArrowTypeException(obj, f"Column '{obj.name}' has an unsupported type: {obj.type}") from e
    visitor.after_field(obj)
    return visitor.field(obj, type_result)
@visit_pyarrow.register(pa.DataType)
def _(obj: pa.DataType, visitor: PyArrowSchemaVisitor[T]) -> T:
    # Struct, list, map and dictionary types have dedicated handlers; any other nested type
    # (for example a union) ends up here and is rejected.
    if not pa.types.is_nested(obj):
        return visitor.primitive(obj)
    raise TypeError(f"Expected primitive type, got: {type(obj)}")
class PyArrowSchemaVisitor(Generic[T], ABC):
    """Base class for post-order visitors over pyarrow schemas, driven by ``visit_pyarrow``.

    Subclasses implement the abstract callbacks, which receive the results already computed
    for a node's children. The ``before_*``/``after_*`` hooks are no-ops unless overridden.
    They are called around the visit of a field, list element, map key or map value.
    """

    # --- optional hooks -------------------------------------------------------

    def before_field(self, field: pa.Field) -> None:
        """Hook called right before ``field`` is visited; no-op by default."""

    def after_field(self, field: pa.Field) -> None:
        """Hook called right after ``field`` is visited; no-op by default."""

    def before_list_element(self, element: pa.Field) -> None:
        """Hook called right before the element of a list type is visited; no-op by default."""

    def after_list_element(self, element: pa.Field) -> None:
        """Hook called right after the element of a list type is visited; no-op by default."""

    def before_map_key(self, key: pa.Field) -> None:
        """Hook called right before the key of a MapType is visited; no-op by default."""

    def after_map_key(self, key: pa.Field) -> None:
        """Hook called right after the key of a MapType is visited; no-op by default."""

    def before_map_value(self, value: pa.Field) -> None:
        """Hook called right before the value of a MapType is visited; no-op by default."""

    def after_map_value(self, value: pa.Field) -> None:
        """Hook called right after the value of a MapType is visited; no-op by default."""

    # --- required callbacks ---------------------------------------------------

    @abstractmethod
    def schema(self, schema: pa.Schema, struct_result: T) -> T:
        """Produce the result for a schema from the result of its top-level struct."""

    @abstractmethod
    def struct(self, struct: pa.StructType, field_results: List[T]) -> T:
        """Produce the result for a struct from the results of its fields."""

    @abstractmethod
    def field(self, field: pa.Field, field_result: T) -> T:
        """Produce the result for a field from the result of its type."""

    @abstractmethod
    def list(self, list_type: pa.ListType, element_result: T) -> T:
        """Produce the result for a list from the result of its element type."""

    @abstractmethod
    def map(self, map_type: pa.MapType, key_result: T, value_result: T) -> T:
        """Produce the result for a map from the results of its key and value types."""

    @abstractmethod
    def primitive(self, primitive: pa.DataType) -> T:
        """Produce the result for a primitive (non-nested) type."""
def _get_field_id(field: pa.Field) -> Optional[int]:
    """Return the field ID stored in ``field``'s metadata, or None when it is absent or empty."""
    metadata = field.metadata
    if not metadata:
        return None
    raw_field_id = metadata.get(PYARROW_PARQUET_FIELD_ID_KEY)
    if not raw_field_id:
        return None
    return int(raw_field_id.decode())
class _HasIds(PyArrowSchemaVisitor[bool]):
    """Report whether every field, list element, map key and map value carries a field ID."""

    def schema(self, schema: pa.Schema, struct_result: bool) -> bool:
        return struct_result

    def struct(self, struct: pa.StructType, field_results: List[bool]) -> bool:
        return all(field_results)

    def field(self, field: pa.Field, field_result: bool) -> bool:
        has_id = _get_field_id(field) is not None
        return has_id and field_result

    def list(self, list_type: pa.ListType, element_result: bool) -> bool:
        has_element_id = _get_field_id(list_type.value_field) is not None
        return element_result and has_element_id

    def map(self, map_type: pa.MapType, key_result: bool, value_result: bool) -> bool:
        entry_ids = (_get_field_id(map_type.key_field), _get_field_id(map_type.item_field))
        ids_present = all(entry_id is not None for entry_id in entry_ids)
        return ids_present and all((key_result, value_result))

    def primitive(self, primitive: pa.DataType) -> bool:
        # Primitive types carry no IDs of their own; the enclosing field is checked instead.
        return True
class _ConvertToIceberg(PyArrowSchemaVisitor[Union[IcebergType, Schema]]):
    """Convert a pyarrow schema into an Iceberg ``Schema``.

    Field IDs are read from each pyarrow field's metadata (see ``_get_field_id``); a missing
    ID raises ``ValueError``. Subclasses may override ``_field_id`` to supply IDs differently.
    ``_field_names`` holds the name path of the node currently being visited.
    """

    _field_names: List[str]

    def __init__(self, downcast_ns_timestamp_to_us: bool = False) -> None:
        self._field_names = []
        self._downcast_ns_timestamp_to_us = downcast_ns_timestamp_to_us

    def _field_id(self, field: pa.Field) -> int:
        field_id = _get_field_id(field)
        if field_id is None:
            raise ValueError(f"Cannot convert {field} to Iceberg Field as field_id is empty.")
        return field_id

    def _nested_field_id(self, name: str, field: pa.Field) -> int:
        # Resolve the ID of a list element / map key / map value with ``name`` pushed on the path.
        self._field_names.append(name)
        field_id = self._field_id(field)
        self._field_names.pop()
        return field_id

    def schema(self, schema: pa.Schema, struct_result: StructType) -> Schema:
        return Schema(*struct_result.fields)

    def struct(self, struct: pa.StructType, field_results: List[NestedField]) -> StructType:
        return StructType(*field_results)

    def field(self, field: pa.Field, field_result: IcebergType) -> NestedField:
        field_id = self._field_id(field)
        metadata = field.metadata
        raw_doc = metadata.get(PYARROW_FIELD_DOC_KEY) if metadata else None
        field_doc = raw_doc.decode() if raw_doc else None
        return NestedField(field_id, field.name, field_result, required=not field.nullable, doc=field_doc)

    def list(self, list_type: pa.ListType, element_result: IcebergType) -> ListType:
        element_field = list_type.value_field
        element_id = self._nested_field_id(LIST_ELEMENT_NAME, element_field)
        return ListType(element_id, element_result, element_required=not element_field.nullable)

    def map(self, map_type: pa.MapType, key_result: IcebergType, value_result: IcebergType) -> MapType:
        key_id = self._nested_field_id(MAP_KEY_NAME, map_type.key_field)
        value_field = map_type.item_field
        value_id = self._nested_field_id(MAP_VALUE_NAME, value_field)
        return MapType(key_id, key_result, value_id, value_result, value_required=not value_field.nullable)

    def primitive(self, primitive: pa.DataType) -> PrimitiveType:
        if pa.types.is_boolean(primitive):
            return BooleanType()
        if pa.types.is_integer(primitive):
            bit_width = primitive.bit_width
            if bit_width <= 32:
                return IntegerType()
            if bit_width <= 64:
                return LongType()
            # Does not exist (yet)
            raise TypeError(f"Unsupported integer type: {primitive}")
        if pa.types.is_float32(primitive):
            return FloatType()
        if pa.types.is_float64(primitive):
            return DoubleType()
        if isinstance(primitive, pa.Decimal128Type):
            return DecimalType(primitive.precision, primitive.scale)
        if pa.types.is_string(primitive) or pa.types.is_large_string(primitive) or pa.types.is_string_view(primitive):
            return StringType()
        if pa.types.is_date32(primitive):
            return DateType()
        if isinstance(primitive, pa.Time64Type) and primitive.unit == "us":
            return TimeType()
        if pa.types.is_timestamp(primitive):
            unit = primitive.unit
            if unit == "ns":
                if not self._downcast_ns_timestamp_to_us:
                    raise TypeError(
                        "Iceberg does not yet support 'ns' timestamp precision. Use 'downcast-ns-timestamp-to-us-on-write' configuration property to automatically downcast 'ns' to 'us' on write.",
                    )
                logger.warning("Iceberg does not yet support 'ns' timestamp precision. Downcasting to 'us'.")
            elif unit not in ("s", "ms", "us"):
                raise TypeError(f"Unsupported precision for timestamp type: {unit}")
            # 's'/'ms'/'us' are supported and will be upcast automatically to 'us'.
            if primitive.tz in UTC_ALIASES:
                return TimestamptzType()
            if primitive.tz is None:
                return TimestampType()
            # Any other timezone falls through to the generic error below.
        elif pa.types.is_binary(primitive) or pa.types.is_large_binary(primitive) or pa.types.is_binary_view(primitive):
            return BinaryType()
        elif pa.types.is_fixed_size_binary(primitive):
            return FixedType(primitive.byte_width)
        elif pa.types.is_null(primitive):
            return UnknownType()

        raise TypeError(f"Unsupported type: {primitive}")

    def before_field(self, field: pa.Field) -> None:
        self._field_names.append(field.name)

    def after_field(self, field: pa.Field) -> None:
        self._field_names.pop()

    def before_list_element(self, element: pa.Field) -> None:
        self._field_names.append(LIST_ELEMENT_NAME)

    def after_list_element(self, element: pa.Field) -> None:
        self._field_names.pop()

    def before_map_key(self, key: pa.Field) -> None:
        self._field_names.append(MAP_KEY_NAME)

    def after_map_key(self, element: pa.Field) -> None:
        self._field_names.pop()

    def before_map_value(self, value: pa.Field) -> None:
        self._field_names.append(MAP_VALUE_NAME)

    def after_map_value(self, element: pa.Field) -> None:
        self._field_names.pop()
class _ConvertToLargeTypes(PyArrowSchemaVisitor[Union[pa.DataType, pa.Schema]]):
    """Rebuild a pyarrow schema using the large (64-bit offset) string, binary and list types.

    Lists and maps are rebuilt from their element/key/value *types* only, so the original
    child field names, nullability and metadata are not carried over.
    """

    def schema(self, schema: pa.Schema, struct_result: pa.StructType) -> pa.Schema:
        return pa.schema(struct_result)

    def struct(self, struct: pa.StructType, field_results: List[pa.Field]) -> pa.StructType:
        return pa.struct(field_results)

    def field(self, field: pa.Field, field_result: pa.DataType) -> pa.Field:
        # Keep the field's name, nullability and metadata; only swap its type.
        return field.with_type(field_result)

    def list(self, list_type: pa.ListType, element_result: pa.DataType) -> pa.DataType:
        return pa.large_list(element_result)

    def map(self, map_type: pa.MapType, key_result: pa.DataType, value_result: pa.DataType) -> pa.DataType:
        return pa.map_(key_result, value_result)

    def primitive(self, primitive: pa.DataType) -> pa.DataType:
        if pa.types.is_string(primitive):
            return pa.large_string()
        if pa.types.is_binary(primitive):
            return pa.large_binary()
        return primitive
class _ConvertToSmallTypes(PyArrowSchemaVisitor[Union[pa.DataType, pa.Schema]]):
    """Rebuild a pyarrow schema using the regular (32-bit offset) string, binary and list types.

    Lists and maps are rebuilt from their element/key/value *types* only, so the original
    child field names, nullability and metadata are not carried over.
    """

    def schema(self, schema: pa.Schema, struct_result: pa.StructType) -> pa.Schema:
        return pa.schema(struct_result)

    def struct(self, struct: pa.StructType, field_results: List[pa.Field]) -> pa.StructType:
        return pa.struct(field_results)

    def field(self, field: pa.Field, field_result: pa.DataType) -> pa.Field:
        # Keep the field's name, nullability and metadata; only swap its type.
        return field.with_type(field_result)

    def list(self, list_type: pa.ListType, element_result: pa.DataType) -> pa.DataType:
        return pa.list_(element_result)

    def map(self, map_type: pa.MapType, key_result: pa.DataType, value_result: pa.DataType) -> pa.DataType:
        return pa.map_(key_result, value_result)

    def primitive(self, primitive: pa.DataType) -> pa.DataType:
        if pa.types.is_large_string(primitive):
            return pa.string()
        if pa.types.is_large_binary(primitive):
            return pa.binary()
        return primitive
class _ConvertToIcebergWithoutIDs(_ConvertToIceberg):
    """Convert a pyarrow schema into an Iceberg ``Schema`` whose field IDs are all -1.

    The -1 placeholders must be replaced before the schema is used. ``pyarrow_to_schema``
    applies a name mapping to the result. When a new table is created, the schema is
    expected to go through ``new_table_metadata``, which assigns new field IDs in order.
    """

    def _field_id(self, field: pa.Field) -> int:
        # Placeholder ID for every field, list element, map key and map value.
        return -1
def _get_column_projection_values(
    file: DataFile, projected_schema: Schema, partition_spec: Optional[PartitionSpec], file_project_field_ids: Set[int]
) -> Tuple[bool, Dict[str, Any]]:
    """Apply Column Projection rules to File Schema.

    A column that is requested by ``projected_schema`` but missing from the data file can
    take its value from the file's partition tuple when it is the source of an identity
    partition field (https://iceberg.apache.org/spec/#column-projection).

    Args:
        file: The data file whose partition tuple supplies the values.
        projected_schema: The schema requested by the scan.
        partition_spec: The spec used to interpret ``file.partition``; None disables projection.
        file_project_field_ids: Field IDs present in the data file's schema.

    Returns:
        ``(should_project, values)``. ``should_project`` is False when no requested column
        is missing or no partition spec is given. ``values`` maps the partition field name
        to this file's constant value; null partition values are left out.
    """
    project_schema_diff = projected_schema.field_ids.difference(file_project_field_ids)
    if len(project_schema_diff) == 0 or partition_spec is None:
        return False, {}

    partition_schema: StructType = partition_spec.partition_type(projected_schema)
    accessors: Dict[int, Accessor] = build_position_accessors(partition_schema)

    projected_missing_fields: Dict[str, Any] = {}
    for field_id in project_schema_diff:
        for partition_field in partition_spec.fields_by_source_id(field_id):
            if not isinstance(partition_field.transform, IdentityTransform):
                continue
            accessor = accessors.get(partition_field.field_id)
            if accessor is None:
                continue
            # The partition field may not exist in the partition record of the data file.
            # This can happen when new partition fields are introduced after the file was written.
            try:
                partition_value = accessor.get(file.partition)
            except IndexError:
                continue
            # Compare with None explicitly: falsy values such as 0, False or "" are valid
            # identity-partition values and must still be projected.
            if partition_value is not None:
                projected_missing_fields[partition_field.name] = partition_value

    return True, projected_missing_fields
def _task_to_record_batches(
    fs: FileSystem,
    task: FileScanTask,
    bound_row_filter: BooleanExpression,
    projected_schema: Schema,
    projected_field_ids: Set[int],
    positional_deletes: Optional[List[ChunkedArray]],
    case_sensitive: bool,
    name_mapping: Optional[NameMapping] = None,
    partition_spec: Optional[PartitionSpec] = None,
) -> Iterator[pa.RecordBatch]:
    """Read one data file and yield its rows as record batches projected onto ``projected_schema``.

    Without positional deletes, the row filter is pushed down into the Arrow scanner. With
    positional deletes, the deleted rows are dropped from each batch first and the filter is
    then evaluated on what remains.

    Args:
        fs: Filesystem used to open the data file.
        task: The scan task describing the data file.
        bound_row_filter: Row filter bound to the table schema; it is translated to the
            file's column names and re-bound to the file schema here.
        projected_schema: Iceberg schema each yielded batch is projected onto.
        projected_field_ids: Field IDs whose columns are read from the file.
        positional_deletes: Deleted row positions for this file, or None.
        case_sensitive: Whether column name lookups are case sensitive.
        name_mapping: Used to assign field IDs when the file's schema has none.
        partition_spec: Spec used to derive constant values for identity-partitioned columns
            that are missing from the file.

    Yields:
        Projected record batches. When positional deletes are present, batches that end up
        empty are skipped.
    """
    _, _, path = _parse_location(task.file.file_path)
    arrow_format = ds.ParquetFileFormat(pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8))
    # The file stays open for as long as this generator is being consumed.
    with fs.open_input_file(path) as fin:
        fragment = arrow_format.make_fragment(fin)
        physical_schema = fragment.physical_schema
        # In V1 and V2 table formats, we only support Timestamp 'us' in Iceberg Schema
        # Hence it is reasonable to always cast 'ns' timestamp to 'us' on read.
        # When V3 support is introduced, we will update `downcast_ns_timestamp_to_us` flag based on
        # the table format version.
        file_schema = pyarrow_to_schema(physical_schema, name_mapping, downcast_ns_timestamp_to_us=True)

        # Re-express the table-level filter in terms of this file's schema (column names may differ).
        pyarrow_filter = None
        if bound_row_filter is not AlwaysTrue():
            translated_row_filter = translate_column_names(bound_row_filter, file_schema, case_sensitive=case_sensitive)
            bound_file_filter = bind(file_schema, translated_row_filter, case_sensitive=case_sensitive)
            pyarrow_filter = expression_to_pyarrow(bound_file_filter)

        # Apply column projection rules
        # https://iceberg.apache.org/spec/#column-projection
        should_project_columns, projected_missing_fields = _get_column_projection_values(
            task.file, projected_schema, partition_spec, file_schema.field_ids
        )

        file_project_schema = prune_columns(file_schema, projected_field_ids, select_full_types=False)

        fragment_scanner = ds.Scanner.from_fragment(
            fragment=fragment,
            schema=physical_schema,
            # This will push down the query to Arrow.
            # But in case there are positional deletes, we have to apply them first
            filter=pyarrow_filter if not positional_deletes else None,
            columns=[col.name for col in file_project_schema.columns],
        )

        # next_index tracks the file row position just past the current batch, so
        # current_index is the file position of the batch's first row.
        next_index = 0
        batches = fragment_scanner.to_batches()
        for batch in batches:
            next_index = next_index + len(batch)
            current_index = next_index - len(batch)
            current_batch = batch

            if positional_deletes:
                # Create the mask of indices that we're interested in
                indices = _combine_positional_deletes(positional_deletes, current_index, current_index + len(batch))
                current_batch = current_batch.take(indices)

                # skip empty batches
                if current_batch.num_rows == 0:
                    continue

                # Apply the user filter
                if pyarrow_filter is not None:
                    # Temporary fix until PyArrow 21 is released ( https://github.com/apache/arrow/pull/46057 )
                    table = pa.Table.from_batches([current_batch])
                    table = table.filter(pyarrow_filter)
                    # skip empty batches
                    if table.num_rows == 0:
                        continue

                    current_batch = table.combine_chunks().to_batches()[0]

            result_batch = _to_requested_schema(
                projected_schema,
                file_project_schema,
                current_batch,
                downcast_ns_timestamp_to_us=True,
            )

            # Inject projected column values if available
            if should_project_columns:
                for name, value in projected_missing_fields.items():
                    index = result_batch.schema.get_field_index(name)
                    if index != -1:
                        # NOTE(review): pa.repeat infers the array type from the Python value
                        # (e.g. int64 for an int), which may differ from the projected column's
                        # type. Confirm that callers tolerate this.
                        arr = pa.repeat(value, result_batch.num_rows)
                        result_batch = result_batch.set_column(index, name, arr)

            yield result_batch
def _read_all_delete_files(io: FileIO, tasks: Iterable[FileScanTask]) -> Dict[str, List[ChunkedArray]]:
    """Read every distinct delete file referenced by ``tasks`` on the shared executor.

    Returns:
        A mapping from data file path to the deleted-position arrays that target it, with
        one array per delete file that references the path.
    """
    deletes_per_file: Dict[str, List[ChunkedArray]] = {}
    unique_deletes = {delete_file for task in tasks for delete_file in task.delete_files}
    if not unique_deletes:
        return deletes_per_file

    executor = ExecutorFactory.get_or_create()
    read_arguments = [(_fs_from_file_path(io, delete_file.file_path), delete_file) for delete_file in unique_deletes]
    results: Iterator[Dict[str, ChunkedArray]] = executor.map(lambda args: _read_deletes(*args), read_arguments)
    for positions_by_path in results:
        for data_file_path, positions in positions_by_path.items():
            deletes_per_file.setdefault(data_file_path, []).append(positions)
    return deletes_per_file
def _fs_from_file_path(io: FileIO, file_path: str) -> FileSystem:
    """Return a pyarrow ``FileSystem`` that can read ``file_path`` through ``io``.

    ``PyArrowFileIO`` provides a filesystem directly. ``FsspecFileIO`` is wrapped in a
    ``PyFileSystem`` via ``FSSpecHandler``.

    Raises:
        ValueError: If ``io`` is neither kind of FileIO, or if a ModuleNotFoundError is
            raised along the fsspec path (e.g. fsspec is not installed).
    """
    scheme, netloc, _ = _parse_location(file_path)
    if isinstance(io, PyArrowFileIO):
        return io.fs_by_scheme(scheme, netloc)

    try:
        from pyiceberg.io.fsspec import FsspecFileIO

        if not isinstance(io, FsspecFileIO):
            raise ValueError(f"Expected PyArrowFileIO or FsspecFileIO, got: {io}")

        from pyarrow.fs import PyFileSystem

        return PyFileSystem(FSSpecHandler(io.get_fs(scheme)))
    except ModuleNotFoundError as e:
        # When FsSpec is not installed
        raise ValueError(f"Expected PyArrowFileIO or FsspecFileIO, got: {io}") from e
class ArrowScan:
    """Scan the Iceberg Table and create an Arrow construct.

    Attributes:
        _table_metadata: Current table metadata of the Iceberg table
        _io: PyIceberg FileIO implementation from which to fetch the io properties
        _projected_schema: Iceberg Schema to project onto the data files
        _bound_row_filter: Schema bound row expression to filter the data with
        _case_sensitive: Case sensitivity when looking up column names
        _limit: Limit the number of records.
    """

    _table_metadata: TableMetadata
    _io: FileIO
    _projected_schema: Schema
    _bound_row_filter: BooleanExpression
    _case_sensitive: bool
    _limit: Optional[int]

    def __init__(
        self,
        table_metadata: TableMetadata,
        io: FileIO,
        projected_schema: Schema,
        row_filter: BooleanExpression,
        case_sensitive: bool = True,
        limit: Optional[int] = None,
    ) -> None:
        self._table_metadata = table_metadata
        self._io = io
        self._projected_schema = projected_schema
        # Bound once against the current table schema; each file re-binds it to its own schema.
        self._bound_row_filter = bind(table_metadata.schema(), row_filter, case_sensitive=case_sensitive)
        self._case_sensitive = case_sensitive
        self._limit = limit

    @property
    def _projected_field_ids(self) -> Set[int]:
        """Set of field IDs that should be projected from the data files.

        Map and list field IDs themselves are excluded, and every field referenced by the
        row filter is included so the filter can be evaluated on the file data.
        """
        return {
            field_id
            for field_id in self._projected_schema.field_ids
            if not isinstance(self._projected_schema.find_type(field_id), (MapType, ListType))
        }.union(extract_field_ids(self._bound_row_filter))

    def to_table(self, tasks: Iterable[FileScanTask]) -> pa.Table:
        """Scan the Iceberg table and return a pa.Table.

        Returns a pa.Table with data from the Iceberg table by resolving the
        right columns that match the current table schema. Only data that
        matches the provided row_filter expression is returned.

        Args:
            tasks: FileScanTasks representing the data files and delete files to read from.

        Returns:
            A PyArrow table. Total number of rows will be capped if specified.

        Raises:
            ResolveError: When a required field cannot be found in the file
            ValueError: When a field type in the file cannot be projected to the schema type
        """
        # ``tasks`` is traversed twice (delete files first, then data files). Materialize it so
        # a one-shot iterator is not exhausted by the first pass.
        tasks = list(tasks)
        deletes_per_file = _read_all_delete_files(self._io, tasks)
        executor = ExecutorFactory.get_or_create()

        def _table_from_scan_task(task: FileScanTask) -> Optional[pa.Table]:
            batches = list(self._record_batches_from_scan_tasks_and_deletes([task], deletes_per_file))
            if len(batches) > 0:
                return pa.Table.from_batches(batches)
            return None

        futures = [executor.submit(_table_from_scan_task, task) for task in tasks]
        total_row_count = 0
        # for consistent ordering, we need to maintain future order
        futures_index = {f: i for i, f in enumerate(futures)}
        completed_futures: SortedList[Future[Optional[pa.Table]]] = SortedList(iterable=[], key=lambda f: futures_index[f])
        for future in concurrent.futures.as_completed(futures):
            completed_futures.add(future)
            if table_result := future.result():
                total_row_count += len(table_result)
            # stop early if limit is satisfied
            if self._limit is not None and total_row_count >= self._limit:
                break

        # by now, we've either completed all tasks or satisfied the limit
        if self._limit is not None:
            for pending in futures:
                if not pending.done():
                    pending.cancel()

        tables = [f.result() for f in completed_futures if f.result()]

        arrow_schema = schema_to_pyarrow(self._projected_schema, include_field_ids=False)

        if len(tables) < 1:
            return pa.Table.from_batches([], schema=arrow_schema)

        result = pa.concat_tables(tables, promote_options="permissive")

        if property_as_bool(self._io.properties, PYARROW_USE_LARGE_TYPES_ON_READ, False):
            deprecation_message(
                deprecated_in="0.10.0",
                removed_in="0.11.0",
                help_message=f"Property `{PYARROW_USE_LARGE_TYPES_ON_READ}` will be removed.",
            )
            result = result.cast(arrow_schema)

        if self._limit is not None:
            return result.slice(0, self._limit)

        return result

    def to_record_batches(self, tasks: Iterable[FileScanTask]) -> Iterator[pa.RecordBatch]:
        """Scan the Iceberg table and return an Iterator[pa.RecordBatch].

        Returns an Iterator of pa.RecordBatch with data from the Iceberg table
        by resolving the right columns that match the current table schema.
        Only data that matches the provided row_filter expression is returned.

        Args:
            tasks: FileScanTasks representing the data files and delete files to read from.

        Returns:
            An Iterator of PyArrow RecordBatches.
            Total number of rows will be capped if specified.

        Raises:
            ResolveError: When a required field cannot be found in the file
            ValueError: When a field type in the file cannot be projected to the schema type
        """
        # Materialize so the delete-file pass does not exhaust a one-shot iterator.
        tasks = list(tasks)
        deletes_per_file = _read_all_delete_files(self._io, tasks)
        return self._record_batches_from_scan_tasks_and_deletes(tasks, deletes_per_file)

    def _record_batches_from_scan_tasks_and_deletes(
        self, tasks: Iterable[FileScanTask], deletes_per_file: Dict[str, List[ChunkedArray]]
    ) -> Iterator[pa.RecordBatch]:
        """Yield projected batches for ``tasks`` in order, stopping once ``_limit`` rows were produced."""
        total_row_count = 0
        # These are identical for every task; compute them once.
        projected_field_ids = self._projected_field_ids
        name_mapping = self._table_metadata.name_mapping()
        specs_by_id = self._table_metadata.specs()
        current_spec = self._table_metadata.spec()
        for task in tasks:
            if self._limit is not None and total_row_count >= self._limit:
                break
            # A data file's partition tuple is laid out per the spec it was written with, which
            # may be an older spec than the current one. Fall back to the current spec when
            # the file's spec is unknown.
            partition_spec = specs_by_id.get(task.file.spec_id, current_spec)
            batches = _task_to_record_batches(
                _fs_from_file_path(self._io, task.file.file_path),
                task,
                self._bound_row_filter,
                self._projected_schema,
                projected_field_ids,
                deletes_per_file.get(task.file.file_path),
                self._case_sensitive,
                name_mapping,
                partition_spec,
            )
            for batch in batches:
                if self._limit is not None:
                    if total_row_count >= self._limit:
                        break
                    elif total_row_count + len(batch) >= self._limit:
                        batch = batch.slice(0, self._limit - total_row_count)
                yield batch
                total_row_count += len(batch)
def _to_requested_schema(
    requested_schema: Schema,
    file_schema: Schema,
    batch: pa.RecordBatch,
    downcast_ns_timestamp_to_us: bool = False,
    include_field_ids: bool = False,
) -> pa.RecordBatch:
    """Project ``batch``, whose columns follow ``file_schema``, onto ``requested_schema``.

    Columns are matched by field ID. The projection itself (casts, null-filling of missing
    optional fields) is done by ``ArrowProjectionVisitor``.
    """
    # We could reuse some of these visitors
    projection = ArrowProjectionVisitor(file_schema, downcast_ns_timestamp_to_us, include_field_ids)
    partner_accessor = ArrowAccessor(file_schema)
    projected = visit_with_partner(requested_schema, batch, projection, partner_accessor)
    return pa.RecordBatch.from_struct_array(projected)
class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Array]]):
_file_schema: Schema
_include_field_ids: bool
_downcast_ns_timestamp_to_us: bool
_use_large_types: Optional[bool]
def __init__(
self,
file_schema: Schema,
downcast_ns_timestamp_to_us: bool = False,
include_field_ids: bool = False,
use_large_types: Optional[bool] = None,
) -> None:
self._file_schema = file_schema
self._include_field_ids = include_field_ids
self._downcast_ns_timestamp_to_us = downcast_ns_timestamp_to_us
self._use_large_types = use_large_types
if use_large_types is not None:
deprecation_message(
deprecated_in="0.10.0",
removed_in="0.11.0",
help_message="Argument `use_large_types` will be removed from ArrowProjectionVisitor",
)
def _cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array:
file_field = self._file_schema.find_field(field.field_id)
if field.field_type.is_primitive:
if field.field_type != file_field.field_type:
target_schema = schema_to_pyarrow(
promote(file_field.field_type, field.field_type), include_field_ids=self._include_field_ids
)
if self._use_large_types is False:
target_schema = _pyarrow_schema_ensure_small_types(target_schema)
return values.cast(target_schema)
elif (target_type := schema_to_pyarrow(field.field_type, include_field_ids=self._include_field_ids)) != values.type:
if field.field_type == TimestampType():
# Downcasting of nanoseconds to microseconds
if (
pa.types.is_timestamp(target_type)
and not target_type.tz
and pa.types.is_timestamp(values.type)
and not values.type.tz
):
if target_type.unit == "us" and values.type.unit == "ns" and self._downcast_ns_timestamp_to_us:
return values.cast(target_type, safe=False)
elif target_type.unit == "us" and values.type.unit in {"s", "ms"}:
return values.cast(target_type)
raise ValueError(f"Unsupported schema projection from {values.type} to {target_type}")
elif field.field_type == TimestamptzType():
if (
pa.types.is_timestamp(target_type)
and target_type.tz == "UTC"
and pa.types.is_timestamp(values.type)
and values.type.tz in UTC_ALIASES
):
if target_type.unit == "us" and values.type.unit == "ns" and self._downcast_ns_timestamp_to_us:
return values.cast(target_type, safe=False)
elif target_type.unit == "us" and values.type.unit in {"s", "ms", "us"}:
return values.cast(target_type)
raise ValueError(f"Unsupported schema projection from {values.type} to {target_type}")
return values
def _construct_field(self, field: NestedField, arrow_type: pa.DataType) -> pa.Field:
metadata = {}
if field.doc:
metadata[PYARROW_FIELD_DOC_KEY] = field.doc
if self._include_field_ids:
metadata[PYARROW_PARQUET_FIELD_ID_KEY] = str(field.field_id)
return pa.field(
name=field.name,
type=arrow_type,
nullable=field.optional,
metadata=metadata,
)
def schema(self, schema: Schema, schema_partner: Optional[pa.Array], struct_result: Optional[pa.Array]) -> Optional[pa.Array]:
return struct_result
def struct(
self, struct: StructType, struct_array: Optional[pa.Array], field_results: List[Optional[pa.Array]]
) -> Optional[pa.Array]:
if struct_array is None:
return None
field_arrays: List[pa.Array] = []
fields: List[pa.Field] = []
for field, field_array in zip(struct.fields, field_results):
if field_array is not None:
array = self._cast_if_needed(field, field_array)
field_arrays.append(array)
fields.append(self._construct_field(field, array.type))
elif field.optional:
arrow_type = schema_to_pyarrow(field.field_type, include_field_ids=self._include_field_ids)
field_arrays.append(pa.nulls(len(struct_array), type=arrow_type))
fields.append(self._construct_field(field, arrow_type))
else:
raise ResolveError(f"Field is required, and could not be found in the file: {field}")
return pa.StructArray.from_arrays(
arrays=field_arrays,
fields=pa.struct(fields),
mask=struct_array.is_null() if isinstance(struct_array, pa.StructArray) else None,
)
def field(self, field: NestedField, _: Optional[pa.Array], field_array: Optional[pa.Array]) -> Optional[pa.Array]:
return field_array
def list(self, list_type: ListType, list_array: Optional[pa.Array], value_array: Optional[pa.Array]) -> Optional[pa.Array]:
if isinstance(list_array, (pa.ListArray, pa.LargeListArray, pa.FixedSizeListArray)) and value_array is not None:
list_initializer = pa.large_list if isinstance(list_array, pa.LargeListArray) else pa.list_
if isinstance(value_array, pa.StructArray):
# This can be removed once this has been fixed:
# https://github.com/apache/arrow/issues/38809
list_array = pa.LargeListArray.from_arrays(list_array.offsets, value_array)
value_array = self._cast_if_needed(list_type.element_field, value_array)
arrow_field = list_initializer(self._construct_field(list_type.element_field, value_array.type))
return list_array.cast(arrow_field)
else:
return None
def map(
    self, map_type: MapType, map_array: Optional[pa.Array], key_result: Optional[pa.Array], value_result: Optional[pa.Array]
) -> Optional[pa.Array]:
    """Re-assemble a projected map array from its projected keys and values.

    Args:
        map_type: The requested Iceberg map type.
        map_array: The map partner read from the file, or ``None`` when absent.
        key_result: The already-projected keys, or ``None`` when absent.
        value_result: The already-projected values, or ``None`` when absent.

    Returns:
        The map array typed with the requested key/value fields. Returns ``None`` when ``map_array``
        is not a ``pa.MapArray`` or either child result is missing.
    """
    if not isinstance(map_array, pa.MapArray) or key_result is None or value_result is None:
        return None

    keys = self._cast_if_needed(map_type.key_field, key_result)
    items = self._cast_if_needed(map_type.value_field, value_result)
    map_arrow_type = pa.map_(
        self._construct_field(map_type.key_field, keys.type),
        self._construct_field(map_type.value_field, items.type),
    )

    if not isinstance(items, pa.StructArray):
        return map_array.cast(map_arrow_type)

    # Arrow does not allow reordering of fields, therefore we have to copy the array :(
    # NOTE(review): `offsets` carries no validity bitmap, so null map slots may be rebuilt as empty
    # maps here -- confirm, and consider passing a null mask if the minimum pyarrow supports it.
    return pa.MapArray.from_arrays(map_array.offsets, keys, items, map_arrow_type)
def primitive(self, _: PrimitiveType, array: Optional[pa.Array]) -> Optional[pa.Array]:
    """Pass primitive arrays through unchanged; casting is handled by the enclosing node."""
    projected_leaf = array
    return projected_leaf
class ArrowAccessor(PartnerAccessor[pa.Array]):
    """Find the Arrow array that partners each Iceberg schema node.

    Struct children are looked up by the name their field ID has in ``file_schema``.
    """

    file_schema: Schema

    def __init__(self, file_schema: Schema):
        self.file_schema = file_schema

    def schema_partner(self, partner: Optional[pa.Array]) -> Optional[pa.Array]:
        """Return the root partner unchanged."""
        return partner

    def field_partner(self, partner_struct: Optional[pa.Array], field_id: int, _: str) -> Optional[pa.Array]:
        """Return the child column of ``partner_struct`` that holds ``field_id``.

        Returns ``None`` when ``partner_struct`` is ``None`` or ``field_id`` is not in the file schema.

        Raises:
            ValueError: If ``partner_struct`` is not a StructArray, Table or RecordBatch.
        """
        if partner_struct is None:
            return None

        try:
            file_field_name = self.file_schema.find_field(field_id).name
        except ValueError:
            return None

        if isinstance(partner_struct, pa.StructArray):
            return partner_struct.field(file_field_name)
        if isinstance(partner_struct, pa.Table):
            # Tables hold chunked columns; flatten them into a single contiguous array.
            return partner_struct.column(file_field_name).combine_chunks()
        if isinstance(partner_struct, pa.RecordBatch):
            return partner_struct.column(file_field_name)
        raise ValueError(f"Cannot find {file_field_name} in expected partner_struct type {type(partner_struct)}")

    def list_element_partner(self, partner_list: Optional[pa.Array]) -> Optional[pa.Array]:
        """Return the element values of a list-like array, or ``None`` for anything else."""
        if isinstance(partner_list, (pa.ListArray, pa.LargeListArray, pa.FixedSizeListArray)):
            return partner_list.values
        return None

    def map_key_partner(self, partner_map: Optional[pa.Array]) -> Optional[pa.Array]:
        """Return the keys of a map array, or ``None`` for anything else."""
        if isinstance(partner_map, pa.MapArray):
            return partner_map.keys
        return None

    def map_value_partner(self, partner_map: Optional[pa.Array]) -> Optional[pa.Array]:
        """Return the values (items) of a map array, or ``None`` for anything else."""
        if isinstance(partner_map, pa.MapArray):
            return partner_map.items
        return None
def _primitive_to_physical(iceberg_type: PrimitiveType) -> str:
    """Return the Parquet physical type name for ``iceberg_type`` per ``PrimitiveToPhysicalType``."""
    physical_type_name: str = visit(iceberg_type, _PRIMITIVE_TO_PHYSICAL_TYPE_VISITOR)
    return physical_type_name
class PrimitiveToPhysicalType(SchemaVisitorPerPrimitiveType[str]):
    """Map an Iceberg primitive type to the name of its Parquet physical type.

    Nested types (schema, struct, field, list, map) are rejected with a ValueError.
    """

    def _reject(self, node: Any) -> str:
        """Raise for any non-primitive node."""
        raise ValueError(f"Expected primitive-type, got: {node}")

    def schema(self, schema: Schema, struct_result: str) -> str:
        return self._reject(schema)

    def struct(self, struct: StructType, field_results: List[str]) -> str:
        return self._reject(struct)

    def field(self, field: NestedField, field_result: str) -> str:
        return self._reject(field)

    def list(self, list_type: ListType, element_result: str) -> str:
        return self._reject(list_type)

    def map(self, map_type: MapType, key_result: str, value_result: str) -> str:
        return self._reject(map_type)

    def visit_fixed(self, fixed_type: FixedType) -> str:
        return "FIXED_LEN_BYTE_ARRAY"

    def visit_decimal(self, decimal_type: DecimalType) -> str:
        # precision <= 9 -> INT32, precision <= 18 -> INT64, anything wider -> FIXED_LEN_BYTE_ARRAY
        if decimal_type.precision <= 9:
            return "INT32"
        if decimal_type.precision <= 18:
            return "INT64"
        return "FIXED_LEN_BYTE_ARRAY"

    def visit_boolean(self, boolean_type: BooleanType) -> str:
        return "BOOLEAN"

    def visit_integer(self, integer_type: IntegerType) -> str:
        return "INT32"

    def visit_long(self, long_type: LongType) -> str:
        return "INT64"

    def visit_float(self, float_type: FloatType) -> str:
        return "FLOAT"

    def visit_double(self, double_type: DoubleType) -> str:
        return "DOUBLE"

    def visit_date(self, date_type: DateType) -> str:
        return "INT32"

    def visit_time(self, time_type: TimeType) -> str:
        return "INT64"

    def visit_timestamp(self, timestamp_type: TimestampType) -> str:
        return "INT64"

    def visit_timestamp_ns(self, timestamp_type: TimestampNanoType) -> str:
        return "INT64"

    def visit_timestamptz(self, timestamptz_type: TimestamptzType) -> str:
        return "INT64"

    def visit_timestamptz_ns(self, timestamptz_ns_type: TimestamptzNanoType) -> str:
        return "INT64"

    def visit_string(self, string_type: StringType) -> str:
        return "BYTE_ARRAY"

    def visit_uuid(self, uuid_type: UUIDType) -> str:
        return "FIXED_LEN_BYTE_ARRAY"

    def visit_binary(self, binary_type: BinaryType) -> str:
        return "BYTE_ARRAY"

    def visit_unknown(self, unknown_type: UnknownType) -> str:
        return "UNKNOWN"


# Shared, stateless visitor instance used by `_primitive_to_physical`.
_PRIMITIVE_TO_PHYSICAL_TYPE_VISITOR = PrimitiveToPhysicalType()
class StatsAggregator:
    """Track the running min/max of one column and serialize them as Iceberg bound bytes.

    Args:
        iceberg_type: The column's Iceberg primitive type.
        physical_type_string: The Parquet physical type reported by the file statistics.
        trunc_length: Optional truncation width applied to the serialized bounds.

    Raises:
        ValueError: If ``physical_type_string`` is neither the expected physical type nor a safe
            widening of it (INT32 -> INT64, FLOAT -> DOUBLE).
    """

    current_min: Any
    current_max: Any
    trunc_length: Optional[int]

    def __init__(self, iceberg_type: PrimitiveType, physical_type_string: str, trunc_length: Optional[int] = None) -> None:
        self.current_min = None
        self.current_max = None
        self.trunc_length = trunc_length

        expected_physical_type = _primitive_to_physical(iceberg_type)
        # INT32 -> INT64 and FLOAT -> DOUBLE are safe type casts, so those mismatches are tolerated.
        is_safe_widening = (physical_type_string, expected_physical_type) in (("INT32", "INT64"), ("FLOAT", "DOUBLE"))
        if expected_physical_type != physical_type_string and not is_safe_widening:
            raise ValueError(
                f"Unexpected physical type {physical_type_string} for {iceberg_type}, expected {expected_physical_type}"
            )

        self.primitive_type = iceberg_type

    def serialize(self, value: Any) -> bytes:
        """Encode ``value`` with Iceberg's single-value binary serialization for this column's type."""
        return to_bytes(self.primitive_type, value)

    def update_min(self, val: Optional[Any]) -> None:
        """Fold ``val`` into the running minimum; ``None`` values are ignored."""
        if val is None:
            return
        self.current_min = val if self.current_min is None else min(val, self.current_min)

    def update_max(self, val: Optional[Any]) -> None:
        """Fold ``val`` into the running maximum; ``None`` values are ignored."""
        if val is None:
            return
        self.current_max = val if self.current_max is None else max(val, self.current_max)

    def min_as_bytes(self) -> Optional[bytes]:
        """Serialize the lower bound, truncating it first when ``trunc_length`` is set."""
        lower = self.current_min
        if lower is None:
            return None
        if self.trunc_length is not None:
            lower = TruncateTransform(width=self.trunc_length).transform(self.primitive_type)(lower)
        return self.serialize(lower)

    def max_as_bytes(self) -> Optional[bytes]:
        """Serialize the upper bound.

        Strings and binaries are truncated to a value that is still an upper bound. Returns ``None``
        if no such truncated bound exists. Other types cannot be truncated.

        Raises:
            ValueError: If the stored maximum has the wrong Python type for a string/binary column,
                or truncation is requested for any other type.
        """
        upper = self.current_max
        if upper is None:
            return None

        if self.primitive_type == StringType():
            if not isinstance(upper, str):
                raise ValueError("Expected the current_max to be a string")
            truncated = truncate_upper_bound_text_string(upper, self.trunc_length)
        elif self.primitive_type == BinaryType():
            if not isinstance(upper, bytes):
                raise ValueError("Expected the current_max to be bytes")
            truncated = truncate_upper_bound_binary_string(upper, self.trunc_length)
        else:
            if self.trunc_length is not None:
                raise ValueError(f"{self.primitive_type} cannot be truncated")
            return self.serialize(upper)

        return None if truncated is None else self.serialize(truncated)
# NOTE(review): not referenced in the code visible here; presumably the default width for
# truncate(...) metrics -- confirm against its users elsewhere in the module.
DEFAULT_TRUNCATION_LENGTH = 16
# Matches a normalized metrics mode of the form "truncate(<digits>)"; group 1 captures the width.
TRUNCATION_EXPR = r"^truncate\((\d+)\)$"


class MetricModeTypes(Enum):
    """The kinds of column metrics collection that `match_metrics_mode` can produce."""

    TRUNCATE = "truncate"
    NONE = "none"
    COUNTS = "counts"
    FULL = "full"


@dataclass(frozen=True)
class MetricsMode(Singleton):
    """A parsed metrics mode: its type plus, for TRUNCATE, the truncation length."""

    # Which kind of metrics to collect.
    type: MetricModeTypes
    # Truncation width; only set for MetricModeTypes.TRUNCATE.
    length: Optional[int] = None
def match_metrics_mode(mode: str) -> MetricsMode:
    """Parse a metrics-mode property value into a `MetricsMode`.

    Accepted values (case-insensitive, surrounding whitespace ignored): ``none``, ``counts``,
    ``full`` and ``truncate(N)`` with ``N >= 1``.

    Raises:
        ValueError: If the value is malformed, the truncation length is below 1, or the mode is unknown.
    """
    sanitized_mode = mode.strip().lower()

    if sanitized_mode.startswith("truncate"):
        found = re.match(TRUNCATION_EXPR, sanitized_mode)
        if not found:
            raise ValueError(f"Malformed truncate: {mode}")
        length = int(found[1])
        if length < 1:
            raise ValueError("Truncation length must be larger than 0")
        return MetricsMode(MetricModeTypes.TRUNCATE, length)

    simple_modes = {
        "none": MetricModeTypes.NONE,
        "counts": MetricModeTypes.COUNTS,
        "full": MetricModeTypes.FULL,
    }
    if sanitized_mode in simple_modes:
        return MetricsMode(simple_modes[sanitized_mode])
    raise ValueError(f"Unsupported metrics mode: {mode}")
@dataclass(frozen=True)
class StatisticsCollector:
    """The statistics-collection plan for one primitive column, as produced by PyArrowStatisticsCollector."""

    # Iceberg field ID of the column.
    field_id: int
    # Iceberg primitive type of the column.
    iceberg_type: PrimitiveType
    # Effective metrics mode after per-column overrides and type/nesting adjustments.
    mode: MetricsMode
    # Full column name as returned by Schema.find_column_name.
    column_name: str
class PyArrowStatisticsCollector(PreOrderSchemaVisitor[List[StatisticsCollector]]):
    """Walk a schema and decide the metrics mode for every primitive column.

    The table-wide default mode applies unless a per-column property overrides it. Truncation is
    only kept for string/binary columns. Columns whose name contains a dot are limited to counts.
    """

    _field_id: int = 0
    _schema: Schema
    _properties: Dict[str, str]
    _default_mode: str

    def __init__(self, schema: Schema, properties: Dict[str, str]):
        from pyiceberg.table import TableProperties

        self._schema = schema
        self._properties = properties
        self._default_mode = self._properties.get(
            TableProperties.DEFAULT_WRITE_METRICS_MODE, TableProperties.DEFAULT_WRITE_METRICS_MODE_DEFAULT
        )

    def schema(self, schema: Schema, struct_result: Callable[[], List[StatisticsCollector]]) -> List[StatisticsCollector]:
        return struct_result()

    def struct(
        self, struct: StructType, field_results: List[Callable[[], List[StatisticsCollector]]]
    ) -> List[StatisticsCollector]:
        # Visit the children in order and concatenate their collectors.
        return [collector for child in field_results for collector in child()]

    def field(self, field: NestedField, field_result: Callable[[], List[StatisticsCollector]]) -> List[StatisticsCollector]:
        self._field_id = field.field_id
        return field_result()

    def list(self, list_type: ListType, element_result: Callable[[], List[StatisticsCollector]]) -> List[StatisticsCollector]:
        self._field_id = list_type.element_id
        return element_result()

    def map(
        self,
        map_type: MapType,
        key_result: Callable[[], List[StatisticsCollector]],
        value_result: Callable[[], List[StatisticsCollector]],
    ) -> List[StatisticsCollector]:
        self._field_id = map_type.key_id
        key_collectors = key_result()
        self._field_id = map_type.value_id
        value_collectors = value_result()
        return key_collectors + value_collectors

    def primitive(self, primitive: PrimitiveType) -> List[StatisticsCollector]:
        from pyiceberg.table import TableProperties

        column_name = self._schema.find_column_name(self._field_id)
        if column_name is None:
            return []

        # The table default is always parsed, even when a per-column override replaces it.
        metrics_mode = match_metrics_mode(self._default_mode)
        column_override = self._properties.get(f"{TableProperties.METRICS_MODE_COLUMN_CONF_PREFIX}.{column_name}")
        if column_override:
            metrics_mode = match_metrics_mode(column_override)

        # Truncation only makes sense for strings and binaries; other types get full bounds instead.
        if metrics_mode.type == MetricModeTypes.TRUNCATE and not isinstance(primitive, (StringType, BinaryType)):
            metrics_mode = MetricsMode(MetricModeTypes.FULL)

        # Nested columns (dotted names) never get lower/upper bounds, only counts.
        if "." in column_name and metrics_mode.type in (MetricModeTypes.TRUNCATE, MetricModeTypes.FULL):
            metrics_mode = MetricsMode(MetricModeTypes.COUNTS)

        return [StatisticsCollector(field_id=self._field_id, iceberg_type=primitive, mode=metrics_mode, column_name=column_name)]
def compute_statistics_plan(
    schema: Schema,
    table_properties: Dict[str, str],
) -> Dict[int, StatisticsCollector]:
    """
    Compute the statistics plan for all columns.

    For each primitive column, the metrics mode configured by the user is resolved and then adjusted
    to the column's data type. Nested columns get no minimum/maximum values, and truncation is only
    applied to text or binary strings.

    Args:
        schema (pyiceberg.schema.Schema): The schema whose primitive columns are planned.
        table_properties (Dict[str, str]): The Iceberg table properties, used to resolve the default
            and per-column metrics collection modes.

    Returns:
        Dict[int, StatisticsCollector]: The plan for each column, keyed by Iceberg field ID.
    """
    collectors = pre_order_visit(schema, PyArrowStatisticsCollector(schema, table_properties))
    return {collector.field_id: collector for collector in collectors}
@dataclass(frozen=True)
class ID2ParquetPath:
    """Pairs an Iceberg field ID with the dotted Parquet column path built by ID2ParquetPathVisitor."""

    # Iceberg field ID of the primitive column.
    field_id: int
    # Dotted Parquet path, e.g. "a.list.element" or "m.key_value.key".
    parquet_path: str
class ID2ParquetPathVisitor(PreOrderSchemaVisitor[List[ID2ParquetPath]]):
    """Build the dotted Parquet column path for every primitive column in a schema.

    Field names are joined with dots. List elements contribute ``list.element``; map keys and values
    contribute ``key_value.key`` and ``key_value.value``.
    """

    _field_id: int = 0
    _path: List[str]

    def __init__(self) -> None:
        self._path = []

    def _descend(self, segment: str, visit_child: Callable[[], List[ID2ParquetPath]]) -> List[ID2ParquetPath]:
        """Visit a child with ``segment`` pushed onto the current path, then pop it again."""
        self._path.append(segment)
        collected = visit_child()
        self._path.pop()
        return collected

    def schema(self, schema: Schema, struct_result: Callable[[], List[ID2ParquetPath]]) -> List[ID2ParquetPath]:
        return struct_result()

    def struct(self, struct: StructType, field_results: List[Callable[[], List[ID2ParquetPath]]]) -> List[ID2ParquetPath]:
        return [entry for child in field_results for entry in child()]

    def field(self, field: NestedField, field_result: Callable[[], List[ID2ParquetPath]]) -> List[ID2ParquetPath]:
        self._field_id = field.field_id
        return self._descend(field.name, field_result)

    def list(self, list_type: ListType, element_result: Callable[[], List[ID2ParquetPath]]) -> List[ID2ParquetPath]:
        self._field_id = list_type.element_id
        return self._descend("list.element", element_result)

    def map(
        self,
        map_type: MapType,
        key_result: Callable[[], List[ID2ParquetPath]],
        value_result: Callable[[], List[ID2ParquetPath]],
    ) -> List[ID2ParquetPath]:
        self._field_id = map_type.key_id
        key_entries = self._descend("key_value.key", key_result)
        self._field_id = map_type.value_id
        value_entries = self._descend("key_value.value", value_result)
        return key_entries + value_entries

    def primitive(self, primitive: PrimitiveType) -> List[ID2ParquetPath]:
        dotted_path = ".".join(self._path)
        return [ID2ParquetPath(field_id=self._field_id, parquet_path=dotted_path)]
def parquet_path_to_id_mapping(
    schema: Schema,
) -> Dict[str, int]:
    """
    Compute the mapping of parquet column path to Iceberg ID.

    For each column, the parquet file metadata has a path_in_schema attribute that follows
    a specific naming scheme for nested columns. This function computes a mapping of
    the full paths to the corresponding Iceberg IDs.

    Args:
        schema (pyiceberg.schema.Schema): The current table schema.

    Returns:
        Dict[str, int]: Dotted Parquet column path -> Iceberg field ID.
    """
    return {entry.parquet_path: entry.field_id for entry in pre_order_visit(schema, ID2ParquetPathVisitor())}
@dataclass(frozen=True)
class DataFileStatistics:
    """Per-file statistics gathered from Parquet metadata, keyed by Iceberg field ID where applicable."""

    record_count: int
    column_sizes: Dict[int, int]
    value_counts: Dict[int, int]
    null_value_counts: Dict[int, int]
    nan_value_counts: Dict[int, int]
    column_aggregates: Dict[int, StatsAggregator]
    split_offsets: List[int]

    def _partition_value(self, partition_field: PartitionField, schema: Schema) -> Any:
        """Infer a single partition value for ``partition_field`` from the column's min/max.

        Returns ``None`` when the source column has no aggregate.

        Raises:
            ValueError: If the transform does not preserve order, or if the transformed min and max
                differ (the file spans more than one partition).
        """
        if partition_field.source_id not in self.column_aggregates:
            return None
        aggregate = self.column_aggregates[partition_field.source_id]

        source_field = schema.find_field(partition_field.source_id)
        iceberg_transform = partition_field.transform
        if not iceberg_transform.preserves_order:
            raise ValueError(
                f"Cannot infer partition value from parquet metadata for a non-linear Partition Field: {partition_field.name} with transform {partition_field.transform}"
            )
        transform_func = iceberg_transform.transform(source_field.field_type)

        def transformed_bound(raw_bound: Any) -> Any:
            return transform_func(partition_record_value(partition_field=partition_field, value=raw_bound, schema=schema))

        # These two names appear verbatim in the error message below (f-string `=` specifier).
        lower_value = transformed_bound(aggregate.current_min)
        upper_value = transformed_bound(aggregate.current_max)
        if lower_value != upper_value:
            raise ValueError(
                f"Cannot infer partition value from parquet metadata as there are more than one partition values for Partition Field: {partition_field.name}. {lower_value=}, {upper_value=}"
            )
        return lower_value

    def partition(self, partition_spec: PartitionSpec, schema: Schema) -> Record:
        """Build the partition Record by inferring each spec field's value from the statistics."""
        values = [self._partition_value(spec_field, schema) for spec_field in partition_spec.fields]
        return Record(*values)

    def to_serialized_dict(self) -> Dict[str, Any]:
        """Return the statistics as DataFile keyword arguments, with bounds serialized to bytes."""
        lower_bounds = {}
        upper_bounds = {}
        for field_id, aggregate in self.column_aggregates.items():
            if (lower := aggregate.min_as_bytes()) is not None:
                lower_bounds[field_id] = lower
            if (upper := aggregate.max_as_bytes()) is not None:
                upper_bounds[field_id] = upper

        return {
            "record_count": self.record_count,
            "column_sizes": self.column_sizes,
            "value_counts": self.value_counts,
            "null_value_counts": self.null_value_counts,
            "nan_value_counts": self.nan_value_counts,
            "lower_bounds": lower_bounds,
            "upper_bounds": upper_bounds,
            "split_offsets": self.split_offsets,
        }
def data_file_statistics_from_parquet_metadata(
    parquet_metadata: pq.FileMetaData,
    stats_columns: Dict[int, StatisticsCollector],
    parquet_column_mapping: Dict[str, int],
) -> DataFileStatistics:
    """
    Compute and return DataFileStatistics that includes the following.

    - record_count
    - column_sizes
    - value_counts
    - null_value_counts
    - nan_value_counts
    - column_aggregates
    - split_offsets

    Args:
        parquet_metadata (pyarrow.parquet.FileMetaData): A pyarrow metadata object.
        stats_columns (Dict[int, StatisticsCollector]): The statistics gathering plan. It is required to
            set the mode for column metrics collection
        parquet_column_mapping (Dict[str, int]): The mapping of the parquet column path
            (``path_in_schema``) to the field ID

    Returns:
        DataFileStatistics: Per-field sizes, counts and min/max aggregates accumulated over all row
        groups. ``nan_value_counts`` is never populated by this function and is always empty.

    Raises:
        KeyError: If a column path is missing from ``parquet_column_mapping``, or its field ID is
            missing from ``stats_columns``.
        ValueError: If ``StatsAggregator`` rejects a column's Parquet physical type.
    """
    column_sizes: Dict[int, int] = {}
    value_counts: Dict[int, int] = {}
    split_offsets: List[int] = []

    null_value_counts: Dict[int, int] = {}
    nan_value_counts: Dict[int, int] = {}

    col_aggs = {}

    # Field IDs whose statistics were missing or unreadable in at least one row group. Their
    # min/max aggregates and null counts are dropped at the end because they would be partial.
    invalidate_col: Set[int] = set()
    for r in range(parquet_metadata.num_row_groups):
        # References:
        # https://github.com/apache/iceberg/blob/fc381a81a1fdb8f51a0637ca27cd30673bd7aad3/parquet/src/main/java/org/apache/iceberg/parquet/ParquetUtil.java#L232
        # https://github.com/apache/parquet-mr/blob/ac29db4611f86a07cc6877b416aa4b183e09b353/parquet-hadoop/src/main/java/org/apache/parquet/hadoop/metadata/ColumnChunkMetaData.java#L184

        row_group = parquet_metadata.row_group(r)

        # The split offset of a row group is where its first column chunk starts: the dictionary
        # page when one exists and precedes the first data page, otherwise the first data page.
        data_offset = row_group.column(0).data_page_offset
        dictionary_offset = row_group.column(0).dictionary_page_offset

        if row_group.column(0).has_dictionary_page and dictionary_offset < data_offset:
            split_offsets.append(dictionary_offset)
        else:
            split_offsets.append(data_offset)

        for pos in range(parquet_metadata.num_columns):
            column = row_group.column(pos)
            field_id = parquet_column_mapping[column.path_in_schema]

            stats_col = stats_columns[field_id]

            # Compressed sizes are summed for every column, regardless of its metrics mode.
            column_sizes.setdefault(field_id, 0)
            column_sizes[field_id] += column.total_compressed_size

            if stats_col.mode == MetricsMode(MetricModeTypes.NONE):
                continue

            value_counts[field_id] = value_counts.get(field_id, 0) + column.num_values

            if column.is_stats_set:
                try:
                    statistics = column.statistics

                    if statistics.has_null_count:
                        null_value_counts[field_id] = null_value_counts.get(field_id, 0) + statistics.null_count

                    # COUNTS mode stops after value/null counts: no lower/upper bounds.
                    if stats_col.mode == MetricsMode(MetricModeTypes.COUNTS):
                        continue

                    # The aggregator is created from the first row group that has statistics; its
                    # constructor checks the physical type against the Iceberg type.
                    if field_id not in col_aggs:
                        col_aggs[field_id] = StatsAggregator(
                            stats_col.iceberg_type, statistics.physical_type, stats_col.mode.length
                        )

                    # Decimals stored as INT32/INT64 are read from their raw unscaled integers and
                    # converted with unscaled_to_decimal using the column's scale. Fixed-length
                    # decimals and every other type use the decoded min/max directly.
                    if isinstance(stats_col.iceberg_type, DecimalType) and statistics.physical_type != "FIXED_LEN_BYTE_ARRAY":
                        scale = stats_col.iceberg_type.scale
                        col_aggs[field_id].update_min(unscaled_to_decimal(statistics.min_raw, scale))
                        col_aggs[field_id].update_max(unscaled_to_decimal(statistics.max_raw, scale))
                    else:
                        col_aggs[field_id].update_min(statistics.min)
                        col_aggs[field_id].update_max(statistics.max)

                except pyarrow.lib.ArrowNotImplementedError as e:
                    invalidate_col.add(field_id)
                    logger.warning(e)
            else:
                invalidate_col.add(field_id)
                logger.warning("PyArrow statistics missing for column %d when writing file", pos)

    split_offsets.sort()

    for field_id in invalidate_col:
        col_aggs.pop(field_id, None)
        null_value_counts.pop(field_id, None)

    return DataFileStatistics(
        record_count=parquet_metadata.num_rows,
        column_sizes=column_sizes,
        value_counts=value_counts,
        null_value_counts=null_value_counts,
        nan_value_counts=nan_value_counts,
        column_aggregates=col_aggs,
        split_offsets=split_offsets,
    )
def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteTask]) -> Iterator[DataFile]:
    """Write each task's record batches to a new Parquet data file and describe it as a DataFile.

    Each task is handed to the shared executor from ``ExecutorFactory`` via ``executor.map``.

    Args:
        io: The FileIO used to create the output files.
        table_metadata: Metadata of the target table. It supplies the schema, writer properties,
            location provider settings and the default partition spec ID.
        tasks: The write tasks; each one produces exactly one data file.

    Returns:
        An iterator over the resulting DataFile entries.
    """
    from pyiceberg.table import DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE, TableProperties

    parquet_writer_kwargs = _get_parquet_writer_kwargs(table_metadata.properties)
    row_group_size = property_as_int(
        properties=table_metadata.properties,
        property_name=TableProperties.PARQUET_ROW_GROUP_LIMIT,
        default=TableProperties.PARQUET_ROW_GROUP_LIMIT_DEFAULT,
    )
    location_provider = load_location_provider(table_location=table_metadata.location, table_properties=table_metadata.properties)

    def write_parquet(task: WriteTask) -> DataFile:
        table_schema = table_metadata.schema()
        # if schema needs to be transformed, use the transformed schema and adjust the arrow table accordingly
        # otherwise use the original schema
        if (sanitized_schema := sanitize_column_names(table_schema)) != table_schema:
            file_schema = sanitized_schema
        else:
            file_schema = table_schema

        downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
        batches = [
            _to_requested_schema(
                requested_schema=file_schema,
                file_schema=task.schema,
                batch=batch,
                downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us,
                include_field_ids=True,
            )
            for batch in task.record_batches
        ]
        arrow_table = pa.Table.from_batches(batches)
        file_path = location_provider.new_data_location(
            data_file_name=task.generate_data_file_filename("parquet"),
            partition_key=task.partition_key,
        )
        fo = io.new_output(file_path)
        with fo.create(overwrite=True) as fos:
            # Decimals with precision <= 18 must be stored as INT32/INT64. That is the physical type
            # PrimitiveToPhysicalType.visit_decimal expects, and StatsAggregator rejects any other
            # type when collecting the statistics below. pyarrow's default writes every decimal as
            # FIXED_LEN_BYTE_ARRAY.
            # NOTE(review): `store_decimal_as_integer` needs a pyarrow release that supports it
            # (19+) -- confirm the project's minimum pyarrow version.
            with pq.ParquetWriter(
                fos, schema=arrow_table.schema, store_decimal_as_integer=True, **parquet_writer_kwargs
            ) as writer:
                writer.write(arrow_table, row_group_size=row_group_size)
        statistics = data_file_statistics_from_parquet_metadata(
            parquet_metadata=writer.writer.metadata,
            stats_columns=compute_statistics_plan(file_schema, table_metadata.properties),
            parquet_column_mapping=parquet_path_to_id_mapping(file_schema),
        )
        data_file = DataFile.from_args(
            content=DataFileContent.DATA,
            file_path=file_path,
            file_format=FileFormat.PARQUET,
            partition=task.partition_key.partition if task.partition_key else Record(),
            file_size_in_bytes=len(fo),
            # After this has been fixed:
            # https://github.com/apache/iceberg-python/issues/271
            # sort_order_id=task.sort_order_id,
            sort_order_id=None,
            # Just copy these from the table for now
            spec_id=table_metadata.default_spec_id,
            equality_ids=None,
            key_metadata=None,
            **statistics.to_serialized_dict(),
        )

        return data_file

    executor = ExecutorFactory.get_or_create()
    data_files = executor.map(write_parquet, tasks)

    return iter(data_files)
def bin_pack_arrow_table(tbl: pa.Table, target_file_size: int) -> Iterator[List[pa.RecordBatch]]:
    """Split ``tbl`` into groups of record batches, each group targeting about ``target_file_size`` bytes.

    The table is first cut into batches. Their row count is estimated from the table's average
    in-memory row size. The batches are then bin-packed by their in-memory size (``nbytes``).

    Args:
        tbl: The table to split.
        target_file_size: Target size in bytes of each group (one group per output file).

    Returns:
        An iterator of record-batch lists. An empty table produces no groups.
    """
    from pyiceberg.utils.bin_packing import PackingIterator

    if tbl.num_rows == 0:
        # Nothing to pack; also avoids dividing by a zero row count below.
        return iter([])

    avg_row_size_bytes = tbl.nbytes / tbl.num_rows
    target_rows_per_file: Optional[int]
    if avg_row_size_bytes > 0:
        # `max_chunksize` must be a positive integer. Keep at least one row per batch even when a
        # single row is larger than the target file size.
        target_rows_per_file = max(1, int(target_file_size // avg_row_size_bytes))
    else:
        # Rows with no measurable size cannot be used to estimate a row count; use pyarrow's default chunking.
        target_rows_per_file = None
    batches = tbl.to_batches(max_chunksize=target_rows_per_file)
    bin_packed_record_batches = PackingIterator(
        items=batches,
        target_weight=target_file_size,
        lookback=len(batches),  # ignore lookback
        weight_func=lambda x: x.nbytes,
        largest_bin_first=False,
    )
    return bin_packed_record_batches
def _check_pyarrow_schema_compatible(
    requested_schema: Schema, provided_schema: pa.Schema, downcast_ns_timestamp_to_us: bool = False
) -> None:
    """
    Check if the `requested_schema` is compatible with `provided_schema`.

    Two schemas are considered compatible when they are equal in terms of the Iceberg Schema type.
    The Arrow schema is first converted using the requested schema's name mapping. If that
    conversion fails, the error lists the Arrow columns that the requested schema lacks.

    Raises:
        ValueError: If the schemas are not compatible.
    """
    name_mapping = requested_schema.name_mapping
    try:
        converted_schema = pyarrow_to_schema(
            provided_schema, name_mapping=name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
        )
    except ValueError as e:
        unmapped_schema = _pyarrow_to_schema_without_ids(provided_schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
        extra_columns = sorted(set(unmapped_schema._name_to_id) - set(requested_schema._name_to_id))
        raise ValueError(
            f"PyArrow table contains more columns: {', '.join(extra_columns)}. Update the schema first (hint, use union_by_name)."
        ) from e

    _check_schema_compatible(requested_schema, converted_schema)
def parquet_files_to_data_files(io: FileIO, table_metadata: TableMetadata, file_paths: Iterator[str]) -> Iterator[DataFile]:
    """Lazily convert each existing Parquet file path into a DataFile entry, in input order."""
    for path in file_paths:
        yield parquet_file_to_data_file(io=io, table_metadata=table_metadata, file_path=path)
def parquet_file_to_data_file(io: FileIO, table_metadata: TableMetadata, file_path: str) -> DataFile:
    """Build a DataFile entry for an existing Parquet file from its footer metadata.

    Args:
        io: The FileIO used to open the file.
        table_metadata: Metadata of the table the file is being added to.
        file_path: Location of the Parquet file.

    Returns:
        The DataFile describing the file, with statistics and an inferred partition.

    Raises:
        NotImplementedError: If the file's Arrow schema carries field IDs.
        ValueError: If the file's schema is not compatible with the table schema.
    """
    input_file = io.new_input(file_path)
    with input_file.open() as stream:
        parquet_metadata = pq.read_metadata(stream)

    arrow_schema = parquet_metadata.schema.to_arrow_schema()
    has_field_ids = visit_pyarrow(arrow_schema, _HasIds())
    if has_field_ids:
        raise NotImplementedError(
            f"Cannot add file {file_path} because it has field IDs. `add_files` only supports addition of files without field_ids"
        )

    table_schema = table_metadata.schema()
    _check_pyarrow_schema_compatible(table_schema, arrow_schema)

    stats_plan = compute_statistics_plan(table_schema, table_metadata.properties)
    column_mapping = parquet_path_to_id_mapping(table_schema)
    statistics = data_file_statistics_from_parquet_metadata(
        parquet_metadata=parquet_metadata,
        stats_columns=stats_plan,
        parquet_column_mapping=column_mapping,
    )
    inferred_partition = statistics.partition(table_metadata.spec(), table_metadata.schema())

    return DataFile.from_args(
        content=DataFileContent.DATA,
        file_path=file_path,
        file_format=FileFormat.PARQUET,
        partition=inferred_partition,
        file_size_in_bytes=len(input_file),
        sort_order_id=None,
        spec_id=table_metadata.default_spec_id,
        equality_ids=None,
        key_metadata=None,
        **statistics.to_serialized_dict(),
    )
ICEBERG_UNCOMPRESSED_CODEC = "uncompressed"
PYARROW_UNCOMPRESSED_CODEC = "none"


def _get_parquet_writer_kwargs(table_properties: Properties) -> Dict[str, Any]:
    """Translate Iceberg Parquet table properties into ``pq.ParquetWriter`` keyword arguments.

    Properties that are not supported yet (row-group byte size, bloom filters) trigger a warning.
    """
    from pyiceberg.table import TableProperties

    unsupported_patterns = (
        TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES,
        TableProperties.PARQUET_BLOOM_FILTER_MAX_BYTES,
        f"{TableProperties.PARQUET_BLOOM_FILTER_COLUMN_ENABLED_PREFIX}.*",
    )
    for pattern in unsupported_patterns:
        matched_keys = fnmatch.filter(table_properties, pattern)
        if matched_keys:
            warnings.warn(f"Parquet writer option(s) {matched_keys} not implemented")

    def int_property(name: str, default: Optional[int]) -> Optional[int]:
        return property_as_int(properties=table_properties, property_name=name, default=default)

    codec = table_properties.get(TableProperties.PARQUET_COMPRESSION, TableProperties.PARQUET_COMPRESSION_DEFAULT)
    level = int_property(TableProperties.PARQUET_COMPRESSION_LEVEL, TableProperties.PARQUET_COMPRESSION_LEVEL_DEFAULT)
    if codec == ICEBERG_UNCOMPRESSED_CODEC:
        # Iceberg spells "no compression" as "uncompressed"; pyarrow expects "none".
        codec = PYARROW_UNCOMPRESSED_CODEC

    return {
        "compression": codec,
        "compression_level": level,
        "data_page_size": int_property(
            TableProperties.PARQUET_PAGE_SIZE_BYTES, TableProperties.PARQUET_PAGE_SIZE_BYTES_DEFAULT
        ),
        "dictionary_pagesize_limit": int_property(
            TableProperties.PARQUET_DICT_SIZE_BYTES, TableProperties.PARQUET_DICT_SIZE_BYTES_DEFAULT
        ),
        "write_batch_size": int_property(
            TableProperties.PARQUET_PAGE_ROW_LIMIT, TableProperties.PARQUET_PAGE_ROW_LIMIT_DEFAULT
        ),
    }
def _dataframe_to_data_files(
    table_metadata: TableMetadata,
    df: pa.Table,
    io: FileIO,
    write_uuid: Optional[uuid.UUID] = None,
    counter: Optional[itertools.count[int]] = None,
) -> Iterable[DataFile]:
    """Convert a PyArrow table into DataFiles.

    The table is split per partition (for partitioned specs) and bin-packed toward the table's
    target file size. Each resulting batch group becomes one WriteTask, and all tasks are handed
    to ``write_file`` at once.

    Returns:
        An iterable that supplies datafiles that represent the table.
    """
    from pyiceberg.table import DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE, TableProperties, WriteTask

    counter = counter or itertools.count(0)
    write_uuid = write_uuid or uuid.uuid4()
    target_file_size: int = property_as_int(  # type: ignore  # The property is set with non-None value.
        properties=table_metadata.properties,
        property_name=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES,
        default=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT,
    )
    name_mapping = table_metadata.schema().name_mapping
    downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
    task_schema = pyarrow_to_schema(df.schema, name_mapping=name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)

    # All task IDs are drawn from `counter` before any file is written.
    write_tasks: List[WriteTask] = []
    if table_metadata.spec().is_unpartitioned():
        for batches in bin_pack_arrow_table(df, target_file_size):
            write_tasks.append(
                WriteTask(write_uuid=write_uuid, task_id=next(counter), record_batches=batches, schema=task_schema)
            )
    else:
        partitions = _determine_partitions(spec=table_metadata.spec(), schema=table_metadata.schema(), arrow_table=df)
        for partition in partitions:
            for batches in bin_pack_arrow_table(partition.arrow_table_partition, target_file_size):
                write_tasks.append(
                    WriteTask(
                        write_uuid=write_uuid,
                        task_id=next(counter),
                        record_batches=batches,
                        partition_key=partition.partition_key,
                        schema=task_schema,
                    )
                )

    yield from write_file(io=io, table_metadata=table_metadata, tasks=iter(write_tasks))
@dataclass(frozen=True)
class _TablePartition:
    """One partition's slice of an Arrow table, as produced by `_determine_partitions`."""

    # The partition key (spec field values) shared by every row in the slice.
    partition_key: PartitionKey
    # The rows of this partition, with the temporary `_partition_*` columns removed.
    arrow_table_partition: pa.Table
def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.Table) -> List[_TablePartition]:
    """Based on the iceberg table partition spec, filter the arrow table into partitions with their keys.

    Example:
    Input:
    An arrow table with partition key of ['n_legs', 'year'] and with data of
    {'year': [2020, 2022, 2022, 2021, 2022, 2022, 2022, 2019, 2021],
     'n_legs': [2, 2, 2, 4, 4, 4, 4, 5, 100],
     'animal': ["Flamingo", "Parrot", "Parrot", "Dog", "Horse", "Horse", "Horse","Brittle stars", "Centipede"]}.
    The algorithm:
    - We determine the set of unique partition keys
    - Then we produce a set of partitions by filtering on each of the combinations
    - We combine the chunks to create a copy to avoid GIL congestion on the original table
    """
    # Assign unique names to columns where the partition transform has been applied
    # to avoid conflicts
    partition_columns = [f"_partition_{spec_field.name}" for spec_field in spec.fields]

    for spec_field, column_name in zip(spec.fields, partition_columns):
        source_field = schema.find_field(spec_field.source_id)
        transformed = spec_field.transform.pyarrow_transform(source_field.field_type)(arrow_table[source_field.name])
        arrow_table = arrow_table.append_column(column_name, transformed)

    distinct_keys = arrow_table.select(partition_columns).group_by(partition_columns).aggregate([])

    def _row_filter(key_row: Dict[str, Any]) -> Any:
        """AND together one predicate per partition column; null keys match with `is_null()`."""
        predicates = []
        for column_name in partition_columns:
            key_value = key_row[column_name]
            predicates.append(
                pc.field(column_name).is_null() if key_value is None else pc.field(column_name) == key_value
            )
        return functools.reduce(operator.and_, predicates)

    table_partitions = []
    # TODO: As a next step, we could also play around with yielding instead of materializing the full list
    for key_row in distinct_keys.to_pylist():
        partition_key = PartitionKey(
            field_values=[
                PartitionFieldValue(field=spec_field, value=key_row[column_name])
                for spec_field, column_name in zip(spec.fields, partition_columns)
            ],
            partition_spec=spec,
            schema=schema,
        )
        partition_rows = arrow_table.filter(_row_filter(key_row)).drop_columns(partition_columns)

        # The combine_chunks seems to be counter-intuitive to do, but it actually returns
        # fresh buffers that don't interfere with each other when it is written out to file
        table_partitions.append(
            _TablePartition(partition_key=partition_key, arrow_table_partition=partition_rows.combine_chunks())
        )

    return table_partitions