pyiceberg/table/update/spec.py (239 lines of code) (raw):

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union

from pyiceberg.expressions import (
    Reference,
)
from pyiceberg.partitioning import (
    INITIAL_PARTITION_SPEC_ID,
    PARTITION_FIELD_ID_START,
    PartitionField,
    PartitionSpec,
    _PartitionNameGenerator,
    _visit_partition_field,
)
from pyiceberg.schema import Schema
from pyiceberg.table.update import (
    AddPartitionSpecUpdate,
    AssertLastAssignedPartitionId,
    SetDefaultSpecUpdate,
    TableRequirement,
    TableUpdate,
    UpdatesAndRequirements,
    UpdateTableMetadata,
)
from pyiceberg.transforms import IdentityTransform, TimeTransform, Transform, VoidTransform, parse_transform

if TYPE_CHECKING:
    from pyiceberg.table import Transaction


class UpdateSpec(UpdateTableMetadata["UpdateSpec"]):
    """Stage additions, removals and renames of partition fields of the current spec.

    The mutating methods (`add_field`, `add_identity`, `remove_field`,
    `rename_field`) only record the requested change and return `self` so calls
    can be chained. The resulting `PartitionSpec` is assembled by `_apply` and
    converted into table updates and requirements by `_commit`.
    """

    # All containers below are (re)created per instance in ``__init__``; they are
    # declared here as annotations only, so no mutable object is shared on the class.
    _transaction: Transaction
    # Fields of the current spec, keyed by partition field name.
    _name_to_field: Dict[str, PartitionField]
    # Fields staged by ``add_field``, keyed by partition field name.
    _name_to_added_field: Dict[str, PartitionField]
    # Fields of the current spec, keyed by (source column id, repr(transform)).
    _transform_to_field: Dict[Tuple[int, str], PartitionField]
    # Fields staged by ``add_field``, keyed by (source column id, repr(transform)).
    _transform_to_added_field: Dict[Tuple[int, str], PartitionField]
    # Staged renames: current partition field name -> new name.
    _renames: Dict[str, str]
    # Staged fields whose transform is a ``TimeTransform``, keyed by source column id.
    _added_time_fields: Dict[int, PartitionField]
    _case_sensitive: bool
    # Staged new fields, in the order they were added.
    _adds: List[PartitionField]
    # Field ids of current-spec fields staged for removal.
    _deletes: Set[int]
    # Highest partition field id handed out so far (see ``_new_field_id``).
    _last_assigned_partition_id: int

    def __init__(self, transaction: Transaction, case_sensitive: bool = True) -> None:
        """Initialise the builder from the transaction's current partition spec.

        Args:
            transaction: Transaction whose table metadata is read and updated.
            case_sensitive: Passed to the column-reference binding in `add_field`.
        """
        super().__init__(transaction)
        current_fields = transaction.table_metadata.spec().fields
        self._name_to_field = {field.name: field for field in current_fields}
        self._name_to_added_field = {}
        self._transform_to_field = {(field.source_id, repr(field.transform)): field for field in current_fields}
        self._transform_to_added_field = {}
        self._adds = []
        self._deletes = set()
        # Fall back to one below the first partition field id when the metadata has
        # no (truthy) last partition id, so the first new id is PARTITION_FIELD_ID_START.
        self._last_assigned_partition_id = transaction.table_metadata.last_partition_id or (PARTITION_FIELD_ID_START - 1)
        self._renames = {}
        self._transaction = transaction
        self._case_sensitive = case_sensitive
        self._added_time_fields = {}

    def add_field(
        self,
        source_column_name: str,
        transform: Union[str, Transform[Any, Any]],
        partition_field_name: Optional[str] = None,
    ) -> UpdateSpec:
        """Stage a new partition field.

        Args:
            source_column_name: Name of the schema column to partition by.
            transform: Transform instance, or a string handed to `parse_transform`.
            partition_field_name: Name of the partition field; when omitted a name
                is generated (or an earlier field's name is reused, see
                `_partition_field`).

        Returns:
            This builder, for chaining.

        Raises:
            ValueError: If the transform cannot be applied to the column type, if an
                equivalent field already exists in the current spec and is not
                staged for removal, if the same source/transform pair or the same
                name was already added, if another time transform was already added
                for the same source column, or if the name clashes with an existing
                non-void partition field.
        """
        ref = Reference(source_column_name)
        bound_ref = ref.bind(self._transaction.table_metadata.schema(), self._case_sensitive)
        if isinstance(transform, str):
            transform = parse_transform(transform)

        # verify transform can actually bind it
        output_type = bound_ref.field.field_type
        if not transform.can_transform(output_type):
            raise ValueError(f"{transform} cannot transform {output_type} values from {bound_ref.field.name}")

        transform_key = (bound_ref.field.field_id, repr(transform))
        existing_partition_field = self._transform_to_field.get(transform_key)
        if existing_partition_field and self._is_duplicate_partition(transform, existing_partition_field):
            raise ValueError(f"Duplicate partition field for {ref.name}={ref}, {existing_partition_field} already exists")

        added = self._transform_to_added_field.get(transform_key)
        if added:
            raise ValueError(f"Already added partition: {added.name}")

        new_field = self._partition_field((bound_ref.field.field_id, transform), partition_field_name)
        if new_field.name in self._name_to_added_field:
            raise ValueError(f"Already added partition field with name: {new_field.name}")

        if isinstance(new_field.transform, TimeTransform):
            # Only one time-based partition field may be added per source column.
            existing_time_field = self._added_time_fields.get(new_field.source_id)
            if existing_time_field:
                raise ValueError(f"Cannot add time partition field: {new_field.name} conflicts with {existing_time_field.name}")
            self._added_time_fields[new_field.source_id] = new_field
        self._transform_to_added_field[transform_key] = new_field

        existing_partition_field = self._name_to_field.get(new_field.name)
        # NOTE(review): this tests the *new* field's id against the staged deletes,
        # not the id of the existing field that owns the name - confirm this is intended.
        if existing_partition_field and new_field.field_id not in self._deletes:
            if isinstance(existing_partition_field.transform, VoidTransform):
                # A void field holds the name: move it aside by suffixing its field id.
                self.rename_field(
                    existing_partition_field.name, existing_partition_field.name + "_" + str(existing_partition_field.field_id)
                )
            else:
                raise ValueError(f"Cannot add duplicate partition field name: {existing_partition_field.name}")

        self._name_to_added_field[new_field.name] = new_field
        self._adds.append(new_field)
        return self

    def add_identity(self, source_column_name: str) -> UpdateSpec:
        """Stage an identity-transform partition field on `source_column_name`.

        Equivalent to `add_field(source_column_name, IdentityTransform(), None)`.
        """
        return self.add_field(source_column_name, IdentityTransform(), None)

    def remove_field(self, name: str) -> UpdateSpec:
        """Stage the removal of the current-spec partition field called `name`.

        Returns:
            This builder, for chaining.

        Raises:
            ValueError: If `name` was added or renamed through this builder, or if
                the current spec has no partition field with that name.
        """
        added = self._name_to_added_field.get(name)
        if added:
            raise ValueError(f"Cannot delete newly added field {name}")
        renamed = self._renames.get(name)
        if renamed:
            raise ValueError(f"Cannot rename and delete field {name}")
        field = self._name_to_field.get(name)
        if not field:
            raise ValueError(f"No such partition field: {name}")

        self._deletes.add(field.field_id)
        return self

    def rename_field(self, name: str, new_name: str) -> UpdateSpec:
        """Stage renaming the current-spec partition field `name` to `new_name`.

        When `new_name` is currently held by a field with a `VoidTransform`, that
        void field is first renamed to `<its name>_<its field id>` so the requested
        rename can still be recorded.

        Returns:
            This builder, for chaining.

        Raises:
            ValueError: If `name` was added through this builder, does not exist in
                the current spec, or is staged for removal.
        """
        existing_field = self._name_to_field.get(new_name)
        if existing_field and isinstance(existing_field.transform, VoidTransform):
            # Free up the target name, then fall through and record the requested
            # rename. Returning here would silently discard ``name -> new_name``.
            self.rename_field(existing_field.name, existing_field.name + "_" + str(existing_field.field_id))
        added = self._name_to_added_field.get(name)
        if added:
            raise ValueError("Cannot rename recently added partitions")
        field = self._name_to_field.get(name)
        if not field:
            raise ValueError(f"Cannot find partition field {name}")
        if field.field_id in self._deletes:
            raise ValueError(f"Cannot delete and rename partition field {name}")
        self._renames[name] = new_name
        return self

    def _commit(self) -> UpdatesAndRequirements:
        """Build the table updates and requirements for the staged changes.

        Returns:
            A `(updates, requirements)` pair. Both tuples are empty when the
            resulting spec id equals the table's current default spec id.
        """
        new_spec = self._apply()
        updates: Tuple[TableUpdate, ...] = ()
        requirements: Tuple[TableRequirement, ...] = ()

        table_metadata = self._transaction.table_metadata
        if table_metadata.default_spec_id != new_spec.spec_id:
            if new_spec.spec_id not in table_metadata.specs():
                updates = (
                    AddPartitionSpecUpdate(spec=new_spec),
                    # presumably -1 selects the spec added just above - confirm
                    # against SetDefaultSpecUpdate.
                    SetDefaultSpecUpdate(spec_id=-1),
                )
            else:
                updates = (SetDefaultSpecUpdate(spec_id=new_spec.spec_id),)

            required_last_assigned_partitioned_id = table_metadata.last_partition_id
            requirements = (AssertLastAssignedPartitionId(last_assigned_partition_id=required_last_assigned_partitioned_id),)

        return updates, requirements

    def _apply(self) -> PartitionSpec:
        """Assemble the partition spec that results from the staged changes.

        Current-spec fields come first (renamed where requested), followed by the
        added fields in insertion order. Removed fields are dropped, except on
        format version 1 where they are kept with a `VoidTransform`.

        Returns:
            The new spec. Its id is that of the first known spec it is
            `compatible_with`, otherwise one above the highest known spec id.

        Raises:
            ValueError: If a retained field's name is empty, duplicated, or names a
                schema column with a different id than the field's source id.
        """

        def _check_and_add_partition_name(schema: Schema, name: str, source_id: int, partition_names: Set[str]) -> None:
            """Validate `name` for a partition field and record it in `partition_names`."""
            try:
                field = schema.find_field(name)
            except ValueError:
                field = None

            if source_id is not None and field is not None and field.field_id != source_id:
                raise ValueError(f"Cannot create identity partition from a different field in the schema {name}")
            elif field is not None and source_id != field.field_id:
                # Only reachable when source_id is None.
                raise ValueError(f"Cannot create partition from name that exists in schema {name}")
            if not name:
                raise ValueError("Undefined name")
            if name in partition_names:
                raise ValueError(f"Partition name has to be unique: {name}")
            partition_names.add(name)

        def _add_new_field(
            schema: Schema, source_id: int, field_id: int, name: str, transform: Transform[Any, Any], partition_names: Set[str]
        ) -> PartitionField:
            """Validate the name, then build the partition field."""
            _check_and_add_partition_name(schema, name, source_id, partition_names)
            return PartitionField(source_id, field_id, transform, name)

        table_metadata = self._transaction.table_metadata
        schema = table_metadata.schema()
        partition_fields: List[PartitionField] = []
        partition_names: Set[str] = set()

        for field in table_metadata.spec().fields:
            transform: Transform[Any, Any]
            if field.field_id not in self._deletes:
                transform = field.transform
            elif table_metadata.format_version == 1:
                # Format version 1 keeps the removed field, replaced by a void transform.
                transform = VoidTransform()
            else:
                continue
            # A falsy (missing or empty) rename keeps the current name.
            field_name = self._renames.get(field.name) or field.name
            partition_fields.append(
                _add_new_field(schema, field.source_id, field.field_id, field_name, transform, partition_names)
            )

        # Added fields were already validated in add_field; copy them as they are.
        partition_fields.extend(
            PartitionField(
                source_id=added_field.source_id,
                field_id=added_field.field_id,
                transform=added_field.transform,
                name=added_field.name,
            )
            for added_field in self._adds
        )

        # Reuse spec id or create a new one.
        new_spec = PartitionSpec(*partition_fields)
        new_spec_id = INITIAL_PARTITION_SPEC_ID
        for spec in table_metadata.specs().values():
            if new_spec.compatible_with(spec):
                new_spec_id = spec.spec_id
                break
            elif new_spec_id <= spec.spec_id:
                new_spec_id = spec.spec_id + 1
        return PartitionSpec(*partition_fields, spec_id=new_spec_id)

    def _partition_field(self, transform_key: Tuple[int, Transform[Any, Any]], name: Optional[str]) -> PartitionField:
        """Create the partition field for a `(source column id, transform)` pair.

        On format version 2, a field from any known spec with the same source id and
        the same transform `repr` is reused (its field id and name), provided `name`
        is `None` or equals that field's name. Otherwise a fresh field id is
        allocated and, when `name` is `None`, a name is generated.
        """
        source_id, transform = transform_key

        if self._transaction.table_metadata.format_version == 2:
            transform_repr = repr(transform)
            for spec in self._transaction.table_metadata.specs().values():
                for field in spec.fields:
                    if field.source_id != source_id or repr(field.transform) != transform_repr:
                        continue
                    if name is None or field.name == name:
                        return PartitionField(source_id, field.field_id, transform, field.name)

        new_field_id = self._new_field_id()
        if name is None:
            # The name generator visits a field, so build a placeholder to derive the name from.
            tmp_field = PartitionField(source_id, new_field_id, transform, "unassigned_field_name")
            name = _visit_partition_field(self._transaction.table_metadata.schema(), tmp_field, _PartitionNameGenerator())
        return PartitionField(source_id, new_field_id, transform, name)

    def _new_field_id(self) -> int:
        """Advance the last assigned partition field id by one and return it."""
        self._last_assigned_partition_id += 1
        return self._last_assigned_partition_id

    def _is_duplicate_partition(self, transform: Transform[Any, Any], partition_field: PartitionField) -> bool:
        """Return True when `partition_field` is not staged for removal and uses `transform`."""
        return partition_field.field_id not in self._deletes and partition_field.transform == transform