pyiceberg/table/sorting.py (121 lines of code) (raw):

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=keyword-arg-before-vararg
from enum import Enum
from typing import Annotated, Any, Callable, Dict, List, Optional, Union

from pydantic import (
    BeforeValidator,
    Field,
    PlainSerializer,
    WithJsonSchema,
    model_validator,
)

from pyiceberg.schema import Schema
from pyiceberg.transforms import IdentityTransform, Transform, parse_transform
from pyiceberg.typedef import IcebergBaseModel
from pyiceberg.types import IcebergType


class SortDirection(Enum):
    """Direction in which a sort field orders its values ("asc" or "desc")."""

    ASC = "asc"
    DESC = "desc"

    def __str__(self) -> str:
        """Return the string representation of the SortDirection class."""
        return self.name

    def __repr__(self) -> str:
        """Return the string representation of the SortDirection class."""
        return f"SortDirection.{self.name}"


class NullOrder(Enum):
    """Placement of null values in a sort ("nulls-first" or "nulls-last")."""

    NULLS_FIRST = "nulls-first"
    NULLS_LAST = "nulls-last"

    def __str__(self) -> str:
        """Return the string representation of the NullOrder class."""
        # NULLS_FIRST -> "NULLS FIRST"
        return self.name.replace("_", " ")

    def __repr__(self) -> str:
        """Return the string representation of the NullOrder class."""
        return f"NullOrder.{self.name}"


class SortField(IcebergBaseModel):
    """Sort order field.

    Args:
      source_id (int): Source column id from the table's schema.
      transform (Transform): Transform that is used to produce values to be sorted on from the source column.
                       This is the same transform as described in partition transforms. String input is
                       parsed with ``parse_transform``; defaults to the identity transform.
      direction (SortDirection): Sort direction, that can only be either asc or desc. Defaults to asc.
      null_order (NullOrder): Null order that describes the order of null values when sorted.
                       Can only be either nulls-first or nulls-last. When omitted it defaults to
                       nulls-first for ascending fields and nulls-last for descending fields.
    """

    def __init__(
        self,
        source_id: Optional[int] = None,
        transform: Optional[Union[Transform[Any, Any], Callable[[IcebergType], Transform[Any, Any]]]] = None,
        direction: Optional[SortDirection] = None,
        null_order: Optional[NullOrder] = None,
        **data: Any,
    ):
        # Translate the pythonic keyword arguments into the aliased keys the model
        # validates ("source-id", "null-order"). Only explicitly supplied values are
        # forwarded, so the "before" validators below can fill in defaults.
        if source_id is not None:
            data["source-id"] = source_id
        if transform is not None:
            data["transform"] = transform
        if direction is not None:
            data["direction"] = direction
        if null_order is not None:
            data["null-order"] = null_order
        super().__init__(**data)

    @model_validator(mode="before")
    @classmethod
    def set_null_order(cls, values: Any) -> Any:
        """Default ``direction`` to ascending and derive a missing ``null-order`` from it.

        Non-dict input is returned unchanged so pydantic can validate (or reject) it.

        Raises:
            ValueError: If ``direction`` is a string that is not a ``SortDirection`` value.
        """
        if not isinstance(values, dict):
            return values
        data: Dict[str, Any] = values

        direction = data.get("direction")
        if not direction:
            direction = SortDirection.ASC
        elif isinstance(direction, str):
            # Serialized input (e.g. JSON) carries the enum *value* ("asc"/"desc"). A plain
            # string never compares equal to an Enum member, so without this coercion an
            # ascending field would wrongly get NULLS_LAST below.
            direction = SortDirection(direction)
        data["direction"] = direction

        if not data.get("null-order"):
            data["null-order"] = NullOrder.NULLS_FIRST if direction == SortDirection.ASC else NullOrder.NULLS_LAST
        return data

    @model_validator(mode="before")
    @classmethod
    def map_source_ids_onto_source_id(cls, data: Any) -> Any:
        """Map a single-element ``source-ids`` list onto ``source-id``.

        Only applies to dict input that has ``source-ids`` but no ``source-id``. If neither
        key is present, the data is returned untouched and pydantic reports the missing
        ``source-id`` field.

        Raises:
            ValueError: If ``source-ids`` is an empty list, or lists more than one id
                (multi-argument transforms are not supported yet).
        """
        if isinstance(data, dict) and "source-id" not in data and "source-ids" in data:
            source_ids = data["source-ids"]
            # Test the key explicitly instead of the list's truthiness, so that an empty
            # list reaches the error below instead of being silently skipped.
            if isinstance(source_ids, list):
                if len(source_ids) == 0:
                    raise ValueError("Empty source-ids is not allowed")
                if len(source_ids) > 1:
                    raise ValueError("Multi argument transforms are not yet supported")
                data["source-id"] = source_ids[0]
        return data

    source_id: int = Field(alias="source-id")
    transform: Annotated[  # type: ignore
        Transform,
        BeforeValidator(parse_transform),
        PlainSerializer(lambda c: str(c), return_type=str),  # pylint: disable=W0108
        WithJsonSchema({"type": "string"}, mode="serialization"),
    ] = Field(default=IdentityTransform())
    direction: SortDirection = Field()
    null_order: NullOrder = Field(alias="null-order")

    def __str__(self) -> str:
        """Return the string representation of the SortField class."""
        if isinstance(self.transform, IdentityTransform):
            # In the case of an identity transform, we can omit the transform
            return f"{self.source_id} {self.direction} {self.null_order}"
        else:
            return f"{self.transform}({self.source_id}) {self.direction} {self.null_order}"


INITIAL_SORT_ORDER_ID = 1


class SortOrder(IcebergBaseModel):
    """Describes how the data is sorted within the table.

    Users can sort their data within partitions by columns to gain performance.

    The order of the sort fields within the list defines the order in which the sort is applied to the data.

    Args:
      fields (List[SortField]): The fields how the table is sorted.

    Keyword Args:
      order_id (int): A unique id of the sort-order of a table. Defaults to INITIAL_SORT_ORDER_ID.
    """

    order_id: int = Field(alias="order-id", default=INITIAL_SORT_ORDER_ID)
    fields: List[SortField] = Field(default_factory=list)

    def __init__(self, *fields: SortField, **data: Any):
        # Positional SortFields take precedence over a "fields" keyword argument.
        if fields:
            data["fields"] = fields
        super().__init__(**data)

    @property
    def is_unsorted(self) -> bool:
        """Return True when this sort order has no fields."""
        return len(self.fields) == 0

    def __str__(self) -> str:
        """Return the string representation of the SortOrder class."""
        result_str = "["
        if self.fields:
            result_str += "\n  " + "\n  ".join([str(field) for field in self.fields]) + "\n"
        result_str += "]"
        return result_str

    def __repr__(self) -> str:
        """Return the string representation of the SortOrder class."""
        fields = f"{', '.join(repr(column) for column in self.fields)}, " if self.fields else ""
        return f"SortOrder({fields}order_id={self.order_id})"


UNSORTED_SORT_ORDER_ID = 0
UNSORTED_SORT_ORDER = SortOrder(order_id=UNSORTED_SORT_ORDER_ID)


def assign_fresh_sort_order_ids(
    sort_order: SortOrder, old_schema: Schema, fresh_schema: Schema, sort_order_id: int = INITIAL_SORT_ORDER_ID
) -> SortOrder:
    """Re-point a sort order's source ids from ``old_schema`` onto ``fresh_schema``.

    Each field's ``source_id`` is resolved to a column name in ``old_schema``. That name
    is then looked up in ``fresh_schema``, and its field id becomes the new ``source_id``.
    Transform, direction and null order are copied unchanged.

    Args:
        sort_order: Sort order whose source ids refer to ``old_schema``.
        old_schema: Schema the current ``source_id`` values belong to.
        fresh_schema: Schema whose field ids the returned sort order should reference.
        sort_order_id: Order id given to the returned sort order.

    Returns:
        A new ``SortOrder`` with id ``sort_order_id``. If ``sort_order`` has no fields, the
        shared ``UNSORTED_SORT_ORDER`` is returned instead and ``sort_order_id`` is ignored.

    Raises:
        ValueError: If a source id has no column name in ``old_schema``, or that name has
            no field in ``fresh_schema``.
    """
    if sort_order.is_unsorted:
        return UNSORTED_SORT_ORDER

    fresh_fields = []
    for field in sort_order.fields:
        original_field = old_schema.find_column_name(field.source_id)
        if original_field is None:
            raise ValueError(f"Could not find in old schema: {field}")
        # NOTE(review): Schema.find_field may raise on its own for unknown names before this
        # None check is reached — verify against pyiceberg.schema.
        fresh_field = fresh_schema.find_field(original_field)
        if fresh_field is None:
            raise ValueError(f"Could not find field in fresh schema: {original_field}")
        fresh_fields.append(
            SortField(
                source_id=fresh_field.field_id,
                transform=field.transform,
                direction=field.direction,
                null_order=field.null_order,
            )
        )
    return SortOrder(*fresh_fields, order_id=sort_order_id)