pyiceberg/partitioning.py

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import uuid
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import date, datetime, time
from functools import cached_property, singledispatch
from typing import Annotated, Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union
from urllib.parse import quote_plus

from pydantic import (
    BeforeValidator,
    Field,
    PlainSerializer,
    WithJsonSchema,
    model_validator,
)

from pyiceberg.schema import Schema
from pyiceberg.transforms import (
    BucketTransform,
    DayTransform,
    HourTransform,
    IdentityTransform,
    MonthTransform,
    Transform,
    TruncateTransform,
    UnknownTransform,
    VoidTransform,
    YearTransform,
    parse_transform,
)
from pyiceberg.typedef import IcebergBaseModel, Record
from pyiceberg.types import (
    DateType,
    IcebergType,
    NestedField,
    PrimitiveType,
    StructType,
    TimestampType,
    TimestamptzType,
    TimeType,
    UUIDType,
)
from pyiceberg.utils.datetime import date_to_days, datetime_to_micros, time_to_micros

INITIAL_PARTITION_SPEC_ID = 0
PARTITION_FIELD_ID_START: int = 1000
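# Editor's note: partition field IDs are assigned starting at 1000, following the
# Iceberg convention, so they do not collide with the table schema's data field IDs.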
""" source_id: int = Field(alias="source-id") field_id: int = Field(alias="field-id") transform: Annotated[ # type: ignore Transform, BeforeValidator(parse_transform), PlainSerializer(lambda c: str(c), return_type=str), # pylint: disable=W0108 WithJsonSchema({"type": "string"}, mode="serialization"), ] = Field() name: str = Field() def __init__( self, source_id: Optional[int] = None, field_id: Optional[int] = None, transform: Optional[Transform[Any, Any]] = None, name: Optional[str] = None, **data: Any, ): if source_id is not None: data["source-id"] = source_id if field_id is not None: data["field-id"] = field_id if transform is not None: data["transform"] = transform if name is not None: data["name"] = name super().__init__(**data) @model_validator(mode="before") @classmethod def map_source_ids_onto_source_id(cls, data: Any) -> Any: if isinstance(data, dict): if "source-id" not in data and (source_ids := data["source-ids"]): if isinstance(source_ids, list): if len(source_ids) == 0: raise ValueError("Empty source-ids is not allowed") if len(source_ids) > 1: raise ValueError("Multi argument transforms are not yet supported") data["source-id"] = source_ids[0] return data def __str__(self) -> str: """Return the string representation of the PartitionField class.""" return f"{self.field_id}: {self.name}: {self.transform}({self.source_id})" class PartitionSpec(IcebergBaseModel): """ PartitionSpec captures the transformation from table data to partition values. Attributes: spec_id(int): any change to PartitionSpec will produce a new specId. fields(Tuple[PartitionField): list of partition fields to produce partition values. """ spec_id: int = Field(alias="spec-id", default=INITIAL_PARTITION_SPEC_ID) fields: Tuple[PartitionField, ...] = Field(default_factory=tuple) def __init__( self, *fields: PartitionField, **data: Any, ): if fields: data["fields"] = tuple(fields) super().__init__(**data) def __eq__(self, other: Any) -> bool: """ Produce a boolean to return True if two objects are considered equal. Note: Equality of PartitionSpec is determined by spec_id and partition fields only. """ if not isinstance(other, PartitionSpec): return False return self.spec_id == other.spec_id and self.fields == other.fields def __str__(self) -> str: """ Produce a human-readable string representation of PartitionSpec. Note: Only include list of partition fields in the PartitionSpec's string representation. 
""" result_str = "[" if self.fields: result_str += "\n " + "\n ".join([str(field) for field in self.fields]) + "\n" result_str += "]" return result_str def __repr__(self) -> str: """Return the string representation of the PartitionSpec class.""" fields = f"{', '.join(repr(column) for column in self.fields)}, " if self.fields else "" return f"PartitionSpec({fields}spec_id={self.spec_id})" def is_unpartitioned(self) -> bool: return not self.fields @property def last_assigned_field_id(self) -> int: if self.fields: return max(pf.field_id for pf in self.fields) return PARTITION_FIELD_ID_START - 1 @cached_property def source_id_to_fields_map(self) -> Dict[int, List[PartitionField]]: source_id_to_fields_map: Dict[int, List[PartitionField]] = {} for partition_field in self.fields: existing = source_id_to_fields_map.get(partition_field.source_id, []) existing.append(partition_field) source_id_to_fields_map[partition_field.source_id] = existing return source_id_to_fields_map def fields_by_source_id(self, field_id: int) -> List[PartitionField]: return self.source_id_to_fields_map.get(field_id, []) def compatible_with(self, other: PartitionSpec) -> bool: """Produce a boolean to return True if two PartitionSpec are considered compatible.""" if self == other: return True if len(self.fields) != len(other.fields): return False return all( this_field.source_id == that_field.source_id and this_field.transform == that_field.transform and this_field.name == that_field.name for this_field, that_field in zip(self.fields, other.fields) ) def partition_type(self, schema: Schema) -> StructType: """Produce a struct of the PartitionSpec. The partition fields should be optional: - All partition transforms are required to produce null if the input value is null, so it can happen when the source column is optional. - Partition fields may be added later, in which case not all files would have the result field, and it may be null. There is a case where we can guarantee that a partition field in the first and only partition spec that uses a required source column will never be null, but it doesn't seem worth tracking this case. :param schema: The schema to bind to. :return: A StructType that represents the PartitionSpec, with a NestedField for each PartitionField. 
""" nested_fields = [] for field in self.fields: source_type = schema.find_type(field.source_id) result_type = field.transform.result_type(source_type) required = schema.find_field(field.source_id).required nested_fields.append(NestedField(field.field_id, field.name, result_type, required=required)) return StructType(*nested_fields) def partition_to_path(self, data: Record, schema: Schema) -> str: partition_type = self.partition_type(schema) field_types = partition_type.fields field_strs = [] value_strs = [] for pos in range(len(self.fields)): partition_field = self.fields[pos] value_str = partition_field.transform.to_human_string(field_types[pos].field_type, value=data[pos]) value_strs.append(quote_plus(value_str, safe="")) field_strs.append(quote_plus(partition_field.name, safe="")) path = "/".join([field_str + "=" + value_str for field_str, value_str in zip(field_strs, value_strs)]) return path UNPARTITIONED_PARTITION_SPEC = PartitionSpec(spec_id=0) def assign_fresh_partition_spec_ids(spec: PartitionSpec, old_schema: Schema, fresh_schema: Schema) -> PartitionSpec: partition_fields = [] for pos, field in enumerate(spec.fields): original_column_name = old_schema.find_column_name(field.source_id) if original_column_name is None: raise ValueError(f"Could not find in old schema: {field}") fresh_field = fresh_schema.find_field(original_column_name) if fresh_field is None: raise ValueError(f"Could not find field in fresh schema: {original_column_name}") partition_fields.append( PartitionField( name=field.name, source_id=fresh_field.field_id, field_id=PARTITION_FIELD_ID_START + pos, transform=field.transform, ) ) return PartitionSpec(*partition_fields, spec_id=INITIAL_PARTITION_SPEC_ID) T = TypeVar("T") class PartitionSpecVisitor(Generic[T], ABC): @abstractmethod def identity(self, field_id: int, source_name: str, source_id: int) -> T: """Visit identity partition field.""" @abstractmethod def bucket(self, field_id: int, source_name: str, source_id: int, num_buckets: int) -> T: """Visit bucket partition field.""" @abstractmethod def truncate(self, field_id: int, source_name: str, source_id: int, width: int) -> T: """Visit truncate partition field.""" @abstractmethod def year(self, field_id: int, source_name: str, source_id: int) -> T: """Visit year partition field.""" @abstractmethod def month(self, field_id: int, source_name: str, source_id: int) -> T: """Visit month partition field.""" @abstractmethod def day(self, field_id: int, source_name: str, source_id: int) -> T: """Visit day partition field.""" @abstractmethod def hour(self, field_id: int, source_name: str, source_id: int) -> T: """Visit hour partition field.""" @abstractmethod def always_null(self, field_id: int, source_name: str, source_id: int) -> T: """Visit void partition field.""" @abstractmethod def unknown(self, field_id: int, source_name: str, source_id: int, transform: str) -> T: """Visit unknown partition field.""" raise ValueError(f"Unknown transform is not supported: {transform}") class _PartitionNameGenerator(PartitionSpecVisitor[str]): def identity(self, field_id: int, source_name: str, source_id: int) -> str: return source_name def bucket(self, field_id: int, source_name: str, source_id: int, num_buckets: int) -> str: return f"{source_name}_bucket_{num_buckets}" def truncate(self, field_id: int, source_name: str, source_id: int, width: int) -> str: return source_name + "_trunc_" + str(width) def year(self, field_id: int, source_name: str, source_id: int) -> str: return source_name + "_year" def month(self, field_id: int, 
source_name: str, source_id: int) -> str: return source_name + "_month" def day(self, field_id: int, source_name: str, source_id: int) -> str: return source_name + "_day" def hour(self, field_id: int, source_name: str, source_id: int) -> str: return source_name + "_hour" def always_null(self, field_id: int, source_name: str, source_id: int) -> str: return source_name + "_null" def unknown(self, field_id: int, source_name: str, source_id: int, transform: str) -> str: return super().unknown(field_id, source_name, source_id, transform) R = TypeVar("R") @singledispatch def _visit(spec: PartitionSpec, schema: Schema, visitor: PartitionSpecVisitor[R]) -> List[R]: return [_visit_partition_field(schema, field, visitor) for field in spec.fields] def _visit_partition_field(schema: Schema, field: PartitionField, visitor: PartitionSpecVisitor[R]) -> R: source_name = schema.find_column_name(field.source_id) if not source_name: raise ValueError(f"Could not find field with id {field.source_id}") transform = field.transform if isinstance(transform, IdentityTransform): return visitor.identity(field.field_id, source_name, field.source_id) elif isinstance(transform, BucketTransform): return visitor.bucket(field.field_id, source_name, field.source_id, transform.num_buckets) elif isinstance(transform, TruncateTransform): return visitor.truncate(field.field_id, source_name, field.source_id, transform.width) elif isinstance(transform, DayTransform): return visitor.day(field.field_id, source_name, field.source_id) elif isinstance(transform, HourTransform): return visitor.hour(field.field_id, source_name, field.source_id) elif isinstance(transform, MonthTransform): return visitor.month(field.field_id, source_name, field.source_id) elif isinstance(transform, YearTransform): return visitor.year(field.field_id, source_name, field.source_id) elif isinstance(transform, VoidTransform): return visitor.always_null(field.field_id, source_name, field.source_id) elif isinstance(transform, UnknownTransform): return visitor.unknown(field.field_id, source_name, field.source_id, repr(transform)) else: raise ValueError(f"Unknown transform {transform}") @dataclass(frozen=True) class PartitionFieldValue: field: PartitionField value: Any @dataclass(frozen=True) class PartitionKey: field_values: List[PartitionFieldValue] partition_spec: PartitionSpec schema: Schema @cached_property def partition(self) -> Record: # partition key transformed with iceberg internal representation as input iceberg_typed_key_values = [] for raw_partition_field_value in self.field_values: partition_fields = self.partition_spec.source_id_to_fields_map[raw_partition_field_value.field.source_id] if len(partition_fields) != 1: raise ValueError(f"Cannot have redundant partitions: {partition_fields}") partition_field = partition_fields[0] iceberg_typed_key_values.append( partition_record_value( partition_field=partition_field, value=raw_partition_field_value.value, schema=self.schema, ) ) return Record(*iceberg_typed_key_values) def to_path(self) -> str: return self.partition_spec.partition_to_path(self.partition, self.schema) def partition_record_value(partition_field: PartitionField, value: Any, schema: Schema) -> Any: """ Return the Partition Record representation of the value. The value is first converted to internal partition representation. For example, UUID is converted to bytes[16], DateType to days since epoch, etc. Then the corresponding PartitionField's transform is applied to return the final partition record value. 
""" iceberg_type = schema.find_field(name_or_id=partition_field.source_id).field_type return _to_partition_representation(iceberg_type, value) @singledispatch def _to_partition_representation(type: IcebergType, value: Any) -> Any: """Strip the logical type into the physical type. It can be that the value is already transformed into its physical type, in this case it will return the original value. Keep in mind that the bucket transform always will return an int, but an identity transform can return date that still needs to be transformed into an int (days since epoch). """ return TypeError(f"Unsupported partition field type: {type}") @_to_partition_representation.register(TimestampType) @_to_partition_representation.register(TimestamptzType) def _(type: IcebergType, value: Optional[Union[int, datetime]]) -> Optional[int]: if value is None: return None elif isinstance(value, int): return value elif isinstance(value, datetime): return datetime_to_micros(value) else: raise ValueError(f"Type not recognized: {value}") @_to_partition_representation.register(DateType) def _(type: IcebergType, value: Optional[Union[int, date]]) -> Optional[int]: if value is None: return None elif isinstance(value, int): return value elif isinstance(value, date): return date_to_days(value) else: raise ValueError(f"Type not recognized: {value}") @_to_partition_representation.register(TimeType) def _(type: IcebergType, value: Optional[time]) -> Optional[int]: return time_to_micros(value) if value is not None else None @_to_partition_representation.register(UUIDType) def _(type: IcebergType, value: Optional[uuid.UUID]) -> Optional[str]: return str(value) if value is not None else None @_to_partition_representation.register(PrimitiveType) def _(type: IcebergType, value: Optional[Any]) -> Optional[Any]: return value