# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=W0511
from __future__ import annotations

import itertools
from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import cached_property, partial, singledispatch
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Generic,
    List,
    Literal,
    Optional,
    Set,
    Tuple,
    TypeVar,
    Union,
)

from pydantic import Field, PrivateAttr, model_validator

from pyiceberg.exceptions import ResolveError
from pyiceberg.typedef import EMPTY_DICT, IcebergBaseModel, StructProtocol
from pyiceberg.types import (
    BinaryType,
    BooleanType,
    DateType,
    DecimalType,
    DoubleType,
    FixedType,
    FloatType,
    IcebergType,
    IntegerType,
    ListType,
    LongType,
    MapType,
    NestedField,
    PrimitiveType,
    StringType,
    StructType,
    TimestampNanoType,
    TimestampType,
    TimestamptzNanoType,
    TimestamptzType,
    TimeType,
    UnknownType,
    UUIDType,
)

if TYPE_CHECKING:
    import pyarrow as pa

    from pyiceberg.table.name_mapping import (
        NameMapping,
    )

# Result type produced by the schema visitors defined below.
T = TypeVar("T")
# Type of the "partner" structure walked alongside a schema by the partner visitors below.
P = TypeVar("P")

# Schema ID given to a Schema when none is supplied (default of Schema.schema_id).
INITIAL_SCHEMA_ID = 0


class Schema(IcebergBaseModel):
    """A table Schema.

    Example:
        >>> from pyiceberg import schema
        >>> from pyiceberg import types
    """

    type: Literal["struct"] = "struct"
    fields: Tuple[NestedField, ...] = Field(default_factory=tuple)
    schema_id: int = Field(alias="schema-id", default=INITIAL_SCHEMA_ID)
    identifier_field_ids: List[int] = Field(alias="identifier-field-ids", default_factory=list)

    # Field name -> field ID index, built eagerly by index_by_name() in __init__.
    _name_to_id: Dict[str, int] = PrivateAttr()

    def __init__(self, *fields: NestedField, **data: Any):
        # Positional fields, when given, replace any ``fields`` keyword argument.
        if fields:
            data["fields"] = fields
        super().__init__(**data)
        self._name_to_id = index_by_name(self)

    def __str__(self) -> str:
        """Return the string representation of the Schema class."""
        return "table {\n" + "\n".join(["  " + str(field) for field in self.columns]) + "\n}"

    def __repr__(self) -> str:
        """Return the string representation of the Schema class."""
        return f"Schema({', '.join(repr(column) for column in self.columns)}, schema_id={self.schema_id}, identifier_field_ids={self.identifier_field_ids})"

    def __len__(self) -> int:
        """Return the number of top-level fields in the schema."""
        return len(self.fields)

    def __eq__(self, other: Any) -> bool:
        """Return the equality of two instances of the Schema class.

        Two schemas are equal when they have the same top-level columns (compared in order)
        and the same identifier field IDs. ``schema_id`` is not compared.
        """
        # Use an isinstance check instead of truthiness: Schema defines __len__, so a schema
        # without columns is falsy and would otherwise never equal anything, not even itself.
        if not isinstance(other, Schema):
            return False

        if len(self.columns) != len(other.columns):
            return False

        identifier_field_ids_is_equal = self.identifier_field_ids == other.identifier_field_ids
        schema_is_equal = all(lhs == rhs for lhs, rhs in zip(self.columns, other.columns))

        return identifier_field_ids_is_equal and schema_is_equal

    @model_validator(mode="after")
    def check_schema(self) -> Schema:
        """Validate every declared identifier field after model construction."""
        if self.identifier_field_ids:
            for field_id in self.identifier_field_ids:
                self._validate_identifier_field(field_id)

        return self

    @property
    def columns(self) -> Tuple[NestedField, ...]:
        """A tuple of the top-level fields."""
        return self.fields

    @cached_property
    def _lazy_id_to_field(self) -> Dict[int, NestedField]:
        """Return an index of field ID to NestedField instance.

        This is calculated once when called for the first time. Subsequent calls to this method will use a cached index.
        """
        return index_by_id(self)

    @cached_property
    def _lazy_id_to_parent(self) -> Dict[int, int]:
        """Return an index of field ID to parent field IDs.

        This is calculated once when called for the first time. Subsequent calls to this method will use a cached index.
        """
        return _index_parents(self)

    @cached_property
    def _lazy_name_to_id_lower(self) -> Dict[str, int]:
        """Return an index of lower-case field names to field IDs.

        This is calculated once when called for the first time. Subsequent calls to this method will use a cached index.
        """
        return {name.lower(): field_id for name, field_id in self._name_to_id.items()}

    @cached_property
    def _lazy_id_to_name(self) -> Dict[int, str]:
        """Return an index of field ID to full name.

        This is calculated once when called for the first time. Subsequent calls to this method will use a cached index.
        """
        return index_name_by_id(self)

    @cached_property
    def _lazy_id_to_accessor(self) -> Dict[int, Accessor]:
        """Return an index of field ID to accessor.

        This is calculated once when called for the first time. Subsequent calls to this method will use a cached index.
        """
        return build_position_accessors(self)

    def as_struct(self) -> StructType:
        """Return the schema as a struct."""
        return StructType(*self.fields)

    def as_arrow(self) -> "pa.Schema":
        """Return the schema as an Arrow schema."""
        from pyiceberg.io.pyarrow import schema_to_pyarrow

        return schema_to_pyarrow(self)

    def find_field(self, name_or_id: Union[str, int], case_sensitive: bool = True) -> NestedField:
        """Find a field using a field name or field ID.

        Args:
            name_or_id (Union[str, int]): Either a field name or a field ID.
            case_sensitive (bool, optional): Whether to perform a case-sensitive lookup using a field name. Defaults to True.

        Raises:
            ValueError: When the value cannot be found.

        Returns:
            NestedField: The matched NestedField.
        """
        if isinstance(name_or_id, int):
            if name_or_id not in self._lazy_id_to_field:
                raise ValueError(f"Could not find field with id: {name_or_id}")
            return self._lazy_id_to_field[name_or_id]

        if case_sensitive:
            field_id = self._name_to_id.get(name_or_id)
        else:
            field_id = self._lazy_name_to_id_lower.get(name_or_id.lower())

        if field_id is None:
            raise ValueError(f"Could not find field with name {name_or_id}, case_sensitive={case_sensitive}")

        return self._lazy_id_to_field[field_id]

    def find_type(self, name_or_id: Union[str, int], case_sensitive: bool = True) -> IcebergType:
        """Find a field type using a field name or field ID.

        Args:
            name_or_id (Union[str, int]): Either a field name or a field ID.
            case_sensitive (bool, optional): Whether to perform a case-sensitive lookup using a field name. Defaults to True.

        Raises:
            ValueError: When the field cannot be found.

        Returns:
            IcebergType: The type of the matched NestedField.
        """
        field = self.find_field(name_or_id=name_or_id, case_sensitive=case_sensitive)
        if not field:
            raise ValueError(f"Could not find field with name or id {name_or_id}, case_sensitive={case_sensitive}")
        return field.field_type

    @property
    def highest_field_id(self) -> int:
        """Return the largest field ID in the schema (including nested fields), or 0 when there are none."""
        return max(self._lazy_id_to_name.keys(), default=0)

    @cached_property
    def name_mapping(self) -> NameMapping:
        """Return a name mapping derived from this schema (computed once and cached)."""
        from pyiceberg.table.name_mapping import create_mapping_from_schema

        return create_mapping_from_schema(self)

    def find_column_name(self, column_id: int) -> Optional[str]:
        """Find a column name given a column ID.

        Args:
            column_id (int): The ID of the column.

        Returns:
            str: The column name (or None if the column ID cannot be found).
        """
        return self._lazy_id_to_name.get(column_id)

    @property
    def column_names(self) -> List[str]:
        """
        Return a list of all the column names, including nested fields.

        Excludes short names.

        Returns:
            List[str]: The column names.
        """
        return list(self._lazy_id_to_name.values())

    def accessor_for_field(self, field_id: int) -> Accessor:
        """Find a schema position accessor given a field ID.

        Args:
            field_id (int): The ID of the field.

        Raises:
            ValueError: When the value cannot be found.

        Returns:
            Accessor: An accessor for the given field ID.
        """
        if field_id not in self._lazy_id_to_accessor:
            raise ValueError(f"Could not find accessor for field with id: {field_id}")

        return self._lazy_id_to_accessor[field_id]

    def identifier_field_names(self) -> Set[str]:
        """Return the names of the identifier fields.

        Raises:
            ValueError: If an identifier field ID does not resolve to a column name.

        Returns:
            Set of names of the identifier fields
        """
        ids = set()
        for field_id in self.identifier_field_ids:
            column_name = self.find_column_name(field_id)
            if column_name is None:
                raise ValueError(f"Could not find identifier column id: {field_id}")
            ids.add(column_name)

        return ids

    def select(self, *names: str, case_sensitive: bool = True) -> Schema:
        """Return a new schema instance pruned to a subset of columns.

        Args:
            *names (str): The column names to keep.
            case_sensitive (bool, optional): Whether to perform a case-sensitive lookup for each column name. Defaults to True.

        Returns:
            Schema: A new schema with pruned columns.

        Raises:
            ValueError: If a column is selected that doesn't exist.
        """
        try:
            if case_sensitive:
                ids = {self._name_to_id[name] for name in names}
            else:
                ids = {self._lazy_name_to_id_lower[name.lower()] for name in names}
        except KeyError as e:
            raise ValueError(f"Could not find column: {e}") from e

        return prune_columns(self, ids)

    @property
    def field_ids(self) -> Set[int]:
        """Return the IDs of the current schema."""
        return set(self._name_to_id.values())

    def _validate_identifier_field(self, field_id: int) -> None:
        """Validate that the field with the given ID is a valid identifier field.

        Args:
          field_id: The ID of the field to validate.

        Raises:
          ValueError: If the field is not valid.
        """
        field = self.find_field(field_id)
        if not field.field_type.is_primitive:
            raise ValueError(f"Identifier field {field_id} invalid: not a primitive type field")

        if not field.required:
            raise ValueError(f"Identifier field {field_id} invalid: not a required field")

        if isinstance(field.field_type, (DoubleType, FloatType)):
            raise ValueError(f"Identifier field {field_id} invalid: must not be float or double field")

        # Check whether the nested field is in a chain of required struct fields
        # Exploring from root for better error message for list and map types
        parent_id = self._lazy_id_to_parent.get(field.field_id)
        fields: List[int] = []
        while parent_id is not None:
            fields.append(parent_id)
            parent_id = self._lazy_id_to_parent.get(parent_id)

        # `fields` was collected innermost-first, so popping walks from the root downwards.
        while fields:
            parent = self.find_field(fields.pop())
            if not parent.field_type.is_struct:
                raise ValueError(f"Cannot add field {field.name} as an identifier field: must not be nested in {parent}")

            if not parent.required:
                raise ValueError(
                    f"Cannot add field {field.name} as an identifier field: must not be nested in an optional field {parent}"
                )

    def check_format_version_compatibility(self, format_version: int) -> None:
        """Check that the schema is compatible for the given table format version.

        Args:
          format_version: The Iceberg table format version.

        Raises:
          ValueError: If the schema is not compatible for the format version.
        """
        for field in self._lazy_id_to_field.values():
            if format_version < field.field_type.minimum_format_version():
                raise ValueError(
                    f"{field.field_type} is only supported in {field.field_type.minimum_format_version()} or higher. Current format version is: {format_version}"
                )


class SchemaVisitor(Generic[T], ABC):
    """Post-order visitor over a schema: child results are passed to the parent's callback.

    The ``before_*`` / ``after_*`` hooks do nothing by default. The list-element and
    map key/value hooks forward to ``before_field`` / ``after_field``, so subclasses that
    only override the generic field hooks still see every nested field.
    """

    def before_field(self, field: NestedField) -> None:
        """Hook invoked right before a field is visited; a no-op unless overridden."""

    def after_field(self, field: NestedField) -> None:
        """Hook invoked right after a field is visited; a no-op unless overridden."""

    def before_list_element(self, element: NestedField) -> None:
        """Hook invoked before a ListType element; forwards to ``before_field`` by default."""
        self.before_field(element)

    def after_list_element(self, element: NestedField) -> None:
        """Hook invoked after a ListType element; forwards to ``after_field`` by default."""
        self.after_field(element)

    def before_map_key(self, key: NestedField) -> None:
        """Hook invoked before a MapType key; forwards to ``before_field`` by default."""
        self.before_field(key)

    def after_map_key(self, key: NestedField) -> None:
        """Hook invoked after a MapType key; forwards to ``after_field`` by default."""
        self.after_field(key)

    def before_map_value(self, value: NestedField) -> None:
        """Hook invoked before a MapType value; forwards to ``before_field`` by default."""
        self.before_field(value)

    def after_map_value(self, value: NestedField) -> None:
        """Hook invoked after a MapType value; forwards to ``after_field`` by default."""
        self.after_field(value)

    @abstractmethod
    def schema(self, schema: Schema, struct_result: T) -> T:
        """Combine the result of the schema's root struct into a result for the Schema."""

    @abstractmethod
    def struct(self, struct: StructType, field_results: List[T]) -> T:
        """Combine the per-field results into a result for the StructType."""

    @abstractmethod
    def field(self, field: NestedField, field_result: T) -> T:
        """Produce a result for a NestedField from the result of its type."""

    @abstractmethod
    def list(self, list_type: ListType, element_result: T) -> T:
        """Produce a result for a ListType from its element result."""

    @abstractmethod
    def map(self, map_type: MapType, key_result: T, value_result: T) -> T:
        """Produce a result for a MapType from its key and value results."""

    @abstractmethod
    def primitive(self, primitive: PrimitiveType) -> T:
        """Produce a result for a PrimitiveType leaf."""


class PreOrderSchemaVisitor(Generic[T], ABC):
    """Pre-order schema visitor.

    Instead of precomputed child results, every callback receives zero-argument callables
    that produce the child results when invoked. The implementation decides whether and
    when to descend into children.
    """

    @abstractmethod
    def schema(self, schema: Schema, struct_result: Callable[[], T]) -> T:
        """Produce a result for a Schema; ``struct_result`` yields the root struct's result."""

    @abstractmethod
    def struct(self, struct: StructType, field_results: List[Callable[[], T]]) -> T:
        """Produce a result for a StructType; one callable per field yields that field's result."""

    @abstractmethod
    def field(self, field: NestedField, field_result: Callable[[], T]) -> T:
        """Produce a result for a NestedField; ``field_result`` yields its type's result."""

    @abstractmethod
    def list(self, list_type: ListType, element_result: Callable[[], T]) -> T:
        """Produce a result for a ListType; ``element_result`` yields the element's result."""

    @abstractmethod
    def map(self, map_type: MapType, key_result: Callable[[], T], value_result: Callable[[], T]) -> T:
        """Produce a result for a MapType; the callables yield the key and value results."""

    @abstractmethod
    def primitive(self, primitive: PrimitiveType) -> T:
        """Produce a result for a PrimitiveType leaf."""


class SchemaWithPartnerVisitor(Generic[P, T], ABC):
    """Post-order schema visitor that carries a "partner" object alongside each schema node.

    Every callback receives the partner paired with the node being visited (typed
    ``Optional[P]``). The ``before_*`` / ``after_*`` hooks do nothing by default. The
    list-element and map key/value hooks forward to ``before_field`` / ``after_field``.
    """

    def before_field(self, field: NestedField, field_partner: Optional[P]) -> None:
        """Hook invoked right before a field is visited; a no-op unless overridden."""

    def after_field(self, field: NestedField, field_partner: Optional[P]) -> None:
        """Hook invoked right after a field is visited; a no-op unless overridden."""

    def before_list_element(self, element: NestedField, element_partner: Optional[P]) -> None:
        """Hook invoked before a ListType element; forwards to ``before_field`` by default."""
        self.before_field(element, element_partner)

    def after_list_element(self, element: NestedField, element_partner: Optional[P]) -> None:
        """Hook invoked after a ListType element; forwards to ``after_field`` by default."""
        self.after_field(element, element_partner)

    def before_map_key(self, key: NestedField, key_partner: Optional[P]) -> None:
        """Hook invoked before a MapType key; forwards to ``before_field`` by default."""
        self.before_field(key, key_partner)

    def after_map_key(self, key: NestedField, key_partner: Optional[P]) -> None:
        """Hook invoked after a MapType key; forwards to ``after_field`` by default."""
        self.after_field(key, key_partner)

    def before_map_value(self, value: NestedField, value_partner: Optional[P]) -> None:
        """Hook invoked before a MapType value; forwards to ``before_field`` by default."""
        self.before_field(value, value_partner)

    def after_map_value(self, value: NestedField, value_partner: Optional[P]) -> None:
        """Hook invoked after a MapType value; forwards to ``after_field`` by default."""
        self.after_field(value, value_partner)

    @abstractmethod
    def schema(self, schema: Schema, schema_partner: Optional[P], struct_result: T) -> T:
        """Produce a result for a Schema and its partner from the root struct's result."""

    @abstractmethod
    def struct(self, struct: StructType, struct_partner: Optional[P], field_results: List[T]) -> T:
        """Produce a result for a StructType and its partner from the per-field results."""

    @abstractmethod
    def field(self, field: NestedField, field_partner: Optional[P], field_result: T) -> T:
        """Produce a result for a NestedField and its partner from the field type's result."""

    @abstractmethod
    def list(self, list_type: ListType, list_partner: Optional[P], element_result: T) -> T:
        """Produce a result for a ListType and its partner from the element result."""

    @abstractmethod
    def map(self, map_type: MapType, map_partner: Optional[P], key_result: T, value_result: T) -> T:
        """Produce a result for a MapType and its partner from the key and value results."""

    @abstractmethod
    def primitive(self, primitive: PrimitiveType, primitive_partner: Optional[P]) -> T:
        """Produce a result for a PrimitiveType leaf and its partner."""


class PrimitiveWithPartnerVisitor(SchemaWithPartnerVisitor[P, T]):
    """Partner visitor that splits ``primitive`` into one abstract method per primitive type."""

    def primitive(self, primitive: PrimitiveType, primitive_partner: Optional[P]) -> T:
        """Visit a PrimitiveType by dispatching to the matching ``visit_*`` method.

        Raises:
            ValueError: If the primitive type is not one of the handled types.
        """
        if isinstance(primitive, BooleanType):
            return self.visit_boolean(primitive, primitive_partner)
        elif isinstance(primitive, IntegerType):
            return self.visit_integer(primitive, primitive_partner)
        elif isinstance(primitive, LongType):
            return self.visit_long(primitive, primitive_partner)
        elif isinstance(primitive, FloatType):
            return self.visit_float(primitive, primitive_partner)
        elif isinstance(primitive, DoubleType):
            return self.visit_double(primitive, primitive_partner)
        elif isinstance(primitive, DecimalType):
            return self.visit_decimal(primitive, primitive_partner)
        elif isinstance(primitive, DateType):
            return self.visit_date(primitive, primitive_partner)
        elif isinstance(primitive, TimeType):
            return self.visit_time(primitive, primitive_partner)
        elif isinstance(primitive, TimestampType):
            return self.visit_timestamp(primitive, primitive_partner)
        elif isinstance(primitive, TimestampNanoType):
            return self.visit_timestamp_ns(primitive, primitive_partner)
        elif isinstance(primitive, TimestamptzType):
            return self.visit_timestamptz(primitive, primitive_partner)
        elif isinstance(primitive, TimestamptzNanoType):
            return self.visit_timestamptz_ns(primitive, primitive_partner)
        elif isinstance(primitive, StringType):
            return self.visit_string(primitive, primitive_partner)
        elif isinstance(primitive, UUIDType):
            return self.visit_uuid(primitive, primitive_partner)
        elif isinstance(primitive, FixedType):
            return self.visit_fixed(primitive, primitive_partner)
        elif isinstance(primitive, BinaryType):
            return self.visit_binary(primitive, primitive_partner)
        elif isinstance(primitive, UnknownType):
            return self.visit_unknown(primitive, primitive_partner)
        else:
            raise ValueError(f"Type not recognized: {primitive}")

    @abstractmethod
    def visit_boolean(self, boolean_type: BooleanType, partner: Optional[P]) -> T:
        """Visit a BooleanType."""

    @abstractmethod
    def visit_integer(self, integer_type: IntegerType, partner: Optional[P]) -> T:
        """Visit an IntegerType."""

    @abstractmethod
    def visit_long(self, long_type: LongType, partner: Optional[P]) -> T:
        """Visit a LongType."""

    @abstractmethod
    def visit_float(self, float_type: FloatType, partner: Optional[P]) -> T:
        """Visit a FloatType."""

    @abstractmethod
    def visit_double(self, double_type: DoubleType, partner: Optional[P]) -> T:
        """Visit a DoubleType."""

    @abstractmethod
    def visit_decimal(self, decimal_type: DecimalType, partner: Optional[P]) -> T:
        """Visit a DecimalType."""

    @abstractmethod
    def visit_date(self, date_type: DateType, partner: Optional[P]) -> T:
        """Visit a DateType."""

    @abstractmethod
    def visit_time(self, time_type: TimeType, partner: Optional[P]) -> T:
        """Visit a TimeType."""

    @abstractmethod
    def visit_timestamp(self, timestamp_type: TimestampType, partner: Optional[P]) -> T:
        """Visit a TimestampType."""

    @abstractmethod
    def visit_timestamp_ns(self, timestamp_ns_type: TimestampNanoType, partner: Optional[P]) -> T:
        """Visit a TimestampNanoType."""

    @abstractmethod
    def visit_timestamptz(self, timestamptz_type: TimestamptzType, partner: Optional[P]) -> T:
        """Visit a TimestamptzType."""

    @abstractmethod
    def visit_timestamptz_ns(self, timestamptz_ns_type: TimestamptzNanoType, partner: Optional[P]) -> T:
        """Visit a TimestamptzNanoType."""

    @abstractmethod
    def visit_string(self, string_type: StringType, partner: Optional[P]) -> T:
        """Visit a StringType."""

    @abstractmethod
    def visit_uuid(self, uuid_type: UUIDType, partner: Optional[P]) -> T:
        """Visit a UUIDType."""

    @abstractmethod
    def visit_fixed(self, fixed_type: FixedType, partner: Optional[P]) -> T:
        """Visit a FixedType."""

    @abstractmethod
    def visit_binary(self, binary_type: BinaryType, partner: Optional[P]) -> T:
        """Visit a BinaryType."""

    @abstractmethod
    def visit_unknown(self, unknown_type: UnknownType, partner: Optional[P]) -> T:
        """Visit an UnknownType."""


class PartnerAccessor(Generic[P], ABC):
    """Strategy that ``visit_with_partner`` uses to find the partner of each schema node.

    Each method receives the parent's partner (possibly ``None``) and returns the partner
    for the corresponding child node, or ``None``.
    """

    @abstractmethod
    def schema_partner(self, partner: Optional[P]) -> Optional[P]:
        """Map the schema-level partner to the partner of the schema's root struct."""

    @abstractmethod
    def field_partner(self, partner_struct: Optional[P], field_id: int, field_name: str) -> Optional[P]:
        """Look up the partner of a struct field, by field ID or name, within the struct's partner."""

    @abstractmethod
    def list_element_partner(self, partner_list: Optional[P]) -> Optional[P]:
        """Return the partner of a list's element, given the list's partner."""

    @abstractmethod
    def map_key_partner(self, partner_map: Optional[P]) -> Optional[P]:
        """Return the partner of a map's key, given the map's partner."""

    @abstractmethod
    def map_value_partner(self, partner_map: Optional[P]) -> Optional[P]:
        """Return the partner of a map's value, given the map's partner."""


@singledispatch
def visit_with_partner(
    schema_or_type: Union[Schema, IcebergType], partner: P, visitor: SchemaWithPartnerVisitor[P, T], accessor: PartnerAccessor[P]
) -> T:
    """Visit a schema (or any type within it) together with a partner structure.

    The traversal is post-order. At each node, ``accessor`` locates the child partners and
    ``visitor`` combines the child results. This base implementation is the fallback for
    objects that have no registered overload.

    Args:
        schema_or_type (Union[Schema, IcebergType]): The Schema or type to start from.
        partner (P): The partner object paired with ``schema_or_type``.
        visitor (SchemaWithPartnerVisitor[P, T]): The visitor producing results of type ``T``.
        accessor (PartnerAccessor[P]): Locates the partner of each child node.

    Raises:
        ValueError: If ``schema_or_type`` has no registered overload.
    """
    raise ValueError(f"Unsupported type: {schema_or_type}")


@visit_with_partner.register(Schema)
def _(schema: Schema, partner: P, visitor: SchemaWithPartnerVisitor[P, T], accessor: PartnerAccessor[P]) -> T:
    """Visit a Schema together with its partner.

    The schema's root struct is visited against ``accessor.schema_partner(partner)``. The
    visitor's ``schema`` callback then receives the original ``partner`` plus that result.
    """
    root_partner = accessor.schema_partner(partner)
    root_result = visit_with_partner(schema.as_struct(), root_partner, visitor, accessor)  # type: ignore
    return visitor.schema(schema, partner, root_result)


@visit_with_partner.register(StructType)
def _(struct: StructType, partner: P, visitor: SchemaWithPartnerVisitor[P, T], accessor: PartnerAccessor[P]) -> T:
    """Visit each field of a StructType with its partner, then combine the results.

    ``after_field`` always runs for a field whose ``before_field`` ran, even if visiting
    the field raises.
    """
    results = []
    for nested in struct.fields:
        nested_partner = accessor.field_partner(partner, nested.field_id, nested.name)
        visitor.before_field(nested, nested_partner)
        try:
            type_result = visit_with_partner(nested.field_type, nested_partner, visitor, accessor)  # type: ignore
            results.append(visitor.field(nested, nested_partner, type_result))
        finally:
            visitor.after_field(nested, nested_partner)

    return visitor.struct(struct, partner, results)


@visit_with_partner.register(ListType)
def _(list_type: ListType, partner: P, visitor: SchemaWithPartnerVisitor[P, T], accessor: PartnerAccessor[P]) -> T:
    """Visit a ListType's element with its partner, then combine into the list result.

    ``after_list_element`` always runs once ``before_list_element`` has run.
    """
    elem_partner = accessor.list_element_partner(partner)
    elem_field = list_type.element_field
    visitor.before_list_element(elem_field, elem_partner)
    try:
        elem_result = visit_with_partner(list_type.element_type, elem_partner, visitor, accessor)  # type: ignore
    finally:
        visitor.after_list_element(elem_field, elem_partner)

    return visitor.list(list_type, partner, elem_result)


@visit_with_partner.register(MapType)
def _(map_type: MapType, partner: P, visitor: SchemaWithPartnerVisitor[P, T], accessor: PartnerAccessor[P]) -> T:
    """Visit a MapType's key and then its value, each with its partner, and combine the results.

    The matching ``after_*`` hook always runs once its ``before_*`` hook has run.
    """
    k_partner = accessor.map_key_partner(partner)
    k_field = map_type.key_field
    visitor.before_map_key(k_field, k_partner)
    try:
        k_result = visit_with_partner(map_type.key_type, k_partner, visitor, accessor)  # type: ignore
    finally:
        visitor.after_map_key(k_field, k_partner)

    v_partner = accessor.map_value_partner(partner)
    v_field = map_type.value_field
    visitor.before_map_value(v_field, v_partner)
    try:
        v_result = visit_with_partner(map_type.value_type, v_partner, visitor, accessor)  # type: ignore
    finally:
        visitor.after_map_value(v_field, v_partner)

    return visitor.map(map_type, partner, k_result, v_result)


@visit_with_partner.register(PrimitiveType)
def _(primitive: PrimitiveType, partner: P, visitor: SchemaWithPartnerVisitor[P, T], _: PartnerAccessor[P]) -> T:
    """Visit a PrimitiveType leaf; the accessor is unused because leaves have no children."""
    leaf_result = visitor.primitive(primitive, partner)
    return leaf_result


class SchemaVisitorPerPrimitiveType(SchemaVisitor[T], ABC):
    """Schema visitor that splits ``primitive`` into one abstract method per primitive type."""

    def primitive(self, primitive: PrimitiveType) -> T:
        """Visit a PrimitiveType by dispatching to the matching ``visit_*`` method.

        Raises:
            ValueError: If the primitive type is not one of the handled types.
        """
        if isinstance(primitive, FixedType):
            return self.visit_fixed(primitive)
        elif isinstance(primitive, DecimalType):
            return self.visit_decimal(primitive)
        elif isinstance(primitive, BooleanType):
            return self.visit_boolean(primitive)
        elif isinstance(primitive, IntegerType):
            return self.visit_integer(primitive)
        elif isinstance(primitive, LongType):
            return self.visit_long(primitive)
        elif isinstance(primitive, FloatType):
            return self.visit_float(primitive)
        elif isinstance(primitive, DoubleType):
            return self.visit_double(primitive)
        elif isinstance(primitive, DateType):
            return self.visit_date(primitive)
        elif isinstance(primitive, TimeType):
            return self.visit_time(primitive)
        elif isinstance(primitive, TimestampType):
            return self.visit_timestamp(primitive)
        elif isinstance(primitive, TimestampNanoType):
            return self.visit_timestamp_ns(primitive)
        elif isinstance(primitive, TimestamptzType):
            return self.visit_timestamptz(primitive)
        elif isinstance(primitive, TimestamptzNanoType):
            return self.visit_timestamptz_ns(primitive)
        elif isinstance(primitive, StringType):
            return self.visit_string(primitive)
        elif isinstance(primitive, UUIDType):
            return self.visit_uuid(primitive)
        elif isinstance(primitive, BinaryType):
            return self.visit_binary(primitive)
        elif isinstance(primitive, UnknownType):
            return self.visit_unknown(primitive)
        else:
            raise ValueError(f"Type not recognized: {primitive}")

    @abstractmethod
    def visit_fixed(self, fixed_type: FixedType) -> T:
        """Visit a FixedType."""

    @abstractmethod
    def visit_decimal(self, decimal_type: DecimalType) -> T:
        """Visit a DecimalType."""

    @abstractmethod
    def visit_boolean(self, boolean_type: BooleanType) -> T:
        """Visit a BooleanType."""

    @abstractmethod
    def visit_integer(self, integer_type: IntegerType) -> T:
        """Visit an IntegerType."""

    @abstractmethod
    def visit_long(self, long_type: LongType) -> T:
        """Visit a LongType."""

    @abstractmethod
    def visit_float(self, float_type: FloatType) -> T:
        """Visit a FloatType."""

    @abstractmethod
    def visit_double(self, double_type: DoubleType) -> T:
        """Visit a DoubleType."""

    @abstractmethod
    def visit_date(self, date_type: DateType) -> T:
        """Visit a DateType."""

    @abstractmethod
    def visit_time(self, time_type: TimeType) -> T:
        """Visit a TimeType."""

    @abstractmethod
    def visit_timestamp(self, timestamp_type: TimestampType) -> T:
        """Visit a TimestampType."""

    @abstractmethod
    def visit_timestamp_ns(self, timestamp_type: TimestampNanoType) -> T:
        """Visit a TimestampNanoType."""

    @abstractmethod
    def visit_timestamptz(self, timestamptz_type: TimestamptzType) -> T:
        """Visit a TimestamptzType."""

    @abstractmethod
    def visit_timestamptz_ns(self, timestamptz_ns_type: TimestamptzNanoType) -> T:
        """Visit a TimestamptzNanoType."""

    @abstractmethod
    def visit_string(self, string_type: StringType) -> T:
        """Visit a StringType."""

    @abstractmethod
    def visit_uuid(self, uuid_type: UUIDType) -> T:
        """Visit a UUIDType."""

    @abstractmethod
    def visit_binary(self, binary_type: BinaryType) -> T:
        """Visit a BinaryType."""

    @abstractmethod
    def visit_unknown(self, unknown_type: UnknownType) -> T:
        """Visit an UnknownType."""


@dataclass(init=True, eq=True, frozen=True)
class Accessor:
    """An accessor for a specific position in a container that implements the StructProtocol.

    Accessors form a chain. ``position`` indexes into the outer container. When ``inner``
    is set, the value found there is indexed again with ``inner.position``, and so on down
    the chain.
    """

    # Index into the container this accessor is applied to.
    position: int
    # Accessor to apply to the value found at ``position`` (None for a leaf accessor).
    inner: Optional[Accessor] = None

    def __str__(self) -> str:
        """Return the string representation of the Accessor class."""
        return f"Accessor(position={self.position},inner={self.inner})"

    def __repr__(self) -> str:
        """Return the string representation of the Accessor class."""
        return self.__str__()

    def get(self, container: StructProtocol) -> Any:
        """Return the value at self.position in `container`, following any inner accessors.

        Args:
            container (StructProtocol): A container to access at position `self.position`.

        Returns:
            Any: The value reached by the accessor chain. Returns None when an intermediate
                nested value is None (a null parent struct), instead of failing while
                indexing into it.
        """
        val = container[self.position]
        accessor = self.inner
        while accessor is not None:
            if val is None:
                # Indexing into None would raise TypeError; a null parent yields a null child.
                return None
            val = val[accessor.position]
            accessor = accessor.inner

        return val


@singledispatch
def visit(obj: Union[Schema, IcebergType], visitor: SchemaVisitor[T]) -> T:
    """Walk `obj` with `visitor`, visiting children before their parents (post-order).

    This is the fallback of a ``functools.singledispatch`` function. Handlers for
    Schema, StructType, ListType, MapType and PrimitiveType are registered on it.
    Anything else ends up here.

    Args:
        obj (Union[Schema, IcebergType]): The schema, or any type inside a schema, to start from.
        visitor (SchemaVisitor[T]): A concrete implementation of the generic SchemaVisitor base class.

    Returns:
        T: Whatever the visitor produces for `obj`.

    Raises:
        NotImplementedError: If no handler is registered for the type of `obj`.
    """
    message = f"Cannot visit non-type: {obj}"
    raise NotImplementedError(message)


@visit.register(Schema)
def _(obj: Schema, visitor: SchemaVisitor[T]) -> T:
    """Visit a Schema: first its root struct, then hand that result to `visitor.schema`."""
    root_struct_result = visit(obj.as_struct(), visitor)
    return visitor.schema(obj, root_struct_result)


@visit.register(StructType)
def _(obj: StructType, visitor: SchemaVisitor[T]) -> T:
    """Visit a StructType with a concrete SchemaVisitor.

    Fields are processed in declaration order. For each one, the before/after hooks
    wrap the recursive visit of its type. `visitor.field` then wraps that result.
    """

    def visit_one_field(nested: NestedField) -> T:
        visitor.before_field(nested)
        type_result = visit(nested.field_type, visitor)
        visitor.after_field(nested)
        return visitor.field(nested, type_result)

    field_results = [visit_one_field(nested) for nested in obj.fields]
    return visitor.struct(obj, field_results)


@visit.register(ListType)
def _(obj: ListType, visitor: SchemaVisitor[T]) -> T:
    """Visit a ListType: visit the element type between its hooks, then call `visitor.list`."""
    element_field = obj.element_field
    visitor.before_list_element(element_field)
    element_result = visit(obj.element_type, visitor)
    visitor.after_list_element(element_field)
    return visitor.list(obj, element_result)


@visit.register(MapType)
def _(obj: MapType, visitor: SchemaVisitor[T]) -> T:
    """Visit a MapType with a concrete SchemaVisitor.

    The key is fully visited, including its before/after hooks, before the value.
    Both results are then passed to `visitor.map`.
    """
    key_field = obj.key_field
    visitor.before_map_key(key_field)
    visited_key = visit(obj.key_type, visitor)
    visitor.after_map_key(key_field)

    value_field = obj.value_field
    visitor.before_map_value(value_field)
    visited_value = visit(obj.value_type, visitor)
    visitor.after_map_value(value_field)

    return visitor.map(obj, visited_key, visited_value)


@visit.register(PrimitiveType)
def _(obj: PrimitiveType, visitor: SchemaVisitor[T]) -> T:
    """Visit a PrimitiveType leaf; it has no children, so dispatch straight to `visitor.primitive`."""
    leaf_result = visitor.primitive(obj)
    return leaf_result


@singledispatch
def pre_order_visit(obj: Union[Schema, IcebergType], visitor: PreOrderSchemaVisitor[T]) -> T:
    """Walk `obj` with `visitor`, letting each parent decide when (or whether) to visit its children.

    This is the pre-order counterpart of the post-order traversal. Children are handed
    to the visitor as zero-argument callables instead of precomputed results. It is a
    slimmed-down version with no before/after hooks, mostly because pre-order traversal
    is rarely used. This function is the ``functools.singledispatch`` fallback for
    unregistered types.

    Args:
        obj (Union[Schema, IcebergType]): The schema, or any type inside a schema, to start from.
        visitor (PreOrderSchemaVisitor[T]): A concrete implementation of the generic PreOrderSchemaVisitor base class.

    Returns:
        T: Whatever the visitor produces for `obj`.

    Raises:
        NotImplementedError: If no handler is registered for the type of `obj`.
    """
    message = f"Cannot visit non-type: {obj}"
    raise NotImplementedError(message)


@pre_order_visit.register(Schema)
def _(obj: Schema, visitor: PreOrderSchemaVisitor[T]) -> T:
    """Visit a Schema, giving `visitor.schema` a thunk that descends into the root struct."""

    def visit_root_struct() -> T:
        return pre_order_visit(obj.as_struct(), visitor)

    return visitor.schema(obj, visit_root_struct)


@pre_order_visit.register(StructType)
def _(obj: StructType, visitor: PreOrderSchemaVisitor[T]) -> T:
    """Visit a StructType with a concrete PreOrderSchemaVisitor.

    The visitor receives one zero-argument callable per field. Calling it invokes
    `visitor.field` for that field, which in turn receives a callable that descends
    into the field's type. Nothing below the struct is visited unless the visitor
    calls these thunks.
    """

    def make_field_thunk(nested: NestedField) -> Callable[[], T]:
        # A factory binds `nested` once per field, avoiding the late-binding closure pitfall
        # that a bare lambda inside the comprehension would have.
        def visit_field_type() -> T:
            return pre_order_visit(nested.field_type, visitor)

        def visit_field() -> T:
            return visitor.field(nested, visit_field_type)

        return visit_field

    return visitor.struct(obj, [make_field_thunk(nested) for nested in obj.fields])


@pre_order_visit.register(ListType)
def _(obj: ListType, visitor: PreOrderSchemaVisitor[T]) -> T:
    """Visit a ListType, giving `visitor.list` a thunk that descends into the element type."""

    def visit_element() -> T:
        return pre_order_visit(obj.element_type, visitor)

    return visitor.list(obj, visit_element)


@pre_order_visit.register(MapType)
def _(obj: MapType, visitor: PreOrderSchemaVisitor[T]) -> T:
    """Visit a MapType, giving `visitor.map` separate thunks for the key and the value type."""

    def visit_key() -> T:
        return pre_order_visit(obj.key_type, visitor)

    def visit_value() -> T:
        return pre_order_visit(obj.value_type, visitor)

    return visitor.map(obj, visit_key, visit_value)


@pre_order_visit.register(PrimitiveType)
def _(obj: PrimitiveType, visitor: PreOrderSchemaVisitor[T]) -> T:
    """Visit a PrimitiveType leaf; there is nothing to descend into."""
    leaf_result = visitor.primitive(obj)
    return leaf_result


class _IndexById(SchemaVisitor[Dict[int, NestedField]]):
    """A schema visitor that builds an index from field ID to NestedField.

    Struct fields, list elements, map keys and map values are all recorded. Every
    callback returns the same dict instance, which is filled in as the traversal runs.
    """

    def __init__(self) -> None:
        self._index: Dict[int, NestedField] = {}

    def _record(self, *nested_fields: NestedField) -> Dict[int, NestedField]:
        """Store each given field under its ID, in argument order, and return the shared index."""
        for nested_field in nested_fields:
            self._index[nested_field.field_id] = nested_field
        return self._index

    def schema(self, schema: Schema, struct_result: Dict[int, NestedField]) -> Dict[int, NestedField]:
        return self._index

    def struct(self, struct: StructType, field_results: List[Dict[int, NestedField]]) -> Dict[int, NestedField]:
        return self._index

    def field(self, field: NestedField, field_result: Dict[int, NestedField]) -> Dict[int, NestedField]:
        """Add the field ID to the index."""
        return self._record(field)

    def list(self, list_type: ListType, element_result: Dict[int, NestedField]) -> Dict[int, NestedField]:
        """Add the list element ID to the index."""
        return self._record(list_type.element_field)

    def map(
        self, map_type: MapType, key_result: Dict[int, NestedField], value_result: Dict[int, NestedField]
    ) -> Dict[int, NestedField]:
        """Add the key ID and value ID as individual items in the index."""
        return self._record(map_type.key_field, map_type.value_field)

    def primitive(self, primitive: PrimitiveType) -> Dict[int, NestedField]:
        return self._index


def index_by_id(schema_or_type: Union[Schema, IcebergType]) -> Dict[int, NestedField]:
    """Build a lookup table from field ID to the NestedField carrying that ID.

    Args:
        schema_or_type (Union[Schema, IcebergType]): The schema or type whose fields are indexed.

    Returns:
        Dict[int, NestedField]: Every struct field, list element, map key and map value, keyed by ID.
    """
    id_indexer = _IndexById()
    return visit(schema_or_type, id_indexer)


class _IndexParents(SchemaVisitor[Dict[int, int]]):
    """A schema visitor that maps each nested field ID to the ID of its enclosing field.

    Fields of the root struct have no parent and are not recorded. For consistency,
    list elements and map keys/values visited with no enclosing field are not recorded
    either. This happens when the traversal starts directly at a list or map type.
    """

    def __init__(self) -> None:
        self.id_to_parent: Dict[int, int] = {}
        # IDs of the fields currently being visited, innermost last.
        self.id_stack: List[int] = []

    def _current_parent(self) -> Optional[int]:
        """Return the ID of the innermost field being visited, or None at the root."""
        return self.id_stack[-1] if self.id_stack else None

    def before_field(self, field: NestedField) -> None:
        self.id_stack.append(field.field_id)

    def after_field(self, field: NestedField) -> None:
        self.id_stack.pop()

    def schema(self, schema: Schema, struct_result: Dict[int, int]) -> Dict[int, int]:
        return self.id_to_parent

    def struct(self, struct: StructType, field_results: List[Dict[int, int]]) -> Dict[int, int]:
        # The parent is the same for every field of this struct, so resolve it once.
        parent_id = self._current_parent()
        if parent_id is not None:
            # fields in the root struct are not added
            for field in struct.fields:
                self.id_to_parent[field.field_id] = parent_id

        return self.id_to_parent

    def field(self, field: NestedField, field_result: Dict[int, int]) -> Dict[int, int]:
        return self.id_to_parent

    def list(self, list_type: ListType, element_result: Dict[int, int]) -> Dict[int, int]:
        parent_id = self._current_parent()
        if parent_id is not None:
            self.id_to_parent[list_type.element_id] = parent_id
        return self.id_to_parent

    def map(self, map_type: MapType, key_result: Dict[int, int], value_result: Dict[int, int]) -> Dict[int, int]:
        parent_id = self._current_parent()
        if parent_id is not None:
            self.id_to_parent[map_type.key_id] = parent_id
            self.id_to_parent[map_type.value_id] = parent_id
        return self.id_to_parent

    def primitive(self, primitive: PrimitiveType) -> Dict[int, int]:
        return self.id_to_parent


def _index_parents(schema_or_type: Union[Schema, IcebergType]) -> Dict[int, int]:
    """Build a lookup table from each nested field ID to the ID of the field that contains it.

    Args:
        schema_or_type (Union[Schema, IcebergType]): The schema or type to analyse.

    Returns:
        Dict[int, int]: Child field ID -> parent field ID.
    """
    parent_indexer = _IndexParents()
    return visit(schema_or_type, parent_indexer)


class _IndexByName(SchemaVisitor[Dict[str, int]]):
    """A schema visitor for generating a field name to field ID index.

    Two kinds of names are tracked while walking the schema:

    * full names join every enclosing name with ``.``, including the names of list
      elements and map values;
    * short names leave out the list-element / map-value name when that element or
      value is a struct.

    A repeated full name is an error. Short names are aliases, and a later short
    name silently overwrites an earlier identical one.
    """

    def __init__(self) -> None:
        self._index: Dict[str, int] = {}
        self._short_name_to_id: Dict[str, int] = {}
        self._combined_index: Dict[str, int] = {}
        self._field_names: List[str] = []
        self._short_field_names: List[str] = []

    def _push_name(self, name: str, include_in_short: bool) -> None:
        """Enter a nested name; `include_in_short` decides whether short names see it."""
        self._field_names.append(name)
        if include_in_short:
            self._short_field_names.append(name)

    def _pop_name(self, include_in_short: bool) -> None:
        """Leave the innermost name entered via `_push_name`."""
        self._field_names.pop()
        if include_in_short:
            self._short_field_names.pop()

    def before_map_value(self, value: NestedField) -> None:
        """Enter a map value; struct-typed values are left out of short names."""
        self._push_name(value.name, not isinstance(value.field_type, StructType))

    def after_map_value(self, value: NestedField) -> None:
        self._pop_name(not isinstance(value.field_type, StructType))

    def before_list_element(self, element: NestedField) -> None:
        """Short field names omit element when the element is a StructType."""
        self._push_name(element.name, not isinstance(element.field_type, StructType))

    def after_list_element(self, element: NestedField) -> None:
        self._pop_name(not isinstance(element.field_type, StructType))

    def before_field(self, field: NestedField) -> None:
        """Store the field name."""
        self._push_name(field.name, True)

    def after_field(self, field: NestedField) -> None:
        """Remove the last field name stored."""
        self._pop_name(True)

    def schema(self, schema: Schema, struct_result: Dict[str, int]) -> Dict[str, int]:
        return self._index

    def struct(self, struct: StructType, field_results: List[Dict[str, int]]) -> Dict[str, int]:
        return self._index

    def field(self, field: NestedField, field_result: Dict[str, int]) -> Dict[str, int]:
        """Add the field name to the index."""
        self._add_field(field.name, field.field_id)
        return self._index

    def list(self, list_type: ListType, element_result: Dict[str, int]) -> Dict[str, int]:
        """Add the list element name to the index."""
        element = list_type.element_field
        self._add_field(element.name, element.field_id)
        return self._index

    def map(self, map_type: MapType, key_result: Dict[str, int], value_result: Dict[str, int]) -> Dict[str, int]:
        """Add the key name and value name as individual items in the index."""
        for entry in (map_type.key_field, map_type.value_field):
            self._add_field(entry.name, entry.field_id)
        return self._index

    def _add_field(self, name: str, field_id: int) -> None:
        """Record `name`, qualified by the enclosing names, under `field_id`.

        Args:
            name (str): The field name.
            field_id (int): The field ID.

        Raises:
            ValueError: If the full name is already contained in the index.
        """
        full_name = ".".join([*self._field_names, name])
        if full_name in self._index:
            raise ValueError(f"Invalid schema, multiple fields for name {full_name}: {self._index[full_name]} and {field_id}")
        self._index[full_name] = field_id

        # Root-level names get no short alias; their full name already is the short form.
        if self._short_field_names:
            short_name = ".".join([*self._short_field_names, name])
            self._short_name_to_id[short_name] = field_id

    def primitive(self, primitive: PrimitiveType) -> Dict[str, int]:
        return self._index

    def by_name(self) -> Dict[str, int]:
        """Return an index of combined full and short names.

        When a short name is spelled the same as a full name, the full name's ID wins.
        """
        return {**self._short_name_to_id, **self._index}

    def by_id(self) -> Dict[int, str]:
        """Return an index of ID to full names."""
        return dict(zip(self._index.values(), self._index.keys()))


def index_by_name(schema_or_type: Union[Schema, IcebergType]) -> Dict[str, int]:
    """Generate an index of field names to field IDs.

    The index contains full dot-separated names and, where they do not clash with a full
    name, short names that omit the list-element / map-value name of struct elements/values.

    Args:
        schema_or_type (Union[Schema, IcebergType]): A schema or type to index.

    Returns:
        Dict[str, int]: An index of field names to field IDs.

    Raises:
        ValueError: If two fields resolve to the same full name.
    """
    # Only containers that expose `fields` (schemas, structs) can be checked for emptiness.
    # Other types (presumably lists, maps and primitives) have no such attribute and are
    # always visited instead of failing with AttributeError.
    fields = getattr(schema_or_type, "fields", None)
    if fields is not None and len(fields) == 0:
        # Fast path: nothing to index.
        return EMPTY_DICT

    indexer = _IndexByName()
    visit(schema_or_type, indexer)
    return indexer.by_name()


def index_name_by_id(schema_or_type: Union[Schema, IcebergType]) -> Dict[int, str]:
    """Generate an index of field IDs to full field names.

    Only full (dot-separated) names are produced; the short-name aliases that
    `index_by_name` also returns are not included.

    Args:
        schema_or_type (Union[Schema, IcebergType]): A schema or type to index.

    Returns:
        Dict[int, str]: An index of field IDs to full names.

    Raises:
        ValueError: If two fields resolve to the same full name.
    """
    indexer = _IndexByName()
    visit(schema_or_type, indexer)
    return indexer.by_id()


# Type alias for ``int``, used as a dict key type in the accessor-index annotations.
# NOTE(review): in `_BuildPositionAccessors` the keys stored under this alias are field IDs
# (`result[field.field_id]`), not positions; the alias name is misleading there.
Position = int


class _BuildPositionAccessors(SchemaVisitor[Dict[Position, Accessor]]):
    """A schema visitor for generating a field ID to accessor index.

    Each key of the result is a field ID. Each value is an Accessor that reads that
    field from a row by position, chaining `inner` accessors for fields nested in
    structs. Fields inside lists and maps get no accessor, because `list` and `map`
    return an empty dict.

    Example:
        >>> from pyiceberg.schema import Schema
        >>> from pyiceberg.types import *
        >>> schema = Schema(
        ...     NestedField(field_id=2, name="id", field_type=IntegerType(), required=False),
        ...     NestedField(field_id=1, name="data", field_type=StringType(), required=True),
        ...     NestedField(
        ...         field_id=3,
        ...         name="location",
        ...         field_type=StructType(
        ...             NestedField(field_id=5, name="latitude", field_type=FloatType(), required=False),
        ...             NestedField(field_id=6, name="longitude", field_type=FloatType(), required=False),
        ...         ),
        ...         required=True,
        ...     ),
        ...     schema_id=1,
        ...     identifier_field_ids=[1],
        ... )
        >>> result = build_position_accessors(schema)
        >>> expected = {
        ...     2: Accessor(position=0, inner=None),
        ...     1: Accessor(position=1, inner=None),
        ...     5: Accessor(position=2, inner=Accessor(position=0, inner=None)),
        ...     6: Accessor(position=2, inner=Accessor(position=1, inner=None)),
        ...     3: Accessor(position=2, inner=None),
        ... }
        >>> result == expected
        True
    """

    def schema(self, schema: Schema, struct_result: Dict[Position, Accessor]) -> Dict[Position, Accessor]:
        return struct_result

    def struct(self, struct: StructType, field_results: List[Dict[Position, Accessor]]) -> Dict[Position, Accessor]:
        result: Dict[Position, Accessor] = {}

        for position, (field, inner_accessors) in enumerate(zip(struct.fields, field_results)):
            # Accessors for sub-fields of a nested struct are re-rooted at this field's position.
            for inner_field_id, acc in inner_accessors.items():
                result[inner_field_id] = Accessor(position, inner=acc)
            result[field.field_id] = Accessor(position)

        return result

    def field(self, field: NestedField, field_result: Dict[Position, Accessor]) -> Dict[Position, Accessor]:
        return field_result

    def list(self, list_type: ListType, element_result: Dict[Position, Accessor]) -> Dict[Position, Accessor]:
        # Positional access does not descend into list elements.
        return {}

    def map(
        self, map_type: MapType, key_result: Dict[Position, Accessor], value_result: Dict[Position, Accessor]
    ) -> Dict[Position, Accessor]:
        # Positional access does not descend into map keys or values.
        return {}

    def primitive(self, primitive: PrimitiveType) -> Dict[Position, Accessor]:
        return {}


def build_position_accessors(schema_or_type: Union[Schema, IcebergType]) -> Dict[int, Accessor]:
    """Map every reachable field ID to an Accessor that reads it by position.

    Args:
        schema_or_type (Union[Schema, IcebergType]): The schema or type to build accessors for.

    Returns:
        Dict[int, Accessor]: Field ID -> positional accessor.
    """
    accessor_builder = _BuildPositionAccessors()
    return visit(schema_or_type, accessor_builder)


def assign_fresh_schema_ids(schema_or_type: Union[Schema, IcebergType], next_id: Optional[Callable[[], int]] = None) -> Schema:
    """Traverse the schema and re-number every field with new IDs.

    Args:
        schema_or_type: The schema (or type) to re-number.
        next_id: Optional callable producing the next ID; defaults to a counter starting at 1.
    """
    id_assigner = _SetFreshIDs(next_id_func=next_id)
    return pre_order_visit(schema_or_type, id_assigner)


class _SetFreshIDs(PreOrderSchemaVisitor[IcebergType]):
    """Traverse the schema and assign new IDs (monotonically increasing with the default counter).

    IDs are drawn from `next_id_func` in pre-order. All fields of a struct receive
    their new IDs before any field type is visited. List element IDs, and map key and
    value IDs, are assigned before the element/key/value types are visited. The
    old-to-new mapping is kept in `old_id_to_new_id` so identifier field IDs can be
    remapped.
    """

    # Maps each original field ID to the freshly assigned one.
    old_id_to_new_id: Dict[int, int]

    def __init__(self, next_id_func: Optional[Callable[[], int]] = None) -> None:
        """Create the visitor.

        Args:
            next_id_func: Callable returning the next ID to assign. When None, a counter
                yielding 1, 2, 3, ... is used.
        """
        self.old_id_to_new_id = {}
        counter = itertools.count(1)
        self.next_id_func = next_id_func if next_id_func is not None else lambda: next(counter)

    def _get_and_increment(self, current_id: int) -> int:
        """Draw the next ID, remember it as the replacement for `current_id`, and return it."""
        new_id = self.next_id_func()
        self.old_id_to_new_id[current_id] = new_id
        return new_id

    def schema(self, schema: Schema, struct_result: Callable[[], StructType]) -> Schema:
        """Build a Schema from the re-numbered root struct and remap its identifier field IDs.

        `struct_result()` is evaluated first (as a positional argument), so
        `old_id_to_new_id` is fully populated before the identifier IDs are looked up.
        No `schema_id` is passed to the new Schema.
        """
        return Schema(
            *struct_result().fields,
            identifier_field_ids=[self.old_id_to_new_id[field_id] for field_id in schema.identifier_field_ids],
        )

    def struct(self, struct: StructType, field_results: List[Callable[[], IcebergType]]) -> StructType:
        """Re-number a struct: assign IDs to all of its fields first, then visit each field's type."""
        # IDs for this level are taken before descending, so (with the default counter)
        # sibling fields receive consecutive IDs.
        new_ids = [self._get_and_increment(field.field_id) for field in struct.fields]
        new_fields = []
        for field_id, field, field_type in zip(new_ids, struct.fields, field_results):
            new_fields.append(
                NestedField(
                    field_id=field_id,
                    name=field.name,
                    # Calling the thunk descends into (and re-numbers) the field's type.
                    field_type=field_type(),
                    required=field.required,
                    doc=field.doc,
                )
            )
        return StructType(*new_fields)

    def field(self, field: NestedField, field_result: Callable[[], IcebergType]) -> IcebergType:
        """Return the re-numbered type of `field`; the field's own ID is assigned in `struct`."""
        return field_result()

    def list(self, list_type: ListType, element_result: Callable[[], IcebergType]) -> ListType:
        """Re-number a list: the element ID is assigned before the element type is visited."""
        element_id = self._get_and_increment(list_type.element_id)
        return ListType(
            element_id=element_id,
            element=element_result(),
            element_required=list_type.element_required,
        )

    def map(self, map_type: MapType, key_result: Callable[[], IcebergType], value_result: Callable[[], IcebergType]) -> MapType:
        """Re-number a map: key and value IDs are assigned before either type is visited."""
        key_id = self._get_and_increment(map_type.key_id)
        value_id = self._get_and_increment(map_type.value_id)
        return MapType(
            key_id=key_id,
            key_type=key_result(),
            value_id=value_id,
            value_type=value_result(),
            value_required=map_type.value_required,
        )

    def primitive(self, primitive: PrimitiveType) -> PrimitiveType:
        """Return primitives unchanged; they carry no field IDs."""
        return primitive


# Implementation copied from the Apache Iceberg repository.
def make_compatible_name(name: str) -> str:
    """Return `name` unchanged if it is already a valid Avro name, otherwise its sanitized form."""
    return name if _valid_avro_name(name) else _sanitize_name(name)


def _valid_avro_name(name: str) -> bool:
    length = len(name)
    assert length > 0, ValueError("Can not validate empty avro name")
    first = name[0]
    if not (first.isalpha() or first == "_"):
        return False

    for character in name[1:]:
        if not (character.isalnum() or character == "_"):
            return False
    return True


def _sanitize_name(name: str) -> str:
    sb = []
    first = name[0]
    if not (first.isalpha() or first == "_"):
        sb.append(_sanitize_char(first))
    else:
        sb.append(first)

    for character in name[1:]:
        if not (character.isalnum() or character == "_"):
            sb.append(_sanitize_char(character))
        else:
            sb.append(character)
    return "".join(sb)


def _sanitize_char(character: str) -> str:
    return "_" + character if character.isdigit() else "_x" + hex(ord(character))[2:].upper()


def sanitize_column_names(schema: Schema) -> Schema:
    """Sanitize column names to make them compatible with Avro.

    A column name is kept as-is when it starts with a letter or ``_`` and contains only
    letters, digits and ``_``. Otherwise it is rewritten to conform to the Avro naming
    convention. Each offending character is escaped: digits become ``_<digit>``, and any
    other character becomes ``_x`` followed by its upper-case hex code point. Field IDs,
    the schema ID and the identifier field IDs are preserved.

    Args:
        schema: The schema to be sanitized.

    Returns:
        The sanitized schema.
    """
    result = visit(schema.as_struct(), _SanitizeColumnsVisitor())
    return Schema(
        *(result or StructType()).fields,
        schema_id=schema.schema_id,
        identifier_field_ids=schema.identifier_field_ids,
    )


class _SanitizeColumnsVisitor(SchemaVisitor[Optional[IcebergType]]):
    """Rebuild a type tree with every struct field name made Avro-compatible.

    Only struct field names are rewritten, via `make_compatible_name`. Field IDs, docs
    and requiredness, plus list/map element IDs and requiredness, are carried over
    unchanged.
    """

    def schema(self, schema: Schema, struct_result: Optional[IcebergType]) -> Optional[IcebergType]:
        return struct_result

    def field(self, field: NestedField, field_result: Optional[IcebergType]) -> Optional[IcebergType]:
        sanitized_name = make_compatible_name(field.name)
        return NestedField(
            field_id=field.field_id,
            name=sanitized_name,
            field_type=field_result,
            doc=field.doc,
            required=field.required,
        )

    def struct(self, struct: StructType, field_results: List[Optional[IcebergType]]) -> Optional[IcebergType]:
        kept_fields = [result for result in field_results if result is not None]
        return StructType(*kept_fields)

    def list(self, list_type: ListType, element_result: Optional[IcebergType]) -> Optional[IcebergType]:
        return ListType(
            element_id=list_type.element_id,
            element_type=element_result,
            element_required=list_type.element_required,
        )

    def map(
        self, map_type: MapType, key_result: Optional[IcebergType], value_result: Optional[IcebergType]
    ) -> Optional[IcebergType]:
        return MapType(
            key_id=map_type.key_id,
            value_id=map_type.value_id,
            key_type=key_result,
            value_type=value_result,
            value_required=map_type.value_required,
        )

    def primitive(self, primitive: PrimitiveType) -> Optional[IcebergType]:
        return primitive


def prune_columns(schema: Schema, selected: Set[int], select_full_types: bool = True) -> Schema:
    """Prune a schema down to the fields whose IDs are in `selected`.

    Args:
        schema: The schema to be pruned.
        selected: The field-ids to be included.
        select_full_types: When True, a selected struct, list or map field keeps its full
            type. When False, a selected struct keeps only its selected sub-fields.

    Returns:
        The pruned schema. The schema ID is kept, and identifier field IDs are restricted
        to the selected ones.
    """
    result = visit(schema.as_struct(), _PruneColumnsVisitor(selected, select_full_types))
    return Schema(
        *(result or StructType()).fields,
        schema_id=schema.schema_id,
        # Filter in the schema's own order. Iterating a set intersection would yield an
        # arbitrary order, and identifier_field_ids is an ordered list.
        identifier_field_ids=[field_id for field_id in schema.identifier_field_ids if field_id in selected],
    )


class _PruneColumnsVisitor(SchemaVisitor[Optional[IcebergType]]):
    """Schema visitor that projects a type tree onto a set of selected field IDs.

    Every callback returns the projected type for the node it visits, or None when
    nothing at or below that node was selected.

    With `select_full_types`, a selected field keeps its complete type. Without it, a
    selected struct keeps only its selected sub-fields. Explicitly selecting a list- or
    map-typed field, or a non-primitive, non-struct list element or map value, raises
    ValueError.
    """

    # Field IDs the caller asked to keep.
    selected: Set[int]
    # When True, a selected field is returned with its full (unpruned) type.
    select_full_types: bool

    def __init__(self, selected: Set[int], select_full_types: bool) -> None:
        self.selected = selected
        self.select_full_types = select_full_types

    def schema(self, schema: Schema, struct_result: Optional[IcebergType]) -> Optional[IcebergType]:
        """Return the projected root struct unchanged."""
        return struct_result

    def struct(self, struct: StructType, field_results: List[Optional[IcebergType]]) -> Optional[IcebergType]:
        """Rebuild `struct` from the projected types of its fields.

        Returns:
            The original `struct` when every field is kept with an equal type. Otherwise a
            new StructType holding only the kept fields, or None if no field was kept.
        """
        fields = struct.fields
        selected_fields = []
        # Becomes False once any kept field's projected type differs from its original type.
        same_type = True

        for idx, projected_type in enumerate(field_results):
            field = fields[idx]
            if field.field_type == projected_type:
                # Unchanged type: reuse the original field object.
                selected_fields.append(field)
            elif projected_type is not None:
                same_type = False
                # Type has changed, create a new field with the projected type
                selected_fields.append(
                    NestedField(
                        field_id=field.field_id,
                        name=field.name,
                        field_type=projected_type,
                        doc=field.doc,
                        required=field.required,
                    )
                )

        if selected_fields:
            if len(selected_fields) == len(fields) and same_type is True:
                # Nothing has changed, and we can return the original struct
                return struct
            else:
                return StructType(*selected_fields)
        return None

    def field(self, field: NestedField, field_result: Optional[IcebergType]) -> Optional[IcebergType]:
        """Return the projected type of `field`, or None if neither it nor a sub-field was selected.

        Raises:
            ValueError: If a list- or map-typed field is selected while `select_full_types` is False.
        """
        if field.field_id in self.selected:
            if self.select_full_types:
                return field.field_type
            elif field.field_type.is_struct:
                # Keep only the selected sub-fields (possibly none) of the struct.
                return self._project_selected_struct(field_result)
            else:
                if not field.field_type.is_primitive:
                    raise ValueError(
                        f"Cannot explicitly project List or Map types, {field.field_id}:{field.name} of type {field.field_type} was selected"
                    )
                # Selected non-struct field
                return field.field_type
        elif field_result is not None:
            # This field wasn't selected but a subfield was so include that
            return field_result
        else:
            return None

    def list(self, list_type: ListType, element_result: Optional[IcebergType]) -> Optional[IcebergType]:
        """Project a list based on whether its element, or something inside the element, was selected.

        Raises:
            ValueError: If a non-primitive, non-struct element is selected while `select_full_types` is False.
        """
        if list_type.element_id in self.selected:
            if self.select_full_types:
                return list_type
            elif list_type.element_type and list_type.element_type.is_struct:
                projected_struct = self._project_selected_struct(element_result)
                return self._project_list(list_type, projected_struct)
            else:
                if not list_type.element_type.is_primitive:
                    raise ValueError(
                        f"Cannot explicitly project List or Map types, {list_type.element_id} of type {list_type.element_type} was selected"
                    )
                return list_type
        elif element_result is not None:
            # The element itself wasn't selected, but something nested inside it was.
            return self._project_list(list_type, element_result)
        else:
            return None

    def map(
        self, map_type: MapType, key_result: Optional[IcebergType], value_result: Optional[IcebergType]
    ) -> Optional[IcebergType]:
        """Project a map based on what was selected in its value (or, failing that, its key).

        The key type is never projected; it is always carried over unchanged.

        Raises:
            ValueError: If a non-primitive, non-struct value is selected while `select_full_types` is False.
        """
        if map_type.value_id in self.selected:
            if self.select_full_types:
                return map_type
            elif map_type.value_type and map_type.value_type.is_struct:
                projected_struct = self._project_selected_struct(value_result)
                return self._project_map(map_type, projected_struct)
            if not map_type.value_type.is_primitive:
                raise ValueError(
                    f"Cannot explicitly project List or Map types, Map value {map_type.value_id} of type {map_type.value_type} was selected"
                )
            return map_type
        elif value_result is not None:
            # The value itself wasn't selected, but something nested inside it was.
            return self._project_map(map_type, value_result)
        elif map_type.key_id in self.selected:
            # Only the key was selected: keep the whole map.
            # NOTE(review): presumably because a map cannot be projected without its values.
            return map_type
        return None

    def primitive(self, primitive: PrimitiveType) -> Optional[IcebergType]:
        """Return None; selection of primitives is decided by the enclosing field/list/map callback."""
        return None

    @staticmethod
    def _project_selected_struct(projected_field: Optional[IcebergType]) -> StructType:
        """Return `projected_field` as a StructType, using an empty struct when nothing below was selected.

        Raises:
            ValueError: If a non-struct projection is passed in.
        """
        if projected_field and not isinstance(projected_field, StructType):
            raise ValueError("Expected a struct")

        if projected_field is None:
            return StructType()
        else:
            return projected_field

    @staticmethod
    def _project_list(list_type: ListType, element_result: IcebergType) -> ListType:
        """Return `list_type` if its element type is unchanged, else a copy with `element_result` as element type."""
        if list_type.element_type == element_result:
            return list_type
        else:
            return ListType(
                element_id=list_type.element_id, element_type=element_result, element_required=list_type.element_required
            )

    @staticmethod
    def _project_map(map_type: MapType, value_result: IcebergType) -> MapType:
        """Return `map_type` if its value type is unchanged, else a copy with `value_result` as value type."""
        if map_type.value_type == value_result:
            return map_type
        else:
            return MapType(
                key_id=map_type.key_id,
                value_id=map_type.value_id,
                key_type=map_type.key_type,
                value_type=value_result,
                value_required=map_type.value_required,
            )


@singledispatch
def promote(file_type: IcebergType, read_type: IcebergType) -> IcebergType:
    """Resolve the type stored in a file against the type requested by the reader.

    This is the ``functools.singledispatch`` fallback, used when no promotion rule is
    registered for the file type. It only accepts identical types.

    Args:
        file_type (IcebergType): The type of the Avro file.
        read_type (IcebergType): The requested read type.

    Returns:
        IcebergType: `file_type` when both types are equal.

    Raises:
        ResolveError: If the two types differ.
    """
    if file_type == read_type:
        return file_type
    raise ResolveError(f"Cannot promote {file_type} to {read_type}")


@promote.register(IntegerType)
def _(file_type: IntegerType, read_type: IcebergType) -> IcebergType:
    """Allow an int column to be read as a long; reject every other read type."""
    if not isinstance(read_type, LongType):
        raise ResolveError(f"Cannot promote an int to {read_type}")
    # Ints/Longs are binary compatible in Avro, so this is okay
    return read_type


@promote.register(FloatType)
def _(file_type: FloatType, read_type: IcebergType) -> IcebergType:
    """Allow a float column to be read as a double; reject every other read type."""
    if not isinstance(read_type, DoubleType):
        raise ResolveError(f"Cannot promote an float to {read_type}")
    # A double type is wider
    return read_type


@promote.register(StringType)
def _(file_type: StringType, read_type: IcebergType) -> IcebergType:
    """Allow a string column to be read as binary; reject every other read type."""
    if not isinstance(read_type, BinaryType):
        raise ResolveError(f"Cannot promote an string to {read_type}")
    return read_type


@promote.register(BinaryType)
def _(file_type: BinaryType, read_type: IcebergType) -> IcebergType:
    """Allow a binary column to be read as a string; reject every other read type."""
    if not isinstance(read_type, StringType):
        raise ResolveError(f"Cannot promote an binary to {read_type}")
    return read_type


@promote.register(DecimalType)
def _(file_type: DecimalType, read_type: IcebergType) -> IcebergType:
    """Promote a decimal to a decimal with equal or larger precision and the same scale.

    Raises:
        ResolveError: If `read_type` is not a decimal, has a smaller precision, or has a different scale.
    """
    if isinstance(read_type, DecimalType):
        # Iceberg only allows widening a decimal's precision; the scale must not change.
        # (Previously this compared `file_type.scale` with itself, which is always true.)
        if file_type.precision <= read_type.precision and file_type.scale == read_type.scale:
            return read_type
        else:
            raise ResolveError(f"Cannot reduce precision from {file_type} to {read_type}")
    else:
        raise ResolveError(f"Cannot promote an decimal to {read_type}")


@promote.register(FixedType)
def _(file_type: FixedType, read_type: IcebergType) -> IcebergType:
    """Allow a 16-byte fixed column to be read as a UUID; reject everything else."""
    if not isinstance(read_type, UUIDType) or len(file_type) != 16:
        raise ResolveError(f"Cannot promote {file_type} to {read_type}")
    # Since pyarrow reads parquet UUID as fixed 16-byte binary, the promotion is needed to ensure read compatibility
    return read_type


def _check_schema_compatible(requested_schema: Schema, provided_schema: Schema) -> None:
    """
    Verify that `provided_schema` is compatible with `requested_schema`.

    Fields are matched by field ID, so both schemas must carry valid IDs and use the
    same ID for the same field.

    The schemas are compatible when:
    1. Every `required` field of `requested_schema` exists in `provided_schema` and is
       `required` there as well.
    2. For every field present in both schemas, the `provided_schema` type equals, or can
       be promoted to, the `requested_schema` type of the field with the same ID.

    Raises:
        ValueError: If the schemas are not compatible; the message contains a per-field report.
    """
    visitor = _SchemaCompatibilityVisitor(provided_schema)
    pre_order_visit(requested_schema, visitor)


class _SchemaCompatibilityVisitor(PreOrderSchemaVisitor[bool]):
    """Pre-order visitor that checks a provided schema against a requested schema.

    Each field of the requested schema is looked up by field ID in ``provided_schema``.
    A row marking it compatible (✅) or incompatible (❌) is added to a rich table.
    If the traversal finds any incompatibility, the rendered table is raised inside a
    ``ValueError``.
    """

    provided_schema: Schema

    def __init__(self, provided_schema: Schema):
        # Imported locally so rich is only loaded when a compatibility check runs.
        from rich.console import Console
        from rich.table import Table as RichTable

        self.provided_schema = provided_schema
        self.rich_table = RichTable(show_header=True, header_style="bold")
        self.rich_table.add_column("")
        self.rich_table.add_column("Table field")
        self.rich_table.add_column("Dataframe field")
        # record=True lets export_text() return what has been printed to this console.
        self.console = Console(record=True)

    def _find_provided_field(self, field_id: int) -> Optional[NestedField]:
        """Return the field with ``field_id`` from the provided schema, or None if it is absent."""
        try:
            return self.provided_schema.find_field(field_id)
        except ValueError:
            return None

    def _is_field_compatible(self, lhs: NestedField) -> bool:
        """Check a single requested field against its provided counterpart and record a report row.

        Only the field itself is checked. For container types, only the container kind is
        compared; nested fields are checked when the visitor reaches them.
        """
        rhs = self._find_provided_field(lhs.field_id)
        if rhs is None:
            # An optional field may be missing from the provided schema, but a required one may not.
            if lhs.required:
                self.rich_table.add_row("❌", str(lhs), "Missing")
                return False
            self.rich_table.add_row("✅", str(lhs), "Missing")
            return True

        # A required field must also be required in the provided schema.
        if lhs.required and not rhs.required:
            self.rich_table.add_row("❌", str(lhs), str(rhs))
            return False

        if lhs.field_type == rhs.field_type:
            compatible = True
        elif any(
            isinstance(lhs.field_type, container_type) and isinstance(rhs.field_type, container_type)
            for container_type in (StructType, MapType, ListType)
        ):
            # Same container kind: the nested types are checked when the children are traversed.
            compatible = True
        else:
            # A provided type that can be promoted to the requested type is compatible.
            try:
                promote(rhs.field_type, lhs.field_type)
                compatible = True
            except ResolveError:
                compatible = False

        self.rich_table.add_row("✅" if compatible else "❌", str(lhs), str(rhs))
        return compatible

    def _visit_field(self, field: NestedField, children_result: Callable[[], bool]) -> bool:
        """Check ``field`` and then, only if it is present and compatible, its nested fields."""
        if not self._is_field_compatible(field):
            return False
        if self._find_provided_field(field.field_id) is None:
            # An optional field that is absent from the provided schema has no nested fields to
            # compare against. Traversing them would wrongly report its required children
            # (e.g. a list's required element) as missing.
            return True
        return children_result()

    def schema(self, schema: Schema, struct_result: Callable[[], bool]) -> bool:
        """Run the check over the whole schema and raise with the rendered report on mismatch."""
        if not (result := struct_result()):
            self.console.print(self.rich_table)
            raise ValueError(f"Mismatch in fields:\n{self.console.export_text()}")
        return result

    def struct(self, struct: StructType, field_results: List[Callable[[], bool]]) -> bool:
        # Evaluate every field (no short-circuit) so all mismatches end up in the report.
        results = [result() for result in field_results]
        return all(results)

    def field(self, field: NestedField, field_result: Callable[[], bool]) -> bool:
        return self._visit_field(field, field_result)

    def list(self, list_type: ListType, element_result: Callable[[], bool]) -> bool:
        return self._visit_field(list_type.element_field, element_result)

    def map(self, map_type: MapType, key_result: Callable[[], bool], value_result: Callable[[], bool]) -> bool:
        # Check both sides before combining so key and value mismatches are both reported.
        key_compatible = self._visit_field(map_type.key_field, key_result)
        value_compatible = self._visit_field(map_type.value_field, value_result)
        return key_compatible and value_compatible

    def primitive(self, primitive: PrimitiveType) -> bool:
        # Primitive compatibility is decided by the enclosing field's type check.
        return True
