pyiceberg/partitioning.py (308 lines of code) (raw):
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import uuid
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import date, datetime, time
from functools import cached_property, singledispatch
from typing import Annotated, Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union
from urllib.parse import quote_plus
from pydantic import (
BeforeValidator,
Field,
PlainSerializer,
WithJsonSchema,
model_validator,
)
from pyiceberg.schema import Schema
from pyiceberg.transforms import (
BucketTransform,
DayTransform,
HourTransform,
IdentityTransform,
MonthTransform,
Transform,
TruncateTransform,
UnknownTransform,
VoidTransform,
YearTransform,
parse_transform,
)
from pyiceberg.typedef import IcebergBaseModel, Record
from pyiceberg.types import (
DateType,
IcebergType,
NestedField,
PrimitiveType,
StructType,
TimestampType,
TimestamptzType,
TimeType,
UUIDType,
)
from pyiceberg.utils.datetime import date_to_days, datetime_to_micros, time_to_micros
# Default `spec_id` of a PartitionSpec, and the id given to specs built by
# assign_fresh_partition_spec_ids.
INITIAL_PARTITION_SPEC_ID = 0
# First id used when numbering partition fields (see assign_fresh_partition_spec_ids).
# A spec without fields reports `PARTITION_FIELD_ID_START - 1` as its last assigned field id.
PARTITION_FIELD_ID_START: int = 1000
class PartitionField(IcebergBaseModel):
    """PartitionField represents how one partition value is derived from the source column via transformation.

    Attributes:
        source_id(int): The source column id of table's schema.
        field_id(int): The partition field id across all the table partition specs.
        transform(Transform): The transform used to produce partition values from source column.
        name(str): The name of this partition field.
    """

    source_id: int = Field(alias="source-id")
    field_id: int = Field(alias="field-id")
    # Accepts anything `parse_transform` understands and serializes back to the transform's string form.
    transform: Annotated[  # type: ignore
        Transform,
        BeforeValidator(parse_transform),
        PlainSerializer(lambda c: str(c), return_type=str),  # pylint: disable=W0108
        WithJsonSchema({"type": "string"}, mode="serialization"),
    ] = Field()
    name: str = Field()

    def __init__(
        self,
        source_id: Optional[int] = None,
        field_id: Optional[int] = None,
        transform: Optional[Transform[Any, Any]] = None,
        name: Optional[str] = None,
        **data: Any,
    ):
        """Build a partition field from positional/keyword arguments or from aliased raw data.

        Explicit arguments that are not None override the corresponding
        entries of ``data``; the ids are stored under their serialized
        aliases (``source-id`` / ``field-id``).
        """
        if source_id is not None:
            data["source-id"] = source_id
        if field_id is not None:
            data["field-id"] = field_id
        if transform is not None:
            data["transform"] = transform
        if name is not None:
            data["name"] = name
        super().__init__(**data)

    @model_validator(mode="before")
    @classmethod
    def map_source_ids_onto_source_id(cls, data: Any) -> Any:
        """Map a single-element ``source-ids`` list onto ``source-id`` when the latter is absent.

        Args:
            data: The raw input handed to the model before field validation.

        Returns:
            The input, with ``source-id`` filled in from ``source-ids`` when applicable.

        Raises:
            ValueError: If ``source-ids`` is an empty list or holds more than one id.
        """
        if not isinstance(data, dict) or "source-id" in data:
            return data
        # `.get` rather than indexing: when neither key is present the regular
        # required-field validation should report the missing source-id, not a KeyError.
        source_ids = data.get("source-ids")
        if not isinstance(source_ids, list):
            return data
        if len(source_ids) == 0:
            raise ValueError("Empty source-ids is not allowed")
        if len(source_ids) > 1:
            raise ValueError("Multi argument transforms are not yet supported")
        data["source-id"] = source_ids[0]
        return data

    def __str__(self) -> str:
        """Return the string representation of the PartitionField class."""
        return f"{self.field_id}: {self.name}: {self.transform}({self.source_id})"
class PartitionSpec(IcebergBaseModel):
    """
    PartitionSpec describes how table data is turned into partition values.

    Attributes:
        spec_id(int): Identifier of this spec; defaults to INITIAL_PARTITION_SPEC_ID.
        fields(Tuple[PartitionField, ...]): The partition fields that produce the partition values.
    """

    spec_id: int = Field(alias="spec-id", default=INITIAL_PARTITION_SPEC_ID)
    fields: Tuple[PartitionField, ...] = Field(default_factory=tuple)

    def __init__(self, *fields: PartitionField, **data: Any):
        """Build a spec from positional partition fields and/or keyword data.

        Positional fields, when given, replace any ``fields`` entry in ``data``.
        """
        if fields:
            data["fields"] = tuple(fields)
        super().__init__(**data)

    def __eq__(self, other: Any) -> bool:
        """
        Compare two partition specs.

        Note:
            Only spec_id and the partition fields take part in the comparison;
            anything that is not a PartitionSpec compares unequal.
        """
        return isinstance(other, PartitionSpec) and self.spec_id == other.spec_id and self.fields == other.fields

    def __str__(self) -> str:
        """
        Render the partition fields, one per line, between square brackets.

        Note:
            The spec id is not part of this representation.
        """
        if not self.fields:
            return "[]"
        rendered_fields = "\n ".join(str(field) for field in self.fields)
        return "[\n " + rendered_fields + "\n]"

    def __repr__(self) -> str:
        """Return the string representation of the PartitionSpec class."""
        arguments = [repr(field) for field in self.fields]
        arguments.append(f"spec_id={self.spec_id}")
        return f"PartitionSpec({', '.join(arguments)})"

    def is_unpartitioned(self) -> bool:
        """Return True when the spec has no partition fields."""
        return len(self.fields) == 0

    @property
    def last_assigned_field_id(self) -> int:
        """Return the highest partition field id in use, or PARTITION_FIELD_ID_START - 1 for an empty spec."""
        return max((pf.field_id for pf in self.fields), default=PARTITION_FIELD_ID_START - 1)

    @cached_property
    def source_id_to_fields_map(self) -> Dict[int, List[PartitionField]]:
        """Group the partition fields by the id of their source column, keeping spec order."""
        grouped: Dict[int, List[PartitionField]] = {}
        for pf in self.fields:
            grouped.setdefault(pf.source_id, []).append(pf)
        return grouped

    def fields_by_source_id(self, field_id: int) -> List[PartitionField]:
        """Return the partition fields whose source column id equals ``field_id`` (empty list if none)."""
        lookup = self.source_id_to_fields_map
        return lookup.get(field_id, [])

    def compatible_with(self, other: PartitionSpec) -> bool:
        """Return True when both specs are equal, or pairwise agree on source id, transform and name."""
        if self == other:
            return True
        if len(self.fields) != len(other.fields):
            return False
        for mine, theirs in zip(self.fields, other.fields):
            # Partition field ids are deliberately not compared here.
            if not (mine.source_id == theirs.source_id and mine.transform == theirs.transform and mine.name == theirs.name):
                return False
        return True

    def partition_type(self, schema: Schema) -> StructType:
        """Produce a struct of the PartitionSpec.

        Each partition field becomes a NestedField that carries the partition
        field's id and name, the transform's result type for the source column,
        and the ``required`` flag of the source column.

        NOTE(review): transforms produce null for null input and partition
        fields may be added after files were written, which argues for always
        optional partition fields; this implementation copies the source
        column's ``required`` flag instead - confirm that this is intended.

        :param schema: The schema to bind to.
        :return: A StructType that represents the PartitionSpec, with a NestedField for each PartitionField.
        """
        return StructType(
            *(
                NestedField(
                    pf.field_id,
                    pf.name,
                    pf.transform.result_type(schema.find_type(pf.source_id)),
                    required=schema.find_field(pf.source_id).required,
                )
                for pf in self.fields
            )
        )

    def partition_to_path(self, data: Record, schema: Schema) -> str:
        """Render a partition record as ``name=value`` segments joined by ``/``.

        Names and human-readable values are percent-encoded with ``quote_plus``
        and no safe characters. ``data`` is read by position, in the order of
        this spec's fields.
        """
        struct_fields = self.partition_type(schema).fields
        segments = []
        for pos, (partition_field, struct_field) in enumerate(zip(self.fields, struct_fields)):
            human_value = partition_field.transform.to_human_string(struct_field.field_type, value=data[pos])
            encoded_value = quote_plus(human_value, safe="")
            encoded_name = quote_plus(partition_field.name, safe="")
            segments.append(encoded_name + "=" + encoded_value)
        return "/".join(segments)
# Module-level spec with spec id 0 and no partition fields (is_unpartitioned() is True for it).
UNPARTITIONED_PARTITION_SPEC = PartitionSpec(spec_id=0)
def assign_fresh_partition_spec_ids(spec: PartitionSpec, old_schema: Schema, fresh_schema: Schema) -> PartitionSpec:
    """Rebuild ``spec`` against ``fresh_schema`` with renumbered partition field ids.

    Every partition field keeps its name and transform. Its source column is
    resolved by name: the column name is looked up in ``old_schema`` and then
    located in ``fresh_schema``. Field ids are handed out sequentially from
    PARTITION_FIELD_ID_START and the resulting spec gets INITIAL_PARTITION_SPEC_ID.

    Raises:
        ValueError: If a source column cannot be found in either schema.
    """
    rebuilt = []
    for offset, stale_field in enumerate(spec.fields):
        column_name = old_schema.find_column_name(stale_field.source_id)
        if column_name is None:
            raise ValueError(f"Could not find in old schema: {stale_field}")

        fresh_column = fresh_schema.find_field(column_name)
        if fresh_column is None:
            raise ValueError(f"Could not find field in fresh schema: {column_name}")

        fresh_partition_field = PartitionField(
            name=stale_field.name,
            source_id=fresh_column.field_id,
            field_id=PARTITION_FIELD_ID_START + offset,
            transform=stale_field.transform,
        )
        rebuilt.append(fresh_partition_field)

    return PartitionSpec(*rebuilt, spec_id=INITIAL_PARTITION_SPEC_ID)
T = TypeVar("T")


class PartitionSpecVisitor(Generic[T], ABC):
    """Visitor over the fields of a partition spec, with one hook per transform kind.

    Every hook receives the partition field id, the name of the source column
    and the source column id; ``bucket``, ``truncate`` and ``unknown``
    additionally receive the transform's parameter.
    """

    @abstractmethod
    def identity(self, field_id: int, source_name: str, source_id: int) -> T:
        """Handle a partition field that uses the identity transform."""

    @abstractmethod
    def bucket(self, field_id: int, source_name: str, source_id: int, num_buckets: int) -> T:
        """Handle a partition field that uses a bucket transform with ``num_buckets`` buckets."""

    @abstractmethod
    def truncate(self, field_id: int, source_name: str, source_id: int, width: int) -> T:
        """Handle a partition field that uses a truncate transform of the given ``width``."""

    @abstractmethod
    def year(self, field_id: int, source_name: str, source_id: int) -> T:
        """Handle a partition field that uses the year transform."""

    @abstractmethod
    def month(self, field_id: int, source_name: str, source_id: int) -> T:
        """Handle a partition field that uses the month transform."""

    @abstractmethod
    def day(self, field_id: int, source_name: str, source_id: int) -> T:
        """Handle a partition field that uses the day transform."""

    @abstractmethod
    def hour(self, field_id: int, source_name: str, source_id: int) -> T:
        """Handle a partition field that uses the hour transform."""

    @abstractmethod
    def always_null(self, field_id: int, source_name: str, source_id: int) -> T:
        """Handle a partition field that uses the void transform."""

    @abstractmethod
    def unknown(self, field_id: int, source_name: str, source_id: int, transform: str) -> T:
        """Handle a partition field with an unknown transform.

        Subclasses must override this hook; the base implementation, reachable
        through ``super()``, rejects the transform with a ValueError.
        """
        raise ValueError(f"Unknown transform is not supported: {transform}")
class _PartitionNameGenerator(PartitionSpecVisitor[str]):
    """Visitor that derives a partition field name from the source column name and the transform."""

    def identity(self, field_id: int, source_name: str, source_id: int) -> str:
        # Identity partitions reuse the source column name unchanged.
        return source_name

    def bucket(self, field_id: int, source_name: str, source_id: int, num_buckets: int) -> str:
        return f"{source_name}_bucket_{num_buckets}"

    def truncate(self, field_id: int, source_name: str, source_id: int, width: int) -> str:
        return f"{source_name}_trunc_{width}"

    def year(self, field_id: int, source_name: str, source_id: int) -> str:
        return f"{source_name}_year"

    def month(self, field_id: int, source_name: str, source_id: int) -> str:
        return f"{source_name}_month"

    def day(self, field_id: int, source_name: str, source_id: int) -> str:
        return f"{source_name}_day"

    def hour(self, field_id: int, source_name: str, source_id: int) -> str:
        return f"{source_name}_hour"

    def always_null(self, field_id: int, source_name: str, source_id: int) -> str:
        return f"{source_name}_null"

    def unknown(self, field_id: int, source_name: str, source_id: int, transform: str) -> str:
        # No name can be derived for an unknown transform; the base hook raises ValueError.
        return super().unknown(field_id, source_name, source_id, transform)
R = TypeVar("R")


@singledispatch
def _visit(spec: PartitionSpec, schema: Schema, visitor: PartitionSpecVisitor[R]) -> List[R]:
    """Apply ``visitor`` to every partition field of ``spec`` and collect the results in spec order."""
    results: List[R] = []
    for partition_field in spec.fields:
        results.append(_visit_partition_field(schema, partition_field, visitor))
    return results
def _visit_partition_field(schema: Schema, field: PartitionField, visitor: PartitionSpecVisitor[R]) -> R:
    """Dispatch one partition field to the visitor hook that matches its transform.

    Raises:
        ValueError: If the source column is not found in ``schema`` or the
            transform is of a kind that has no visitor hook.
    """
    column_name = schema.find_column_name(field.source_id)
    if not column_name:
        raise ValueError(f"Could not find field with id {field.source_id}")

    transform = field.transform
    partition_id = field.field_id
    column_id = field.source_id

    # The isinstance checks run in a fixed order; the first match wins.
    if isinstance(transform, IdentityTransform):
        return visitor.identity(partition_id, column_name, column_id)
    if isinstance(transform, BucketTransform):
        return visitor.bucket(partition_id, column_name, column_id, transform.num_buckets)
    if isinstance(transform, TruncateTransform):
        return visitor.truncate(partition_id, column_name, column_id, transform.width)
    if isinstance(transform, DayTransform):
        return visitor.day(partition_id, column_name, column_id)
    if isinstance(transform, HourTransform):
        return visitor.hour(partition_id, column_name, column_id)
    if isinstance(transform, MonthTransform):
        return visitor.month(partition_id, column_name, column_id)
    if isinstance(transform, YearTransform):
        return visitor.year(partition_id, column_name, column_id)
    if isinstance(transform, VoidTransform):
        return visitor.always_null(partition_id, column_name, column_id)
    if isinstance(transform, UnknownTransform):
        return visitor.unknown(partition_id, column_name, column_id, repr(transform))
    raise ValueError(f"Unknown transform {transform}")
@dataclass(frozen=True)
class PartitionFieldValue:
    """Immutable pairing of a partition field with a value for it."""

    # The partition field the value belongs to.
    field: PartitionField
    # The value itself; no type is enforced here.
    value: Any
@dataclass(frozen=True)
class PartitionKey:
    """Partition field values bundled with the spec and schema needed to interpret them."""

    field_values: List[PartitionFieldValue]
    partition_spec: PartitionSpec
    schema: Schema

    @cached_property
    def partition(self) -> Record:
        """Return the partition Record, with each value converted by partition_record_value.

        Values appear in the order of ``field_values``.

        Raises:
            ValueError: If a source column does not map to exactly one partition field.
        """
        record_values = []
        for field_value in self.field_values:
            source_id = field_value.field.source_id
            candidates = self.partition_spec.source_id_to_fields_map[source_id]
            if len(candidates) != 1:
                raise ValueError(f"Cannot have redundant partitions: {candidates}")
            record_value = partition_record_value(
                partition_field=candidates[0],
                value=field_value.value,
                schema=self.schema,
            )
            record_values.append(record_value)
        return Record(*record_values)

    def to_path(self) -> str:
        """Render this key as a partition path using the partition spec."""
        spec = self.partition_spec
        return spec.partition_to_path(self.partition, self.schema)
def partition_record_value(partition_field: PartitionField, value: Any, schema: Schema) -> Any:
    """
    Return the Partition Record representation of the value.

    The type of the partition field's source column is looked up in the schema
    and the value is converted to the partition representation registered for
    that type (for example a date becomes days since epoch and a UUID becomes
    its string form).

    The partition field's transform is not applied by this function.
    """
    source_field = schema.find_field(name_or_id=partition_field.source_id)
    return _to_partition_representation(source_field.field_type, value)
@singledispatch
def _to_partition_representation(type: IcebergType, value: Any) -> Any:
"""Strip the logical type into the physical type.
It can be that the value is already transformed into its physical type,
in this case it will return the original value. Keep in mind that the
bucket transform always will return an int, but an identity transform
can return date that still needs to be transformed into an int (days
since epoch).
"""
return TypeError(f"Unsupported partition field type: {type}")
@_to_partition_representation.register(TimestampType)
@_to_partition_representation.register(TimestamptzType)
def _(type: IcebergType, value: Optional[Union[int, datetime]]) -> Optional[int]:
    """Convert a timestamp value with datetime_to_micros; None and ints pass through unchanged."""
    if value is None or isinstance(value, int):
        # Nothing to convert: either missing or already an int.
        return value
    if isinstance(value, datetime):
        return datetime_to_micros(value)
    raise ValueError(f"Type not recognized: {value}")
@_to_partition_representation.register(DateType)
def _(type: IcebergType, value: Optional[Union[int, date]]) -> Optional[int]:
    """Convert a date value with date_to_days; None and ints pass through unchanged."""
    if value is None or isinstance(value, int):
        # Nothing to convert: either missing or already an int.
        return value
    if isinstance(value, date):
        return date_to_days(value)
    raise ValueError(f"Type not recognized: {value}")
@_to_partition_representation.register(TimeType)
def _(type: IcebergType, value: Optional[Union[int, time]]) -> Optional[int]:
    """Convert a time value with time_to_micros; None and ints pass through unchanged.

    Raises:
        ValueError: If the value is neither None, an int nor a ``datetime.time``.
    """
    if value is None:
        return None
    elif isinstance(value, int):
        # Already in physical form; mirrors the date and timestamp handlers.
        return value
    elif isinstance(value, time):
        return time_to_micros(value)
    else:
        raise ValueError(f"Type not recognized: {value}")
@_to_partition_representation.register(UUIDType)
def _(type: IcebergType, value: Optional[Union[uuid.UUID, int]]) -> Optional[Union[str, int]]:
    """Convert a UUID value to its string form; None and ints pass through unchanged."""
    if value is None:
        return None
    if isinstance(value, int):
        # An int is already a physical value (e.g. the result of a bucket
        # transform) and must not be turned into a string.
        return value
    return str(value)
@_to_partition_representation.register(PrimitiveType)
def _(type: IcebergType, value: Optional[Any]) -> Optional[Any]:
    """Return the value unchanged; handler registered for PrimitiveType."""
    return value