pyiceberg/table/update/__init__.py
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import itertools
import uuid
from abc import ABC, abstractmethod
from datetime import datetime
from functools import singledispatch
from typing import TYPE_CHECKING, Annotated, Any, Dict, Generic, List, Literal, Optional, Tuple, TypeVar, Union, cast
from pydantic import Field, field_validator, model_validator
from pyiceberg.exceptions import CommitFailedException
from pyiceberg.partitioning import PARTITION_FIELD_ID_START, PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.table.metadata import SUPPORTED_TABLE_FORMAT_VERSION, TableMetadata, TableMetadataUtil
from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef
from pyiceberg.table.snapshots import (
MetadataLogEntry,
Snapshot,
SnapshotLogEntry,
)
from pyiceberg.table.sorting import SortOrder
from pyiceberg.table.statistics import StatisticsFile, filter_statistics_by_snapshot_id
from pyiceberg.typedef import (
IcebergBaseModel,
Properties,
)
from pyiceberg.types import (
transform_dict_value_to_str,
)
from pyiceberg.utils.datetime import datetime_to_millis
from pyiceberg.utils.deprecated import deprecation_notice
from pyiceberg.utils.properties import property_as_int
if TYPE_CHECKING:
from pyiceberg.table import Transaction
U = TypeVar("U")
class UpdateTableMetadata(ABC, Generic[U]):
_transaction: Transaction
def __init__(self, transaction: Transaction) -> None:
self._transaction = transaction
@abstractmethod
def _commit(self) -> UpdatesAndRequirements: ...
def commit(self) -> None:
self._transaction._apply(*self._commit())
def __exit__(self, _: Any, value: Any, traceback: Any) -> None:
"""Close and commit the change."""
self.commit()
def __enter__(self) -> U:
"""Update the table."""
return self # type: ignore
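# Illustrative sketch (not part of this module's API surface): concrete subclasses of
# UpdateTableMetadata, such as the library's UpdateSchema, are normally used as context
# managers, so leaving the `with` block calls commit() and applies the staged changes:
#
#     with transaction.update_schema() as update:
#         update.add_column("age", IntegerType())
#     # __exit__ -> commit() -> transaction._apply(*self._commit())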
class AssignUUIDUpdate(IcebergBaseModel):
action: Literal["assign-uuid"] = Field(default="assign-uuid")
uuid: uuid.UUID
class UpgradeFormatVersionUpdate(IcebergBaseModel):
action: Literal["upgrade-format-version"] = Field(default="upgrade-format-version")
format_version: int = Field(alias="format-version")
class AddSchemaUpdate(IcebergBaseModel):
action: Literal["add-schema"] = Field(default="add-schema")
schema_: Schema = Field(alias="schema")
    # This field is required by the REST spec (https://github.com/apache/iceberg/pull/7445),
    # but is deprecated here because the last column id is tracked internally.
last_column_id: Optional[int] = Field(
alias="last-column-id",
default=None,
deprecated=deprecation_notice(
deprecated_in="0.9.0",
removed_in="0.10.0",
help_message="last-field-id is handled internally, and should not be part of the update.",
),
)
class SetCurrentSchemaUpdate(IcebergBaseModel):
action: Literal["set-current-schema"] = Field(default="set-current-schema")
schema_id: int = Field(
alias="schema-id", description="Schema ID to set as current, or -1 to set last added schema", default=-1
)
class AddPartitionSpecUpdate(IcebergBaseModel):
action: Literal["add-spec"] = Field(default="add-spec")
spec: PartitionSpec
class SetDefaultSpecUpdate(IcebergBaseModel):
action: Literal["set-default-spec"] = Field(default="set-default-spec")
spec_id: int = Field(
alias="spec-id", description="Partition spec ID to set as the default, or -1 to set last added spec", default=-1
)
class AddSortOrderUpdate(IcebergBaseModel):
action: Literal["add-sort-order"] = Field(default="add-sort-order")
sort_order: SortOrder = Field(alias="sort-order")
class SetDefaultSortOrderUpdate(IcebergBaseModel):
action: Literal["set-default-sort-order"] = Field(default="set-default-sort-order")
sort_order_id: int = Field(
alias="sort-order-id", description="Sort order ID to set as the default, or -1 to set last added sort order", default=-1
)
class AddSnapshotUpdate(IcebergBaseModel):
action: Literal["add-snapshot"] = Field(default="add-snapshot")
snapshot: Snapshot
class SetSnapshotRefUpdate(IcebergBaseModel):
action: Literal["set-snapshot-ref"] = Field(default="set-snapshot-ref")
ref_name: str = Field(alias="ref-name")
type: Literal["tag", "branch"]
snapshot_id: int = Field(alias="snapshot-id")
max_ref_age_ms: Annotated[Optional[int], Field(alias="max-ref-age-ms", default=None)]
max_snapshot_age_ms: Annotated[Optional[int], Field(alias="max-snapshot-age-ms", default=None)]
min_snapshots_to_keep: Annotated[Optional[int], Field(alias="min-snapshots-to-keep", default=None)]
class RemoveSnapshotsUpdate(IcebergBaseModel):
action: Literal["remove-snapshots"] = Field(default="remove-snapshots")
snapshot_ids: List[int] = Field(alias="snapshot-ids")
class RemoveSnapshotRefUpdate(IcebergBaseModel):
action: Literal["remove-snapshot-ref"] = Field(default="remove-snapshot-ref")
ref_name: str = Field(alias="ref-name")
class SetLocationUpdate(IcebergBaseModel):
action: Literal["set-location"] = Field(default="set-location")
location: str
class SetPropertiesUpdate(IcebergBaseModel):
action: Literal["set-properties"] = Field(default="set-properties")
updates: Dict[str, str]
@field_validator("updates", mode="before")
def transform_properties_dict_value_to_str(cls, properties: Properties) -> Dict[str, str]:
return transform_dict_value_to_str(properties)
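# Illustrative note: because of the "before" validator above, property values that are not
# already strings are stringified during model construction (hypothetical example):
#
#     SetPropertiesUpdate(updates={"commit.retry.num-retries": 4})
#     # -> updates == {"commit.retry.num-retries": "4"}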
class RemovePropertiesUpdate(IcebergBaseModel):
action: Literal["remove-properties"] = Field(default="remove-properties")
removals: List[str]
class SetStatisticsUpdate(IcebergBaseModel):
action: Literal["set-statistics"] = Field(default="set-statistics")
statistics: StatisticsFile
snapshot_id: Optional[int] = Field(
None,
alias="snapshot-id",
description="snapshot-id is **DEPRECATED for REMOVAL** since it contains redundant information. Use `statistics.snapshot-id` field instead.",
)
@model_validator(mode="before")
def validate_snapshot_id(cls, data: Dict[str, Any]) -> Dict[str, Any]:
stats = cast(StatisticsFile, data["statistics"])
data["snapshot_id"] = stats.snapshot_id
return data
class RemoveStatisticsUpdate(IcebergBaseModel):
action: Literal["remove-statistics"] = Field(default="remove-statistics")
snapshot_id: int = Field(alias="snapshot-id")
TableUpdate = Annotated[
Union[
AssignUUIDUpdate,
UpgradeFormatVersionUpdate,
AddSchemaUpdate,
SetCurrentSchemaUpdate,
AddPartitionSpecUpdate,
SetDefaultSpecUpdate,
AddSortOrderUpdate,
SetDefaultSortOrderUpdate,
AddSnapshotUpdate,
SetSnapshotRefUpdate,
RemoveSnapshotsUpdate,
RemoveSnapshotRefUpdate,
SetLocationUpdate,
SetPropertiesUpdate,
RemovePropertiesUpdate,
SetStatisticsUpdate,
RemoveStatisticsUpdate,
],
Field(discriminator="action"),
]
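# Illustrative sketch: `action` is the pydantic discriminator, so a raw REST payload can be
# parsed directly into the matching update class (assuming pydantic v2's TypeAdapter):
#
#     from pydantic import TypeAdapter
#
#     update = TypeAdapter(TableUpdate).validate_python(
#         {"action": "set-location", "location": "s3://bucket/warehouse/db/tbl"}
#     )
#     assert isinstance(update, SetLocationUpdate)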
class _TableMetadataUpdateContext:
_updates: List[TableUpdate]
def __init__(self) -> None:
self._updates = []
def add_update(self, update: TableUpdate) -> None:
self._updates.append(update)
def is_added_snapshot(self, snapshot_id: int) -> bool:
return any(
update.snapshot.snapshot_id == snapshot_id for update in self._updates if isinstance(update, AddSnapshotUpdate)
)
def is_added_schema(self, schema_id: int) -> bool:
return any(update.schema_.schema_id == schema_id for update in self._updates if isinstance(update, AddSchemaUpdate))
def is_added_partition_spec(self, spec_id: int) -> bool:
return any(update.spec.spec_id == spec_id for update in self._updates if isinstance(update, AddPartitionSpecUpdate))
def is_added_sort_order(self, sort_order_id: int) -> bool:
return any(
update.sort_order.order_id == sort_order_id for update in self._updates if isinstance(update, AddSortOrderUpdate)
)
def has_changes(self) -> bool:
return len(self._updates) > 0
@singledispatch
def _apply_table_update(update: TableUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata:
"""Apply a table update to the table metadata.
Args:
update: The update to be applied.
base_metadata: The base metadata to be updated.
context: Contains previous updates and other change tracking information in the current transaction.
Returns:
The updated metadata.
"""
raise NotImplementedError(f"Unsupported table update: {update}")
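# Illustrative note: `singledispatch` selects the handler registered below for the runtime
# type of `update`, e.g. (hypothetical values):
#
#     _apply_table_update(
#         SetLocationUpdate(location="s3://bucket/warehouse/db/tbl"),
#         base_metadata,
#         _TableMetadataUpdateContext(),
#     )
#     # -> dispatches to the SetLocationUpdate handler registered below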
@_apply_table_update.register(AssignUUIDUpdate)
def _(update: AssignUUIDUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata:
if update.uuid == base_metadata.table_uuid:
return base_metadata
context.add_update(update)
return base_metadata.model_copy(update={"table_uuid": update.uuid})
@_apply_table_update.register(SetLocationUpdate)
def _(update: SetLocationUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata:
context.add_update(update)
return base_metadata.model_copy(update={"location": update.location})
@_apply_table_update.register(UpgradeFormatVersionUpdate)
def _(
update: UpgradeFormatVersionUpdate,
base_metadata: TableMetadata,
context: _TableMetadataUpdateContext,
) -> TableMetadata:
if update.format_version > SUPPORTED_TABLE_FORMAT_VERSION:
raise ValueError(f"Unsupported table format version: {update.format_version}")
elif update.format_version < base_metadata.format_version:
raise ValueError(f"Cannot downgrade v{base_metadata.format_version} table to v{update.format_version}")
elif update.format_version == base_metadata.format_version:
return base_metadata
updated_metadata = base_metadata.model_copy(update={"format_version": update.format_version})
context.add_update(update)
return TableMetadataUtil._construct_without_validation(updated_metadata)
@_apply_table_update.register(SetPropertiesUpdate)
def _(update: SetPropertiesUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata:
if len(update.updates) == 0:
return base_metadata
properties = dict(base_metadata.properties)
properties.update(update.updates)
context.add_update(update)
return base_metadata.model_copy(update={"properties": properties})
@_apply_table_update.register(RemovePropertiesUpdate)
def _(update: RemovePropertiesUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata:
if len(update.removals) == 0:
return base_metadata
properties = dict(base_metadata.properties)
for key in update.removals:
properties.pop(key)
context.add_update(update)
return base_metadata.model_copy(update={"properties": properties})
@_apply_table_update.register(AddSchemaUpdate)
def _(update: AddSchemaUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata:
metadata_updates: Dict[str, Any] = {
"last_column_id": max(base_metadata.last_column_id, update.schema_.highest_field_id),
"schemas": base_metadata.schemas + [update.schema_],
}
context.add_update(update)
return base_metadata.model_copy(update=metadata_updates)
@_apply_table_update.register(SetCurrentSchemaUpdate)
def _(update: SetCurrentSchemaUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata:
new_schema_id = update.schema_id
if new_schema_id == -1:
# The last added schema should be in base_metadata.schemas at this point
new_schema_id = max(schema.schema_id for schema in base_metadata.schemas)
if not context.is_added_schema(new_schema_id):
raise ValueError("Cannot set current schema to last added schema when no schema has been added")
if new_schema_id == base_metadata.current_schema_id:
return base_metadata
schema = base_metadata.schema_by_id(new_schema_id)
if schema is None:
raise ValueError(f"Schema with id {new_schema_id} does not exist")
context.add_update(update)
return base_metadata.model_copy(update={"current_schema_id": new_schema_id})
@_apply_table_update.register(AddPartitionSpecUpdate)
def _(update: AddPartitionSpecUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata:
for spec in base_metadata.partition_specs:
# Only raise in case of a discrepancy
if spec.spec_id == update.spec.spec_id and spec != update.spec:
raise ValueError(f"Partition spec with id {spec.spec_id} already exists: {spec}")
metadata_updates: Dict[str, Any] = {
"partition_specs": base_metadata.partition_specs + [update.spec],
"last_partition_id": max(
max([field.field_id for field in update.spec.fields], default=0),
base_metadata.last_partition_id or PARTITION_FIELD_ID_START - 1,
),
}
context.add_update(update)
return base_metadata.model_copy(update=metadata_updates)
@_apply_table_update.register(SetDefaultSpecUpdate)
def _(update: SetDefaultSpecUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata:
new_spec_id = update.spec_id
if new_spec_id == -1:
new_spec_id = max(spec.spec_id for spec in base_metadata.partition_specs)
if not context.is_added_partition_spec(new_spec_id):
raise ValueError("Cannot set current partition spec to last added one when no partition spec has been added")
if new_spec_id == base_metadata.default_spec_id:
return base_metadata
found_spec_id = False
for spec in base_metadata.partition_specs:
found_spec_id = spec.spec_id == new_spec_id
if found_spec_id:
break
if not found_spec_id:
raise ValueError(f"Failed to find spec with id {new_spec_id}")
context.add_update(update)
return base_metadata.model_copy(update={"default_spec_id": new_spec_id})
@_apply_table_update.register(AddSnapshotUpdate)
def _(update: AddSnapshotUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata:
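    # A snapshot can only be added once the table has at least one schema, partition spec
    # and sort order; for v2 tables with a parent snapshot, its sequence number must also
    # be newer than the table's last sequence number.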
if len(base_metadata.schemas) == 0:
raise ValueError("Attempting to add a snapshot before a schema is added")
elif len(base_metadata.partition_specs) == 0:
raise ValueError("Attempting to add a snapshot before a partition spec is added")
elif len(base_metadata.sort_orders) == 0:
raise ValueError("Attempting to add a snapshot before a sort order is added")
elif base_metadata.snapshot_by_id(update.snapshot.snapshot_id) is not None:
raise ValueError(f"Snapshot with id {update.snapshot.snapshot_id} already exists")
elif (
base_metadata.format_version == 2
and update.snapshot.sequence_number is not None
and update.snapshot.sequence_number <= base_metadata.last_sequence_number
and update.snapshot.parent_snapshot_id is not None
):
raise ValueError(
f"Cannot add snapshot with sequence number {update.snapshot.sequence_number} "
f"older than last sequence number {base_metadata.last_sequence_number}"
)
context.add_update(update)
return base_metadata.model_copy(
update={
"last_updated_ms": update.snapshot.timestamp_ms,
"last_sequence_number": update.snapshot.sequence_number,
"snapshots": base_metadata.snapshots + [update.snapshot],
}
)
@_apply_table_update.register(SetSnapshotRefUpdate)
def _(update: SetSnapshotRefUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata:
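    # Setting a ref is a no-op when it already points at the same snapshot; for the "main"
    # branch it additionally advances current_snapshot_id and appends a snapshot-log entry.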
snapshot_ref = SnapshotRef(
snapshot_id=update.snapshot_id,
snapshot_ref_type=update.type,
min_snapshots_to_keep=update.min_snapshots_to_keep,
max_snapshot_age_ms=update.max_snapshot_age_ms,
max_ref_age_ms=update.max_ref_age_ms,
)
existing_ref = base_metadata.refs.get(update.ref_name)
if existing_ref is not None and existing_ref == snapshot_ref:
return base_metadata
snapshot = base_metadata.snapshot_by_id(snapshot_ref.snapshot_id)
if snapshot is None:
raise ValueError(f"Cannot set {update.ref_name} to unknown snapshot {snapshot_ref.snapshot_id}")
metadata_updates: Dict[str, Any] = {}
if context.is_added_snapshot(snapshot_ref.snapshot_id):
metadata_updates["last_updated_ms"] = snapshot.timestamp_ms
if update.ref_name == MAIN_BRANCH:
metadata_updates["current_snapshot_id"] = snapshot_ref.snapshot_id
if "last_updated_ms" not in metadata_updates:
metadata_updates["last_updated_ms"] = datetime_to_millis(datetime.now().astimezone())
metadata_updates["snapshot_log"] = base_metadata.snapshot_log + [
SnapshotLogEntry(
snapshot_id=snapshot_ref.snapshot_id,
timestamp_ms=metadata_updates["last_updated_ms"],
)
]
metadata_updates["refs"] = {**base_metadata.refs, update.ref_name: snapshot_ref}
context.add_update(update)
return base_metadata.model_copy(update=metadata_updates)
@_apply_table_update.register(RemoveSnapshotsUpdate)
def _(update: RemoveSnapshotsUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata:
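    # Removing snapshots cascades: any ref or statistics file pointing at a removed snapshot
    # is dropped as well, by re-dispatching RemoveSnapshotRefUpdate / RemoveStatisticsUpdate
    # through _apply_table_update below.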
for remove_snapshot_id in update.snapshot_ids:
if not any(snapshot.snapshot_id == remove_snapshot_id for snapshot in base_metadata.snapshots):
raise ValueError(f"Snapshot with snapshot id {remove_snapshot_id} does not exist: {base_metadata.snapshots}")
snapshots = [
(
snapshot.model_copy(update={"parent_snapshot_id": None})
if snapshot.parent_snapshot_id in update.snapshot_ids
else snapshot
)
for snapshot in base_metadata.snapshots
if snapshot.snapshot_id not in update.snapshot_ids
]
snapshot_log = [
snapshot_log_entry
for snapshot_log_entry in base_metadata.snapshot_log
if snapshot_log_entry.snapshot_id not in update.snapshot_ids
]
remove_ref_updates = (
RemoveSnapshotRefUpdate(ref_name=ref_name)
for ref_name, ref in base_metadata.refs.items()
if ref.snapshot_id in update.snapshot_ids
)
remove_statistics_updates = (
        RemoveStatisticsUpdate(snapshot_id=statistics_file.snapshot_id)
for statistics_file in base_metadata.statistics
if statistics_file.snapshot_id in update.snapshot_ids
)
updates = itertools.chain(remove_ref_updates, remove_statistics_updates)
new_metadata = base_metadata
for upd in updates:
new_metadata = _apply_table_update(upd, new_metadata, context)
context.add_update(update)
return new_metadata.model_copy(update={"snapshots": snapshots, "snapshot_log": snapshot_log})
@_apply_table_update.register(RemoveSnapshotRefUpdate)
def _(update: RemoveSnapshotRefUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata:
if update.ref_name not in base_metadata.refs:
return base_metadata
existing_ref = base_metadata.refs[update.ref_name]
if base_metadata.snapshot_by_id(existing_ref.snapshot_id) is None:
raise ValueError(f"Cannot remove {update.ref_name} ref with unknown snapshot {existing_ref.snapshot_id}")
current_snapshot_id = None if update.ref_name == MAIN_BRANCH else base_metadata.current_snapshot_id
metadata_refs = {ref_name: ref for ref_name, ref in base_metadata.refs.items() if ref_name != update.ref_name}
context.add_update(update)
return base_metadata.model_copy(update={"refs": metadata_refs, "current_snapshot_id": current_snapshot_id})
@_apply_table_update.register(AddSortOrderUpdate)
def _(update: AddSortOrderUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata:
for sort in base_metadata.sort_orders:
# Only raise in case of a discrepancy
if sort.order_id == update.sort_order.order_id and sort != update.sort_order:
raise ValueError(f"Sort-order with id {sort.order_id} already exists: {sort}")
context.add_update(update)
return base_metadata.model_copy(
update={
"sort_orders": base_metadata.sort_orders + [update.sort_order],
}
)
@_apply_table_update.register(SetDefaultSortOrderUpdate)
def _(
update: SetDefaultSortOrderUpdate,
base_metadata: TableMetadata,
context: _TableMetadataUpdateContext,
) -> TableMetadata:
new_sort_order_id = update.sort_order_id
if new_sort_order_id == -1:
# The last added sort order should be in base_metadata.sort_orders at this point
new_sort_order_id = max(sort_order.order_id for sort_order in base_metadata.sort_orders)
if not context.is_added_sort_order(new_sort_order_id):
raise ValueError("Cannot set current sort order to the last added one when no sort order has been added")
if new_sort_order_id == base_metadata.default_sort_order_id:
return base_metadata
sort_order = base_metadata.sort_order_by_id(new_sort_order_id)
if sort_order is None:
raise ValueError(f"Sort order with id {new_sort_order_id} does not exist")
context.add_update(update)
return base_metadata.model_copy(update={"default_sort_order_id": new_sort_order_id})
@_apply_table_update.register(SetStatisticsUpdate)
def _(update: SetStatisticsUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata:
statistics = filter_statistics_by_snapshot_id(base_metadata.statistics, update.statistics.snapshot_id)
context.add_update(update)
return base_metadata.model_copy(update={"statistics": statistics + [update.statistics]})
@_apply_table_update.register(RemoveStatisticsUpdate)
def _(update: RemoveStatisticsUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata:
if not any(stat.snapshot_id == update.snapshot_id for stat in base_metadata.statistics):
raise ValueError(f"Statistics with snapshot id {update.snapshot_id} does not exist")
statistics = filter_statistics_by_snapshot_id(base_metadata.statistics, update.snapshot_id)
context.add_update(update)
return base_metadata.model_copy(update={"statistics": statistics})
def update_table_metadata(
base_metadata: TableMetadata,
updates: Tuple[TableUpdate, ...],
enforce_validation: bool = False,
metadata_location: Optional[str] = None,
) -> TableMetadata:
"""Update the table metadata with the given updates in one transaction.
Args:
base_metadata: The base metadata to be updated.
updates: The updates in one transaction.
enforce_validation: Whether to trigger validation after applying the updates.
metadata_location: Current metadata location of the table
Returns:
The metadata with the updates applied.
"""
context = _TableMetadataUpdateContext()
new_metadata = base_metadata
for update in updates:
new_metadata = _apply_table_update(update, new_metadata, context)
# Update last_updated_ms if it was not updated by update operations
if context.has_changes():
if metadata_location:
new_metadata = _update_table_metadata_log(new_metadata, metadata_location, base_metadata.last_updated_ms)
if base_metadata.last_updated_ms == new_metadata.last_updated_ms:
new_metadata = new_metadata.model_copy(update={"last_updated_ms": datetime_to_millis(datetime.now().astimezone())})
if enforce_validation:
return TableMetadataUtil.parse_obj(new_metadata.model_dump())
else:
return new_metadata.model_copy(deep=True)
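# Illustrative sketch (hypothetical values): applying a couple of updates in one transaction,
# where `base_metadata` is an existing TableMetadata instance.
#
#     new_metadata = update_table_metadata(
#         base_metadata,
#         updates=(
#             SetPropertiesUpdate(updates={"owner": "analytics"}),
#             SetLocationUpdate(location="s3://bucket/warehouse/db/tbl"),
#         ),
#         enforce_validation=True,
#         metadata_location="s3://bucket/warehouse/db/tbl/metadata/00001.metadata.json",
#     )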
def _update_table_metadata_log(base_metadata: TableMetadata, metadata_location: str, last_updated_ms: int) -> TableMetadata:
    """Update the metadata log of the table.

    Args:
        base_metadata: The base metadata to be updated.
        metadata_location: Current metadata location of the table.
        last_updated_ms: The timestamp of the last update of table metadata.

    Returns:
        The metadata with the updates applied to metadata-log.
    """
    from pyiceberg.table import TableProperties
max_metadata_log_entries = max(
1,
property_as_int(
base_metadata.properties,
TableProperties.METADATA_PREVIOUS_VERSIONS_MAX,
TableProperties.METADATA_PREVIOUS_VERSIONS_MAX_DEFAULT,
), # type: ignore
)
previous_metadata_log = base_metadata.metadata_log
if len(base_metadata.metadata_log) >= max_metadata_log_entries: # type: ignore
remove_index = len(base_metadata.metadata_log) - max_metadata_log_entries + 1 # type: ignore
previous_metadata_log = base_metadata.metadata_log[remove_index:]
metadata_updates: Dict[str, Any] = {
"metadata_log": previous_metadata_log + [MetadataLogEntry(metadata_file=metadata_location, timestamp_ms=last_updated_ms)]
}
return base_metadata.model_copy(update=metadata_updates)
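# Worked example (hypothetical numbers): with TableProperties.METADATA_PREVIOUS_VERSIONS_MAX
# set to 5 and five entries already in the metadata log, remove_index = 5 - 5 + 1 = 1, so the
# oldest entry is dropped before the current metadata_location is appended, capping the log
# at five entries.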
class ValidatableTableRequirement(IcebergBaseModel):
type: str
@abstractmethod
def validate(self, base_metadata: Optional[TableMetadata]) -> None:
"""Validate the requirement against the base metadata.
Args:
base_metadata: The base metadata to be validated against.
Raises:
CommitFailedException: When the requirement is not met.
"""
...
class AssertCreate(ValidatableTableRequirement):
"""The table must not already exist; used for create transactions."""
type: Literal["assert-create"] = Field(default="assert-create")
def validate(self, base_metadata: Optional[TableMetadata]) -> None:
if base_metadata is not None:
raise CommitFailedException("Table already exists")
class AssertTableUUID(ValidatableTableRequirement):
"""The table UUID must match the requirement's `uuid`."""
type: Literal["assert-table-uuid"] = Field(default="assert-table-uuid")
uuid: uuid.UUID
def validate(self, base_metadata: Optional[TableMetadata]) -> None:
if base_metadata is None:
raise CommitFailedException("Requirement failed: current table metadata is missing")
elif self.uuid != base_metadata.table_uuid:
raise CommitFailedException(f"Table UUID does not match: {self.uuid} != {base_metadata.table_uuid}")
class AssertRefSnapshotId(ValidatableTableRequirement):
"""The table branch or tag identified by the requirement's `ref` must reference the requirement's `snapshot-id`.
if `snapshot-id` is `null` or missing, the ref must not already exist.
"""
type: Literal["assert-ref-snapshot-id"] = Field(default="assert-ref-snapshot-id")
ref: str = Field(...)
snapshot_id: Optional[int] = Field(default=None, alias="snapshot-id")
def validate(self, base_metadata: Optional[TableMetadata]) -> None:
if base_metadata is None:
raise CommitFailedException("Requirement failed: current table metadata is missing")
elif snapshot_ref := base_metadata.refs.get(self.ref):
ref_type = snapshot_ref.snapshot_ref_type
if self.snapshot_id is None:
raise CommitFailedException(f"Requirement failed: {ref_type} {self.ref} was created concurrently")
elif self.snapshot_id != snapshot_ref.snapshot_id:
raise CommitFailedException(
f"Requirement failed: {ref_type} {self.ref} has changed: expected id {self.snapshot_id}, found {snapshot_ref.snapshot_id}"
)
elif self.snapshot_id is not None:
raise CommitFailedException(f"Requirement failed: branch or tag {self.ref} is missing, expected {self.snapshot_id}")
class AssertLastAssignedFieldId(ValidatableTableRequirement):
"""The table's last assigned column id must match the requirement's `last-assigned-field-id`."""
type: Literal["assert-last-assigned-field-id"] = Field(default="assert-last-assigned-field-id")
last_assigned_field_id: int = Field(..., alias="last-assigned-field-id")
def validate(self, base_metadata: Optional[TableMetadata]) -> None:
if base_metadata is None:
raise CommitFailedException("Requirement failed: current table metadata is missing")
elif base_metadata.last_column_id != self.last_assigned_field_id:
raise CommitFailedException(
f"Requirement failed: last assigned field id has changed: expected {self.last_assigned_field_id}, found {base_metadata.last_column_id}"
)
class AssertCurrentSchemaId(ValidatableTableRequirement):
"""The table's current schema id must match the requirement's `current-schema-id`."""
type: Literal["assert-current-schema-id"] = Field(default="assert-current-schema-id")
current_schema_id: int = Field(..., alias="current-schema-id")
def validate(self, base_metadata: Optional[TableMetadata]) -> None:
if base_metadata is None:
raise CommitFailedException("Requirement failed: current table metadata is missing")
elif self.current_schema_id != base_metadata.current_schema_id:
raise CommitFailedException(
f"Requirement failed: current schema id has changed: expected {self.current_schema_id}, found {base_metadata.current_schema_id}"
)
class AssertLastAssignedPartitionId(ValidatableTableRequirement):
"""The table's last assigned partition id must match the requirement's `last-assigned-partition-id`."""
type: Literal["assert-last-assigned-partition-id"] = Field(default="assert-last-assigned-partition-id")
last_assigned_partition_id: Optional[int] = Field(..., alias="last-assigned-partition-id")
def validate(self, base_metadata: Optional[TableMetadata]) -> None:
if base_metadata is None:
raise CommitFailedException("Requirement failed: current table metadata is missing")
elif base_metadata.last_partition_id != self.last_assigned_partition_id:
raise CommitFailedException(
f"Requirement failed: last assigned partition id has changed: expected {self.last_assigned_partition_id}, found {base_metadata.last_partition_id}"
)
class AssertDefaultSpecId(ValidatableTableRequirement):
"""The table's default spec id must match the requirement's `default-spec-id`."""
type: Literal["assert-default-spec-id"] = Field(default="assert-default-spec-id")
default_spec_id: int = Field(..., alias="default-spec-id")
def validate(self, base_metadata: Optional[TableMetadata]) -> None:
if base_metadata is None:
raise CommitFailedException("Requirement failed: current table metadata is missing")
elif self.default_spec_id != base_metadata.default_spec_id:
raise CommitFailedException(
f"Requirement failed: default spec id has changed: expected {self.default_spec_id}, found {base_metadata.default_spec_id}"
)
class AssertDefaultSortOrderId(ValidatableTableRequirement):
"""The table's default sort order id must match the requirement's `default-sort-order-id`."""
type: Literal["assert-default-sort-order-id"] = Field(default="assert-default-sort-order-id")
default_sort_order_id: int = Field(..., alias="default-sort-order-id")
def validate(self, base_metadata: Optional[TableMetadata]) -> None:
if base_metadata is None:
raise CommitFailedException("Requirement failed: current table metadata is missing")
elif self.default_sort_order_id != base_metadata.default_sort_order_id:
raise CommitFailedException(
f"Requirement failed: default sort order id has changed: expected {self.default_sort_order_id}, found {base_metadata.default_sort_order_id}"
)
TableRequirement = Annotated[
Union[
AssertCreate,
AssertTableUUID,
AssertRefSnapshotId,
AssertLastAssignedFieldId,
AssertCurrentSchemaId,
AssertLastAssignedPartitionId,
AssertDefaultSpecId,
AssertDefaultSortOrderId,
],
Field(discriminator="type"),
]
UpdatesAndRequirements = Tuple[Tuple[TableUpdate, ...], Tuple[TableRequirement, ...]]
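# Illustrative sketch (hypothetical values): a commit is expressed as a pair of update and
# requirement tuples, matching the UpdatesAndRequirements alias above.
#
#     updates_and_requirements: UpdatesAndRequirements = (
#         (SetPropertiesUpdate(updates={"owner": "analytics"}),),
#         (AssertDefaultSpecId(default_spec_id=0),),
#     )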