pyiceberg/table/name_mapping.py (248 lines of code) (raw):

# Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. """ Contains everything around the name mapping. More information can be found on here: https://iceberg.apache.org/spec/#name-mapping-serialization """ from __future__ import annotations from abc import ABC, abstractmethod from collections import ChainMap from functools import cached_property, singledispatch from typing import Any, Dict, Generic, Iterator, List, Optional, TypeVar, Union from pydantic import Field, conlist, field_validator, model_serializer from pyiceberg.schema import P, PartnerAccessor, Schema, SchemaVisitor, SchemaWithPartnerVisitor, visit, visit_with_partner from pyiceberg.typedef import IcebergBaseModel, IcebergRootModel from pyiceberg.types import IcebergType, ListType, MapType, NestedField, PrimitiveType, StructType class MappedField(IcebergBaseModel): field_id: Optional[int] = Field(alias="field-id", default=None) names: List[str] = conlist(str) fields: List[MappedField] = Field(default_factory=list) @field_validator("fields", mode="before") @classmethod def convert_null_to_empty_List(cls, v: Any) -> Any: return v or [] @model_serializer def ser_model(self) -> Dict[str, Any]: """Set custom serializer to leave out the field when it is empty.""" serialized: Dict[str, Any] = {"names": self.names} if self.field_id is not None: serialized["field-id"] = self.field_id if len(self.fields) > 0: serialized["fields"] = self.fields return serialized def __len__(self) -> int: """Return the number of fields.""" return len(self.fields) def __str__(self) -> str: """Convert the mapped-field into a nicely formatted string.""" # Otherwise the UTs fail because the order of the set can change fields_str = ", ".join([str(e) for e in self.fields]) or "" fields_str = " " + fields_str if fields_str else "" field_id = "?" if self.field_id is None else (str(self.field_id) or "?") return "([" + ", ".join(self.names) + "] -> " + field_id + fields_str + ")" class NameMapping(IcebergRootModel[List[MappedField]]): root: List[MappedField] @cached_property def _field_by_name(self) -> Dict[str, MappedField]: return visit_name_mapping(self, _IndexByName()) def __len__(self) -> int: """Return the number of mappings.""" return len(self.root) def __iter__(self) -> Iterator[MappedField]: """Iterate over the mapped fields.""" return iter(self.root) def __str__(self) -> str: """Convert the name-mapping into a nicely formatted string.""" if len(self.root) == 0: return "[]" else: return "[\n " + "\n ".join([str(e) for e in self.root]) + "\n]" S = TypeVar("S") T = TypeVar("T") class NameMappingVisitor(Generic[S, T], ABC): @abstractmethod def mapping(self, nm: NameMapping, field_results: S) -> S: """Visit a NameMapping.""" @abstractmethod def fields(self, struct: List[MappedField], field_results: List[T]) -> S: """Visit a List[MappedField].""" @abstractmethod def field(self, field: MappedField, field_result: S) -> T: """Visit a MappedField.""" class _IndexByName(NameMappingVisitor[Dict[str, MappedField], Dict[str, MappedField]]): def mapping(self, nm: NameMapping, field_results: Dict[str, MappedField]) -> Dict[str, MappedField]: return field_results def fields(self, struct: List[MappedField], field_results: List[Dict[str, MappedField]]) -> Dict[str, MappedField]: return dict(ChainMap(*field_results)) def field(self, field: MappedField, field_result: Dict[str, MappedField]) -> Dict[str, MappedField]: result: Dict[str, MappedField] = { f"{field_name}.{key}": result_field for key, result_field in field_result.items() for field_name in field.names } for name in field.names: result[name] = field return result @singledispatch def visit_name_mapping(obj: Union[NameMapping, List[MappedField], MappedField], visitor: NameMappingVisitor[S, T]) -> S: """Traverse the name mapping in post-order traversal.""" raise NotImplementedError(f"Cannot visit non-type: {obj}") @visit_name_mapping.register(NameMapping) def _(obj: NameMapping, visitor: NameMappingVisitor[S, T]) -> S: return visitor.mapping(obj, visit_name_mapping(obj.root, visitor)) @visit_name_mapping.register(list) def _(fields: List[MappedField], visitor: NameMappingVisitor[S, T]) -> S: results = [visitor.field(field, visit_name_mapping(field.fields, visitor)) for field in fields] return visitor.fields(fields, results) def parse_mapping_from_json(mapping: str) -> NameMapping: return NameMapping.model_validate_json(mapping) class _CreateMapping(SchemaVisitor[List[MappedField]]): def schema(self, schema: Schema, struct_result: List[MappedField]) -> List[MappedField]: return struct_result def struct(self, struct: StructType, field_results: List[List[MappedField]]) -> List[MappedField]: return [ MappedField(field_id=field.field_id, names=[field.name], fields=result) for field, result in zip(struct.fields, field_results) ] def field(self, field: NestedField, field_result: List[MappedField]) -> List[MappedField]: return field_result def list(self, list_type: ListType, element_result: List[MappedField]) -> List[MappedField]: return [MappedField(field_id=list_type.element_id, names=["element"], fields=element_result)] def map(self, map_type: MapType, key_result: List[MappedField], value_result: List[MappedField]) -> List[MappedField]: return [ MappedField(field_id=map_type.key_id, names=["key"], fields=key_result), MappedField(field_id=map_type.value_id, names=["value"], fields=value_result), ] def primitive(self, primitive: PrimitiveType) -> List[MappedField]: return [] class _UpdateMapping(NameMappingVisitor[List[MappedField], MappedField]): _updates: Dict[int, NestedField] _adds: Dict[int, List[NestedField]] def __init__(self, updates: Dict[int, NestedField], adds: Dict[int, List[NestedField]]): self._updates = updates self._adds = adds @staticmethod def _remove_reassigned_names(field: MappedField, assignments: Dict[str, int]) -> Optional[MappedField]: removed_names = set() for name in field.names: if (assigned_id := assignments.get(name)) and assigned_id != field.field_id: removed_names.add(name) remaining_names = [f for f in field.names if f not in removed_names] if remaining_names: return MappedField(field_id=field.field_id, names=remaining_names, fields=field.fields) else: return None def _add_new_fields(self, mapped_fields: List[MappedField], parent_id: int) -> List[MappedField]: if fields_to_add := self._adds.get(parent_id): fields: List[MappedField] = [] new_fields: List[MappedField] = [] for add in fields_to_add: new_fields.append( MappedField(field_id=add.field_id, names=[add.name], fields=visit(add.field_type, _CreateMapping())) ) reassignments = {f.name: f.field_id for f in fields_to_add} fields = [ updated_field for field in mapped_fields if (updated_field := self._remove_reassigned_names(field, reassignments)) is not None ] + new_fields return fields else: return mapped_fields def mapping(self, nm: NameMapping, field_results: List[MappedField]) -> List[MappedField]: return self._add_new_fields(field_results, -1) def fields(self, struct: List[MappedField], field_results: List[MappedField]) -> List[MappedField]: reassignments: Dict[str, int] = { update.name: update.field_id for f in field_results if f.field_id is not None and (update := self._updates.get(f.field_id)) } return [ updated_field for field in field_results if (updated_field := self._remove_reassigned_names(field, reassignments)) is not None ] def field(self, field: MappedField, field_result: List[MappedField]) -> MappedField: if field.field_id is None: return field field_names = field.names if (update := self._updates.get(field.field_id)) is not None and update.name not in field_names: field_names.append(update.name) return MappedField(field_id=field.field_id, names=field_names, fields=self._add_new_fields(field_result, field.field_id)) def create_mapping_from_schema(schema: Schema) -> NameMapping: return NameMapping(visit(schema, _CreateMapping())) def update_mapping(mapping: NameMapping, updates: Dict[int, NestedField], adds: Dict[int, List[NestedField]]) -> NameMapping: return NameMapping(visit_name_mapping(mapping, _UpdateMapping(updates, adds))) class NameMappingAccessor(PartnerAccessor[MappedField]): def schema_partner(self, partner: Optional[MappedField]) -> Optional[MappedField]: return partner def field_partner( self, partner_struct: Optional[Union[List[MappedField], MappedField]], _: int, field_name: str ) -> Optional[MappedField]: if partner_struct is not None: if isinstance(partner_struct, MappedField): partner_struct = partner_struct.fields for field in partner_struct: if field_name in field.names: return field return None def list_element_partner(self, partner_list: Optional[MappedField]) -> Optional[MappedField]: if partner_list is not None: for field in partner_list.fields: if "element" in field.names: return field return None def map_key_partner(self, partner_map: Optional[MappedField]) -> Optional[MappedField]: if partner_map is not None: for field in partner_map.fields: if "key" in field.names: return field return None def map_value_partner(self, partner_map: Optional[MappedField]) -> Optional[MappedField]: if partner_map is not None: for field in partner_map.fields: if "value" in field.names: return field return None class NameMappingProjectionVisitor(SchemaWithPartnerVisitor[MappedField, IcebergType]): current_path: List[str] def __init__(self) -> None: # For keeping track where we are in case when a field cannot be found self.current_path = [] def before_field(self, field: NestedField, field_partner: Optional[P]) -> None: self.current_path.append(field.name) def after_field(self, field: NestedField, field_partner: Optional[P]) -> None: self.current_path.pop() def before_list_element(self, element: NestedField, element_partner: Optional[P]) -> None: self.current_path.append("element") def after_list_element(self, element: NestedField, element_partner: Optional[P]) -> None: self.current_path.pop() def before_map_key(self, key: NestedField, key_partner: Optional[P]) -> None: self.current_path.append("key") def after_map_key(self, key: NestedField, key_partner: Optional[P]) -> None: self.current_path.pop() def before_map_value(self, value: NestedField, value_partner: Optional[P]) -> None: self.current_path.append("value") def after_map_value(self, value: NestedField, value_partner: Optional[P]) -> None: self.current_path.pop() def schema(self, schema: Schema, schema_partner: Optional[MappedField], struct_result: StructType) -> IcebergType: return Schema(*struct_result.fields, schema_id=schema.schema_id) def struct(self, struct: StructType, struct_partner: Optional[MappedField], field_results: List[NestedField]) -> IcebergType: return StructType(*field_results) def field(self, field: NestedField, field_partner: Optional[MappedField], field_result: IcebergType) -> IcebergType: if field_partner is None or field_partner.field_id is None: raise ValueError(f"Field or field ID missing from NameMapping: {'.'.join(self.current_path)}") return NestedField( field_id=field_partner.field_id, name=field.name, field_type=field_result, required=field.required, doc=field.doc, initial_default=field.initial_default, initial_write=field.write_default, ) def list(self, list_type: ListType, list_partner: Optional[MappedField], element_result: IcebergType) -> IcebergType: if list_partner is None: raise ValueError(f"Could not find field with name: {'.'.join(self.current_path)}") element_id = next(field for field in list_partner.fields if "element" in field.names).field_id return ListType(element_id=element_id, element=element_result, element_required=list_type.element_required) def map( self, map_type: MapType, map_partner: Optional[MappedField], key_result: IcebergType, value_result: IcebergType ) -> IcebergType: if map_partner is None: raise ValueError(f"Could not find field with name: {'.'.join(self.current_path)}") key_id = next(field for field in map_partner.fields if "key" in field.names).field_id value_id = next(field for field in map_partner.fields if "value" in field.names).field_id return MapType( key_id=key_id, key_type=key_result, value_id=value_id, value_type=value_result, value_required=map_type.value_required, ) def primitive(self, primitive: PrimitiveType, primitive_partner: Optional[MappedField]) -> PrimitiveType: if primitive_partner is None: raise ValueError(f"Could not find field with name: {'.'.join(self.current_path)}") return primitive def apply_name_mapping(schema_without_ids: Schema, name_mapping: NameMapping) -> Schema: return visit_with_partner(schema_without_ids, name_mapping, NameMappingProjectionVisitor(), NameMappingAccessor()) # type: ignore