#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# mypy: disable-error-code="operator"

from pyspark.resource import ResourceProfile
from pyspark.sql.connect.utils import check_dependencies

check_dependencies(__name__)

from typing import (
    Any,
    List,
    Optional,
    Type,
    Sequence,
    Union,
    cast,
    TYPE_CHECKING,
    Mapping,
    Dict,
    Tuple,
)
import functools
import json
import pickle
from threading import Lock
from inspect import signature, isclass

import pyarrow as pa

from pyspark.serializers import CloudPickleSerializer
from pyspark.storagelevel import StorageLevel
from pyspark.sql.types import DataType

import pyspark.sql.connect.proto as proto
from pyspark.sql.column import Column
from pyspark.sql.connect.logging import logger
from pyspark.sql.connect.proto import base_pb2 as spark_dot_connect_dot_base__pb2
from pyspark.sql.connect.conversion import storage_level_to_proto
from pyspark.sql.connect.expressions import Expression, SubqueryExpression
from pyspark.sql.connect.types import pyspark_types_to_proto_types, UnparsedDataType
from pyspark.errors import (
    AnalysisException,
    PySparkValueError,
    PySparkPicklingError,
)

if TYPE_CHECKING:
    from pyspark.sql.connect.client import SparkConnectClient
    from pyspark.sql.connect.udf import UserDefinedFunction
    from pyspark.sql.connect.observation import Observation
    from pyspark.sql.connect.session import SparkSession


class LogicalPlan:
    """Base class of all Spark Connect logical plan nodes.

    A node has an optional child plan and an optional list of referenced plans (e.g.,
    subqueries). Every node receives a process-wide unique plan id when it is constructed.
    Subclasses override :meth:`plan` to build a relation and/or :meth:`command` to build a
    command.
    """

    # Guards ``_nextPlanId`` so that concurrently created plans get distinct ids.
    _lock: Lock = Lock()
    _nextPlanId: int = 0

    # Number of spaces added per nesting level by :meth:`print`.
    INDENT = 2

    def __init__(
        self, child: Optional["LogicalPlan"], references: Optional[Sequence["LogicalPlan"]] = None
    ) -> None:
        """

        Parameters
        ----------
        child : :class:`LogicalPlan`, optional.
            The child logical plan.
        references : list of :class:`LogicalPlan`, optional.
            The list of logical plans that are referenced as subqueries in this logical plan.
        """
        self._child = child
        self._root_plan_id = LogicalPlan._fresh_plan_id()

        self._references: Sequence["LogicalPlan"] = references or []
        self._plan_id_with_rel: Optional[int] = None
        if len(self._references) > 0:
            assert all(isinstance(r, LogicalPlan) for r in self._references)
            # A second id is reserved for the ``WithRelations`` wrapper built by
            # :meth:`_with_relations`.
            self._plan_id_with_rel = LogicalPlan._fresh_plan_id()

    @property
    def _plan_id(self) -> int:
        """Id of the outermost relation this node produces: the ``WithRelations`` wrapper id
        when there are references, otherwise the root relation id."""
        if self._plan_id_with_rel is not None:
            return self._plan_id_with_rel
        return self._root_plan_id

    @staticmethod
    def _fresh_plan_id() -> int:
        """Atomically allocate and return the next plan id."""
        with LogicalPlan._lock:
            plan_id = LogicalPlan._nextPlanId
            LogicalPlan._nextPlanId += 1
        return plan_id

    def _create_proto_relation(self) -> proto.Relation:
        """Create an empty relation whose ``common.plan_id`` is this node's root plan id."""
        plan = proto.Relation()
        plan.common.plan_id = self._root_plan_id
        return plan

    def plan(self, session: "SparkConnectClient") -> proto.Relation:  # type: ignore[empty-body]
        ...

    def command(self, session: "SparkConnectClient") -> proto.Command:  # type: ignore[empty-body]
        ...

    def _verify(self, session: "SparkConnectClient") -> bool:
        """This method is used to verify that the current logical plan
        can be serialized to Proto and back and afterwards is identical."""
        plan = proto.Plan()
        plan.root.CopyFrom(self.plan(session))

        serialized_plan = plan.SerializeToString()
        test_plan = proto.Plan()
        test_plan.ParseFromString(serialized_plan)

        return test_plan == plan

    def to_proto(self, session: "SparkConnectClient", debug: bool = False) -> proto.Plan:
        """
        Generates connect proto plan based on this LogicalPlan.

        Parameters
        ----------
        session : :class:`SparkConnectClient`
            a session that connects remote spark cluster.
        debug: bool
            if enabled, the proto plan will be printed.
        """
        plan = proto.Plan()
        plan.root.CopyFrom(self.plan(session))

        if debug:
            print(plan)

        return plan

    @property
    def observations(self) -> Dict[str, "Observation"]:
        """Observations registered in this plan; by default those of the child (if any)."""
        if self._child is None:
            return {}
        else:
            return self._child.observations

    @staticmethod
    def _collect_references(
        cols_or_exprs: Sequence[Union[Column, Expression]]
    ) -> Sequence["LogicalPlan"]:
        """Collect the plans of every :class:`SubqueryExpression` found in the given
        columns/expressions, in traversal order."""
        references: List[LogicalPlan] = []

        def append_reference(e: Expression) -> None:
            if isinstance(e, SubqueryExpression):
                references.append(e._plan)

        for col_or_expr in cols_or_exprs:
            if isinstance(col_or_expr, Column):
                col_or_expr._expr.foreach(append_reference)
            else:
                col_or_expr.foreach(append_reference)
        return references

    def _with_relations(
        self, root: proto.Relation, session: "SparkConnectClient"
    ) -> proto.Relation:
        """Wrap ``root`` in a ``WithRelations`` relation when this node has references;
        otherwise return ``root`` unchanged."""
        if len(self._references) == 0:
            return root
        else:
            # When there are references to other DataFrame, e.g., subqueries, build new plan like:
            # with_relations [id 10]
            #     root: plan  [id 9]
            #     reference:
            #          refs#1: [id 8]
            #          refs#2: [id 5]
            plan = proto.Relation()
            assert isinstance(self._plan_id_with_rel, int)
            plan.common.plan_id = self._plan_id_with_rel
            plan.with_relations.root.CopyFrom(root)
            plan.with_relations.references.extend([ref.plan(session) for ref in self._references])
            return plan

    def _parameters_to_print(self, parameters: Mapping[str, Any]) -> Mapping[str, Any]:
        """
        Extracts the parameters that are able to be printed. It looks up the signature
        in the constructor of this :class:`LogicalPlan`, and retrieves the variables
        from this instance by the same name (or the name with prefix `_`)  defined
        in the constructor.

        Parameters
        ----------
        parameters : map
            Parameter mapping from ``inspect.signature(...).parameters``

        Returns
        -------
        dict
            A dictionary consisting of a string name and variable found in this
            :class:`LogicalPlan`.

        Notes
        -----
        Parameters annotated with :class:`LogicalPlan` (or a subclass), with a string or
        forward reference ending in ``LogicalPlan``, or with a wrapper such as
        ``Optional["LogicalPlan"]`` are considered non-printable and filtered out.

        Examples
        --------
        The example below returns a dictionary from `self._start`, `self._end`,
        `self._num_partitions`.

        >>> rg = Range(0, 10, 1)
        >>> rg._parameters_to_print(signature(rg.__class__.__init__).parameters)
        {'start': 0, 'end': 10, 'step': 1, 'num_partitions': None}

        If the child is defined, it is not considered as a printable instance

        >>> project = Project(rg, "value")
        >>> project._parameters_to_print(signature(project.__class__.__init__).parameters)
        {'columns': ['value']}
        """

        def refers_to_logical_plan(annotation: Any) -> bool:
            # Raw string annotation, e.g., `right: "LogicalPlan"`.
            if isinstance(annotation, str):
                return annotation.endswith("LogicalPlan")
            # typing.ForwardRef, e.g., the argument of Optional["LogicalPlan"].
            forward_arg = getattr(annotation, "__forward_arg__", None)
            if isinstance(forward_arg, str):
                return forward_arg.endswith("LogicalPlan")
            # A concrete class: LogicalPlan itself or one of its subclasses.
            if isclass(annotation):
                try:
                    return issubclass(annotation, LogicalPlan)
                except TypeError:
                    # Generic aliases such as ``list[int]`` pass ``isclass`` on some
                    # Python versions but are rejected by ``issubclass``.
                    return False
            return False

        params = {}
        for name, tpe in parameters.items():
            annotation = tpe.annotation
            # Skip the plan itself and wrapped plans, e.g., Optional[LogicalPlan].
            if refers_to_logical_plan(annotation) or any(
                refers_to_logical_plan(a) for a in getattr(annotation, "__args__", ())
            ):
                continue

            # Searches self.name or self._name
            try:
                params[name] = getattr(self, name)
            except AttributeError:
                try:
                    params[name] = getattr(self, "_" + name)
                except AttributeError:
                    pass  # Simply ignore
        return params

    def print(self, indent: int = 0) -> str:
        """
        Print the simple string representation of the current :class:`LogicalPlan`.

        Parameters
        ----------
        indent : int
            The number of leading spaces for the output string.

        Returns
        -------
        str
            Simple string representation of this :class:`LogicalPlan`.
        """
        params = self._parameters_to_print(signature(self.__class__.__init__).parameters)
        pretty_params = [f"{name}='{param}'" for name, param in params.items()]
        if len(pretty_params) == 0:
            pretty_str = ""
        else:
            pretty_str = " " + ", ".join(pretty_params)
        return f"{' ' * indent}<{self.__class__.__name__}{pretty_str}>\n{self._child_print(indent)}"

    def _repr_html_(self) -> str:
        """Returns a  :class:`LogicalPlan` with HTML code. This is generally called in third-party
        systems such as Jupyter.

        Returns
        -------
        str
            HTML representation of this :class:`LogicalPlan`.
        """
        params = self._parameters_to_print(signature(self.__class__.__init__).parameters)
        pretty_params = [
            f"\n              {name}: " f"{param} <br/>" for name, param in params.items()
        ]
        if len(pretty_params) == 0:
            pretty_str = ""
        else:
            pretty_str = "".join(pretty_params)
        return f"""
        <ul>
           <li>
              <b>{self.__class__.__name__}</b><br/>{pretty_str}
              {self._child_repr()}
           </li>
        </ul>
        """

    def _child_print(self, indent: int) -> str:
        """Text representation of the child (indented one more level), or "" if none."""
        return self._child.print(indent + LogicalPlan.INDENT) if self._child is not None else ""

    def _child_repr(self) -> str:
        """HTML representation of the child, or "" if none."""
        return self._child._repr_html_() if self._child is not None else ""


class DataSource(LogicalPlan):
    """A datasource with a format and optional a schema from which Spark reads data"""

    def __init__(
        self,
        format: Optional[str] = None,
        schema: Optional[str] = None,
        options: Optional[Mapping[str, str]] = None,
        paths: Optional[List[str]] = None,
        predicates: Optional[List[str]] = None,
        is_streaming: Optional[bool] = None,
    ) -> None:
        super().__init__(None)

        assert format is None or isinstance(format, str)
        assert schema is None or isinstance(schema, str)

        if options is not None:
            # Options whose value is None are dropped; the rest must be str -> str.
            kept: Dict[str, str] = {}
            for key, value in options.items():
                if value is None:
                    continue
                assert isinstance(key, str)
                assert isinstance(value, str)
                kept[key] = value
            options = kept

        # Both `paths` and `predicates` must be lists of strings when given.
        for string_list in (paths, predicates):
            if string_list is None:
                continue
            assert isinstance(string_list, list)
            assert all(isinstance(item, str) for item in string_list)

        self._format = format
        self._schema = schema
        self._options = options
        self._paths = paths
        self._predicates = predicates
        self._is_streaming = is_streaming

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        """Build a ``read.data_source`` relation, setting only the fields that were given."""
        relation = self._create_proto_relation()
        # Reading a sub-message does not mark it as set; only the writes below do.
        source = relation.read.data_source
        if self._format is not None:
            source.format = self._format
        if self._schema is not None:
            source.schema = self._schema
        for key, value in (self._options or {}).items():
            source.options[key] = value
        if self._paths:
            source.paths.extend(self._paths)
        if self._predicates:
            source.predicates.extend(self._predicates)
        if self._is_streaming is not None:
            relation.read.is_streaming = self._is_streaming
        return relation


class Read(LogicalPlan):
    """Reads a table identified by an unparsed (possibly multi-part) name."""

    def __init__(
        self,
        table_name: str,
        options: Optional[Dict[str, str]] = None,
        is_streaming: Optional[bool] = None,
    ) -> None:
        super().__init__(None)
        self.table_name = table_name
        self.options = options or {}
        self._is_streaming = is_streaming

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        """Build a ``read.named_table`` relation with the table name and its options."""
        relation = self._create_proto_relation()
        named_table = relation.read.named_table
        named_table.unparsed_identifier = self.table_name
        if self._is_streaming is not None:
            relation.read.is_streaming = self._is_streaming
        for option_key, option_value in self.options.items():
            named_table.options[option_key] = option_value
        return relation

    def print(self, indent: int = 0) -> str:
        """Single-line text representation; a Read node has no child."""
        padding = " " * indent
        return "{}<Read table_name={}>\n".format(padding, self.table_name)


class LocalRelation(LogicalPlan):
    """Creates a LocalRelation plan object based on a PyArrow Table."""

    def __init__(
        self,
        table: Optional["pa.Table"],
        schema: Optional[str] = None,
    ) -> None:
        super().__init__(None)

        # Either Arrow data or a schema string (or both) must be supplied.
        if table is not None:
            assert isinstance(table, pa.Table)
        else:
            assert schema is not None

        assert schema is None or isinstance(schema, str)

        self._table = table
        self._schema = schema

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        """Build a ``local_relation``: Arrow IPC stream bytes for the data and/or a schema."""
        relation = self._create_proto_relation()
        if self._table is not None:
            buffer = pa.BufferOutputStream()
            with pa.ipc.new_stream(buffer, self._table.schema) as stream_writer:
                for record_batch in self._table.to_batches():
                    stream_writer.write_batch(record_batch)
            relation.local_relation.data = buffer.getvalue().to_pybytes()

        if self._schema is not None:
            relation.local_relation.schema = self._schema
        return relation

    def serialize(self, session: "SparkConnectClient") -> bytes:
        """Return the serialized ``local_relation`` message of :meth:`plan`."""
        local_relation = self.plan(session).local_relation
        return bytes(local_relation.SerializeToString())

    def print(self, indent: int = 0) -> str:
        return " " * indent + "<LocalRelation>\n"

    def _repr_html_(self) -> str:
        return """
        <ul>
            <li><b>LocalRelation</b></li>
        </ul>
        """


class CachedLocalRelation(LogicalPlan):
    """Creates a CachedLocalRelation plan object based on a hash of a LocalRelation."""

    def __init__(self, hash: str) -> None:
        super().__init__(None)
        # Hash identifying the previously uploaded local relation.
        self._hash = hash

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        """Build a ``cached_local_relation`` carrying only the hash."""
        relation = self._create_proto_relation()
        relation.cached_local_relation.hash = self._hash
        return relation

    def print(self, indent: int = 0) -> str:
        return " " * indent + "<CachedLocalRelation>\n"

    def _repr_html_(self) -> str:
        return """
        <ul>
            <li><b>CachedLocalRelation</b></li>
        </ul>
        """


class ShowString(LogicalPlan):
    """Renders the child's rows as a text table (``show_string`` relation)."""

    def __init__(
        self, child: Optional["LogicalPlan"], num_rows: int, truncate: int, vertical: bool
    ) -> None:
        super().__init__(child)
        self.num_rows = num_rows
        self.truncate = truncate
        self.vertical = vertical

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        assert self._child is not None
        relation = self._create_proto_relation()
        show = relation.show_string
        show.input.CopyFrom(self._child.plan(session))
        show.num_rows = self.num_rows
        show.truncate = self.truncate
        show.vertical = self.vertical
        return relation


class HtmlString(LogicalPlan):
    """Renders the child's rows as an HTML table (``html_string`` relation)."""

    def __init__(self, child: Optional["LogicalPlan"], num_rows: int, truncate: int) -> None:
        super().__init__(child)
        self.num_rows = num_rows
        self.truncate = truncate

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        assert self._child is not None
        relation = self._create_proto_relation()
        html = relation.html_string
        html.input.CopyFrom(self._child.plan(session))
        html.num_rows = self.num_rows
        html.truncate = self.truncate
        return relation


class Project(LogicalPlan):
    """Logical plan object for a projection.

    Every column is serialized directly into the corresponding protocol buffer
    expression; this class performs only minimal validation, namely that each
    input is a :class:`Column` so that it can be sent to the server.
    """

    def __init__(
        self,
        child: Optional["LogicalPlan"],
        columns: List[Column],
    ) -> None:
        assert all(isinstance(col, Column) for col in columns)
        # Subqueries inside the projected columns become plan references.
        subquery_plans = self._collect_references(columns)
        super().__init__(child, subquery_plans)
        self._columns = columns

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        assert self._child is not None
        relation = self._create_proto_relation()
        projection = relation.project
        projection.input.CopyFrom(self._child.plan(session))
        projection.expressions.extend([col.to_plan(session) for col in self._columns])
        return self._with_relations(relation, session)


class WithColumns(LogicalPlan):
    """Logical plan object for a withColumns operation.

    Each new column is sent as an alias expression: ``columns[i]`` named
    ``columnNames[i]``, optionally with the JSON string ``metadata[i]``.
    """

    def __init__(
        self,
        child: Optional["LogicalPlan"],
        columnNames: Sequence[str],
        columns: Sequence[Column],
        metadata: Optional[Sequence[str]] = None,
    ) -> None:
        assert isinstance(columnNames, list)
        assert len(columnNames) > 0
        assert all(isinstance(name, str) for name in columnNames)

        assert isinstance(columns, list)
        assert len(columns) == len(columnNames)
        assert all(isinstance(col, Column) for col in columns)

        if metadata is not None:
            assert isinstance(metadata, list)
            assert len(metadata) == len(columnNames)
            for meta in metadata:
                assert isinstance(meta, str)
                # An empty string means "no metadata"; anything else must be valid JSON.
                assert meta == "" or json.loads(meta) is not None

        super().__init__(child, self._collect_references(columns))

        self._columnNames = columnNames
        self._columns = columns
        self._metadata = metadata

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        assert self._child is not None
        relation = self._create_proto_relation()
        relation.with_columns.input.CopyFrom(self._child.plan(session))

        # Lengths were checked in __init__; None placeholders mean "no metadata".
        if self._metadata is None:
            metadata_per_column: Sequence[Optional[str]] = [None] * len(self._columnNames)
        else:
            metadata_per_column = self._metadata

        for name, col, meta in zip(self._columnNames, self._columns, metadata_per_column):
            alias = proto.Expression.Alias()
            alias.expr.CopyFrom(col.to_plan(session))
            alias.name.append(name)
            if meta is not None:
                alias.metadata = meta
            relation.with_columns.aliases.append(alias)

        return self._with_relations(relation, session)


class WithWatermark(LogicalPlan):
    """Logical plan object for a WithWatermark operation."""

    def __init__(self, child: Optional["LogicalPlan"], event_time: str, delay_threshold: str):
        super().__init__(child)
        # Both values are forwarded to the server verbatim as strings.
        self._event_time = event_time
        self._delay_threshold = delay_threshold

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        assert self._child is not None
        relation = self._create_proto_relation()
        watermark = relation.with_watermark
        watermark.input.CopyFrom(self._child.plan(session))
        watermark.event_time = self._event_time
        watermark.delay_threshold = self._delay_threshold
        return relation


class CachedRemoteRelation(LogicalPlan):
    """Logical plan object for a DataFrame reference which represents a DataFrame that's been
    cached on the server with a given id."""

    def __init__(self, relation_id: str, spark_session: "SparkSession") -> None:
        super().__init__(None)
        self._relation_id = relation_id
        # Needs to hold the session to make a request itself.
        self._spark_session = spark_session

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        """Build a ``cached_remote_relation`` relation referencing the server-side id."""
        plan = self._create_proto_relation()
        plan.cached_remote_relation.relation_id = self._relation_id
        return plan

    def __del__(self) -> None:
        """Best-effort release of the server-side cached DataFrame.

        Sends a ``RemoveRemoteCachedRelation`` command when the owning session's client is
        still open. Failures are logged and swallowed, since exceptions must not escape
        ``__del__``.
        """
        # If __init__ did not complete, these attributes may be missing.
        session = getattr(self, "_spark_session", None)
        relation_id = getattr(self, "_relation_id", None)
        if session is None or relation_id is None:
            return
        try:
            # If session is already closed, all cached DataFrame should be released.
            if session.client.is_closed:
                return

            command = RemoveRemoteCachedRelation(self).command(session=session.client)
            req = session.client._execute_plan_request_with_metadata()
            if session.client._user_id:
                req.user_context.user_id = session.client._user_id
            req.plan.command.CopyFrom(command)

            for attempt in session.client._retrying():
                with attempt:
                    # !!HACK ALERT!!
                    # unary_stream does not work on Python's exit for an unknown reasons
                    # Therefore, here we open unary_unary channel instead.
                    # See also :class:`SparkConnectServiceStub`.
                    request_serializer = (
                        spark_dot_connect_dot_base__pb2.ExecutePlanRequest.SerializeToString
                    )
                    response_deserializer = (
                        spark_dot_connect_dot_base__pb2.ExecutePlanResponse.FromString
                    )
                    channel = session.client._channel.unary_unary(
                        "/spark.connect.SparkConnectService/ExecutePlan",
                        request_serializer=request_serializer,
                        response_deserializer=response_deserializer,
                    )
                    metadata = session.client._builder.metadata()
                    channel(req, metadata=metadata)  # type: ignore[arg-type]
        except Exception as e:
            # `Logger.warn` is a deprecated alias of `Logger.warning`.
            logger.warning(f"RemoveRemoteCachedRelation failed with exception: {e}.")


class Hint(LogicalPlan):
    """Logical plan object for a Hint operation."""

    def __init__(
        self,
        child: Optional["LogicalPlan"],
        name: str,
        parameters: Sequence[Column],
    ) -> None:
        assert isinstance(name, str)

        # `parameters` must be a list (hence not None) of Column objects.
        assert isinstance(parameters, list)
        assert all(isinstance(param, Column) for param in parameters)

        super().__init__(child, self._collect_references(parameters))
        self._name = name
        self._parameters = parameters

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        assert self._child is not None
        relation = self._create_proto_relation()
        hint = relation.hint
        hint.input.CopyFrom(self._child.plan(session))
        hint.name = self._name
        hint.parameters.extend([param.to_plan(session) for param in self._parameters])
        return self._with_relations(relation, session)


class Filter(LogicalPlan):
    """Keeps only the child's rows for which the condition column holds."""

    def __init__(self, child: Optional["LogicalPlan"], filter: Column) -> None:
        # The condition may contain subqueries, which become plan references.
        super().__init__(child, self._collect_references([filter]))
        self.filter = filter

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        assert self._child is not None
        relation = self._create_proto_relation()
        condition_filter = relation.filter
        condition_filter.input.CopyFrom(self._child.plan(session))
        condition_filter.condition.CopyFrom(self.filter.to_plan(session))
        return self._with_relations(relation, session)


class Limit(LogicalPlan):
    """Keeps at most ``limit`` rows of the child."""

    def __init__(self, child: Optional["LogicalPlan"], limit: int) -> None:
        super().__init__(child)
        self.limit = limit

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        assert self._child is not None
        relation = self._create_proto_relation()
        limit_rel = relation.limit
        limit_rel.input.CopyFrom(self._child.plan(session))
        limit_rel.limit = self.limit
        return relation


class Tail(LogicalPlan):
    """Builds a ``tail`` relation over the child with the given ``limit``."""

    def __init__(self, child: Optional["LogicalPlan"], limit: int) -> None:
        super().__init__(child)
        self.limit = limit

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        assert self._child is not None
        relation = self._create_proto_relation()
        tail_rel = relation.tail
        tail_rel.input.CopyFrom(self._child.plan(session))
        tail_rel.limit = self.limit
        return relation


class Offset(LogicalPlan):
    """Builds an ``offset`` relation over the child (default offset 0)."""

    def __init__(self, child: Optional["LogicalPlan"], offset: int = 0) -> None:
        super().__init__(child)
        self.offset = offset

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        assert self._child is not None
        relation = self._create_proto_relation()
        offset_rel = relation.offset
        offset_rel.input.CopyFrom(self._child.plan(session))
        offset_rel.offset = self.offset
        return relation


class Deduplicate(LogicalPlan):
    """Builds a ``deduplicate`` relation, keyed on all columns or on the given names."""

    def __init__(
        self,
        child: Optional["LogicalPlan"],
        all_columns_as_keys: bool = False,
        column_names: Optional[List[str]] = None,
        within_watermark: bool = False,
    ) -> None:
        super().__init__(child)
        self.all_columns_as_keys = all_columns_as_keys
        self.column_names = column_names
        self.within_watermark = within_watermark

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        assert self._child is not None
        relation = self._create_proto_relation()
        dedup = relation.deduplicate
        dedup.input.CopyFrom(self._child.plan(session))
        dedup.all_columns_as_keys = self.all_columns_as_keys
        dedup.within_watermark = self.within_watermark
        # Extending with an empty list is a no-op, so None adds no names.
        dedup.column_names.extend(self.column_names or [])
        return relation


class Sort(LogicalPlan):
    """Sorts the child by sort-order columns, globally or within partitions."""

    def __init__(
        self,
        child: Optional["LogicalPlan"],
        columns: List[Column],
        is_global: bool,
    ) -> None:
        assert all(isinstance(col, Column) for col in columns)
        assert isinstance(is_global, bool)

        super().__init__(child, self._collect_references(columns))
        self.columns = columns
        self.is_global = is_global

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        assert self._child is not None
        relation = self._create_proto_relation()
        sort_rel = relation.sort
        sort_rel.input.CopyFrom(self._child.plan(session))
        # Each column is expected to serialize into a sort-order expression.
        sort_rel.order.extend([col.to_plan(session).sort_order for col in self.columns])
        sort_rel.is_global = self.is_global
        return self._with_relations(relation, session)


class Drop(LogicalPlan):
    """Drops columns given either as :class:`Column` objects or as column-name strings."""

    def __init__(
        self,
        child: Optional["LogicalPlan"],
        columns: List[Union[Column, str]],
    ) -> None:
        if len(columns) != 0:
            assert all(isinstance(col, (Column, str)) for col in columns)

        # Only Column objects can carry subqueries; plain names cannot.
        column_objects = [col for col in columns if isinstance(col, Column)]
        super().__init__(child, self._collect_references(column_objects))
        self._columns = columns

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        assert self._child is not None
        relation = self._create_proto_relation()
        drop = relation.drop
        drop.input.CopyFrom(self._child.plan(session))
        for col in self._columns:
            if not isinstance(col, Column):
                drop.column_names.append(col)
                continue
            drop.columns.append(col.to_plan(session))
        return self._with_relations(relation, session)


class Sample(LogicalPlan):
    """Builds a ``sample`` relation over the child between two fractional bounds."""

    def __init__(
        self,
        child: Optional["LogicalPlan"],
        lower_bound: float,
        upper_bound: float,
        with_replacement: bool,
        seed: int,
        deterministic_order: bool = False,
    ) -> None:
        super().__init__(child)
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        self.with_replacement = with_replacement
        self.seed = seed
        self.deterministic_order = deterministic_order

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        assert self._child is not None
        relation = self._create_proto_relation()
        sample = relation.sample
        sample.input.CopyFrom(self._child.plan(session))
        sample.lower_bound = self.lower_bound
        sample.upper_bound = self.upper_bound
        sample.with_replacement = self.with_replacement
        sample.seed = self.seed
        sample.deterministic_order = self.deterministic_order
        return relation


class Aggregate(LogicalPlan):
    """Aggregation plan supporting groupBy, rollup, cube, pivot and grouping sets.

    ``pivot_col``/``pivot_values`` are only meaningful for ``group_type == "pivot"`` and
    ``grouping_sets`` only for ``group_type == "grouping_sets"``; for the other group
    types they must all be None.
    """

    def __init__(
        self,
        child: Optional["LogicalPlan"],
        group_type: str,
        grouping_cols: Sequence[Column],
        aggregate_cols: Sequence[Column],
        pivot_col: Optional[Column],
        pivot_values: Optional[Sequence[Column]],
        grouping_sets: Optional[Sequence[Sequence[Column]]],
    ) -> None:
        assert isinstance(group_type, str) and group_type in [
            "groupby",
            "rollup",
            "cube",
            "pivot",
            "grouping_sets",
        ]

        assert isinstance(grouping_cols, list) and all(isinstance(c, Column) for c in grouping_cols)

        assert isinstance(aggregate_cols, list) and all(
            isinstance(c, Column) for c in aggregate_cols
        )

        if group_type == "pivot":
            assert pivot_col is not None and isinstance(pivot_col, Column)
            assert pivot_values is None or isinstance(pivot_values, list)
        elif group_type == "grouping_sets":
            assert grouping_sets is None or isinstance(grouping_sets, list)
        else:
            assert pivot_col is None
            assert pivot_values is None
            assert grouping_sets is None

        # Gather every column that may contain subqueries, in a fixed order:
        # grouping, aggregate, pivot column, pivot values, then grouping-set members.
        referenced = grouping_cols + aggregate_cols
        if pivot_col is not None:
            referenced = referenced + [pivot_col]
        if pivot_values is not None:
            referenced = referenced + pivot_values
        if grouping_sets is not None:
            referenced = referenced + [col for group in grouping_sets for col in group]

        super().__init__(child, self._collect_references(referenced))
        self._group_type = group_type
        self._grouping_cols = grouping_cols
        self._aggregate_cols = aggregate_cols
        self._pivot_col = pivot_col
        self._pivot_values = pivot_values
        self._grouping_sets = grouping_sets

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        assert self._child is not None
        relation = self._create_proto_relation()
        agg = relation.aggregate
        agg.input.CopyFrom(self._child.plan(session))
        agg.grouping_expressions.extend([col.to_plan(session) for col in self._grouping_cols])
        agg.aggregate_expressions.extend([col.to_plan(session) for col in self._aggregate_cols])

        proto_group_types = {
            "groupby": proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY,
            "rollup": proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP,
            "cube": proto.Aggregate.GroupType.GROUP_TYPE_CUBE,
            "pivot": proto.Aggregate.GroupType.GROUP_TYPE_PIVOT,
            "grouping_sets": proto.Aggregate.GroupType.GROUP_TYPE_GROUPING_SETS,
        }
        proto_group_type = proto_group_types.get(self._group_type)
        if proto_group_type is not None:
            agg.group_type = proto_group_type

        if self._group_type == "pivot":
            assert self._pivot_col is not None
            agg.pivot.col.CopyFrom(self._pivot_col.to_plan(session))
            if self._pivot_values:
                # Pivot values are expected to serialize into literal expressions.
                agg.pivot.values.extend(
                    [value.to_plan(session).literal for value in self._pivot_values]
                )
        elif self._group_type == "grouping_sets":
            assert self._grouping_sets is not None
            agg.grouping_sets.extend(
                [
                    proto.Aggregate.GroupingSets(
                        grouping_set=[col.to_plan(session) for col in group]
                    )
                    for group in self._grouping_sets
                ]
            )
        return self._with_relations(relation, session)


class Join(LogicalPlan):
    """Logical plan joining ``left`` with ``right``.

    ``on`` may be a single column name or a list of column names (sent as
    USING columns), a single ``Column`` condition, or a list of ``Column``
    conditions that are AND-ed together into one join condition.

    ``how`` is matched case- and underscore-insensitively ("left_outer",
    "LeftOuter" and "leftouter" are equivalent); ``None`` means an inner join.
    An unknown join type raises ``AnalysisException``.
    """

    def __init__(
        self,
        left: Optional["LogicalPlan"],
        right: "LogicalPlan",
        on: Optional[Union[str, List[str], Column, List[Column]]],
        how: Optional[str],
    ) -> None:
        super().__init__(
            left,
            self._collect_references(
                []
                if on is None or isinstance(on, str)
                else [on]
                if isinstance(on, Column)
                else [c for c in on if isinstance(c, Column)]
            ),
        )
        self.left = cast(LogicalPlan, left)
        self.right = right
        self.on = on
        # Normalize case and underscores so that every spelling listed as
        # "supported" in the error below (e.g. "full_outer", "left_semi") is
        # actually accepted rather than rejected.
        normalized = None if how is None else how.lower().replace("_", "")
        if normalized is None or normalized == "inner":
            join_type = proto.Join.JoinType.JOIN_TYPE_INNER
        elif normalized in ["outer", "full", "fullouter"]:
            join_type = proto.Join.JoinType.JOIN_TYPE_FULL_OUTER
        elif normalized in ["leftouter", "left"]:
            join_type = proto.Join.JoinType.JOIN_TYPE_LEFT_OUTER
        elif normalized in ["rightouter", "right"]:
            join_type = proto.Join.JoinType.JOIN_TYPE_RIGHT_OUTER
        elif normalized in ["leftsemi", "semi"]:
            join_type = proto.Join.JoinType.JOIN_TYPE_LEFT_SEMI
        elif normalized in ["leftanti", "anti"]:
            join_type = proto.Join.JoinType.JOIN_TYPE_LEFT_ANTI
        elif normalized == "cross":
            join_type = proto.Join.JoinType.JOIN_TYPE_CROSS
        else:
            raise AnalysisException(
                errorClass="UNSUPPORTED_JOIN_TYPE",
                messageParameters={
                    "typ": how,
                    "supported": (
                        "'"
                        + "', '".join(
                            [
                                "inner",
                                "outer",
                                "full",
                                "fullouter",
                                "full_outer",
                                "leftouter",
                                "left",
                                "left_outer",
                                "rightouter",
                                "right",
                                "right_outer",
                                "leftsemi",
                                "left_semi",
                                "semi",
                                "leftanti",
                                "left_anti",
                                "anti",
                                "cross",
                            ]
                        )
                        + "'"
                    ),
                },
            )
        self.how = join_type

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        """Build the ``Join`` relation from both children, ``on`` and the join type."""
        plan = self._create_proto_relation()
        plan.join.left.CopyFrom(self.left.plan(session))
        plan.join.right.CopyFrom(self.right.plan(session))
        if self.on is not None:
            if not isinstance(self.on, list):
                if isinstance(self.on, str):
                    plan.join.using_columns.append(self.on)
                else:
                    plan.join.join_condition.CopyFrom(self.on.to_plan(session))
            elif len(self.on) > 0:
                if isinstance(self.on[0], str):
                    plan.join.using_columns.extend(cast(List[str], self.on))
                else:
                    # Several Column conditions are combined with logical AND.
                    merge_column = functools.reduce(lambda c1, c2: c1 & c2, self.on)
                    plan.join.join_condition.CopyFrom(cast(Column, merge_column).to_plan(session))
        plan.join.join_type = self.how
        return self._with_relations(plan, session)

    @property
    def observations(self) -> Dict[str, "Observation"]:
        return dict(**super().observations, **self.right.observations)

    def print(self, indent: int = 0) -> str:
        i = " " * indent
        o = " " * (indent + LogicalPlan.INDENT)
        n = indent + LogicalPlan.INDENT * 2
        return (
            f"{i}<Join on={self.on} how={self.how}>\n{o}"
            f"left=\n{self.left.print(n)}\n{o}right=\n{self.right.print(n)}"
        )

    def _repr_html_(self) -> str:
        return f"""
        <ul>
            <li>
                <b>Join</b><br />
                Left: {self.left._repr_html_()}
                Right: {self.right._repr_html_()}
            </li>
        </uL>
        """


class AsOfJoin(LogicalPlan):
    """Logical plan for an as-of join of ``left`` and ``right``.

    The join is keyed on ``left_as_of`` / ``right_as_of``. ``on`` is handled as
    in ``Join``: column name(s) become USING columns, while ``Column``
    condition(s) are AND-ed into the join expression. ``how``, ``tolerance``,
    ``allow_exact_matches`` and ``direction`` are forwarded to the proto as-is.
    """

    def __init__(
        self,
        left: LogicalPlan,
        right: LogicalPlan,
        left_as_of: Column,
        right_as_of: Column,
        on: Optional[Union[str, List[str], Column, List[Column]]],
        how: str,
        tolerance: Optional[Column],
        allow_exact_matches: bool,
        direction: str,
    ) -> None:
        # Every Column this plan carries may reference other plans.
        refs: List[Column] = [left_as_of, right_as_of]
        if isinstance(on, Column):
            refs.append(on)
        elif on is not None and not isinstance(on, str):
            refs.extend(c for c in on if isinstance(c, Column))
        if tolerance is not None:
            refs.append(tolerance)
        super().__init__(left, self._collect_references(refs))

        self.left = left
        self.right = right
        self.left_as_of = left_as_of
        self.right_as_of = right_as_of
        self.on = on
        self.how = how
        self.tolerance = tolerance
        self.allow_exact_matches = allow_exact_matches
        self.direction = direction

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        relation = self._create_proto_relation()
        asof = relation.as_of_join
        asof.left.CopyFrom(self.left.plan(session))
        asof.right.CopyFrom(self.right.plan(session))

        asof.left_as_of.CopyFrom(self.left_as_of.to_plan(session))
        asof.right_as_of.CopyFrom(self.right_as_of.to_plan(session))

        on = self.on
        if isinstance(on, str):
            asof.using_columns.append(on)
        elif isinstance(on, list):
            if on and isinstance(on[0], str):
                asof.using_columns.extend(cast(List[str], on))
            elif on:
                condition = functools.reduce(lambda lhs, rhs: lhs & rhs, on)
                asof.join_expr.CopyFrom(cast(Column, condition).to_plan(session))
        elif on is not None:
            asof.join_expr.CopyFrom(on.to_plan(session))

        asof.join_type = self.how

        if self.tolerance is not None:
            asof.tolerance.CopyFrom(self.tolerance.to_plan(session))

        asof.allow_exact_matches = self.allow_exact_matches
        asof.direction = self.direction

        return self._with_relations(relation, session)

    @property
    def observations(self) -> Dict[str, "Observation"]:
        inherited = super().observations
        return dict(**inherited, **self.right.observations)

    def print(self, indent: int = 0) -> str:
        assert self.left is not None
        assert self.right is not None

        pad = " " * indent
        child_pad = " " * (indent + LogicalPlan.INDENT)
        nested = indent + LogicalPlan.INDENT * 2
        header = (
            f"{pad}<AsOfJoin left_as_of={self.left_as_of}, right_as_of={self.right_as_of}, "
            f"on={self.on} how={self.how}>"
        )
        return (
            f"{header}\n{child_pad}left=\n{self.left.print(nested)}"
            f"\n{child_pad}right=\n{self.right.print(nested)}"
        )

    def _repr_html_(self) -> str:
        assert self.left is not None
        assert self.right is not None

        return f"""
        <ul>
            <li>
                <b>AsOfJoin</b><br />
                Left: {self.left._repr_html_()}
                Right: {self.right._repr_html_()}
            </li>
        </uL>
        """


class LateralJoin(LogicalPlan):
    """Logical plan for a lateral join of ``left`` with ``right``.

    ``on`` is an optional ``Column`` join condition. ``how`` is matched case-
    and underscore-insensitively; only inner (the default when ``None``),
    left outer and cross joins are accepted, and anything else raises
    ``AnalysisException``.
    """

    def __init__(
        self,
        left: Optional[LogicalPlan],
        right: LogicalPlan,
        on: Optional[Column],
        how: Optional[str],
    ) -> None:
        super().__init__(left, self._collect_references([on] if on is not None else []))
        self.left = cast(LogicalPlan, left)
        self.right = right
        self.on = on
        # Normalize case and underscores so that every spelling listed as
        # "supported" in the error below (including "left_outer") is accepted.
        normalized = None if how is None else how.lower().replace("_", "")
        if normalized is None or normalized == "inner":
            join_type = proto.Join.JoinType.JOIN_TYPE_INNER
        elif normalized in ["leftouter", "left"]:
            join_type = proto.Join.JoinType.JOIN_TYPE_LEFT_OUTER
        elif normalized == "cross":
            join_type = proto.Join.JoinType.JOIN_TYPE_CROSS
        else:
            raise AnalysisException(
                errorClass="UNSUPPORTED_JOIN_TYPE",
                messageParameters={
                    "typ": how,
                    "supported": (
                        "'"
                        + "', '".join(["inner", "leftouter", "left", "left_outer", "cross"])
                        + "'"
                    ),
                },
            )
        self.how = join_type

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        """Build the ``LateralJoin`` relation from both children and the condition."""
        plan = self._create_proto_relation()
        plan.lateral_join.left.CopyFrom(self.left.plan(session))
        plan.lateral_join.right.CopyFrom(self.right.plan(session))
        if self.on is not None:
            plan.lateral_join.join_condition.CopyFrom(self.on.to_plan(session))
        plan.lateral_join.join_type = self.how
        return self._with_relations(plan, session)

    @property
    def observations(self) -> Dict[str, "Observation"]:
        return dict(**super().observations, **self.right.observations)

    def print(self, indent: int = 0) -> str:
        i = " " * indent
        o = " " * (indent + LogicalPlan.INDENT)
        n = indent + LogicalPlan.INDENT * 2
        return (
            f"{i}<LateralJoin on={self.on} how={self.how}>\n{o}"
            f"left=\n{self.left.print(n)}\n{o}right=\n{self.right.print(n)}"
        )

    def _repr_html_(self) -> str:
        return f"""
        <ul>
            <li>
                <b>LateralJoin</b><br />
                Left: {self.left._repr_html_()}
                Right: {self.right._repr_html_()}
            </li>
        </uL>
        """


class SetOperation(LogicalPlan):
    """Logical plan combining two relations with UNION, INTERSECT or EXCEPT.

    ``set_op`` must be "union", "intersect" or "except"; any other value raises
    ``PySparkValueError`` when the plan is built. ``is_all``, ``by_name`` and
    ``allow_missing_columns`` are forwarded to the proto unchanged.
    """

    def __init__(
        self,
        child: Optional["LogicalPlan"],
        other: Optional["LogicalPlan"],
        set_op: str,
        is_all: bool = True,
        by_name: bool = False,
        allow_missing_columns: bool = False,
    ) -> None:
        super().__init__(child)
        self.other = other
        self.by_name = by_name
        self.is_all = is_all
        self.set_op = set_op
        self.allow_missing_columns = allow_missing_columns

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        assert self._child is not None
        relation = self._create_proto_relation()
        setop = relation.set_op
        # The None guards only matter when asserts are disabled (python -O).
        if self._child is not None:
            setop.left_input.CopyFrom(self._child.plan(session))
        if self.other is not None:
            setop.right_input.CopyFrom(self.other.plan(session))

        op_types = {
            "union": proto.SetOperation.SET_OP_TYPE_UNION,
            "intersect": proto.SetOperation.SET_OP_TYPE_INTERSECT,
            "except": proto.SetOperation.SET_OP_TYPE_EXCEPT,
        }
        if self.set_op not in op_types:
            raise PySparkValueError(
                errorClass="UNSUPPORTED_OPERATION",
                messageParameters={"operation": self.set_op},
            )
        setop.set_op_type = op_types[self.set_op]

        setop.is_all = self.is_all
        setop.by_name = self.by_name
        setop.allow_missing_columns = self.allow_missing_columns
        return relation

    @property
    def observations(self) -> Dict[str, "Observation"]:
        inherited = super().observations
        from_other = {} if self.other is None else self.other.observations
        return dict(**inherited, **from_other)

    def print(self, indent: int = 0) -> str:
        assert self._child is not None
        assert self.other is not None

        pad = " " * indent
        child_pad = " " * (indent + LogicalPlan.INDENT)
        nested = indent + LogicalPlan.INDENT * 2
        first = self._child.print(nested)
        second = self.other.print(nested)
        return f"{pad}SetOperation\n{child_pad}child1=\n{first}\n{child_pad}child2=\n{second}"

    def _repr_html_(self) -> str:
        assert self._child is not None
        assert self.other is not None

        return f"""
        <ul>
            <li>
                <b>SetOperation</b><br />
                Left: {self._child._repr_html_()}
                Right: {self.other._repr_html_()}
            </li>
        </uL>
        """


class Repartition(LogicalPlan):
    """Plan redistributing its child into ``num_partitions`` partitions.

    ``shuffle`` is passed through to the ``repartition`` proto as-is.
    """

    def __init__(self, child: Optional["LogicalPlan"], num_partitions: int, shuffle: bool) -> None:
        super().__init__(child)
        self._num_partitions = num_partitions
        self._shuffle = shuffle

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        relation = self._create_proto_relation()
        rep = relation.repartition
        if self._child is not None:
            rep.input.CopyFrom(self._child.plan(session))
        rep.shuffle = self._shuffle
        rep.num_partitions = self._num_partitions
        return relation


class RepartitionByExpression(LogicalPlan):
    """Plan repartitioning its child by the given partitioning expressions.

    ``num_partitions`` is optional; when ``None`` it is left unset in the proto.
    """

    def __init__(
        self,
        child: Optional["LogicalPlan"],
        num_partitions: Optional[int],
        columns: List[Column],
    ) -> None:
        assert all(isinstance(col, Column) for col in columns)
        super().__init__(child, self._collect_references(columns))
        self.num_partitions = num_partitions
        self.columns = columns

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        relation = self._create_proto_relation()
        by_expr = relation.repartition_by_expression
        by_expr.partition_exprs.extend([col.to_plan(session) for col in self.columns])

        if self._child is not None:
            by_expr.input.CopyFrom(self._child.plan(session))
        if self.num_partitions is not None:
            by_expr.num_partitions = self.num_partitions
        return self._with_relations(relation, session)


class SubqueryAlias(LogicalPlan):
    """Plan that gives its child relation the name ``alias``."""

    def __init__(self, child: Optional["LogicalPlan"], alias: str) -> None:
        super().__init__(child)
        self._alias = alias

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        relation = self._create_proto_relation()
        aliased = relation.subquery_alias
        if self._child is not None:
            aliased.input.CopyFrom(self._child.plan(session))
        aliased.alias = self._alias
        return relation


class WithRelations(LogicalPlan):
    """Plan bundling a root relation with the relations it references.

    ``references`` must be a non-empty sequence of ``LogicalPlan`` objects; each
    one is serialized into the ``with_relations.references`` list in order.
    """

    def __init__(
        self,
        child: Optional["LogicalPlan"],
        references: Sequence["LogicalPlan"],
    ) -> None:
        super().__init__(child)
        assert references is not None and len(references) > 0
        assert all(isinstance(r, LogicalPlan) for r in references)
        self._references = references

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        relation = self._create_proto_relation()
        bundle = relation.with_relations
        if self._child is not None:
            bundle.root.CopyFrom(self._child.plan(session))
        bundle.references.extend([r.plan(session) for r in self._references])
        return relation


class SQL(LogicalPlan):
    """Leaf plan for a SQL query string with optional parameters.

    ``args`` supplies positional parameters and ``named_args`` supplies named
    parameters (both as ``Column`` values). ``views`` (``SubqueryAlias`` plans)
    are handed to the base class as this plan's references.
    """

    def __init__(
        self,
        query: str,
        args: Optional[List[Column]] = None,
        named_args: Optional[Dict[str, Column]] = None,
        views: Optional[Sequence[SubqueryAlias]] = None,
    ) -> None:
        assert args is None or (
            isinstance(args, list) and all(isinstance(a, Column) for a in args)
        )
        assert named_args is None or (
            isinstance(named_args, dict)
            and all(
                isinstance(name, str) and isinstance(value, Column)
                for name, value in named_args.items()
            )
        )
        assert views is None or (
            isinstance(views, list) and all(isinstance(v, SubqueryAlias) for v in views)
        )

        super().__init__(None, views)
        self._query = query
        self._args = args
        self._named_args = named_args
        self._views = views

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        relation = self._create_proto_relation()
        sql = relation.sql
        sql.query = self._query

        if self._args:
            sql.pos_arguments.extend([a.to_plan(session) for a in self._args])
        if self._named_args:
            for name, value in self._named_args.items():
                sql.named_arguments[name].CopyFrom(value.to_plan(session))

        return self._with_relations(relation, session)

    def command(self, session: "SparkConnectClient") -> proto.Command:
        wrapper = proto.Command()
        wrapper.sql_command.input.CopyFrom(self.plan(session))
        return wrapper


class Range(LogicalPlan):
    """Leaf plan for a numeric range defined by ``start``, ``end`` and ``step``.

    ``num_partitions`` is only set in the proto when it is not ``None``.
    """

    def __init__(
        self,
        start: int,
        end: int,
        step: int,
        num_partitions: Optional[int] = None,
    ) -> None:
        super().__init__(None)
        self._start = start
        self._end = end
        self._step = step
        self._num_partitions = num_partitions

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        relation = self._create_proto_relation()
        rng = relation.range
        rng.start = self._start
        rng.end = self._end
        rng.step = self._step
        if self._num_partitions is not None:
            rng.num_partitions = self._num_partitions
        return relation


class ToSchema(LogicalPlan):
    """Plan projecting its child onto the given ``DataType`` schema."""

    def __init__(self, child: Optional["LogicalPlan"], schema: DataType) -> None:
        super().__init__(child)
        self._schema = schema

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        assert self._child is not None
        relation = self._create_proto_relation()
        target = relation.to_schema
        target.input.CopyFrom(self._child.plan(session))
        target.schema.CopyFrom(pyspark_types_to_proto_types(self._schema))
        return relation


class WithColumnsRenamed(LogicalPlan):
    """Plan renaming columns of its child per the ``old name -> new name`` map."""

    def __init__(self, child: Optional["LogicalPlan"], colsMap: Mapping[str, str]) -> None:
        super().__init__(child)
        self._colsMap = colsMap

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        assert self._child is not None
        relation = self._create_proto_relation()
        renamed = relation.with_columns_renamed
        renamed.input.CopyFrom(self._child.plan(session))
        # One Rename entry per mapping item, in the mapping's iteration order.
        renamed.renames.extend(
            [
                proto.WithColumnsRenamed.Rename(col_name=old, new_col_name=new)
                for old, new in self._colsMap.items()
            ]
        )
        return relation


class Unpivot(LogicalPlan):
    """Plan unpivoting its child from wide to long format.

    ``ids`` are the identifier columns. ``values`` are the columns to unpivot;
    when ``None`` the proto's ``values`` field is left unset. The output column
    names come from ``variable_column_name`` and ``value_column_name``.
    """

    def __init__(
        self,
        child: Optional["LogicalPlan"],
        ids: List[Column],
        values: Optional[List[Column]],
        variable_column_name: str,
        value_column_name: str,
    ) -> None:
        referenced = ids + (values or [])
        super().__init__(child, self._collect_references(referenced))
        self.ids = ids
        self.values = values
        self.variable_column_name = variable_column_name
        self.value_column_name = value_column_name

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        assert self._child is not None
        relation = self._create_proto_relation()
        unpivot = relation.unpivot
        unpivot.input.CopyFrom(self._child.plan(session))
        unpivot.ids.extend([ident.to_plan(session) for ident in self.ids])
        if self.values is not None:
            unpivot.values.values.extend([val.to_plan(session) for val in self.values])
        unpivot.variable_column_name = self.variable_column_name
        unpivot.value_column_name = self.value_column_name
        return self._with_relations(relation, session)


class Transpose(LogicalPlan):
    """Plan transposing its child around the given index columns."""

    def __init__(
        self,
        child: Optional["LogicalPlan"],
        index_columns: Sequence[Column],
    ) -> None:
        super().__init__(child, self._collect_references(index_columns))
        self.index_columns = index_columns

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        assert self._child is not None
        relation = self._create_proto_relation()
        transpose = relation.transpose
        transpose.input.CopyFrom(self._child.plan(session))
        for index_col in self.index_columns or []:
            transpose.index_columns.append(index_col.to_plan(session))
        return self._with_relations(relation, session)


class UnresolvedTableValuedFunction(LogicalPlan):
    """Leaf plan calling the table-valued function ``name`` with ``args``."""

    def __init__(self, name: str, args: Sequence[Column]):
        super().__init__(None, self._collect_references(args))
        self._name = name
        self._args = args

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        relation = self._create_proto_relation()
        tvf = relation.unresolved_table_valued_function
        tvf.function_name = self._name
        tvf.arguments.extend([arg.to_plan(session) for arg in self._args])
        return self._with_relations(relation, session)


class CollectMetrics(LogicalPlan):
    """Plan attaching named metric expressions to its child.

    ``observation`` is either a plain metrics name or an ``Observation``. In the
    latter case, its ``_name`` is used as the metrics name and the observation
    is exposed through :attr:`observations`.
    """

    def __init__(
        self,
        child: Optional["LogicalPlan"],
        observation: Union[str, "Observation"],
        exprs: List[Column],
    ) -> None:
        assert all(isinstance(expr, Column) for expr in exprs)
        super().__init__(child, self._collect_references(exprs))
        self._observation = observation
        self._exprs = exprs

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        assert self._child is not None
        relation = self._create_proto_relation()
        metrics = relation.collect_metrics
        metrics.input.CopyFrom(self._child.plan(session))
        obs = self._observation
        metrics.name = obs if isinstance(obs, str) else str(obs._name)
        metrics.metrics.extend([expr.to_plan(session) for expr in self._exprs])
        return self._with_relations(relation, session)

    @property
    def observations(self) -> Dict[str, "Observation"]:
        from pyspark.sql.connect.observation import Observation

        own = (
            {str(self._observation._name): self._observation}
            if isinstance(self._observation, Observation)
            else {}
        )
        return dict(**super().observations, **own)


class NAFill(LogicalPlan):
    """Plan filling null values of its child (``fill_na`` relation).

    ``values`` must be a non-empty list of bool/int/float/str values. When
    ``cols`` is non-empty and more than one value is given, ``cols`` and
    ``values`` must have the same length.
    """

    def __init__(
        self, child: Optional["LogicalPlan"], cols: Optional[List[str]], values: List[Any]
    ) -> None:
        super().__init__(child)

        accepted_types = (bool, int, float, str)
        assert (
            isinstance(values, list)
            and len(values) > 0
            and all(isinstance(val, accepted_types) for val in values)
        )

        if cols:
            assert isinstance(cols, list) and all(isinstance(name, str) for name in cols)
            assert len(values) <= 1 or len(cols) == len(values)

        self.cols = cols
        self.values = values

    def _convert_value(self, v: Any) -> proto.Expression.Literal:
        literal = proto.Expression.Literal()
        # bool must be tested before int because bool is a subclass of int.
        if isinstance(v, bool):
            literal.boolean = v
        elif isinstance(v, int):
            literal.long = v
        elif isinstance(v, float):
            literal.double = v
        else:
            literal.string = v
        return literal

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        assert self._child is not None
        relation = self._create_proto_relation()
        fill = relation.fill_na
        fill.input.CopyFrom(self._child.plan(session))
        if self.cols:
            fill.cols.extend(self.cols)
        fill.values.extend([self._convert_value(val) for val in self.values])
        return relation


class NADrop(LogicalPlan):
    """Plan dropping rows with null values (``drop_na`` relation).

    ``cols`` restricts which columns are considered when non-empty, and
    ``min_non_nulls`` is set in the proto only when it is not ``None``.
    """

    def __init__(
        self,
        child: Optional["LogicalPlan"],
        cols: Optional[List[str]],
        min_non_nulls: Optional[int],
    ) -> None:
        super().__init__(child)

        self.cols = cols
        self.min_non_nulls = min_non_nulls

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        assert self._child is not None
        relation = self._create_proto_relation()
        drop = relation.drop_na
        drop.input.CopyFrom(self._child.plan(session))
        if self.cols:
            drop.cols.extend(self.cols)
        if self.min_non_nulls is not None:
            drop.min_non_nulls = self.min_non_nulls
        return relation


class NAReplace(LogicalPlan):
    """Plan replacing values in its child (``replace`` relation).

    ``replacements`` is a list of ``(old, new)`` ``Column`` pairs. Only the
    literal part of each column's expression is placed in the proto.
    """

    def __init__(
        self,
        child: Optional["LogicalPlan"],
        cols: Optional[List[str]],
        replacements: Sequence[Tuple[Column, Column]],
    ) -> None:
        assert replacements is not None and isinstance(replacements, list)
        for old, new in replacements:
            assert old is not None and isinstance(old, Column)
            assert new is not None and isinstance(new, Column)

        flattened = [col for pair in replacements for col in pair]
        super().__init__(child, self._collect_references(flattened))
        self.cols = cols
        self.replacements = replacements

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        assert self._child is not None
        relation = self._create_proto_relation()
        replace = relation.replace
        replace.input.CopyFrom(self._child.plan(session))
        if self.cols:
            replace.cols.extend(self.cols)
        for old, new in self.replacements:
            entry = proto.NAReplace.Replacement()
            entry.old_value.CopyFrom(old.to_plan(session).literal)
            entry.new_value.CopyFrom(new.to_plan(session).literal)
            replace.replacements.append(entry)
        return self._with_relations(relation, session)


class StatSummary(LogicalPlan):
    """Plan computing the named summary ``statistics`` over its child."""

    def __init__(self, child: Optional["LogicalPlan"], statistics: List[str]) -> None:
        super().__init__(child)
        self.statistics = statistics

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        assert self._child is not None
        relation = self._create_proto_relation()
        summary = relation.summary
        summary.input.CopyFrom(self._child.plan(session))
        summary.statistics.extend(self.statistics)
        return relation


class StatDescribe(LogicalPlan):
    """Plan describing the columns ``cols`` of its child (``describe`` relation)."""

    def __init__(self, child: Optional["LogicalPlan"], cols: List[str]) -> None:
        super().__init__(child)
        self.cols = cols

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        assert self._child is not None
        relation = self._create_proto_relation()
        describe = relation.describe
        describe.input.CopyFrom(self._child.plan(session))
        describe.cols.extend(self.cols)
        return relation


class StatCov(LogicalPlan):
    """Plan computing the covariance of columns ``col1`` and ``col2``."""

    def __init__(self, child: Optional["LogicalPlan"], col1: str, col2: str) -> None:
        super().__init__(child)
        self._col1 = col1
        self._col2 = col2

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        assert self._child is not None
        relation = self._create_proto_relation()
        cov = relation.cov
        cov.input.CopyFrom(self._child.plan(session))
        cov.col1 = self._col1
        cov.col2 = self._col2
        return relation


class StatApproxQuantile(LogicalPlan):
    """Plan computing approximate quantiles.

    The quantiles at ``probabilities`` are computed for every column in
    ``cols``; ``relativeError`` is forwarded as the proto's ``relative_error``.
    """

    def __init__(
        self,
        child: Optional["LogicalPlan"],
        cols: List[str],
        probabilities: List[float],
        relativeError: float,
    ) -> None:
        super().__init__(child)
        self._cols = cols
        self._probabilities = probabilities
        self._relativeError = relativeError

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        assert self._child is not None
        relation = self._create_proto_relation()
        quantile = relation.approx_quantile
        quantile.input.CopyFrom(self._child.plan(session))
        quantile.cols.extend(self._cols)
        quantile.probabilities.extend(self._probabilities)
        quantile.relative_error = self._relativeError
        return relation


class StatCrosstab(LogicalPlan):
    """Plan building a contingency table of columns ``col1`` and ``col2``."""

    def __init__(self, child: Optional["LogicalPlan"], col1: str, col2: str) -> None:
        super().__init__(child)
        self.col1 = col1
        self.col2 = col2

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        assert self._child is not None
        relation = self._create_proto_relation()
        crosstab = relation.crosstab
        crosstab.input.CopyFrom(self._child.plan(session))
        crosstab.col1 = self.col1
        crosstab.col2 = self.col2
        return relation


class StatFreqItems(LogicalPlan):
    """Plan finding frequent items in ``cols`` with the given ``support``."""

    def __init__(
        self,
        child: Optional["LogicalPlan"],
        cols: List[str],
        support: float,
    ) -> None:
        super().__init__(child)
        self._cols = cols
        self._support = support

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        assert self._child is not None
        relation = self._create_proto_relation()
        freq = relation.freq_items
        freq.input.CopyFrom(self._child.plan(session))
        freq.cols.extend(self._cols)
        freq.support = self._support
        return relation


class StatSampleBy(LogicalPlan):
    """Plan for stratified sampling of its child (``sample_by`` relation).

    ``fractions`` pairs each stratum value (a ``Column`` whose literal is sent)
    with the float fraction to sample for it. ``seed`` may be ``None``, in
    which case the proto's ``seed`` field is left unset.
    """

    def __init__(
        self,
        child: Optional["LogicalPlan"],
        col: Column,
        fractions: Sequence[Tuple[Column, float]],
        seed: int,
    ) -> None:
        # NOTE(review): a str ``col`` passes this assert, but plan() dereferences
        # ``self._col._expr`` and therefore requires a Column.
        assert col is not None and isinstance(col, (Column, str))

        assert fractions is not None and isinstance(fractions, List)
        for k, v in fractions:
            assert k is not None and isinstance(k, Column)
            assert v is not None and isinstance(v, float)

        assert seed is None or isinstance(seed, int)

        # The parentheses matter: without them the conditional expression
        # swallowed the "+", so the stratum columns were only collected when
        # ``col`` was *not* a Column.
        references = ([col] if isinstance(col, Column) else []) + [c for c, _ in fractions]
        super().__init__(child, self._collect_references(references))
        self._col = col
        self._fractions = fractions
        self._seed = seed

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        """Build the ``StatSampleBy`` relation.

        The seed is omitted when it is ``None``; assigning ``None`` to an
        integer proto field would raise ``TypeError``.
        """
        assert self._child is not None
        plan = self._create_proto_relation()
        plan.sample_by.input.CopyFrom(self._child.plan(session))
        plan.sample_by.col.CopyFrom(self._col._expr.to_plan(session))
        for k, v in self._fractions:
            fraction = proto.StatSampleBy.Fraction()
            fraction.stratum.CopyFrom(k.to_plan(session).literal)
            fraction.fraction = float(v)
            plan.sample_by.fractions.append(fraction)
        if self._seed is not None:
            plan.sample_by.seed = self._seed
        return self._with_relations(plan, session)


class StatCorr(LogicalPlan):
    """Plan computing the correlation of ``col1`` and ``col2`` using ``method``."""

    def __init__(self, child: Optional["LogicalPlan"], col1: str, col2: str, method: str) -> None:
        super().__init__(child)
        self._col1 = col1
        self._col2 = col2
        self._method = method

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        assert self._child is not None
        relation = self._create_proto_relation()
        corr = relation.corr
        corr.input.CopyFrom(self._child.plan(session))
        corr.col1 = self._col1
        corr.col2 = self._col2
        corr.method = self._method
        return relation


class ToDF(LogicalPlan):
    """Logical plan that attaches ``cols`` as the ``column_names`` of a ``to_df`` relation."""

    def __init__(self, child: Optional["LogicalPlan"], cols: Sequence[str]) -> None:
        super().__init__(child)
        self._cols = cols

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        assert self._child is not None
        relation = self._create_proto_relation()
        to_df = relation.to_df
        to_df.input.CopyFrom(self._child.plan(session))
        to_df.column_names.extend(self._cols)
        return relation


class CreateView(LogicalPlan):
    """Command that creates a DataFrame view called ``name`` from the child plan."""

    def __init__(
        self, child: Optional["LogicalPlan"], name: str, is_global: bool, replace: bool
    ) -> None:
        super().__init__(child)
        self._name = name
        self._is_global = is_global
        self._replace = replace

    def command(self, session: "SparkConnectClient") -> proto.Command:
        assert self._child is not None
        cmd = proto.Command()
        view = cmd.create_dataframe_view
        view.replace = self._replace
        view.is_global = self._is_global
        view.name = self._name
        view.input.CopyFrom(self._child.plan(session))
        return cmd


class WriteOperation(LogicalPlan):
    """
    Logical plan for a batch write.

    The public attributes are assigned after construction; :meth:`command` translates
    them into a ``proto.Command`` carrying a ``write_operation``.
    """

    def __init__(self, child: "LogicalPlan") -> None:
        super(WriteOperation, self).__init__(child)
        self.source: Optional[str] = None
        self.path: Optional[str] = None
        self.table_name: Optional[str] = None
        # Case-insensitive "save_as_table" or "insert_into"; only used with table_name.
        self.table_save_method: Optional[str] = None
        self.mode: Optional[str] = None
        self.sort_cols: List[str] = []
        self.partitioning_cols: List[str] = []
        self.clustering_cols: List[str] = []
        # An option whose value is None is not written into the proto map.
        self.options: Dict[str, Optional[str]] = {}
        # Bucketing is only emitted when num_buckets is positive.
        self.num_buckets: int = -1
        self.bucket_cols: List[str] = []

    def command(self, session: "SparkConnectClient") -> proto.Command:
        """
        Build the ``proto.Command`` for this write.

        When ``table_name`` is set it takes precedence over ``path``.

        Raises
        ------
        PySparkValueError
            If ``table_save_method`` or ``mode`` is not a recognized value.
        """
        assert self._child is not None
        plan = proto.Command()

        plan.write_operation.input.CopyFrom(self._child.plan(session))
        if self.source is not None:
            plan.write_operation.source = self.source
        plan.write_operation.sort_column_names.extend(self.sort_cols)
        plan.write_operation.partitioning_columns.extend(self.partitioning_cols)
        plan.write_operation.clustering_columns.extend(self.clustering_cols)

        if self.num_buckets > 0:
            plan.write_operation.bucket_by.bucket_column_names.extend(self.bucket_cols)
            plan.write_operation.bucket_by.num_buckets = self.num_buckets

        for k, v in self.options.items():
            if v is None:
                plan.write_operation.options.pop(k, None)
            else:
                plan.write_operation.options[k] = v

        if self.table_name is not None:
            plan.write_operation.table.table_name = self.table_name
            if self.table_save_method is not None:
                tsm = self.table_save_method.lower()
                if tsm == "save_as_table":
                    plan.write_operation.table.save_method = (
                        proto.WriteOperation.SaveTable.TableSaveMethod.TABLE_SAVE_METHOD_SAVE_AS_TABLE  # noqa: E501
                    )
                elif tsm == "insert_into":
                    plan.write_operation.table.save_method = (
                        proto.WriteOperation.SaveTable.TableSaveMethod.TABLE_SAVE_METHOD_INSERT_INTO
                    )
                else:
                    raise PySparkValueError(
                        errorClass="UNSUPPORTED_OPERATION",
                        messageParameters={"operation": tsm},
                    )
        elif self.path is not None:
            plan.write_operation.path = self.path

        if self.mode is not None:
            wm = self.mode.lower()
            if wm == "append":
                plan.write_operation.mode = proto.WriteOperation.SaveMode.SAVE_MODE_APPEND
            elif wm == "overwrite":
                plan.write_operation.mode = proto.WriteOperation.SaveMode.SAVE_MODE_OVERWRITE
            elif wm in ("error", "errorifexists", "default"):
                # Same aliases the classic DataFrameWriter accepts for ErrorIfExists.
                plan.write_operation.mode = proto.WriteOperation.SaveMode.SAVE_MODE_ERROR_IF_EXISTS
            elif wm == "ignore":
                plan.write_operation.mode = proto.WriteOperation.SaveMode.SAVE_MODE_IGNORE
            else:
                raise PySparkValueError(
                    errorClass="UNSUPPORTED_OPERATION",
                    messageParameters={"operation": self.mode},
                )
        return plan

    def print(self, indent: int = 0) -> str:
        """Return a one-line textual description of this write, indented by ``indent``."""
        i = " " * indent
        return (
            f"{i}"
            f"<WriteOperation source='{self.source}' "
            f"path='{self.path}' "
            f"table_name='{self.table_name}' "
            f"table_save_method='{self.table_save_method}' "
            f"mode='{self.mode}' "
            f"sort_cols='{self.sort_cols}' "
            f"partitioning_cols='{self.partitioning_cols}' "
            f"clustering_cols='{self.clustering_cols}' "
            f"num_buckets='{self.num_buckets}' "
            f"bucket_cols='{self.bucket_cols}' "
            f"options='{self.options}'>"
        )

    def _repr_html_(self) -> str:
        """Return an HTML list describing this write."""
        return (
            f"<ul><li>WriteOperation <br />source='{self.source}'<br />"
            f"path: '{self.path}' <br />"
            f"table_name: '{self.table_name}' <br />"
            f"table_save_method: '{self.table_save_method}' <br />"
            f"mode: '{self.mode}' <br />"
            f"sort_cols: '{self.sort_cols}' <br />"
            f"partitioning_cols: '{self.partitioning_cols}' <br />"
            f"clustering_cols: '{self.clustering_cols}' <br />"
            f"num_buckets: '{self.num_buckets}' <br />"
            f"bucket_cols: '{self.bucket_cols}' <br />"
            f"options: '{self.options}'<br />"
            f"</li></ul>"
        )


class WriteOperationV2(LogicalPlan):
    """
    Logical plan for a V2 table write, emitted as ``write_operation_v2``.

    The public attributes are assigned after construction and read by :meth:`command`.
    """

    def __init__(self, child: "LogicalPlan", table_name: str) -> None:
        super(WriteOperationV2, self).__init__(child)
        self.table_name: Optional[str] = table_name
        self.provider: Optional[str] = None
        self.partitioning_columns: List[Column] = []
        self.clustering_columns: List[str] = []
        # Entries whose value is None are not written into the proto maps.
        self.options: Dict[str, Optional[str]] = {}
        self.table_properties: Dict[str, Optional[str]] = {}
        self.mode: Optional[str] = None
        # Only sent when mode is "overwrite".
        self.overwrite_condition: Optional[Column] = None

    def command(self, session: "SparkConnectClient") -> proto.Command:
        assert self._child is not None
        cmd = proto.Command()
        op = cmd.write_operation_v2
        op.input.CopyFrom(self._child.plan(session))
        if self.table_name is not None:
            op.table_name = self.table_name
        if self.provider is not None:
            op.provider = self.provider

        op.partitioning_columns.extend([c.to_plan(session) for c in self.partitioning_columns])
        op.clustering_columns.extend(self.clustering_columns)

        # Options first, then table properties; None values are dropped from the map.
        for target, source in ((op.options, self.options), (op.table_properties, self.table_properties)):
            for key, value in source.items():
                if value is None:
                    target.pop(key, None)
                else:
                    target[key] = value

        if self.mode is not None:
            Mode = proto.WriteOperationV2.Mode
            known_modes = {
                "create": Mode.MODE_CREATE,
                "overwrite": Mode.MODE_OVERWRITE,
                "overwrite_partitions": Mode.MODE_OVERWRITE_PARTITIONS,
                "append": Mode.MODE_APPEND,
                "replace": Mode.MODE_REPLACE,
                "create_or_replace": Mode.MODE_CREATE_OR_REPLACE,
            }
            normalized = self.mode.lower()
            if normalized not in known_modes:
                raise PySparkValueError(
                    errorClass="UNSUPPORTED_OPERATION",
                    messageParameters={"operation": self.mode},
                )
            op.mode = known_modes[normalized]
            if normalized == "overwrite" and self.overwrite_condition is not None:
                op.overwrite_condition.CopyFrom(self.overwrite_condition.to_plan(session))
        return cmd


class WriteStreamOperation(LogicalPlan):
    """
    Command that starts a streaming write described by ``write_op``.

    ``write_op`` is a ``proto.WriteStreamOperationStart``; its ``input`` field is
    (re)written from the child plan each time :meth:`command` runs.
    """

    def __init__(self, child: "LogicalPlan") -> None:
        super(WriteStreamOperation, self).__init__(child)
        # Fields other than ``input`` are presumably set by the caller — confirm at call sites.
        self.write_op = proto.WriteStreamOperationStart()

    def command(self, session: "SparkConnectClient") -> proto.Command:
        assert self._child is not None
        start = self.write_op
        start.input.CopyFrom(self._child.plan(session))
        wrapper = proto.Command()
        wrapper.write_stream_operation_start.CopyFrom(start)
        return wrapper


class RemoveRemoteCachedRelation(LogicalPlan):
    """Command that removes the server-side cached relation referenced by ``relation``."""

    def __init__(self, relation: CachedRemoteRelation) -> None:
        super().__init__(None)
        self._relation = relation

    def command(self, session: "SparkConnectClient") -> proto.Command:
        relation_msg = self._create_proto_relation()
        cached = relation_msg.cached_remote_relation
        cached.relation_id = self._relation._relation_id
        wrapper = proto.Command()
        wrapper.remove_cached_remote_relation_command.relation.CopyFrom(cached)
        return wrapper


class Checkpoint(LogicalPlan):
    """
    Command that checkpoints the child plan.

    ``local`` and ``eager`` are forwarded as-is; ``storage_level`` is only sent when given.
    """

    def __init__(
        self,
        child: Optional["LogicalPlan"],
        local: bool,
        eager: bool,
        storage_level: Optional[StorageLevel] = None,
    ) -> None:
        super().__init__(child)
        self._local = local
        self._eager = eager
        self._storage_level = storage_level

    def command(self, session: "SparkConnectClient") -> proto.Command:
        assert self._child is not None
        request = proto.CheckpointCommand(
            relation=self._child.plan(session),
            local=self._local,
            eager=self._eager,
        )
        level = self._storage_level
        if level is not None:
            request.storage_level.CopyFrom(storage_level_to_proto(level))
        wrapper = proto.Command()
        wrapper.checkpoint_command.CopyFrom(request)
        return wrapper


# Catalog API (internal-only)
class CurrentDatabase(LogicalPlan):
    """Catalog plan issuing a ``current_database`` request."""

    def __init__(self) -> None:
        super().__init__(None)

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        relation = self._create_proto_relation()
        # The request has no fields; SetInParent marks the empty sub-message as present.
        relation.catalog.current_database.SetInParent()
        return relation


class SetCurrentDatabase(LogicalPlan):
    """Catalog plan issuing a ``set_current_database`` request for ``db_name``."""

    def __init__(self, db_name: str) -> None:
        super().__init__(None)
        self._db_name = db_name

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        relation = self._create_proto_relation()
        request = relation.catalog.set_current_database
        request.db_name = self._db_name
        return relation


class ListDatabases(LogicalPlan):
    """Catalog plan issuing a ``list_databases`` request, optionally filtered by ``pattern``."""

    def __init__(self, pattern: Optional[str] = None) -> None:
        super().__init__(None)
        self._pattern = pattern

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        relation = self._create_proto_relation()
        request = relation.catalog.list_databases
        # Mark the request present even when no pattern is supplied.
        request.SetInParent()
        if self._pattern is not None:
            request.pattern = self._pattern
        return relation


class ListTables(LogicalPlan):
    """Catalog plan issuing a ``list_tables`` request with optional ``db_name``/``pattern``."""

    def __init__(self, db_name: Optional[str] = None, pattern: Optional[str] = None) -> None:
        super().__init__(None)
        self._db_name = db_name
        self._pattern = pattern

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        relation = self._create_proto_relation()
        request = relation.catalog.list_tables
        # Mark the request present even when neither filter is supplied.
        request.SetInParent()
        if self._db_name is not None:
            request.db_name = self._db_name
        if self._pattern is not None:
            request.pattern = self._pattern
        return relation


class ListFunctions(LogicalPlan):
    """Catalog plan issuing a ``list_functions`` request with optional ``db_name``/``pattern``."""

    def __init__(self, db_name: Optional[str] = None, pattern: Optional[str] = None) -> None:
        super().__init__(None)
        self._db_name = db_name
        self._pattern = pattern

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        relation = self._create_proto_relation()
        request = relation.catalog.list_functions
        # Mark the request present even when neither filter is supplied.
        request.SetInParent()
        if self._db_name is not None:
            request.db_name = self._db_name
        if self._pattern is not None:
            request.pattern = self._pattern
        return relation


class ListColumns(LogicalPlan):
    """Catalog plan issuing a ``list_columns`` request for ``table_name`` (and ``db_name``)."""

    def __init__(self, table_name: str, db_name: Optional[str] = None) -> None:
        super().__init__(None)
        self._table_name = table_name
        self._db_name = db_name

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        relation = self._create_proto_relation()
        request = relation.catalog.list_columns
        request.table_name = self._table_name
        if self._db_name is not None:
            request.db_name = self._db_name
        return relation


class GetDatabase(LogicalPlan):
    """Catalog plan issuing a ``get_database`` request for ``db_name``."""

    def __init__(self, db_name: str) -> None:
        super().__init__(None)
        self._db_name = db_name

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        relation = self._create_proto_relation()
        request = relation.catalog.get_database
        request.db_name = self._db_name
        return relation


class GetTable(LogicalPlan):
    """Catalog plan issuing a ``get_table`` request for ``table_name`` (and ``db_name``)."""

    def __init__(self, table_name: str, db_name: Optional[str] = None) -> None:
        super().__init__(None)
        self._table_name = table_name
        self._db_name = db_name

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        relation = self._create_proto_relation()
        request = relation.catalog.get_table
        request.table_name = self._table_name
        if self._db_name is not None:
            request.db_name = self._db_name
        return relation


class GetFunction(LogicalPlan):
    """Catalog plan issuing a ``get_function`` request for ``function_name`` (and ``db_name``)."""

    def __init__(self, function_name: str, db_name: Optional[str] = None) -> None:
        super().__init__(None)
        self._function_name = function_name
        self._db_name = db_name

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        relation = self._create_proto_relation()
        request = relation.catalog.get_function
        request.function_name = self._function_name
        if self._db_name is not None:
            request.db_name = self._db_name
        return relation


class DatabaseExists(LogicalPlan):
    """Catalog plan issuing a ``database_exists`` request for ``db_name``."""

    def __init__(self, db_name: str) -> None:
        super().__init__(None)
        self._db_name = db_name

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        relation = self._create_proto_relation()
        request = relation.catalog.database_exists
        request.db_name = self._db_name
        return relation


class TableExists(LogicalPlan):
    """Catalog plan issuing a ``table_exists`` request for ``table_name`` (and ``db_name``)."""

    def __init__(self, table_name: str, db_name: Optional[str] = None) -> None:
        super().__init__(None)
        self._table_name = table_name
        self._db_name = db_name

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        relation = self._create_proto_relation()
        request = relation.catalog.table_exists
        request.table_name = self._table_name
        if self._db_name is not None:
            request.db_name = self._db_name
        return relation


class FunctionExists(LogicalPlan):
    """Catalog plan issuing a ``function_exists`` request for ``function_name``."""

    def __init__(self, function_name: str, db_name: Optional[str] = None) -> None:
        super().__init__(None)
        self._function_name = function_name
        self._db_name = db_name

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        relation = self._create_proto_relation()
        request = relation.catalog.function_exists
        request.function_name = self._function_name
        if self._db_name is not None:
            request.db_name = self._db_name
        return relation


class CreateTable(LogicalPlan):
    """
    Catalog plan issuing a ``create_table`` request.

    Optional fields (``path``, ``source``, ``description``, ``schema``) are only set when
    not None; options whose value is None are skipped.
    """

    def __init__(
        self,
        table_name: str,
        path: Optional[str],
        source: Optional[str] = None,
        description: Optional[str] = None,
        schema: Optional[DataType] = None,
        options: Optional[Mapping[str, str]] = None,
    ) -> None:
        super().__init__(None)
        self._table_name = table_name
        self._path = path
        self._source = source
        self._description = description
        self._schema = schema
        # Use a fresh empty dict per instance instead of a shared mutable default.
        self._options: Mapping[str, str] = {} if options is None else options

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        plan = self._create_proto_relation()
        plan.catalog.create_table.table_name = self._table_name
        if self._path is not None:
            plan.catalog.create_table.path = self._path
        if self._source is not None:
            plan.catalog.create_table.source = self._source
        if self._description is not None:
            plan.catalog.create_table.description = self._description
        if self._schema is not None:
            plan.catalog.create_table.schema.CopyFrom(pyspark_types_to_proto_types(self._schema))
        for k, v in self._options.items():
            if v is not None:
                plan.catalog.create_table.options[k] = v
        return plan


class DropTempView(LogicalPlan):
    """Catalog plan issuing a ``drop_temp_view`` request for ``view_name``."""

    def __init__(self, view_name: str) -> None:
        super().__init__(None)
        self._view_name = view_name

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        relation = self._create_proto_relation()
        request = relation.catalog.drop_temp_view
        request.view_name = self._view_name
        return relation


class DropGlobalTempView(LogicalPlan):
    """Catalog plan issuing a ``drop_global_temp_view`` request for ``view_name``."""

    def __init__(self, view_name: str) -> None:
        super().__init__(None)
        self._view_name = view_name

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        relation = self._create_proto_relation()
        request = relation.catalog.drop_global_temp_view
        request.view_name = self._view_name
        return relation


class RecoverPartitions(LogicalPlan):
    """Catalog plan issuing a ``recover_partitions`` request for ``table_name``."""

    def __init__(self, table_name: str) -> None:
        super().__init__(None)
        self._table_name = table_name

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        relation = self._create_proto_relation()
        request = relation.catalog.recover_partitions
        request.table_name = self._table_name
        return relation


class IsCached(LogicalPlan):
    """Catalog plan issuing an ``is_cached`` request for ``table_name``."""

    def __init__(self, table_name: str) -> None:
        super().__init__(None)
        self._table_name = table_name

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        relation = self._create_proto_relation()
        request = relation.catalog.is_cached
        request.table_name = self._table_name
        return relation


class CacheTable(LogicalPlan):
    """Catalog plan issuing a ``cache_table`` request, with an optional storage level."""

    def __init__(self, table_name: str, storage_level: Optional[StorageLevel] = None) -> None:
        super().__init__(None)
        self._table_name = table_name
        self._storage_level = storage_level

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        plan = self._create_proto_relation()
        _cache_table = proto.CacheTable(table_name=self._table_name)
        # Test for None explicitly instead of relying on the object's truthiness.
        if self._storage_level is not None:
            _cache_table.storage_level.CopyFrom(storage_level_to_proto(self._storage_level))
        plan.catalog.cache_table.CopyFrom(_cache_table)
        return plan


class UncacheTable(LogicalPlan):
    """Catalog plan issuing an ``uncache_table`` request for ``table_name``."""

    def __init__(self, table_name: str) -> None:
        super().__init__(None)
        self._table_name = table_name

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        relation = self._create_proto_relation()
        request = relation.catalog.uncache_table
        request.table_name = self._table_name
        return relation


class ClearCache(LogicalPlan):
    """Catalog plan issuing a ``clear_cache`` request."""

    def __init__(self) -> None:
        super().__init__(None)

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        relation = self._create_proto_relation()
        # The request has no fields; SetInParent marks the empty sub-message as present.
        relation.catalog.clear_cache.SetInParent()
        return relation


class RefreshTable(LogicalPlan):
    """Catalog plan issuing a ``refresh_table`` request for ``table_name``."""

    def __init__(self, table_name: str) -> None:
        super().__init__(None)
        self._table_name = table_name

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        relation = self._create_proto_relation()
        request = relation.catalog.refresh_table
        request.table_name = self._table_name
        return relation


class RefreshByPath(LogicalPlan):
    """Catalog plan issuing a ``refresh_by_path`` request for ``path``."""

    def __init__(self, path: str) -> None:
        super().__init__(None)
        self._path = path

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        relation = self._create_proto_relation()
        request = relation.catalog.refresh_by_path
        request.path = self._path
        return relation


class CurrentCatalog(LogicalPlan):
    """Catalog plan issuing a ``current_catalog`` request."""

    def __init__(self) -> None:
        super().__init__(None)

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        relation = self._create_proto_relation()
        # The request has no fields; SetInParent marks the empty sub-message as present.
        relation.catalog.current_catalog.SetInParent()
        return relation


class SetCurrentCatalog(LogicalPlan):
    """Catalog plan issuing a ``set_current_catalog`` request for ``catalog_name``."""

    def __init__(self, catalog_name: str) -> None:
        super().__init__(None)
        self._catalog_name = catalog_name

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        relation = self._create_proto_relation()
        request = relation.catalog.set_current_catalog
        request.catalog_name = self._catalog_name
        return relation


class ListCatalogs(LogicalPlan):
    """Catalog plan issuing a ``list_catalogs`` request, optionally filtered by ``pattern``."""

    def __init__(self, pattern: Optional[str] = None) -> None:
        super().__init__(None)
        self._pattern = pattern

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        relation = self._create_proto_relation()
        request = relation.catalog.list_catalogs
        # Mark the request present even when no pattern is supplied.
        request.SetInParent()
        if self._pattern is not None:
            request.pattern = self._pattern
        return relation


class MapPartitions(LogicalPlan):
    """Logical plan object for a mapPartitions-equivalent API: mapInPandas, mapInArrow."""

    def __init__(
        self,
        child: Optional["LogicalPlan"],
        function: "UserDefinedFunction",
        cols: List[str],
        is_barrier: bool,
        profile: Optional[ResourceProfile],
    ) -> None:
        super().__init__(child)

        # Bind the UDF to the given input column names up front.
        self._function = function._build_common_inline_user_defined_function(*cols)
        self._is_barrier = is_barrier
        self._profile = profile

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        assert self._child is not None
        relation = self._create_proto_relation()
        mapper = relation.map_partitions
        mapper.input.CopyFrom(self._child.plan(session))
        mapper.func.CopyFrom(self._function.to_plan_udf(session))
        mapper.is_barrier = self._is_barrier
        # A resource profile is optional; only its id travels in the plan.
        if self._profile is not None:
            mapper.profile_id = self._profile.id
        return relation


class GroupMap(LogicalPlan):
    """Logical plan object for a Group Map API: apply, applyInPandas."""

    def __init__(
        self,
        child: Optional["LogicalPlan"],
        grouping_cols: Sequence[Column],
        function: "UserDefinedFunction",
        cols: List[str],
    ):
        assert isinstance(grouping_cols, list)
        assert all(isinstance(c, Column) for c in grouping_cols)

        super().__init__(child, self._collect_references(grouping_cols))
        self._grouping_cols = grouping_cols
        self._function = function._build_common_inline_user_defined_function(*cols)

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        assert self._child is not None
        relation = self._create_proto_relation()
        group_map = relation.group_map
        group_map.input.CopyFrom(self._child.plan(session))
        group_map.grouping_expressions.extend([c.to_plan(session) for c in self._grouping_cols])
        group_map.func.CopyFrom(self._function.to_plan_udf(session))
        return self._with_relations(relation, session)


class CoGroupMap(LogicalPlan):
    """Logical plan object for a CoGroup Map API: applyInPandas."""

    def __init__(
        self,
        input: Optional["LogicalPlan"],
        input_grouping_cols: Sequence[Column],
        other: Optional["LogicalPlan"],
        other_grouping_cols: Sequence[Column],
        function: "UserDefinedFunction",
    ):
        # Both sides must supply a list of Column objects.
        for grouping in (input_grouping_cols, other_grouping_cols):
            assert isinstance(grouping, list) and all(isinstance(c, Column) for c in grouping)

        super().__init__(input, self._collect_references(input_grouping_cols + other_grouping_cols))
        self._input_grouping_cols = input_grouping_cols
        self._other_grouping_cols = other_grouping_cols
        self._other = cast(LogicalPlan, other)
        # The UDF receives whole DataFrames, so it is built without binding input columns.
        self._function = function._build_common_inline_user_defined_function()

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        assert self._child is not None
        relation = self._create_proto_relation()
        co_group = relation.co_group_map
        co_group.input.CopyFrom(self._child.plan(session))
        co_group.input_grouping_expressions.extend(
            [c.to_plan(session) for c in self._input_grouping_cols]
        )
        co_group.other.CopyFrom(self._other.plan(session))
        co_group.other_grouping_expressions.extend(
            [c.to_plan(session) for c in self._other_grouping_cols]
        )
        co_group.func.CopyFrom(self._function.to_plan_udf(session))
        return self._with_relations(relation, session)


class ApplyInPandasWithState(LogicalPlan):
    """Logical plan object for a applyInPandasWithState."""

    def __init__(
        self,
        child: Optional["LogicalPlan"],
        grouping_cols: Sequence[Column],
        function: "UserDefinedFunction",
        output_schema: str,
        state_schema: str,
        output_mode: str,
        timeout_conf: str,
        cols: List[str],
    ):
        assert isinstance(grouping_cols, list)
        assert all(isinstance(c, Column) for c in grouping_cols)

        super().__init__(child, self._collect_references(grouping_cols))
        self._grouping_cols = grouping_cols
        self._function = function._build_common_inline_user_defined_function(*cols)
        # Schemas, output mode and timeout are forwarded to the proto as plain strings.
        self._output_schema = output_schema
        self._state_schema = state_schema
        self._output_mode = output_mode
        self._timeout_conf = timeout_conf

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        assert self._child is not None
        relation = self._create_proto_relation()
        stateful = relation.apply_in_pandas_with_state
        stateful.input.CopyFrom(self._child.plan(session))
        stateful.grouping_expressions.extend([c.to_plan(session) for c in self._grouping_cols])
        stateful.func.CopyFrom(self._function.to_plan_udf(session))
        stateful.output_schema = self._output_schema
        stateful.state_schema = self._state_schema
        stateful.output_mode = self._output_mode
        stateful.timeout_conf = self._timeout_conf
        return self._with_relations(relation, session)


class BaseTransformWithStateInPySpark(LogicalPlan):
    """Base implementation of logical plan object for a TransformWithStateIn(PySpark/Pandas)."""

    def __init__(
        self,
        child: Optional["LogicalPlan"],
        grouping_cols: Sequence[Column],
        function: "UserDefinedFunction",
        output_schema: Union[DataType, str],
        output_mode: str,
        time_mode: str,
        event_time_col_name: str,
        cols: List[str],
        initial_state_plan: Optional["LogicalPlan"],
        initial_state_grouping_cols: Optional[Sequence[Column]],
    ):
        assert isinstance(grouping_cols, list) and all(isinstance(c, Column) for c in grouping_cols)
        # References come from the grouping columns, plus the initial-state grouping
        # columns when an initial state plan is supplied.
        reference_cols = grouping_cols
        if initial_state_plan is not None:
            assert isinstance(initial_state_grouping_cols, list) and all(
                isinstance(c, Column) for c in initial_state_grouping_cols
            )
            reference_cols = grouping_cols + initial_state_grouping_cols
        super().__init__(child, self._collect_references(reference_cols))

        self._grouping_cols = grouping_cols
        # A string schema is kept unparsed here and converted to proto in plan().
        self._output_schema: DataType = (
            UnparsedDataType(output_schema) if isinstance(output_schema, str) else output_schema
        )
        self._output_mode = output_mode
        self._time_mode = time_mode
        self._event_time_col_name = event_time_col_name
        self._function = function._build_common_inline_user_defined_function(*cols)
        self._initial_state_plan = initial_state_plan
        self._initial_state_grouping_cols = initial_state_grouping_cols

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        assert self._child is not None
        relation = self._create_proto_relation()
        group_map = relation.group_map
        group_map.input.CopyFrom(self._child.plan(session))
        group_map.grouping_expressions.extend([c.to_plan(session) for c in self._grouping_cols])
        group_map.output_mode = self._output_mode

        # Optional initial state: its input relation plus its own grouping expressions.
        initial_plan = self._initial_state_plan
        if initial_plan is not None:
            group_map.initial_input.CopyFrom(initial_plan.plan(session))
            assert self._initial_state_grouping_cols is not None
            group_map.initial_grouping_expressions.extend(
                [c.to_plan(session) for c in self._initial_state_grouping_cols]
            )

        # transformWithState-specific settings travel in a dedicated sub-message.
        info = proto.TransformWithStateInfo()
        info.time_mode = self._time_mode
        info.event_time_column_name = self._event_time_col_name
        info.output_schema.CopyFrom(pyspark_types_to_proto_types(self._output_schema))
        group_map.transform_with_state_info.CopyFrom(info)

        # The transformWithState UDF is sent as the group map's function.
        group_map.func.CopyFrom(self._function.to_plan_udf(session))

        return self._with_relations(relation, session)


class TransformWithStateInPySpark(BaseTransformWithStateInPySpark):
    """Logical plan object for a TransformWithStateInPySpark.

    All behavior is inherited unchanged from ``BaseTransformWithStateInPySpark``.
    """


# Retaining this to avoid breaking backward compatibility.
class TransformWithStateInPandas(BaseTransformWithStateInPySpark):
    """Logical plan object for a TransformWithStateInPandas.

    All behavior is inherited unchanged from ``BaseTransformWithStateInPySpark``.
    """


class PythonUDTF:
    """Represents a Python user-defined table function."""

    def __init__(
        self,
        func: Type,
        return_type: Optional[Union[DataType, str]],
        eval_type: int,
        python_ver: str,
    ) -> None:
        self._func = func
        self._name = func.__name__
        # None stays None (no return type is sent); a string is wrapped unparsed.
        self._return_type: Optional[DataType] = (
            None
            if return_type is None
            else UnparsedDataType(return_type)
            if isinstance(return_type, str)
            else return_type
        )
        self._eval_type = eval_type
        self._python_ver = python_ver

    def to_plan(self, session: "SparkConnectClient") -> proto.PythonUDTF:
        """
        Serialize this UDTF into a ``proto.PythonUDTF`` message.

        The function object is pickled with ``CloudPickleSerializer``.

        Raises
        ------
        PySparkPicklingError
            If pickling the function raises ``pickle.PicklingError``.
        """
        udtf = proto.PythonUDTF()
        if self._return_type is not None:
            udtf.return_type.CopyFrom(pyspark_types_to_proto_types(self._return_type))
        udtf.eval_type = self._eval_type
        try:
            udtf.command = CloudPickleSerializer().dumps(self._func)
        except pickle.PicklingError as e:
            # Record the pickling error as the explicit cause: the message below points
            # users at the stack trace to find what failed to serialize.
            raise PySparkPicklingError(
                errorClass="UDTF_SERIALIZATION_ERROR",
                messageParameters={
                    "name": self._name,
                    "message": "Please check the stack trace and "
                    "make sure the function is serializable.",
                },
            ) from e
        udtf.python_ver = self._python_ver
        return udtf

    def __repr__(self) -> str:
        return (
            f"PythonUDTF({self._name}, {self._return_type}, "
            f"{self._eval_type}, {self._python_ver})"
        )


class CommonInlineUserDefinedTableFunction(LogicalPlan):
    """
    Logical plan object for a user-defined table function with
    an inlined defined function body.
    """

    def __init__(
        self,
        function_name: str,
        function: PythonUDTF,
        deterministic: bool,
        arguments: Sequence[Expression],
    ) -> None:
        """
        Parameters
        ----------
        function_name : str
            Name recorded in the plan and shown by ``__repr__``.
        function : PythonUDTF
            The UDTF definition serialized into the plan.
        deterministic : bool
            Determinism flag copied into the plan.
        arguments : sequence of Expression
            Argument expressions; they are also passed through
            ``_collect_references`` for the ``LogicalPlan`` base.
        """
        super().__init__(None, self._collect_references(arguments))
        self._function_name = function_name
        self._deterministic = deterministic
        self._arguments = arguments
        self._function = function

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        """Build a ``proto.Relation`` wrapping this table function call."""
        plan = self._create_proto_relation()
        # Delegate to udtf_plan so both representations are always built the
        # same way instead of duplicating the field-by-field population.
        plan.common_inline_user_defined_table_function.CopyFrom(self.udtf_plan(session))
        return self._with_relations(plan, session)

    def udtf_plan(
        self, session: "SparkConnectClient"
    ) -> "proto.CommonInlineUserDefinedTableFunction":
        """
        Compared to `plan`, it returns a `proto.CommonInlineUserDefinedTableFunction`
        instead of a `proto.Relation`.
        """
        plan = proto.CommonInlineUserDefinedTableFunction()
        plan.function_name = self._function_name
        plan.deterministic = self._deterministic
        # Extending with an empty list is a no-op, so no length guard is needed.
        plan.arguments.extend([arg.to_plan(session) for arg in self._arguments])
        # PythonUDTF.to_plan already returns a proto.PythonUDTF, which is exactly
        # what python_udtf holds; no cast or type-ignore is required.
        plan.python_udtf.CopyFrom(self._function.to_plan(session))
        return plan

    def __repr__(self) -> str:
        return f"{self._function_name}({', '.join([str(arg) for arg in self._arguments])})"


class PythonDataSource:
    """
    Represents a user-defined Python data source.

    Bundles the data source class with the Python version string so both can
    be emitted as a ``proto.PythonDataSource`` message.
    """

    def __init__(self, data_source: Type, python_ver: str):
        self._python_ver = python_ver
        self._data_source = data_source

    def to_plan(self, session: "SparkConnectClient") -> proto.PythonDataSource:
        """
        Return a ``proto.PythonDataSource`` holding the cloudpickled data source
        class and the Python version. ``session`` is unused.

        NOTE(review): pickling errors propagate as raised by the serializer;
        they are not wrapped in ``PySparkPicklingError`` here.
        """
        serialized = CloudPickleSerializer().dumps(self._data_source)
        return proto.PythonDataSource(command=serialized, python_ver=self._python_ver)


class CommonInlineUserDefinedDataSource(LogicalPlan):
    """
    Logical plan object for a user-defined data source.

    The same content can be produced either wrapped in a ``proto.Relation``
    (``plan``) or as a bare ``proto.CommonInlineUserDefinedDataSource``
    (``to_data_source_proto``).
    """

    def __init__(self, name: str, data_source: PythonDataSource) -> None:
        super().__init__(None)
        self._name = name
        self._data_source = data_source

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        """Return a ``proto.Relation`` whose payload is this data source."""
        relation = self._create_proto_relation()
        relation.common_inline_user_defined_data_source.CopyFrom(
            self.to_data_source_proto(session)
        )
        return relation

    def to_data_source_proto(
        self, session: "SparkConnectClient"
    ) -> "proto.CommonInlineUserDefinedDataSource":
        """Return the bare data source message (name plus serialized source)."""
        return proto.CommonInlineUserDefinedDataSource(
            name=self._name,
            python_data_source=self._data_source.to_plan(session),
        )


class CachedRelation(LogicalPlan):
    """
    Logical plan node that serves an already-built ``proto.Relation``.

    NOTE(review): the relation passed in is stored by reference and its
    ``common.plan_id`` is overwritten in place, so the caller's object is
    modified; ``plan()`` returns that same object on every call.
    """

    def __init__(self, plan: proto.Relation) -> None:
        super().__init__(None)
        # Re-stamp the relation with this node's plan ID (taken from the
        # incremented counter) so the served plan carries a fresh identifier.
        plan.common.plan_id = self._plan_id
        self._plan = plan

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        """Return the wrapped relation unchanged; ``session`` is unused."""
        return self._plan
