pyiceberg/table/update/spec.py (239 lines of code) (raw):
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
from pyiceberg.expressions import (
Reference,
)
from pyiceberg.partitioning import (
INITIAL_PARTITION_SPEC_ID,
PARTITION_FIELD_ID_START,
PartitionField,
PartitionSpec,
_PartitionNameGenerator,
_visit_partition_field,
)
from pyiceberg.schema import Schema
from pyiceberg.table.update import (
AddPartitionSpecUpdate,
AssertLastAssignedPartitionId,
SetDefaultSpecUpdate,
TableRequirement,
TableUpdate,
UpdatesAndRequirements,
UpdateTableMetadata,
)
from pyiceberg.transforms import IdentityTransform, TimeTransform, Transform, VoidTransform, parse_transform
if TYPE_CHECKING:
from pyiceberg.table import Transaction
class UpdateSpec(UpdateTableMetadata["UpdateSpec"]):
    """Builder that evolves the table's current partition spec.

    Partition fields can be added (``add_field`` / ``add_identity``), removed
    (``remove_field``) and renamed (``rename_field``). Every one of these methods
    returns ``self``, so calls can be chained. The operations are only recorded
    here: ``_apply`` builds the resulting ``PartitionSpec`` and ``_commit`` turns it
    into table updates plus a requirement on the last assigned partition id.
    """

    # Per-instance state. Every attribute is assigned in ``__init__``; there are
    # deliberately no class-level ``{}`` defaults, so mutable containers can never
    # be shared between instances.
    _transaction: Transaction
    # Fields of the current (pre-update) spec, keyed by name.
    _name_to_field: Dict[str, PartitionField]
    # Fields added by this update, keyed by name.
    _name_to_added_field: Dict[str, PartitionField]
    # Fields of the current spec, keyed by (source field id, repr(transform)).
    _transform_to_field: Dict[Tuple[int, str], PartitionField]
    # Fields added by this update, keyed by (source field id, repr(transform)).
    _transform_to_added_field: Dict[Tuple[int, str], PartitionField]
    # Pending renames: current field name -> new name.
    _renames: Dict[str, str]
    # Time-transform fields added by this update, keyed by source field id.
    _added_time_fields: Dict[int, PartitionField]
    # Whether source column names are resolved case-sensitively.
    _case_sensitive: bool
    # Fields added by this update, in call order.
    _adds: List[PartitionField]
    # Field ids of current-spec fields removed by this update.
    _deletes: Set[int]
    # Highest partition field id handed out so far.
    _last_assigned_partition_id: int

    def __init__(self, transaction: Transaction, case_sensitive: bool = True) -> None:
        """Start a spec update against ``transaction``'s current table metadata.

        Args:
            transaction: The transaction whose table metadata is being evolved.
            case_sensitive: Whether source column names given to ``add_field`` are
                bound to the schema case-sensitively.
        """
        super().__init__(transaction)
        current_fields = transaction.table_metadata.spec().fields
        self._name_to_field = {field.name: field for field in current_fields}
        self._name_to_added_field = {}
        self._transform_to_field = {(field.source_id, repr(field.transform)): field for field in current_fields}
        self._transform_to_added_field = {}
        self._adds = []
        self._deletes = set()
        # New ids continue after the table's last assigned partition id. When the
        # metadata has no (or a falsy) value, fall back to just below
        # PARTITION_FIELD_ID_START so the first new id is PARTITION_FIELD_ID_START.
        self._last_assigned_partition_id = transaction.table_metadata.last_partition_id or PARTITION_FIELD_ID_START - 1
        self._renames = {}
        self._transaction = transaction
        self._case_sensitive = case_sensitive
        self._added_time_fields = {}

    def add_field(
        self,
        source_column_name: str,
        transform: Union[str, Transform[Any, Any]],
        partition_field_name: Optional[str] = None,
    ) -> UpdateSpec:
        """Add a partition field derived from ``source_column_name``.

        Args:
            source_column_name: Schema column the partition values are sourced from.
            transform: A ``Transform`` instance, or its string form (parsed with
                ``parse_transform``).
            partition_field_name: Explicit name for the new field. When ``None``, a
                historical field name is reused (format v2) or a name is generated.

        Returns:
            This ``UpdateSpec``, for chaining.

        Raises:
            ValueError: If the transform cannot be applied to the column's type; if an
                identical partition already exists or was already added; if the name is
                already used by an added field or by a live, non-void current field; or
                if a second time-based partition is added for the same source column.
        """
        ref = Reference(source_column_name)
        bound_ref = ref.bind(self._transaction.table_metadata.schema(), self._case_sensitive)
        if isinstance(transform, str):
            transform = parse_transform(transform)
        # verify transform can actually bind it
        output_type = bound_ref.field.field_type
        if not transform.can_transform(output_type):
            raise ValueError(f"{transform} cannot transform {output_type} values from {bound_ref.field.name}")

        transform_key = (bound_ref.field.field_id, repr(transform))
        existing_partition_field = self._transform_to_field.get(transform_key)
        if existing_partition_field and self._is_duplicate_partition(transform, existing_partition_field):
            raise ValueError(f"Duplicate partition field for {ref.name}={ref}, {existing_partition_field} already exists")

        added = self._transform_to_added_field.get(transform_key)
        if added:
            raise ValueError(f"Already added partition: {added.name}")

        new_field = self._partition_field((bound_ref.field.field_id, transform), partition_field_name)
        if new_field.name in self._name_to_added_field:
            raise ValueError(f"Already added partition field with name: {new_field.name}")

        is_time_field = isinstance(new_field.transform, TimeTransform)
        if is_time_field:
            existing_time_field = self._added_time_fields.get(new_field.source_id)
            if existing_time_field:
                raise ValueError(f"Cannot add time partition field: {new_field.name} conflicts with {existing_time_field.name}")

        # A current field with the same name only blocks the add when it stays live
        # (not removed in this update) and is not a void placeholder.
        same_name_field = self._name_to_field.get(new_field.name)
        same_name_is_deleted = same_name_field is not None and same_name_field.field_id in self._deletes
        if (
            same_name_field is not None
            and not same_name_is_deleted
            and not isinstance(same_name_field.transform, VoidTransform)
        ):
            raise ValueError(f"Cannot add duplicate partition field name: {same_name_field.name}")

        # All validation passed; only now record the change, so a rejected call leaves
        # the recorded changes untouched. A field id allocated by ``_partition_field``
        # may still have been consumed.
        if is_time_field:
            self._added_time_fields[new_field.source_id] = new_field
        self._transform_to_added_field[transform_key] = new_field
        if same_name_field is not None:
            # Move the old field out of the way so the new field can take its name.
            displaced_name = same_name_field.name + "_" + str(same_name_field.field_id)
            if same_name_is_deleted:
                # Removed in this update: format v1 keeps it as a void field in
                # ``_apply``, so it must not collide with the new field's name.
                self._renames[same_name_field.name] = displaced_name
            else:
                # Void placeholder field: rename it.
                self.rename_field(same_name_field.name, displaced_name)
        self._name_to_added_field[new_field.name] = new_field
        self._adds.append(new_field)
        return self

    def add_identity(self, source_column_name: str) -> UpdateSpec:
        """Add an identity partition on ``source_column_name`` with a generated name."""
        return self.add_field(source_column_name, IdentityTransform(), None)

    def remove_field(self, name: str) -> UpdateSpec:
        """Remove the current-spec partition field called ``name``.

        Raises:
            ValueError: If ``name`` was added or renamed in this same update, or if no
                such field exists in the current spec.
        """
        added = self._name_to_added_field.get(name)
        if added:
            raise ValueError(f"Cannot delete newly added field {name}")
        renamed = self._renames.get(name)
        if renamed:
            raise ValueError(f"Cannot rename and delete field {name}")
        field = self._name_to_field.get(name)
        if not field:
            raise ValueError(f"No such partition field: {name}")

        self._deletes.add(field.field_id)
        return self

    def rename_field(self, name: str, new_name: str) -> UpdateSpec:
        """Rename the current-spec partition field ``name`` to ``new_name``.

        If ``new_name`` is currently held by a void field, that void field is first
        renamed to ``<its name>_<its field id>`` to free the name.

        Raises:
            ValueError: If ``name`` was added in this update, does not exist in the
                current spec, or is removed in this update.
        """
        existing_field = self._name_to_field.get(new_name)
        if existing_field and isinstance(existing_field.transform, VoidTransform):
            # Free the target name by moving the void field aside, then carry on and
            # rename ``name`` itself (previously ``name`` was renamed to
            # ``name_<void id>`` instead, which renamed the wrong field).
            self.rename_field(existing_field.name, existing_field.name + "_" + str(existing_field.field_id))

        added = self._name_to_added_field.get(name)
        if added:
            raise ValueError("Cannot rename recently added partitions")
        field = self._name_to_field.get(name)
        if not field:
            raise ValueError(f"Cannot find partition field {name}")
        if field.field_id in self._deletes:
            raise ValueError(f"Cannot delete and rename partition field {name}")
        self._renames[name] = new_name
        return self

    def _commit(self) -> UpdatesAndRequirements:
        """Translate the recorded changes into table updates and requirements.

        If the resulting spec is already the default, no update is emitted. If its id
        matches a known spec, that spec becomes the default. Otherwise the spec is added
        and made the default via ``spec_id=-1``, which presumably refers to the spec
        added just before it. The last-assigned-partition-id requirement is always
        returned.
        """
        new_spec = self._apply()
        updates: Tuple[TableUpdate, ...] = ()
        requirements: Tuple[TableRequirement, ...] = ()

        if self._transaction.table_metadata.default_spec_id != new_spec.spec_id:
            if new_spec.spec_id not in self._transaction.table_metadata.specs():
                updates = (
                    AddPartitionSpecUpdate(spec=new_spec),
                    SetDefaultSpecUpdate(spec_id=-1),
                )
            else:
                updates = (SetDefaultSpecUpdate(spec_id=new_spec.spec_id),)

        required_last_assigned_partitioned_id = self._transaction.table_metadata.last_partition_id
        requirements = (AssertLastAssignedPartitionId(last_assigned_partition_id=required_last_assigned_partitioned_id),)

        return updates, requirements

    def _apply(self) -> PartitionSpec:
        """Build the partition spec that results from the recorded changes.

        Current fields keep their ids and are renamed as requested. Removed fields are
        dropped, except in format v1, where they stay in place with a ``VoidTransform``.
        Added fields are appended last. The spec id is reused from a compatible existing
        spec, or set to one more than the highest existing spec id.

        Raises:
            ValueError: On an empty partition name, a duplicate partition name, or a
                current field's name clashing with a different schema field.
        """
        metadata = self._transaction.table_metadata
        schema = metadata.schema()
        partition_names: Set[str] = set()

        def _check_schema_conflict(name: str, source_id: Optional[int]) -> None:
            # Reject a partition name that equals a schema field name unless that
            # schema field is the partition's own source field.
            try:
                schema_field = schema.find_field(name)
            except ValueError:
                return
            if source_id is not None and schema_field.field_id != source_id:
                raise ValueError(f"Cannot create identity partition from a different field in the schema {name}")
            if source_id is None:
                raise ValueError(f"Cannot create partition from name that exists in schema {name}")

        def _claim_partition_name(name: str) -> None:
            # Partition names must be non-empty and unique within the spec.
            if not name:
                raise ValueError("Undefined name")
            if name in partition_names:
                raise ValueError(f"Partition name has to be unique: {name}")
            partition_names.add(name)

        is_v1 = metadata.format_version == 1
        partition_fields: List[PartitionField] = []
        for field in metadata.spec().fields:
            is_deleted = field.field_id in self._deletes
            if is_deleted and not is_v1:
                continue
            # v1 keeps removed fields in place, replacing their transform with void.
            transform = VoidTransform() if is_deleted else field.transform
            name = self._renames.get(field.name) or field.name
            _check_schema_conflict(name, field.source_id)
            _claim_partition_name(name)
            partition_fields.append(PartitionField(field.source_id, field.field_id, transform, name))

        for added_field in self._adds:
            # Added names must not collide with (possibly renamed) existing fields,
            # e.g. ``rename_field("a", "b")`` combined with adding a field named "b".
            _claim_partition_name(added_field.name)
            partition_fields.append(
                PartitionField(
                    source_id=added_field.source_id,
                    field_id=added_field.field_id,
                    transform=added_field.transform,
                    name=added_field.name,
                )
            )

        # Reuse spec id or create a new one.
        new_spec = PartitionSpec(*partition_fields)
        new_spec_id = INITIAL_PARTITION_SPEC_ID
        for spec in metadata.specs().values():
            if new_spec.compatible_with(spec):
                new_spec_id = spec.spec_id
                break
            elif new_spec_id <= spec.spec_id:
                new_spec_id = spec.spec_id + 1

        return PartitionSpec(*partition_fields, spec_id=new_spec_id)

    def _partition_field(self, transform_key: Tuple[int, Transform[Any, Any]], name: Optional[str]) -> PartitionField:
        """Create (or, in format v2, recycle) a partition field for ``transform_key``.

        In format v2, a field from any known spec with the same source id and transform
        is reused when ``name`` is ``None`` or equals that field's name; it keeps its
        original field id and name. Otherwise a new field id is allocated, and a name
        is generated when ``name`` is ``None``.

        Args:
            transform_key: ``(source field id, transform)``.
            name: Requested partition field name, or ``None``.
        """
        source_id, transform = transform_key
        if self._transaction.table_metadata.format_version == 2:
            for spec in self._transaction.table_metadata.specs().values():
                for field in spec.fields:
                    if field.source_id == source_id and repr(field.transform) == repr(transform):
                        if name is None or field.name == name:
                            return PartitionField(source_id, field.field_id, transform, field.name)

        new_field_id = self._new_field_id()
        if name is None:
            # The name generator works on a field, so build a throwaway one.
            tmp_field = PartitionField(source_id, new_field_id, transform, "unassigned_field_name")
            name = _visit_partition_field(self._transaction.table_metadata.schema(), tmp_field, _PartitionNameGenerator())
        return PartitionField(source_id, new_field_id, transform, name)

    def _new_field_id(self) -> int:
        """Allocate and return the next partition field id."""
        self._last_assigned_partition_id += 1
        return self._last_assigned_partition_id

    def _is_duplicate_partition(self, transform: Transform[Any, Any], partition_field: PartitionField) -> bool:
        """Return True if ``partition_field`` stays live and already uses ``transform``."""
        return partition_field.field_id not in self._deletes and partition_field.transform == transform