# pyiceberg/table/update/schema.py
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import itertools
from copy import copy
from dataclasses import dataclass
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
from pyiceberg.exceptions import ResolveError, ValidationError
from pyiceberg.expressions import literal # type: ignore
from pyiceberg.schema import (
PartnerAccessor,
Schema,
SchemaVisitor,
SchemaWithPartnerVisitor,
assign_fresh_schema_ids,
promote,
visit,
visit_with_partner,
)
from pyiceberg.table.name_mapping import (
NameMapping,
update_mapping,
)
from pyiceberg.table.update import (
AddSchemaUpdate,
AssertCurrentSchemaId,
SetCurrentSchemaUpdate,
SetPropertiesUpdate,
TableRequirement,
TableUpdate,
UpdatesAndRequirements,
UpdateTableMetadata,
)
from pyiceberg.typedef import L
from pyiceberg.types import IcebergType, ListType, MapType, NestedField, PrimitiveType, StructType
if TYPE_CHECKING:
import pyarrow as pa
from pyiceberg.table import Transaction
TABLE_ROOT_ID = -1
class _MoveOperation(Enum):
First = 1
Before = 2
After = 3
@dataclass
class _Move:
field_id: int
full_name: str
op: _MoveOperation
other_field_id: Optional[int] = None
class UpdateSchema(UpdateTableMetadata["UpdateSchema"]):
_schema: Schema
_last_column_id: itertools.count[int]
_identifier_field_names: Set[str]
_adds: Dict[int, List[NestedField]] = {}
_updates: Dict[int, NestedField] = {}
_deletes: Set[int] = set()
_moves: Dict[int, List[_Move]] = {}
_added_name_to_id: Dict[str, int] = {}
# Part of https://github.com/apache/iceberg/pull/8393
_id_to_parent: Dict[int, str] = {}
_allow_incompatible_changes: bool
_case_sensitive: bool
def __init__(
self,
transaction: Transaction,
allow_incompatible_changes: bool = False,
case_sensitive: bool = True,
schema: Optional[Schema] = None,
name_mapping: Optional[NameMapping] = None,
) -> None:
super().__init__(transaction)
if isinstance(schema, Schema):
self._schema = schema
self._last_column_id = itertools.count(1 + schema.highest_field_id)
else:
self._schema = self._transaction.table_metadata.schema()
self._last_column_id = itertools.count(1 + self._transaction.table_metadata.last_column_id)
self._name_mapping = name_mapping
self._identifier_field_names = self._schema.identifier_field_names()
self._adds = {}
self._updates = {}
self._deletes = set()
self._moves = {}
self._added_name_to_id = {}
def get_column_name(field_id: int) -> str:
column_name = self._schema.find_column_name(column_id=field_id)
if column_name is None:
raise ValueError(f"Could not find field-id: {field_id}")
return column_name
self._id_to_parent = {
field_id: get_column_name(parent_field_id) for field_id, parent_field_id in self._schema._lazy_id_to_parent.items()
}
self._allow_incompatible_changes = allow_incompatible_changes
self._case_sensitive = case_sensitive
self._transaction = transaction
def case_sensitive(self, case_sensitive: bool) -> UpdateSchema:
"""Determine if the case of schema needs to be considered when comparing column names.
Args:
case_sensitive: When false case is not considered in column name comparisons.
Returns:
This for method chaining
"""
self._case_sensitive = case_sensitive
return self
def union_by_name(self, new_schema: Union[Schema, "pa.Schema"]) -> UpdateSchema:
from pyiceberg.catalog import Catalog
visit_with_partner(
Catalog._convert_schema_if_needed(new_schema),
-1,
_UnionByNameVisitor(update_schema=self, existing_schema=self._schema, case_sensitive=self._case_sensitive),
# type: ignore
PartnerIdByNameAccessor(partner_schema=self._schema, case_sensitive=self._case_sensitive),
)
return self
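# Illustrative usage sketch (not part of the module): union_by_name is normally reached
# through the Table API. Assuming `table` is a loaded pyiceberg Table and `df` is a
# PyArrow table whose schema carries extra columns, the missing columns can be merged
# into the Iceberg schema by name:
#
#   with table.update_schema() as update:
#       update.union_by_name(df.schema)
#
# Existing columns keep their field IDs; new columns are added, columns that arrive as
# optional are relaxed, and primitive types are widened where a safe promotion exists
# (handled by _UnionByNameVisitor below).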
def add_column(
self,
path: Union[str, Tuple[str, ...]],
field_type: IcebergType,
doc: Optional[str] = None,
required: bool = False,
default_value: Optional[L] = None,
) -> UpdateSchema:
"""Add a new column to a nested struct or Add a new top-level column.
Because "." may be interpreted as a column path separator or may be used in field names, it
is not allowed to add nested column by passing in a string. To add to nested structures or
to add fields with names that contain "." use a tuple instead to indicate the path.
If type is a nested type, its field IDs are reassigned when added to the existing schema.
Args:
path: Name for the new column.
field_type: Type for the new column.
doc: Documentation string for the new column.
required: Whether the new column is required.
default_value: Default value for the new column.
Returns:
This for method chaining.
"""
if isinstance(path, str):
if "." in path:
raise ValueError(f"Cannot add column with ambiguous name: {path}, provide a tuple instead")
path = (path,)
name = path[-1]
parent = path[:-1]
full_name = ".".join(path)
parent_full_path = ".".join(parent)
parent_id: int = TABLE_ROOT_ID
if len(parent) > 0:
parent_field = self._schema.find_field(parent_full_path, self._case_sensitive)
parent_type = parent_field.field_type
if isinstance(parent_type, MapType):
parent_field = parent_type.value_field
elif isinstance(parent_type, ListType):
parent_field = parent_type.element_field
if not parent_field.field_type.is_struct:
raise ValueError(f"Cannot add column '{name}' to non-struct type: {parent_full_path}")
parent_id = parent_field.field_id
existing_field = None
try:
existing_field = self._schema.find_field(full_name, self._case_sensitive)
except ValueError:
pass
if existing_field is not None and existing_field.field_id not in self._deletes:
raise ValueError(f"Cannot add column, name already exists: {full_name}")
# assign new IDs in order
new_id = self.assign_new_column_id()
new_type = assign_fresh_schema_ids(field_type, self.assign_new_column_id)
if default_value is not None:
try:
# To make sure that the value is valid for the type
initial_default = literal(default_value).to(new_type).value
except ValueError as e:
raise ValueError(f"Invalid default value: {e}") from e
else:
initial_default = default_value # type: ignore
if (required and initial_default is None) and not self._allow_incompatible_changes:
# Table format version 1 and 2 cannot add required column because there is no initial value
raise ValueError(f"Incompatible change: cannot add required column: {'.'.join(path)}")
# update tracking for moves
self._added_name_to_id[full_name] = new_id
self._id_to_parent[new_id] = parent_full_path
field = NestedField(
field_id=new_id,
name=name,
field_type=new_type,
required=required,
doc=doc,
initial_default=initial_default,
write_default=initial_default,
)
if parent_id in self._adds:
self._adds[parent_id].append(field)
else:
self._adds[parent_id] = [field]
return self
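# Illustrative usage sketch (not part of the module), assuming `table` is a loaded
# pyiceberg Table; the column names are made up. A top-level column is addressed by a
# plain string; a nested column (or one whose name contains ".") takes a tuple path:
#
#   from pyiceberg.types import IntegerType, StringType
#
#   with table.update_schema() as update:
#       update.add_column("retries", IntegerType(), doc="number of retries")
#       update.add_column(("details", "confirmed_by"), StringType())
#
# Adding a required column without a default value is rejected unless the update was
# created with allow_incompatible_changes, because existing rows would have no value
# for it.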
def delete_column(self, path: Union[str, Tuple[str, ...]]) -> UpdateSchema:
"""Delete a column from a table.
Args:
path: The path to the column.
Returns:
The UpdateSchema with the delete operation staged.
"""
name = (path,) if isinstance(path, str) else path
full_name = ".".join(name)
field = self._schema.find_field(full_name, case_sensitive=self._case_sensitive)
if field.field_id in self._adds:
raise ValueError(f"Cannot delete a column that has additions: {full_name}")
if field.field_id in self._updates:
raise ValueError(f"Cannot delete a column that has updates: {full_name}")
self._deletes.add(field.field_id)
return self
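# Illustrative usage sketch (not part of the module):
#
#   with table.update_schema() as update:
#       update.delete_column("retries")
#       update.delete_column(("details", "confirmed_by"))
#
# A column that already has staged child additions or staged updates in the same change
# set cannot be deleted.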
def set_default_value(self, path: Union[str, Tuple[str, ...]], default_value: Optional[L]) -> UpdateSchema:
"""Set the default value of a column.
Args:
path: The path to the column.
Returns:
The UpdateSchema with the delete operation staged.
"""
self._set_column_default_value(path, default_value)
return self
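# Illustrative usage sketch (not part of the module), assuming the table format
# supports column defaults:
#
#   with table.update_schema() as update:
#       update.set_default_value("retries", 3)
#
# The value is checked against the column type (via literal(...).to(field_type)) before
# the change is staged.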
def rename_column(self, path_from: Union[str, Tuple[str, ...]], new_name: str) -> UpdateSchema:
"""Update the name of a column.
Args:
path_from: The path to the column to be renamed.
new_name: The new name of the column.
Returns:
The UpdateSchema with the rename operation staged.
"""
path_from = ".".join(path_from) if isinstance(path_from, tuple) else path_from
field_from = self._schema.find_field(path_from, self._case_sensitive)
if field_from.field_id in self._deletes:
raise ValueError(f"Cannot rename a column that will be deleted: {path_from}")
if updated := self._updates.get(field_from.field_id):
self._updates[field_from.field_id] = NestedField(
field_id=updated.field_id,
name=new_name,
field_type=updated.field_type,
doc=updated.doc,
required=updated.required,
initial_default=updated.initial_default,
write_default=updated.write_default,
)
else:
self._updates[field_from.field_id] = NestedField(
field_id=field_from.field_id,
name=new_name,
field_type=field_from.field_type,
doc=field_from.doc,
required=field_from.required,
initial_default=field_from.initial_default,
write_default=field_from.write_default,
)
# Look up the correctly cased column name, since the given path may differ in case
from_field_correct_casing = self._schema.find_column_name(field_from.field_id)
if from_field_correct_casing in self._identifier_field_names:
self._identifier_field_names.remove(from_field_correct_casing)
new_identifier_path = f"{from_field_correct_casing[: -len(field_from.name)]}{new_name}"
self._identifier_field_names.add(new_identifier_path)
return self
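# Illustrative usage sketch (not part of the module):
#
#   with table.update_schema() as update:
#       update.rename_column("retries", "num_retries")
#       update.rename_column(("details", "confirmed_by"), "exchange")
#
# A rename keeps the field ID, so previously written data files stay readable, and
# identifier fields that reference the old name are updated to the new one.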
def make_column_optional(self, path: Union[str, Tuple[str, ...]]) -> UpdateSchema:
"""Make a column optional.
Args:
path: The path to the field.
Returns:
The UpdateSchema with the requirement change staged.
"""
self._set_column_requirement(path, required=False)
return self
def set_identifier_fields(self, *fields: str) -> None:
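"""Set the identifier fields of the schema, replacing any previously configured identifier fields."""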
self._identifier_field_names = set(fields)
def _set_column_requirement(self, path: Union[str, Tuple[str, ...]], required: bool) -> None:
path = (path,) if isinstance(path, str) else path
name = ".".join(path)
field = self._schema.find_field(name, self._case_sensitive)
if (field.required and required) or (field.optional and not required):
# if the change is a noop, allow it even if allow_incompatible_changes is false
return
if not self._allow_incompatible_changes and required:
raise ValueError(f"Cannot change column nullability: {name}: optional -> required")
if field.field_id in self._deletes:
raise ValueError(f"Cannot update a column that will be deleted: {name}")
if updated := self._updates.get(field.field_id):
self._updates[field.field_id] = NestedField(
field_id=updated.field_id,
name=updated.name,
field_type=updated.field_type,
doc=updated.doc,
required=required,
initial_default=updated.initial_default,
write_default=updated.write_default,
)
else:
self._updates[field.field_id] = NestedField(
field_id=field.field_id,
name=field.name,
field_type=field.field_type,
doc=field.doc,
required=required,
initial_default=field.initial_default,
write_default=field.write_default,
)
def _set_column_default_value(self, path: Union[str, Tuple[str, ...]], default_value: Any) -> None:
path = (path,) if isinstance(path, str) else path
name = ".".join(path)
field = self._schema.find_field(name, self._case_sensitive)
if default_value is not None:
try:
# To make sure that the value is valid for the type
default_value = literal(default_value).to(field.field_type).value
except ValueError as e:
raise ValueError(f"Invalid default value: {e}") from e
if field.required and default_value == field.write_default:
# if the change is a noop, allow it even if allow_incompatible_changes is false
return
if not self._allow_incompatible_changes and field.required and default_value is None:
raise ValueError("Cannot change change default-value of a required column to None")
if field.field_id in self._deletes:
raise ValueError(f"Cannot update a column that will be deleted: {name}")
if updated := self._updates.get(field.field_id):
self._updates[field.field_id] = NestedField(
field_id=updated.field_id,
name=updated.name,
field_type=updated.field_type,
doc=updated.doc,
required=updated.required,
initial_default=updated.initial_default,
write_default=default_value,
)
else:
self._updates[field.field_id] = NestedField(
field_id=field.field_id,
name=field.name,
field_type=field.field_type,
doc=field.doc,
required=field.required,
initial_default=field.initial_default,
write_default=default_value,
)
def update_column(
self,
path: Union[str, Tuple[str, ...]],
field_type: Optional[IcebergType] = None,
required: Optional[bool] = None,
doc: Optional[str] = None,
) -> UpdateSchema:
"""Update the type of column.
Args:
path: The path to the field.
field_type: The new type
required: If the field should be required
doc: Documentation describing the column
Returns:
The UpdateSchema with the type update staged.
"""
path = (path,) if isinstance(path, str) else path
full_name = ".".join(path)
if field_type is None and required is None and doc is None:
return self
field = self._schema.find_field(full_name, self._case_sensitive)
if field.field_id in self._deletes:
raise ValueError(f"Cannot update a column that will be deleted: {full_name}")
if field_type is not None:
if not field.field_type.is_primitive:
raise ValidationError(f"Cannot change column type: {field.field_type} is not a primitive")
if not self._allow_incompatible_changes and field.field_type != field_type:
try:
promote(field.field_type, field_type)
except ResolveError as e:
raise ValidationError(f"Cannot change column type: {full_name}: {field.field_type} -> {field_type}") from e
# if other updates for the same field exist in one transaction:
if updated := self._updates.get(field.field_id):
self._updates[field.field_id] = NestedField(
field_id=updated.field_id,
name=updated.name,
field_type=field_type or updated.field_type,
doc=doc if doc is not None else updated.doc,
required=updated.required,
initial_default=updated.initial_default,
write_default=updated.write_default,
)
else:
self._updates[field.field_id] = NestedField(
field_id=field.field_id,
name=field.name,
field_type=field_type or field.field_type,
doc=doc if doc is not None else field.doc,
required=field.required,
initial_default=field.initial_default,
write_default=field.write_default,
)
if required is not None:
self._set_column_requirement(path, required=required)
return self
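# Illustrative usage sketch (not part of the module); the column names are made up.
# Without allow_incompatible_changes, only safe promotions are accepted (for example
# int -> long, float -> double, or widening a decimal's precision):
#
#   from pyiceberg.types import LongType
#
#   with table.update_schema() as update:
#       update.update_column("retries", field_type=LongType())
#       update.update_column("confirmed_by", required=False, doc="exchange that confirmed")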
def _find_for_move(self, name: str) -> Optional[int]:
try:
return self._schema.find_field(name, self._case_sensitive).field_id
except ValueError:
pass
return self._added_name_to_id.get(name)
def _move(self, move: _Move) -> None:
if parent_name := self._id_to_parent.get(move.field_id):
parent_field = self._schema.find_field(parent_name, case_sensitive=self._case_sensitive)
if not parent_field.field_type.is_struct:
raise ValueError(f"Cannot move fields in non-struct type: {parent_field.field_type}")
if move.op == _MoveOperation.After or move.op == _MoveOperation.Before:
if move.other_field_id is None:
raise ValueError("Expected other field when performing before/after move")
if self._id_to_parent.get(move.field_id) != self._id_to_parent.get(move.other_field_id):
raise ValueError(f"Cannot move field {move.full_name} to a different struct")
self._moves[parent_field.field_id] = self._moves.get(parent_field.field_id, []) + [move]
else:
# The field is at the top level of the schema
if move.op == _MoveOperation.After or move.op == _MoveOperation.Before:
if move.other_field_id is None:
raise ValueError("Expected other field when performing before/after move")
if other_struct := self._id_to_parent.get(move.other_field_id):
raise ValueError(f"Cannot move field {move.full_name} to a different struct: {other_struct}")
self._moves[TABLE_ROOT_ID] = self._moves.get(TABLE_ROOT_ID, []) + [move]
def move_first(self, path: Union[str, Tuple[str, ...]]) -> UpdateSchema:
"""Move the field to the first position of the parent struct.
Args:
path: The path to the field.
Returns:
The UpdateSchema with the move operation staged.
"""
full_name = ".".join(path) if isinstance(path, tuple) else path
field_id = self._find_for_move(full_name)
if field_id is None:
raise ValueError(f"Cannot move missing column: {full_name}")
self._move(_Move(field_id=field_id, full_name=full_name, op=_MoveOperation.First))
return self
def move_before(self, path: Union[str, Tuple[str, ...]], before_path: Union[str, Tuple[str, ...]]) -> UpdateSchema:
"""Move the field to before another field.
Args:
path: The path to the field.
before_path: The path to the field to move this field before.
Returns:
The UpdateSchema with the move operation staged.
"""
full_name = ".".join(path) if isinstance(path, tuple) else path
field_id = self._find_for_move(full_name)
if field_id is None:
raise ValueError(f"Cannot move missing column: {full_name}")
before_full_name = ".".join(before_path) if isinstance(before_path, tuple) else before_path
before_field_id = self._find_for_move(before_full_name)
if before_field_id is None:
raise ValueError(f"Cannot move {full_name} before missing column: {before_full_name}")
if field_id == before_field_id:
raise ValueError(f"Cannot move {full_name} before itself")
self._move(_Move(field_id=field_id, full_name=full_name, other_field_id=before_field_id, op=_MoveOperation.Before))
return self
def move_after(self, path: Union[str, Tuple[str, ...]], after_name: Union[str, Tuple[str, ...]]) -> UpdateSchema:
"""Move the field to after another field.
Args:
path: The path to the field.
after_name: The path to the field to move this field after.
Returns:
The UpdateSchema with the move operation staged.
"""
full_name = ".".join(path) if isinstance(path, tuple) else path
field_id = self._find_for_move(full_name)
if field_id is None:
raise ValueError(f"Cannot move missing column: {full_name}")
after_path = ".".join(after_name) if isinstance(after_name, tuple) else after_name
after_field_id = self._find_for_move(after_path)
if after_field_id is None:
raise ValueError(f"Cannot move {full_name} after missing column: {after_path}")
if field_id == after_field_id:
raise ValueError(f"Cannot move {full_name} after itself")
self._move(_Move(field_id=field_id, full_name=full_name, other_field_id=after_field_id, op=_MoveOperation.After))
return self
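# Illustrative usage sketch (not part of the module); the column names are made up.
# Moves only reorder fields within their own struct, and a before/after move requires
# both fields to share the same parent:
#
#   with table.update_schema() as update:
#       update.move_first("symbol")
#       update.move_after("bid", "ask")
#       update.move_before(("details", "created_by"), ("details", "confirmed_by"))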
def _commit(self) -> UpdatesAndRequirements:
"""Apply the pending changes and commit."""
from pyiceberg.table import TableProperties
new_schema = self._apply()
existing_schema_id = next(
(schema.schema_id for schema in self._transaction.table_metadata.schemas if schema == new_schema), None
)
requirements: Tuple[TableRequirement, ...] = ()
updates: Tuple[TableUpdate, ...] = ()
# Only emit updates and requirements if the resulting schema differs from the current schema
if existing_schema_id != self._schema.schema_id:
requirements += (AssertCurrentSchemaId(current_schema_id=self._schema.schema_id),)
if existing_schema_id is None:
last_column_id = max(self._transaction.table_metadata.last_column_id, new_schema.highest_field_id)
updates += (
AddSchemaUpdate(schema=new_schema, last_column_id=last_column_id),
SetCurrentSchemaUpdate(schema_id=-1),  # -1 refers to the schema added in this update
)
else:
updates += (SetCurrentSchemaUpdate(schema_id=existing_schema_id),)
if name_mapping := self._name_mapping:
updated_name_mapping = update_mapping(name_mapping, self._updates, self._adds)
updates += (
SetPropertiesUpdate(updates={TableProperties.DEFAULT_NAME_MAPPING: updated_name_mapping.model_dump_json()}),
)
return updates, requirements
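# For reference (not part of the module), a rough sketch of what _commit emits when the
# pending changes produce a schema that does not yet exist in the table metadata:
#
#   updates = (
#       AddSchemaUpdate(schema=new_schema, last_column_id=last_column_id),
#       SetCurrentSchemaUpdate(schema_id=-1),  # -1: the schema added above
#   )
#   requirements = (AssertCurrentSchemaId(current_schema_id=current_schema_id),)
#
# If an identical schema already exists, only a SetCurrentSchemaUpdate pointing at that
# schema's ID is emitted; if the result equals the current schema, nothing is emitted.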
def _apply(self) -> Schema:
"""Apply the pending changes to the original schema and returns the result.
Returns:
the result Schema when all pending updates are applied
"""
struct = visit(self._schema, _ApplyChanges(self._adds, self._updates, self._deletes, self._moves))
if struct is None:
# Should never happen
raise ValueError("Could not apply changes")
# Check the field-ids
new_schema = Schema(*struct.fields)
field_ids = set()
for name in self._identifier_field_names:
try:
field = new_schema.find_field(name, case_sensitive=self._case_sensitive)
except ValueError as e:
raise ValueError(
f"Cannot find identifier field {name}. In case of deletion, update the identifier fields first."
) from e
field_ids.add(field.field_id)
if txn := self._transaction:
next_schema_id = 1 + (
max(schema.schema_id for schema in txn.table_metadata.schemas) if txn.table_metadata is not None else 0
)
else:
next_schema_id = 0
return Schema(*struct.fields, schema_id=next_schema_id, identifier_field_ids=field_ids)
def assign_new_column_id(self) -> int:
return next(self._last_column_id)
class _ApplyChanges(SchemaVisitor[Optional[IcebergType]]):
_adds: Dict[int, List[NestedField]]
_updates: Dict[int, NestedField]
_deletes: Set[int]
_moves: Dict[int, List[_Move]]
def __init__(
self,
adds: Dict[int, List[NestedField]],
updates: Dict[int, NestedField],
deletes: Set[int],
moves: Dict[int, List[_Move]],
) -> None:
self._adds = adds
self._updates = updates
self._deletes = deletes
self._moves = moves
def schema(self, schema: Schema, struct_result: Optional[IcebergType]) -> Optional[IcebergType]:
added = self._adds.get(TABLE_ROOT_ID)
moves = self._moves.get(TABLE_ROOT_ID)
if added is not None or moves is not None:
if not isinstance(struct_result, StructType):
raise ValueError(f"Cannot add fields to non-struct: {struct_result}")
if new_fields := _add_and_move_fields(struct_result.fields, added or [], moves or []):
return StructType(*new_fields)
return struct_result
def struct(self, struct: StructType, field_results: List[Optional[IcebergType]]) -> Optional[IcebergType]:
has_changes = False
new_fields = []
for idx, result_type in enumerate(field_results):
# Has been deleted
if result_type is None:
has_changes = True
continue
field = struct.fields[idx]
name = field.name
doc = field.doc
required = field.required
write_default = field.write_default
# There is an update
if update := self._updates.get(field.field_id):
name = update.name
doc = update.doc
required = update.required
write_default = update.write_default
if (
field.name == name
and field.field_type == result_type
and field.required == required
and field.doc == doc
and field.write_default == write_default
):
new_fields.append(field)
else:
has_changes = True
new_fields.append(
NestedField(
field_id=field.field_id,
name=name,
field_type=result_type,
required=required,
doc=doc,
initial_default=field.initial_default,
write_default=write_default,
)
)
if has_changes:
return StructType(*new_fields)
return struct
def field(self, field: NestedField, field_result: Optional[IcebergType]) -> Optional[IcebergType]:
# the API validates deletes, updates, and additions don't conflict
# handle deletes
if field.field_id in self._deletes:
return None
# handle updates
if (update := self._updates.get(field.field_id)) and field.field_type != update.field_type:
return update.field_type
if isinstance(field_result, StructType):
# handle add & moves
added = self._adds.get(field.field_id)
moves = self._moves.get(field.field_id)
if added is not None or moves is not None:
if not isinstance(field.field_type, StructType):
raise ValueError(f"Cannot add fields to non-struct: {field}")
if new_fields := _add_and_move_fields(field_result.fields, added or [], moves or []):
return StructType(*new_fields)
return field_result
def list(self, list_type: ListType, element_result: Optional[IcebergType]) -> Optional[IcebergType]:
element_type = self.field(list_type.element_field, element_result)
if element_type is None:
raise ValueError(f"Cannot delete element type from list: {element_result}")
return ListType(element_id=list_type.element_id, element=element_type, element_required=list_type.element_required)
def map(
self, map_type: MapType, key_result: Optional[IcebergType], value_result: Optional[IcebergType]
) -> Optional[IcebergType]:
key_id: int = map_type.key_field.field_id
if key_id in self._deletes:
raise ValueError(f"Cannot delete map keys: {map_type}")
if key_id in self._updates:
raise ValueError(f"Cannot update map keys: {map_type}")
if key_id in self._adds:
raise ValueError(f"Cannot add fields to map keys: {map_type}")
if map_type.key_type != key_result:
raise ValueError(f"Cannot alter map keys: {map_type}")
value_field: NestedField = map_type.value_field
value_type = self.field(value_field, value_result)
if value_type is None:
raise ValueError(f"Cannot delete value type from map: {value_field}")
return MapType(
key_id=map_type.key_id,
key_type=map_type.key_type,
value_id=map_type.value_id,
value_type=value_type,
value_required=map_type.value_required,
)
def primitive(self, primitive: PrimitiveType) -> Optional[IcebergType]:
return primitive
class _UnionByNameVisitor(SchemaWithPartnerVisitor[int, bool]):
update_schema: UpdateSchema
existing_schema: Schema
case_sensitive: bool
def __init__(self, update_schema: UpdateSchema, existing_schema: Schema, case_sensitive: bool) -> None:
self.update_schema = update_schema
self.existing_schema = existing_schema
self.case_sensitive = case_sensitive
def schema(self, schema: Schema, partner_id: Optional[int], struct_result: bool) -> bool:
return struct_result
def struct(self, struct: StructType, partner_id: Optional[int], missing_positions: List[bool]) -> bool:
if partner_id is None:
return True
fields = struct.fields
partner_struct = self._find_field_type(partner_id)
if not partner_struct.is_struct:
raise ValueError(f"Expected a struct, got: {partner_struct}")
for pos, missing in enumerate(missing_positions):
if missing:
self._add_column(partner_id, fields[pos])
else:
field = fields[pos]
if nested_field := partner_struct.field_by_name(field.name, case_sensitive=self.case_sensitive):
self._update_column(field, nested_field)
return False
def _add_column(self, parent_id: int, field: NestedField) -> None:
if parent_name := self.existing_schema.find_column_name(parent_id):
path: Tuple[str, ...] = (parent_name, field.name)
else:
path = (field.name,)
self.update_schema.add_column(path=path, field_type=field.field_type, required=field.required, doc=field.doc)
def _update_column(self, field: NestedField, existing_field: NestedField) -> None:
full_name = self.existing_schema.find_column_name(existing_field.field_id)
if full_name is None:
raise ValueError(f"Could not find field: {existing_field}")
if field.optional and existing_field.required:
self.update_schema.make_column_optional(full_name)
if field.field_type.is_primitive and field.field_type != existing_field.field_type:
try:
# If the current type is wider than the new type, then
# we perform a noop
_ = promote(field.field_type, existing_field.field_type)
except ResolveError:
# If this is not the case, perform the type evolution
self.update_schema.update_column(full_name, field_type=field.field_type)
if field.doc is not None and field.doc != existing_field.doc:
self.update_schema.update_column(full_name, doc=field.doc)
def _find_field_type(self, field_id: int) -> IcebergType:
if field_id == -1:
return self.existing_schema.as_struct()
else:
return self.existing_schema.find_field(field_id).field_type
def field(self, field: NestedField, partner_id: Optional[int], field_result: bool) -> bool:
return partner_id is None
def list(self, list_type: ListType, list_partner_id: Optional[int], element_missing: bool) -> bool:
if list_partner_id is None:
return True
if element_missing:
raise ValueError("Error traversing schemas: element is missing, but list is present")
partner_list_type = self._find_field_type(list_partner_id)
if not isinstance(partner_list_type, ListType):
raise ValueError(f"Expected list-type, got: {partner_list_type}")
self._update_column(list_type.element_field, partner_list_type.element_field)
return False
def map(self, map_type: MapType, map_partner_id: Optional[int], key_missing: bool, value_missing: bool) -> bool:
if map_partner_id is None:
return True
if key_missing:
raise ValueError("Error traversing schemas: key is missing, but map is present")
if value_missing:
raise ValueError("Error traversing schemas: value is missing, but map is present")
partner_map_type = self._find_field_type(map_partner_id)
if not isinstance(partner_map_type, MapType):
raise ValueError(f"Expected map-type, got: {partner_map_type}")
self._update_column(map_type.key_field, partner_map_type.key_field)
self._update_column(map_type.value_field, partner_map_type.value_field)
return False
def primitive(self, primitive: PrimitiveType, primitive_partner_id: Optional[int]) -> bool:
return primitive_partner_id is None
class PartnerIdByNameAccessor(PartnerAccessor[int]):
partner_schema: Schema
case_sensitive: bool
def __init__(self, partner_schema: Schema, case_sensitive: bool) -> None:
self.partner_schema = partner_schema
self.case_sensitive = case_sensitive
def schema_partner(self, partner: Optional[int]) -> Optional[int]:
return -1
def field_partner(self, partner_field_id: Optional[int], field_id: int, field_name: str) -> Optional[int]:
if partner_field_id is not None:
if partner_field_id == -1:
struct = self.partner_schema.as_struct()
else:
struct = self.partner_schema.find_field(partner_field_id).field_type
if not struct.is_struct:
raise ValueError(f"Expected StructType: {struct}")
if field := struct.field_by_name(name=field_name, case_sensitive=self.case_sensitive):
return field.field_id
return None
def list_element_partner(self, partner_list_id: Optional[int]) -> Optional[int]:
if partner_list_id is not None and (field := self.partner_schema.find_field(partner_list_id)):
if not isinstance(field.field_type, ListType):
raise ValueError(f"Expected ListType: {field}")
return field.field_type.element_field.field_id
else:
return None
def map_key_partner(self, partner_map_id: Optional[int]) -> Optional[int]:
if partner_map_id is not None and (field := self.partner_schema.find_field(partner_map_id)):
if not isinstance(field.field_type, MapType):
raise ValueError(f"Expected MapType: {field}")
return field.field_type.key_field.field_id
else:
return None
def map_value_partner(self, partner_map_id: Optional[int]) -> Optional[int]:
if partner_map_id is not None and (field := self.partner_schema.find_field(partner_map_id)):
if not isinstance(field.field_type, MapType):
raise ValueError(f"Expected MapType: {field}")
return field.field_type.value_field.field_id
else:
return None
def _add_fields(fields: Tuple[NestedField, ...], adds: Optional[List[NestedField]]) -> Tuple[NestedField, ...]:
adds = adds or []
return fields + tuple(adds)
def _move_fields(fields: Tuple[NestedField, ...], moves: List[_Move]) -> Tuple[NestedField, ...]:
reordered = list(copy(fields))
for move in moves:
# Find the field that we're about to move
field = next(field for field in reordered if field.field_id == move.field_id)
# Remove the field that we're about to move from the list
reordered = [field for field in reordered if field.field_id != move.field_id]
if move.op == _MoveOperation.First:
reordered = [field] + reordered
elif move.op == _MoveOperation.Before or move.op == _MoveOperation.After:
other_field_id = move.other_field_id
other_field_pos = next(i for i, field in enumerate(reordered) if field.field_id == other_field_id)
if move.op == _MoveOperation.Before:
reordered.insert(other_field_pos, field)
else:
reordered.insert(other_field_pos + 1, field)
else:
raise ValueError(f"Unknown operation: {move.op}")
return tuple(reordered)
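# Illustrative worked example (not part of the module): with fields (a, b, c) and the
# staged moves
#
#   [_Move(field_id=<c>, op=First), _Move(field_id=<a>, other_field_id=<b>, op=After)]
#
# _move_fields first yields (c, a, b) and then (c, b, a): each move removes the field
# from the current ordering and re-inserts it at its target position.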
def _add_and_move_fields(
fields: Tuple[NestedField, ...], adds: List[NestedField], moves: List[_Move]
) -> Optional[Tuple[NestedField, ...]]:
if len(adds) > 0:
# always apply adds first so that added fields can be moved
added = _add_fields(fields, adds)
if len(moves) > 0:
return _move_fields(added, moves)
else:
return added
elif len(moves) > 0:
return _move_fields(fields, moves)
return None if len(adds) == 0 else (*fields, *adds)