#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Names exported by ``from <this module> import *``.
__all__ = [
    "ChannelBuilder",
    "DefaultChannelBuilder",
    "SparkConnectClient",
]

import atexit

from pyspark.sql.connect.utils import check_dependencies

# NOTE(review): this call runs before the third-party imports below (pandas, pyarrow,
# grpc, protobuf); presumably it reports missing optional dependencies up front --
# confirm against check_dependencies before reordering the imports.
check_dependencies(__name__)

import logging
import threading
import os
import copy
import platform
import urllib.parse
import uuid
import sys
from typing import (
    Iterable,
    Iterator,
    Optional,
    Any,
    Union,
    List,
    Tuple,
    Dict,
    Set,
    NoReturn,
    Mapping,
    cast,
    TYPE_CHECKING,
    Type,
    Sequence,
)

import pandas as pd
import pyarrow as pa

import google.protobuf.message
from grpc_status import rpc_status
import grpc
from google.protobuf import text_format, any_pb2
from google.rpc import error_details_pb2

from pyspark.util import is_remote_only
from pyspark.accumulators import SpecialAccumulatorIds
from pyspark.loose_version import LooseVersion
from pyspark.version import __version__
from pyspark.resource.information import ResourceInformation
from pyspark.sql.metrics import MetricValue, PlanMetrics, ExecutionInfo, ObservedMetrics
from pyspark.sql.connect.client.artifact import ArtifactManager
from pyspark.sql.connect.logging import logger
from pyspark.sql.connect.profiler import ConnectProfilerCollector
from pyspark.sql.connect.client.reattach import ExecutePlanResponseReattachableIterator
from pyspark.sql.connect.client.retries import RetryPolicy, Retrying, DefaultPolicy
from pyspark.sql.connect.conversion import (
    storage_level_to_proto,
    proto_to_storage_level,
    proto_to_remote_cached_dataframe,
)
import pyspark.sql.connect.proto as pb2
import pyspark.sql.connect.proto.base_pb2_grpc as grpc_lib
import pyspark.sql.connect.types as types
from pyspark.errors.exceptions.connect import (
    convert_exception,
    SparkConnectException,
    SparkConnectGrpcException,
)
from pyspark.sql.connect.expressions import (
    LiteralExpression,
    PythonUDF,
    CommonInlineUserDefinedFunction,
    JavaUDF,
)
from pyspark.sql.connect.plan import (
    CommonInlineUserDefinedTableFunction,
    CommonInlineUserDefinedDataSource,
    PythonUDTF,
    PythonDataSource,
)
from pyspark.sql.connect.observation import Observation
from pyspark.sql.connect.utils import get_python_ver
from pyspark.sql.pandas.types import _create_converter_to_pandas, from_arrow_schema
from pyspark.sql.types import DataType, StructType, TimestampType, _has_type
from pyspark.util import PythonEvalType
from pyspark.storagelevel import StorageLevel
from pyspark.errors import PySparkValueError, PySparkAssertionError, PySparkNotImplementedError
from pyspark.sql.connect.shell.progress import Progress, ProgressHandler, from_proto

if TYPE_CHECKING:
    from google.rpc.error_details_pb2 import ErrorInfo
    from pyspark.sql.connect._typing import DataTypeOrString
    from pyspark.sql.datasource import DataSource


class ChannelBuilder:
    """
    This is a helper class that is used to create a GRPC channel based on the given
    connection string per the documentation of Spark Connect.

    The standard implementation is in :class:`DefaultChannelBuilder`.

    Parameters
    ----------
    channelOptions : list of tuple, optional
        Additional ``(key, value)`` options passed to the GRPC channel construction. An
        option whose key already exists in ``GRPC_DEFAULT_OPTIONS`` replaces the default.
    params : dict, optional
        Initial connection parameters (for example ``token`` or ``user_id``).
    """

    PARAM_USE_SSL = "use_ssl"
    PARAM_TOKEN = "token"
    PARAM_USER_ID = "user_id"
    PARAM_USER_AGENT = "user_agent"
    PARAM_SESSION_ID = "session_id"

    GRPC_MAX_MESSAGE_LENGTH_DEFAULT = 128 * 1024 * 1024

    GRPC_DEFAULT_OPTIONS = [
        ("grpc.max_send_message_length", GRPC_MAX_MESSAGE_LENGTH_DEFAULT),
        ("grpc.max_receive_message_length", GRPC_MAX_MESSAGE_LENGTH_DEFAULT),
    ]

    def __init__(
        self,
        channelOptions: Optional[List[Tuple[str, Any]]] = None,
        params: Optional[Dict[str, str]] = None,
    ):
        self._interceptors: List[grpc.UnaryStreamClientInterceptor] = []
        # A non-empty ``params`` dict is stored as-is (not copied).
        self._params: Dict[str, str] = params or dict()
        # Copy the defaults so per-instance overrides never leak into the class attribute.
        self._channel_options: List[Tuple[str, Any]] = ChannelBuilder.GRPC_DEFAULT_OPTIONS.copy()

        if channelOptions is not None:
            for key, value in channelOptions:
                self.setChannelOption(key, value)

    def get(self, key: str) -> Any:
        """
        Parameters
        ----------
        key : str
            Parameter key name.
        Returns
        -------
        The parameter value if present, raises exception otherwise.
        """
        return self._params[key]

    def getDefault(self, key: str, default: Any) -> Any:
        """
        Parameters
        ----------
        key : str
            Parameter key name.
        default : any
            Value to return when the parameter is not set.
        Returns
        -------
        The parameter value if present, ``default`` otherwise.
        """
        return self._params.get(key, default)

    def set(self, key: str, value: Any) -> None:
        """Sets (or overwrites) the connection parameter ``key``."""
        self._params[key] = value

    def setChannelOption(self, key: str, value: Any) -> None:
        """Sets a GRPC channel option, replacing an existing option with the same key."""
        # overwrite option if it exists already else append it
        for i, option in enumerate(self._channel_options):
            if option[0] == key:
                self._channel_options[i] = (key, value)
                return
        self._channel_options.append((key, value))

    def add_interceptor(self, interceptor: grpc.UnaryStreamClientInterceptor) -> None:
        """Registers an interceptor that is applied to every channel created afterwards."""
        self._interceptors.append(interceptor)

    def toChannel(self) -> grpc.Channel:
        """
        The actual channel builder implementations should implement this function
        to return grpc Channel.
        This function should generally use self._insecure_channel or
        self._secure_channel so that configuration options are applied
        appropriately.
        """
        raise PySparkNotImplementedError

    @property
    def host(self) -> str:
        """
        The hostname where this client intends to connect.
        This is used for end-user display purpose in REPL
        """
        raise PySparkNotImplementedError

    def _apply_interceptors(self, channel: grpc.Channel) -> grpc.Channel:
        """Wraps ``channel`` with the registered interceptors, if there are any."""
        if len(self._interceptors) > 0:
            logger.debug(f"Applying interceptors ({self._interceptors})")
            channel = grpc.intercept_channel(channel, *self._interceptors)
        return channel

    def _insecure_channel(self, target: Any, **kwargs: Any) -> grpc.Channel:
        """Creates an insecure channel to ``target`` with the configured options applied."""
        channel = grpc.insecure_channel(target, options=self._channel_options, **kwargs)
        return self._apply_interceptors(channel)

    def _secure_channel(self, target: Any, credentials: Any, **kwargs: Any) -> grpc.Channel:
        """Creates a secure channel to ``target`` with the configured options applied."""
        channel = grpc.secure_channel(target, credentials, options=self._channel_options, **kwargs)
        return self._apply_interceptors(channel)

    @property
    def userId(self) -> Optional[str]:
        """
        Returns
        -------
        The user_id (extracted from connection string or configured by other means).
        """
        return self._params.get(ChannelBuilder.PARAM_USER_ID, None)

    @property
    def token(self) -> Optional[str]:
        """
        Returns
        -------
        The ``token`` parameter if set, otherwise the value of the
        ``SPARK_CONNECT_AUTHENTICATE_TOKEN`` environment variable, or `None` if neither
        is defined.
        """
        return self._params.get(
            ChannelBuilder.PARAM_TOKEN, os.environ.get("SPARK_CONNECT_AUTHENTICATE_TOKEN")
        )

    def metadata(self) -> Iterable[Tuple[str, str]]:
        """
        Builds the GRPC specific metadata list to be injected into the request. All
        parameters will be converted to metadata except ones that are explicitly used
        by the channel.
        Returns
        -------
        A list of tuples (key, value)
        """
        # Parameters consumed by the channel/client itself are never sent as metadata.
        reserved = (
            ChannelBuilder.PARAM_TOKEN,
            ChannelBuilder.PARAM_USE_SSL,
            ChannelBuilder.PARAM_USER_ID,
            ChannelBuilder.PARAM_USER_AGENT,
            ChannelBuilder.PARAM_SESSION_ID,
        )
        return [(k, v) for k, v in self._params.items() if k not in reserved]

    @property
    def session_id(self) -> Optional[str]:
        """
        Returns
        -------
        The session_id extracted from the parameters of the connection string or `None` if not
        specified.

        Raises
        ------
        PySparkValueError
            If a session_id is specified but cannot be parsed as a UUID.
        """
        session_id = self._params.get(ChannelBuilder.PARAM_SESSION_ID, None)
        if session_id is not None:
            try:
                uuid.UUID(session_id, version=4)
            except ValueError as ve:
                raise PySparkValueError(
                    errorClass="INVALID_SESSION_UUID_ID",
                    messageParameters={"arg_name": "session_id", "origin": str(ve)},
                ) from ve
        return session_id

    @property
    def userAgent(self) -> str:
        """
        Returns
        -------
        user_agent : str
            The user_agent parameter specified in the connection string, falling back to
            the ``SPARK_CONNECT_USER_AGENT`` environment variable and then to
            "_SPARK_CONNECT_PYTHON", followed by the Spark version, the operating system
            and the Python version (space separated).

        Raises
        ------
        SparkConnectException
            If the percent-encoded user agent is longer than 2048 characters.
        """
        user_agent = self._params.get(
            ChannelBuilder.PARAM_USER_AGENT,
            os.getenv("SPARK_CONNECT_USER_AGENT", "_SPARK_CONNECT_PYTHON"),
        )

        # The limit applies to the percent-encoded form; the value itself is returned as-is.
        ua_len = len(urllib.parse.quote(user_agent))
        if ua_len > 2048:
            raise SparkConnectException(
                f"'user_agent' parameter should not exceed 2048 characters, "
                f"found {ua_len} characters."
            )
        return " ".join(
            [
                user_agent,
                f"spark/{__version__}",
                f"os/{platform.uname().system.lower()}",
                f"python/{platform.python_version()}",
            ]
        )


class DefaultChannelBuilder(ChannelBuilder):
    """
    This is a helper class that is used to create a GRPC channel based on the given
    connection string per the documentation of Spark Connect.

    .. versionadded:: 3.4.0

    Examples
    --------
    >>> cb = DefaultChannelBuilder("sc://localhost")
    >>> cb.endpoint
    'localhost:15002'

    >>> cb = DefaultChannelBuilder("sc://localhost/;use_ssl=true;token=aaa")
    >>> cb.secure
    True
    """

    @staticmethod
    def default_port() -> int:
        """
        Returns
        -------
        The port used when the connection string does not specify one: 15002, or, in the
        test/development mode, the port reported by the locally started server.
        """
        if "SPARK_TESTING" in os.environ and not is_remote_only():
            from pyspark.sql.session import SparkSession as PySparkSession

            # In the case when Spark Connect uses the local mode, it starts the regular Spark
            # session that starts Spark Connect server that sets `SparkSession._instantiatedSession`
            # via SparkSession.__init__.
            #
            # We are getting the actual server port from the Spark session via Py4J to address
            # the case when the server port is set to 0 (in which allocates an ephemeral port).
            #
            # This is only used in the test/development mode.
            session = PySparkSession._instantiatedSession

            if session is not None:
                # Use the reference that was just checked instead of re-reading the class
                # attribute, which may have been reset in the meantime.
                jvm = session._jvm
                return getattr(
                    getattr(
                        jvm,
                        "org.apache.spark.sql.connect.service.SparkConnectService$",
                    ),
                    "MODULE$",
                ).localPort()
        return 15002

    def __init__(self, url: str, channelOptions: Optional[List[Tuple[str, Any]]] = None) -> None:
        """
        Constructs a new channel builder. This is used to create the proper GRPC channel from
        the connection string.

        Parameters
        ----------
        url : str
            Spark Connect connection string
        channelOptions: list of tuple, optional
            Additional options that can be passed to the GRPC channel construction.

        Raises
        ------
        PySparkValueError
            If the connection string does not start with ``sc://``, has a non-empty path,
            or has malformed parameters or a malformed target destination.
        """

        super().__init__(channelOptions=channelOptions)

        # Explicitly check the scheme of the URL.
        if url[:5] != "sc://":
            raise PySparkValueError(
                errorClass="INVALID_CONNECT_URL",
                messageParameters={
                    "detail": "The URL must start with 'sc://'. Please update the URL to "
                    "follow the correct format, e.g., 'sc://hostname:port'.",
                },
            )
        # Rewrite the URL to use http as the scheme so that we can leverage
        # Python's built-in parser.
        tmp_url = "http" + url[2:]
        self.url = urllib.parse.urlparse(tmp_url)
        if len(self.url.path) > 0 and self.url.path != "/":
            raise PySparkValueError(
                errorClass="INVALID_CONNECT_URL",
                messageParameters={
                    "detail": f"The path component '{self.url.path}' must be empty. Please update "
                    f"the URL to follow the correct format, e.g., 'sc://hostname:port'.",
                },
            )
        self._extract_attributes()

    def _extract_attributes(self) -> None:
        """
        Populates the connection parameters, host and port from the parsed URL.

        Raises
        ------
        PySparkValueError
            If a parameter is not a ``key=value`` pair, or the target destination does not
            match ``<host>[:<port>]`` with an integer port.
        """
        if len(self.url.params) > 0:
            for p in self.url.params.split(";"):
                kv = p.split("=")
                if len(kv) != 2:
                    raise PySparkValueError(
                        errorClass="INVALID_CONNECT_URL",
                        messageParameters={
                            "detail": f"Parameter '{p}' should be provided as a "
                            f"key-value pair separated by an equal sign (=). Please update "
                            f"the parameter to follow the correct format, e.g., 'key=value'.",
                        },
                    )
                # Values may be percent-encoded in the connection string.
                self.set(kv[0], urllib.parse.unquote(kv[1]))

        netloc = self.url.netloc.split(":")
        if len(netloc) == 1:
            self._host = netloc[0]
            self._port = DefaultChannelBuilder.default_port()
        elif len(netloc) == 2:
            self._host = netloc[0]
            try:
                self._port = int(netloc[1])
            except ValueError as e:
                # Report a malformed port like every other malformed-URL case instead of
                # leaking the bare int() error.
                raise PySparkValueError(
                    errorClass="INVALID_CONNECT_URL",
                    messageParameters={
                        "detail": f"The port of target destination '{self.url.netloc}' "
                        f"should be an integer. Please update the destination to follow "
                        f"the correct format, e.g., 'hostname:port'.",
                    },
                ) from e
        else:
            raise PySparkValueError(
                errorClass="INVALID_CONNECT_URL",
                messageParameters={
                    "detail": f"Target destination '{self.url.netloc}' should match the "
                    f"'<host>:<port>' pattern. Please update the destination to follow "
                    f"the correct format, e.g., 'hostname:port'.",
                },
            )

    @property
    def secure(self) -> bool:
        """
        Whether a secure channel is created: either SSL is requested or a token is set.
        """
        return self.use_ssl or self.token is not None

    @property
    def use_ssl(self) -> bool:
        """
        Whether the ``use_ssl`` parameter is set to ``true`` (case-insensitive).
        """
        return self.getDefault(ChannelBuilder.PARAM_USE_SSL, "").lower() == "true"

    @property
    def host(self) -> str:
        """
        The hostname where this client intends to connect.
        """
        return self._host

    @property
    def endpoint(self) -> str:
        """
        The ``<host>:<port>`` target the channel connects to.
        """
        return f"{self._host}:{self._port}"

    def toChannel(self) -> grpc.Channel:
        """
        Applies the parameters of the connection string and creates a new
        GRPC channel according to the configuration. Passes optional channel options to
        construct the channel.

        Returns
        -------
        GRPC Channel instance.
        """

        if not self.secure:
            return self._insecure_channel(self.endpoint)

        # A token without SSL on localhost uses local channel credentials; every other
        # secure configuration uses SSL credentials.
        if not self.use_ssl and self._host == "localhost":
            creds = grpc.local_channel_credentials()
        else:
            creds = grpc.ssl_channel_credentials()

        token = self.token
        if token is not None:
            creds = grpc.composite_channel_credentials(
                creds, grpc.access_token_call_credentials(token)
            )
        return self._secure_channel(self.endpoint, creds)


class PlanObservedMetrics(ObservedMetrics):
    """
    Observed metrics of a plan, held as proto literals together with their keys.

    When no keys are supplied, keys of the form ``observed_metric_<index>`` are generated,
    one per metric.
    """

    def __init__(self, name: str, metrics: List[pb2.Expression.Literal], keys: List[str]):
        self._name = name
        self._metrics = metrics
        if keys:
            self._keys = keys
        else:
            self._keys = [f"observed_metric_{i}" for i in range(len(self.metrics))]

    def __repr__(self) -> str:
        return f"Plan observed({self._name}={self._metrics})"

    @property
    def name(self) -> str:
        """Name of the observation."""
        return self._name

    @property
    def metrics(self) -> List[pb2.Expression.Literal]:
        """The metric values as proto literals."""
        return self._metrics

    @property
    def keys(self) -> List[str]:
        """The metric keys, index-aligned with ``metrics``."""
        return self._keys

    @property
    def pairs(self) -> dict[str, Any]:
        """Mapping of each key to the Python value converted from its literal."""
        converted = {}
        for position, literal in enumerate(self.metrics):
            value = LiteralExpression._to_value(literal)
            converted[self.keys[position]] = value
        return converted


class AnalyzeResult:
    """
    Holder for the outcome of an analyze request.

    Every attribute is optional; :meth:`fromProto` fills in at most one of them, depending
    on which result field the response carries.
    """

    def __init__(
        self,
        schema: Optional[DataType],
        explain_string: Optional[str],
        tree_string: Optional[str],
        is_local: Optional[bool],
        is_streaming: Optional[bool],
        input_files: Optional[List[str]],
        spark_version: Optional[str],
        parsed: Optional[DataType],
        is_same_semantics: Optional[bool],
        semantic_hash: Optional[int],
        storage_level: Optional[StorageLevel],
        ddl_string: Optional[str],
    ):
        self.schema = schema
        self.explain_string = explain_string
        self.tree_string = tree_string
        self.is_local = is_local
        self.is_streaming = is_streaming
        self.input_files = input_files
        self.spark_version = spark_version
        self.parsed = parsed
        self.is_same_semantics = is_same_semantics
        self.semantic_hash = semantic_hash
        self.storage_level = storage_level
        self.ddl_string = ddl_string

    @classmethod
    def fromProto(cls, pb: Any) -> "AnalyzeResult":
        """
        Builds an :class:`AnalyzeResult` from an analyze response message.

        Raises
        ------
        SparkConnectException
            If the response carries none of the known result fields.
        """
        # (response field, AnalyzeResult attribute, extractor), probed in this order.
        # An attribute of None marks responses that carry no value (persist/unpersist).
        probes: Sequence[Tuple[str, Optional[str], Any]] = (
            (
                "schema",
                "schema",
                lambda m: types.proto_schema_to_pyspark_data_type(m.schema.schema),
            ),
            ("explain", "explain_string", lambda m: m.explain.explain_string),
            ("tree_string", "tree_string", lambda m: m.tree_string.tree_string),
            ("is_local", "is_local", lambda m: m.is_local.is_local),
            ("is_streaming", "is_streaming", lambda m: m.is_streaming.is_streaming),
            ("input_files", "input_files", lambda m: m.input_files.files),
            ("spark_version", "spark_version", lambda m: m.spark_version.version),
            (
                "ddl_parse",
                "parsed",
                lambda m: types.proto_schema_to_pyspark_data_type(m.ddl_parse.parsed),
            ),
            ("same_semantics", "is_same_semantics", lambda m: m.same_semantics.result),
            ("semantic_hash", "semantic_hash", lambda m: m.semantic_hash.result),
            ("persist", None, None),
            ("unpersist", None, None),
            (
                "get_storage_level",
                "storage_level",
                lambda m: proto_to_storage_level(m.get_storage_level.storage_level),
            ),
            ("json_to_ddl", "ddl_string", lambda m: m.json_to_ddl.ddl_string),
        )

        # Start with every attribute unset; the matching probe fills in at most one.
        values: Dict[str, Any] = dict.fromkeys(
            (
                "schema",
                "explain_string",
                "tree_string",
                "is_local",
                "is_streaming",
                "input_files",
                "spark_version",
                "parsed",
                "is_same_semantics",
                "semantic_hash",
                "storage_level",
                "ddl_string",
            )
        )

        for proto_field, attribute, extract in probes:
            if not pb.HasField(proto_field):
                continue
            if attribute is not None:
                values[attribute] = extract(pb)
            break
        else:
            raise SparkConnectException("No analyze result found!")

        return AnalyzeResult(**values)


class ConfigResult:
    """
    Key/value pairs and warnings taken from a config response.
    """

    def __init__(self, pairs: List[Tuple[str, Optional[str]]], warnings: List[str]):
        self.pairs = pairs
        self.warnings = warnings

    @classmethod
    def fromProto(cls, pb: pb2.ConfigResponse) -> "ConfigResult":
        """
        Builds a :class:`ConfigResult` from the response; a pair without a value is
        reported with `None` as its value.
        """
        collected: List[Tuple[str, Optional[str]]] = []
        for entry in pb.pairs:
            key = entry.key
            value = entry.value if entry.HasField("value") else None
            collected.append((key, value))
        return ConfigResult(pairs=collected, warnings=list(pb.warnings))


class SparkConnectClient(object):
    """
    Conceptually the remote spark session that communicates with the server
    """

    def __init__(
        self,
        connection: Union[str, ChannelBuilder],
        user_id: Optional[str] = None,
        channel_options: Optional[List[Tuple[str, Any]]] = None,
        retry_policy: Optional[Dict[str, Any]] = None,
        use_reattachable_execute: bool = True,
    ):
        """
        Creates a new SparkSession for the Spark Connect interface.

        Parameters
        ----------
        connection : str or :class:`ChannelBuilder`
            Either a connection string, from which the connection parameters are extracted
            to configure the GRPC connection, or a :class:`ChannelBuilder` instance that
            creates the GRPC connection.
        user_id : str, optional
            Optional unique user ID that is used to differentiate multiple users and
            isolate their Spark Sessions. A user ID defined by the channel builder (e.g. as
            part of the connection string) takes precedence. If neither is set, the
            ``SPARK_USER`` and then the ``USER`` environment variables are used.
        channel_options: list of tuple, optional
            Additional options that can be passed to the GRPC channel construction. Only
            used when ``connection`` is a connection string.
        retry_policy: dict of str and any, optional
            Additional configuration for retrying, passed as keyword arguments to the
            default retry policy. There are four configurations as below
                * ``max_retries``
                    Maximum number of tries default 15
                * ``backoff_multiplier``
                    Backoff multiplier for the policy. Default: 4(ms)
                * ``initial_backoff``
                    Backoff to wait before the first retry. Default: 50(ms)
                * ``max_backoff``
                    Maximum backoff controls the maximum amount of time to wait before retrying
                    a failed request. Default: 60000(ms).
        use_reattachable_execute: bool
            Enable reattachable execution.
        """
        self.thread_local = threading.local()

        # A ready-made builder is used as-is; anything else is parsed as a connection string.
        if isinstance(connection, ChannelBuilder):
            self._builder = connection
        else:
            self._builder = DefaultChannelBuilder(connection, channel_options)
        self._user_id = None
        self._retry_policies: List[RetryPolicy] = []

        policy_kwargs = retry_policy or dict()
        self.set_retry_policies([DefaultPolicy(**policy_kwargs)])

        predefined_session_id = self._builder.session_id
        if predefined_session_id is not None:
            # Use the pre-defined session ID.
            self._session_id = str(predefined_session_id)
        else:
            # Generate a unique session ID for this client. This UUID must be unique to allow
            # concurrent Spark sessions of the same user. If the channel is closed, creating
            # a new client will create a new session ID.
            self._session_id = str(uuid.uuid4())

        # Precedence: channel builder, then the explicit argument, then the environment.
        builder_user_id = self._builder.userId
        if builder_user_id is not None:
            self._user_id = builder_user_id
        elif user_id is not None:
            self._user_id = user_id
        else:
            self._user_id = os.getenv("SPARK_USER", os.getenv("USER", None))

        self._channel = self._builder.toChannel()
        self._closed = False
        self._internal_stub = grpc_lib.SparkConnectServiceStub(self._channel)
        self._artifact_manager = ArtifactManager(
            self._user_id, self._session_id, self._channel, self._builder.metadata()
        )
        self._use_reattachable_execute = use_reattachable_execute

        # Capture the server-side session ID and set it to None initially. It will
        # be updated on the first response received.
        self._server_session_id: Optional[str] = None

        self._profiler_collector = ConnectProfilerCollector()

        self._progress_handlers: List[ProgressHandler] = []

        # cleanup ml cache if possible
        atexit.register(self._cleanup_ml_cache)

    @property
    def _stub(self) -> grpc_lib.SparkConnectServiceStub:
        """
        The service stub used to issue requests.

        Raises
        ------
        SparkConnectException
            With error class ``NO_ACTIVE_SESSION`` if the client is closed.
        """
        if not self.is_closed:
            return self._internal_stub
        raise SparkConnectException(
            errorClass="NO_ACTIVE_SESSION", messageParameters=dict()
        ) from None

    # For testing only.
    @_stub.setter
    def _stub(self, value: grpc_lib.SparkConnectServiceStub) -> None:
        self._internal_stub = value

    def register_progress_handler(self, handler: ProgressHandler) -> None:
        """
        Register a progress handler to be called when a progress message is received.
        Registering a handler that is already registered has no effect.

        Parameters
        ----------
        handler : ProgressHandler
          The callable that will be called with the progress information.

        """
        if handler not in self._progress_handlers:
            self._progress_handlers.append(handler)

    def clear_progress_handlers(self) -> None:
        """
        Remove all registered progress handlers.
        """
        del self._progress_handlers[:]

    def remove_progress_handler(self, handler: ProgressHandler) -> None:
        """
        Remove a progress handler from the list of registered handlers.

        Parameters
        ----------
        handler : ProgressHandler
          The callable to remove from the list of progress handlers.

        Raises
        ------
        ValueError
          If the handler is not currently registered.

        """
        registered = self._progress_handlers
        registered.remove(handler)

    def _retrying(self) -> "Retrying":
        """
        Create a new ``Retrying`` from the currently configured retry policies.
        """
        active_policies = self._retry_policies
        return Retrying(active_policies)

    def disable_reattachable_execute(self) -> "SparkConnectClient":
        """
        Turn reattachable execution off.

        Returns
        -------
        This client, so that calls can be chained.
        """
        self._use_reattachable_execute = False
        return self

    def enable_reattachable_execute(self) -> "SparkConnectClient":
        """
        Turn reattachable execution on.

        Returns
        -------
        This client, so that calls can be chained.
        """
        self._use_reattachable_execute = True
        return self

    def set_retry_policies(self, policies: Iterable[RetryPolicy]) -> None:
        """
        Replaces the policies used for retries with the given ones,
        e.g. set_retry_policies([DefaultPolicy(), CustomPolicy()]).

        The iterable is copied into a new list.
        """
        self._retry_policies = [*policies]

    def get_retry_policies(self) -> List[RetryPolicy]:
        """
        Return the currently used policies as a new list (a copy, not the internal list).
        """
        return [*self._retry_policies]

    def register_udf(
        self,
        function: Any,
        return_type: "DataTypeOrString",
        name: Optional[str] = None,
        eval_type: int = PythonEvalType.SQL_BATCHED_UDF,
        deterministic: bool = True,
    ) -> str:
        """
        Create a temporary UDF in the session catalog on the other side. We generate a
        temporary name for it when no name is given.

        Parameters
        ----------
        function : any
            The Python function to register.
        return_type : :class:`DataType` or str
            The return type of the function.
        name : str, optional
            Name to register the function under. Defaults to a generated ``fun_<hex>`` name.
        eval_type : int
            The Python evaluation type of the UDF.
        deterministic : bool
            Whether the function is deterministic.

        Returns
        -------
        The name the function was registered under.
        """

        if name is None:
            name = f"fun_{uuid.uuid4().hex}"

        # construct a PythonUDF
        py_udf = PythonUDF(
            output_type=return_type,
            eval_type=eval_type,
            func=function,
            # Same helper that register_udtf and register_data_source use for this argument.
            python_ver=get_python_ver(),
        )

        # construct a CommonInlineUserDefinedFunction
        fun = CommonInlineUserDefinedFunction(
            function_name=name,
            arguments=[],
            function=py_udf,
            deterministic=deterministic,
        ).to_plan_udf(self)

        # construct the request
        req = self._execute_plan_request_with_metadata()
        req.plan.command.register_function.CopyFrom(fun)

        self._execute(req)
        return name

    def register_udtf(
        self,
        function: Any,
        return_type: Optional["DataTypeOrString"],
        name: str,
        eval_type: int = PythonEvalType.SQL_TABLE_UDF,
        deterministic: bool = True,
    ) -> str:
        """
        Register a user-defined table function (UDTF) in the session catalog
        as a temporary function. The return type, if specified, must be a
        struct type and it's validated when building the proto message
        for the PythonUDTF.

        Returns
        -------
        The name the function was registered under (the given ``name``).
        """
        python_udtf = PythonUDTF(
            func=function,
            return_type=return_type,
            eval_type=eval_type,
            python_ver=get_python_ver(),
        )
        inline_udtf = CommonInlineUserDefinedTableFunction(
            function_name=name,
            function=python_udtf,
            deterministic=deterministic,
            arguments=[],
        )
        udtf_proto = inline_udtf.udtf_plan(self)

        request = self._execute_plan_request_with_metadata()
        request.plan.command.register_table_function.CopyFrom(udtf_proto)
        self._execute(request)

        return name

    def register_data_source(self, dataSource: Type["DataSource"]) -> None:
        """
        Register a data source in the session catalog, under the name returned by
        ``dataSource.name()``.
        """
        python_source = PythonDataSource(
            data_source=dataSource,
            python_ver=get_python_ver(),
        )
        inline_source = CommonInlineUserDefinedDataSource(
            name=dataSource.name(),
            data_source=python_source,
        )
        source_proto = inline_source.to_data_source_proto(self)

        request = self._execute_plan_request_with_metadata()
        request.plan.command.register_data_source.CopyFrom(source_proto)
        self._execute(request)

    def register_java(
        self,
        name: str,
        javaClassName: str,
        return_type: Optional["DataTypeOrString"] = None,
        aggregate: bool = False,
    ) -> None:
        """
        Register a Java function, identified by its class name, under ``name``.

        Parameters
        ----------
        name : str
            Name to register the function under.
        javaClassName : str
            Fully qualified name of the Java class.
        return_type : :class:`DataType` or str, optional
            The return type of the function.
        aggregate : bool
            Whether the function is an aggregate function. Only forwarded when
            ``return_type`` is not given.
        """
        # NOTE(review): ``aggregate`` is not passed on when ``return_type`` is given --
        # confirm this is intended.
        udf_kwargs: Dict[str, Any] = {"class_name": javaClassName}
        if return_type is None:
            udf_kwargs["aggregate"] = aggregate
        else:
            udf_kwargs["output_type"] = return_type
        java_udf = JavaUDF(**udf_kwargs)

        inline_function = CommonInlineUserDefinedFunction(
            function_name=name,
            function=java_udf,
        )
        function_proto = inline_function.to_plan_judf(self)

        # Wrap the function in a register-function command and execute it.
        request = self._execute_plan_request_with_metadata()
        request.plan.command.register_function.CopyFrom(function_proto)
        self._execute(request)

    def _build_metrics(self, metrics: "pb2.ExecutePlanResponse.Metrics") -> Iterator[PlanMetrics]:
        """
        Lazily convert the metric entries of an execute-plan response into ``PlanMetrics``.
        """

        def _to_plan_metrics(node: Any) -> PlanMetrics:
            return PlanMetrics(
                node.name,
                node.plan_id,
                node.parent,
                [
                    MetricValue(key, metric.value, metric.metric_type)
                    for key, metric in node.execution_metrics.items()
                ],
            )

        return (_to_plan_metrics(node) for node in metrics.metrics)

    def _resources(self) -> Dict[str, ResourceInformation]:
        """
        Run a get-resources command and return its ``get_resources_command_result`` property.
        """
        logger.debug("Fetching the resources")
        command = pb2.Command()
        # The command has no fields to fill in; just mark it as the one that is set.
        command.get_resources_command.SetInParent()
        _, properties, _ = self.execute_command(command)
        return properties["get_resources_command_result"]

    def _build_observed_metrics(
        self, metrics: Sequence["pb2.ExecutePlanResponse.ObservedMetrics"]
    ) -> Iterator[PlanObservedMetrics]:
        """
        Lazily convert observed-metrics messages into ``PlanObservedMetrics``.
        """
        return (
            PlanObservedMetrics(observed.name, list(observed.values), list(observed.keys))
            for observed in metrics
        )

    def to_table_as_iterator(
        self, plan: pb2.Plan, observations: Dict[str, Observation]
    ) -> Iterator[Union[StructType, "pa.Table"]]:
        """
        Execute the given plan and yield its results incrementally: ``StructType`` responses
        are yielded as-is and each Arrow record batch is yielded as a single-batch PyArrow
        Table. Any other response is skipped.
        """
        # Rendering the proto as text is costly, so only do it when debug logging is on.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f"Executing plan {self._proto_to_string(plan, True)}")

        request = self._execute_plan_request_with_metadata()
        request.plan.CopyFrom(plan)

        with Progress(
            handlers=self._progress_handlers, operation_id=request.operation_id
        ) as tracker:
            for item in self._execute_and_fetch_as_iterator(request, observations, tracker):
                if isinstance(item, StructType):
                    yield item
                    continue
                if isinstance(item, pa.RecordBatch):
                    yield pa.Table.from_batches([item])

    def to_table(
        self, plan: pb2.Plan, observations: Dict[str, Observation]
    ) -> Tuple["pa.Table", Optional[StructType], ExecutionInfo]:
        """
        Execute the given plan and return its whole result as one PyArrow Table.

        Returns
        -------
        A tuple of the PyArrow Table, the schema if one was received (otherwise None), and the
        ``ExecutionInfo`` built from the collected metrics and observed metrics.
        """
        if logger.isEnabledFor(logging.DEBUG):
            # Rendering the proto as a string is costly, so it is only done at debug log level.
            logger.debug(f"Executing plan {self._proto_to_string(plan, True)}")

        request = self._execute_plan_request_with_metadata()
        request.plan.CopyFrom(plan)
        fetched = self._execute_and_fetch(request, observations)
        # The trailing properties dict of the fetched result is not needed here.
        arrow_table, plan_schema, plan_metrics, plan_observed_metrics, _ = fetched

        # Create a query execution object.
        execution_info = ExecutionInfo(plan_metrics, plan_observed_metrics)
        assert arrow_table is not None
        return arrow_table, plan_schema, execution_info

    def to_pandas(
        self, plan: pb2.Plan, observations: Dict[str, Observation]
    ) -> Tuple["pd.DataFrame", "ExecutionInfo"]:
        """
        Return given plan as a pandas DataFrame.

        Parameters
        ----------
        plan : pb2.Plan
            The plan to execute.
        observations : dict
            Observations by name; passed through to the execution.

        Returns
        -------
        A tuple of the resulting pandas DataFrame and the ``ExecutionInfo`` built from the
        metrics and observed metrics collected during execution. Non-empty metrics and observed
        metrics are additionally stored in ``DataFrame.attrs``.
        """
        if logger.isEnabledFor(logging.DEBUG):
            # inside an if statement to not incur a performance cost converting proto to string
            # when not at debug log level.
            logger.debug(f"Executing plan {self._proto_to_string(plan, True)}")
        req = self._execute_plan_request_with_metadata()
        req.plan.CopyFrom(plan)
        # Read the self-destruct setting, using "false" as the default value.
        (self_destruct_conf,) = self.get_config_with_defaults(
            ("spark.sql.execution.arrow.pyspark.selfDestruct.enabled", "false"),
        )
        self_destruct = cast(str, self_destruct_conf).lower() == "true"
        table, schema, metrics, observed_metrics, _ = self._execute_and_fetch(
            req, observations, self_destruct=self_destruct
        )
        assert table is not None
        ei = ExecutionInfo(metrics, observed_metrics)

        # Fall back to a schema derived from the Arrow table when no schema was received.
        schema = schema or from_arrow_schema(table.schema, prefer_timestamp_ntz=True)
        assert schema is not None and isinstance(schema, StructType)

        # SPARK-51112: If the table is empty, we avoid using pyarrow to_pandas to create the
        # DataFrame, as it may fail with a segmentation fault. Instead, we create an empty pandas
        # DataFrame manually with the correct schema.
        if table.num_rows == 0:
            pdf = pd.DataFrame(columns=schema.names, index=range(0))
        else:
            # Rename columns to avoid duplicated column names.
            renamed_table = table.rename_columns([f"col_{i}" for i in range(table.num_columns)])

            pandas_options = {}
            if self_destruct:
                # Configure PyArrow to use as little memory as possible:
                # self_destruct - free columns as they are converted
                # split_blocks - create a separate Pandas block for each column
                # use_threads - convert one column at a time
                pandas_options.update(
                    {
                        "self_destruct": True,
                        "split_blocks": True,
                        "use_threads": False,
                    }
                )
            if LooseVersion(pa.__version__) >= LooseVersion("13.0.0"):
                # A legacy option to coerce date32, date64, duration, and timestamp
                # time units to nanoseconds when converting to pandas.
                # This option can only be added since 13.0.0.
                pandas_options.update(
                    {
                        "coerce_temporal_nanoseconds": True,
                    }
                )
            pdf = renamed_table.to_pandas(**pandas_options)
            # Restore the real column names after the conversion.
            pdf.columns = schema.names

        if len(pdf.columns) > 0:
            # The session time zone and the struct handling mode are only fetched when the
            # schema actually contains a timestamp / struct type.
            timezone: Optional[str] = None
            if any(_has_type(f.dataType, TimestampType) for f in schema.fields):
                (timezone,) = self.get_configs("spark.sql.session.timeZone")

            struct_in_pandas: Optional[str] = None
            error_on_duplicated_field_names: bool = False
            if any(_has_type(f.dataType, StructType) for f in schema.fields):
                (struct_in_pandas,) = self.get_config_with_defaults(
                    ("spark.sql.execution.pandas.structHandlingMode", "legacy"),
                )

                # "legacy" is handled as "dict" with errors on duplicated field names.
                if struct_in_pandas == "legacy":
                    error_on_duplicated_field_names = True
                    struct_in_pandas = "dict"

            # Convert every column with a converter created for its Spark field, then
            # reassemble the converted columns into one DataFrame.
            # NOTE(review): pa_field is not used inside the comprehension.
            pdf = pd.concat(
                [
                    _create_converter_to_pandas(
                        field.dataType,
                        field.nullable,
                        timezone=timezone,
                        struct_in_pandas=struct_in_pandas,
                        error_on_duplicated_field_names=error_on_duplicated_field_names,
                    )(pser)
                    for (_, pser), field, pa_field in zip(pdf.items(), schema.fields, table.schema)
                ],
                axis="columns",
            )

        if len(metrics) > 0:
            pdf.attrs["metrics"] = metrics
        if len(observed_metrics) > 0:
            pdf.attrs["observed_metrics"] = observed_metrics
        return pdf, ei

    def _proto_to_string(self, p: google.protobuf.message.Message, truncate: bool = False) -> str:
        """
        Render a protobuf message as a single-line text string.

        Parameters
        ----------
        p : google.protobuf.message.Message
            The message to render.
        truncate : bool
            When True, the message is first abbreviated with :meth:`_truncate`, using a
            recursion depth of 8.

        Returns
        -------
        The one-line text form of the message, or a placeholder string if rendering raised.
        """
        try:
            if truncate:
                rendered = self._truncate(p, 8)
            else:
                rendered = p
            return text_format.MessageToString(rendered, as_one_line=True)
        except RecursionError:
            return "<Truncated message due to recursion error>"
        except Exception:
            return "<Truncated message due to truncation error>"

    def _truncate(
        self, p: google.protobuf.message.Message, allowed_recursion_depth: int
    ) -> google.protobuf.message.Message:
        """
        Build an abbreviated deep copy of a protobuf message.

        In the copy, strings longer than 1024 characters and bytes longer than 8 bytes are cut
        and suffixed with ``[truncated]``, and message-typed fields are cleared once
        ``allowed_recursion_depth`` is 0.
        Refer to 'org.apache.spark.sql.connect.common.Abbreviator' in the server side.
        """

        def shorten_text(text: str) -> str:
            return text[:1024] + "[truncated]" if len(text) > 1024 else text

        def shorten_blob(blob: bytes) -> bytes:
            return blob[:8] + b"[truncated]" if len(blob) > 8 else blob

        result = copy.deepcopy(p)
        child_depth = allowed_recursion_depth - 1

        for field, value in p.ListFields():
            if value is None:
                continue

            name = field.name
            kind = field.type

            if kind == field.TYPE_MESSAGE:
                if allowed_recursion_depth == 0:
                    # Depth budget exhausted: drop the nested message(s) entirely.
                    result.ClearField(name)
                elif field.label == field.LABEL_REPEATED:
                    # Rebuild the repeated field from the abbreviated elements.
                    result.ClearField(name)
                    getattr(result, name).extend(
                        [self._truncate(item, child_depth) for item in value]
                    )
                else:
                    getattr(result, name).CopyFrom(self._truncate(value, child_depth))

            elif kind == field.TYPE_STRING:
                if field.label == field.LABEL_REPEATED:
                    result.ClearField(name)
                    getattr(result, name).extend([shorten_text(item) for item in value])
                else:
                    setattr(result, name, shorten_text(value))

            elif kind == field.TYPE_BYTES:
                if field.label == field.LABEL_REPEATED:
                    result.ClearField(name)
                    getattr(result, name).extend([shorten_blob(item) for item in value])
                else:
                    setattr(result, name, shorten_blob(value))

        return result

    def schema(self, plan: pb2.Plan) -> StructType:
        """
        Analyze the given plan and return its schema.
        """
        if logger.isEnabledFor(logging.DEBUG):
            # Rendering the proto as a string is costly, so it is only done at debug log level.
            logger.debug(f"Schema for plan: {self._proto_to_string(plan, True)}")

        analyzed = self._analyze(method="schema", plan=plan)
        plan_schema = analyzed.schema
        assert plan_schema is not None
        # Server side should populate the struct field which is the schema.
        assert isinstance(plan_schema, StructType)
        return plan_schema

    def explain_string(self, plan: pb2.Plan, explain_mode: str = "extended") -> str:
        """
        Analyze the given plan and return its explain string for the requested mode.
        """
        if logger.isEnabledFor(logging.DEBUG):
            # Rendering the proto as a string is costly, so it is only done at debug log level.
            logger.debug(
                f"Explain (mode={explain_mode}) for plan {self._proto_to_string(plan, True)}"
            )

        analyzed = self._analyze(method="explain", plan=plan, explain_mode=explain_mode)
        explained = analyzed.explain_string
        assert explained is not None
        return explained

    def execute_command(
        self, command: pb2.Command, observations: Optional[Dict[str, Observation]] = None
    ) -> Tuple[Optional[pd.DataFrame], Dict[str, Any], ExecutionInfo]:
        """
        Execute the given command and collect its whole result.

        Returns
        -------
        A tuple of the result data as a pandas DataFrame (None when no data was received), the
        properties dict collected from the responses, and the ``ExecutionInfo`` built from the
        metrics and observed metrics.
        """
        if logger.isEnabledFor(logging.DEBUG):
            # Rendering the proto as a string is costly, so it is only done at debug log level.
            logger.debug(f"Execute command for command {self._proto_to_string(command, True)}")

        request = self._execute_plan_request_with_metadata()
        if self._user_id:
            request.user_context.user_id = self._user_id
        request.plan.command.CopyFrom(command)

        fetched = self._execute_and_fetch(request, observations or {})
        table, _, plan_metrics, plan_observed_metrics, props = fetched
        # Create a query execution object.
        execution_info = ExecutionInfo(plan_metrics, plan_observed_metrics)
        frame = None if table is None else table.to_pandas()
        return (frame, props, execution_info)

    def execute_command_as_iterator(
        self, command: pb2.Command, observations: Optional[Dict[str, Observation]] = None
    ) -> Iterator[Dict[str, Any]]:
        """
        Execute the given command and yield each dict response as it arrives.

        Similar to execute_command, but streaming. A response that is not a dict raises a
        ``PySparkValueError`` with error class ``UNKNOWN_RESPONSE``.
        """
        if logger.isEnabledFor(logging.DEBUG):
            # Rendering the proto as a string is costly, so it is only done at debug log level.
            logger.debug(
                f"Execute command as iterator for command {self._proto_to_string(command, True)}"
            )

        request = self._execute_plan_request_with_metadata()
        if self._user_id:
            request.user_context.user_id = self._user_id
        request.plan.command.CopyFrom(command)

        for item in self._execute_and_fetch_as_iterator(request, observations or {}):
            if not isinstance(item, dict):
                raise PySparkValueError(
                    errorClass="UNKNOWN_RESPONSE",
                    messageParameters={
                        "response": str(item),
                    },
                )
            yield item

    def same_semantics(self, plan: pb2.Plan, other: pb2.Plan) -> bool:
        """
        Return whether the two plans have the same semantics, as reported by the analyze RPC.
        """
        analyzed = self._analyze(method="same_semantics", plan=plan, other=other)
        is_same = analyzed.is_same_semantics
        assert is_same is not None
        return is_same

    def semantic_hash(self, plan: pb2.Plan) -> int:
        """
        Return the semantic hash of the given plan, as reported by the analyze RPC.
        """
        analyzed = self._analyze(method="semantic_hash", plan=plan)
        hash_code = analyzed.semantic_hash
        assert hash_code is not None
        return hash_code

    def close(self) -> None:
        """
        Close the channel.

        Calls ``ExecutePlanResponseReattachableIterator.shutdown()`` first, then closes the
        channel, and finally records that this client has been closed.
        """
        ExecutePlanResponseReattachableIterator.shutdown()
        self._channel.close()
        # Only reached when neither of the two calls above raised.
        self._closed = True

    @property
    def is_closed(self) -> bool:
        """
        Returns if the channel was closed previously using close() method
        """
        return self._closed

    @property
    def host(self) -> str:
        """
        The hostname where this client intends to connect.
        """
        return self._builder.host

    @property
    def token(self) -> Optional[str]:
        """
        The authentication bearer token during connection.
        If authentication is not using a bearer token, None will be returned.
        """
        return self._builder.token

    def _execute_plan_request_with_metadata(
        self, operation_id: Optional[str] = None
    ) -> pb2.ExecutePlanRequest:
        """
        Create an ``ExecutePlanRequest`` pre-filled with this client's session metadata.

        Parameters
        ----------
        operation_id : str, optional
            Operation id to set on the request. It must be parseable as a UUID.

        Returns
        -------
        The request with session id, client type and tags set, plus the observed server-side
        session id, the user id and the operation id when they are available.

        Raises
        ------
        PySparkValueError
            If ``operation_id`` is given but cannot be parsed as a UUID.
        """
        request = pb2.ExecutePlanRequest(
            session_id=self._session_id,
            client_type=self._builder.userAgent,
            tags=list(self.get_tags()),
        )
        if self._server_session_id is not None:
            request.client_observed_server_side_session_id = self._server_session_id
        if self._user_id:
            request.user_context.user_id = self._user_id

        if operation_id is None:
            return request

        # Validate the id before putting it on the request.
        try:
            uuid.UUID(operation_id, version=4)
        except ValueError as ve:
            raise PySparkValueError(
                errorClass="INVALID_OPERATION_UUID_ID",
                messageParameters={"arg_name": "operation_id", "origin": str(ve)},
            )
        request.operation_id = operation_id
        return request

    def _analyze_plan_request_with_metadata(self) -> pb2.AnalyzePlanRequest:
        """
        Create an ``AnalyzePlanRequest`` pre-filled with this client's session metadata.

        The observed server-side session id and the user id are only set when available.
        """
        request = pb2.AnalyzePlanRequest(
            session_id=self._session_id,
            client_type=self._builder.userAgent,
        )
        if self._server_session_id is not None:
            request.client_observed_server_side_session_id = self._server_session_id
        if self._user_id:
            request.user_context.user_id = self._user_id
        return request

    def _analyze(self, method: str, **kwargs: Any) -> AnalyzeResult:
        """
        Call the analyze RPC of Spark Connect.

        Parameters
        ----------
        method : str
            Name of the analysis to run; it selects which field of the request is populated.
        **kwargs
            Inputs of the selected analysis (for example ``plan``, ``relation`` or
            ``explain_mode``).

        Returns
        -------
        The result of the analyze call.

        Raises
        ------
        PySparkValueError
            If ``method`` is not supported, or if the explain mode is unknown.
        """
        request = self._analyze_plan_request_with_metadata()
        arg = kwargs.get

        if method == "schema":
            request.schema.plan.CopyFrom(cast(pb2.Plan, arg("plan")))
        elif method == "explain":
            request.explain.plan.CopyFrom(cast(pb2.Plan, arg("plan")))
            mode = arg("explain_mode")
            if mode not in ["simple", "extended", "codegen", "cost", "formatted"]:
                raise PySparkValueError(
                    errorClass="UNKNOWN_EXPLAIN_MODE",
                    messageParameters={
                        "explain_mode": str(mode),
                    },
                )
            proto_modes = pb2.AnalyzePlanRequest.Explain.ExplainMode
            request.explain.explain_mode = {
                "simple": proto_modes.EXPLAIN_MODE_SIMPLE,
                "extended": proto_modes.EXPLAIN_MODE_EXTENDED,
                "cost": proto_modes.EXPLAIN_MODE_COST,
                "codegen": proto_modes.EXPLAIN_MODE_CODEGEN,
                "formatted": proto_modes.EXPLAIN_MODE_FORMATTED,
            }[mode]
        elif method == "tree_string":
            request.tree_string.plan.CopyFrom(cast(pb2.Plan, arg("plan")))
            depth = arg("level")
            if depth and isinstance(depth, int):
                request.tree_string.level = depth
        elif method == "is_local":
            request.is_local.plan.CopyFrom(cast(pb2.Plan, arg("plan")))
        elif method == "is_streaming":
            request.is_streaming.plan.CopyFrom(cast(pb2.Plan, arg("plan")))
        elif method == "input_files":
            request.input_files.plan.CopyFrom(cast(pb2.Plan, arg("plan")))
        elif method == "spark_version":
            request.spark_version.SetInParent()
        elif method == "ddl_parse":
            request.ddl_parse.ddl_string = cast(str, arg("ddl_string"))
        elif method == "same_semantics":
            request.same_semantics.target_plan.CopyFrom(cast(pb2.Plan, arg("plan")))
            request.same_semantics.other_plan.CopyFrom(cast(pb2.Plan, arg("other")))
        elif method == "semantic_hash":
            request.semantic_hash.plan.CopyFrom(cast(pb2.Plan, arg("plan")))
        elif method == "persist":
            request.persist.relation.CopyFrom(cast(pb2.Relation, arg("relation")))
            level = arg("storage_level", None)
            if level is not None:
                request.persist.storage_level.CopyFrom(
                    storage_level_to_proto(cast(StorageLevel, level))
                )
        elif method == "unpersist":
            request.unpersist.relation.CopyFrom(cast(pb2.Relation, arg("relation")))
            blocking = arg("blocking", None)
            if blocking is not None:
                request.unpersist.blocking = cast(bool, blocking)
        elif method == "get_storage_level":
            request.get_storage_level.relation.CopyFrom(cast(pb2.Relation, arg("relation")))
        elif method == "json_to_ddl":
            request.json_to_ddl.json_string = cast(str, arg("json_string"))
        else:
            raise PySparkValueError(
                errorClass="UNSUPPORTED_OPERATION",
                messageParameters={
                    "operation": method,
                },
            )

        try:
            for attempt in self._retrying():
                with attempt:
                    response = self._stub.AnalyzePlan(request, metadata=self._builder.metadata())
                    self._verify_response_integrity(response)
                    return AnalyzeResult.fromProto(response)
            raise SparkConnectException("Invalid state during retry exception handling.")
        except Exception as error:
            self._handle_error(error)

    def _execute(self, req: pb2.ExecutePlanRequest) -> None:
        """
        Run the request `req` to completion, discarding the content of every response.

        Each response is still passed to ``_verify_response_integrity``.

        Parameters
        ----------
        req : pb2.ExecutePlanRequest
            Proto representation of the plan.

        """
        logger.debug("Execute")

        try:
            if self._use_reattachable_execute:
                # Don't use retryHandler - own retry handling is inside.
                responses = ExecutePlanResponseReattachableIterator(
                    req, self._stub, self._retrying, self._builder.metadata()
                )
                for response in responses:
                    self._verify_response_integrity(response)
            else:
                for attempt in self._retrying():
                    with attempt:
                        responses = self._stub.ExecutePlan(req, metadata=self._builder.metadata())
                        for response in responses:
                            self._verify_response_integrity(response)
        except Exception as error:
            self._handle_error(error)

    def _execute_and_fetch_as_iterator(
        self,
        req: pb2.ExecutePlanRequest,
        observations: Dict[str, Observation],
        progress: Optional["Progress"] = None,
    ) -> Iterator[
        Union[
            "pa.RecordBatch",
            StructType,
            PlanMetrics,
            PlanObservedMetrics,
            Dict[str, Any],
        ]
    ]:
        """
        Execute the request and lazily yield the converted content of each response.

        Parameters
        ----------
        req : pb2.ExecutePlanRequest
            Proto representation of the plan.
        observations : dict
            Observations by name. When observed metrics with a matching name arrive, the
            ``_result`` dict of the observation is updated in place.
        progress : Progress, optional
            When given, it receives tick updates from execution-progress responses and is
            finished if the execution is interrupted from the keyboard.

        Yields
        ------
        Plan metrics, observed metrics, the schema, Arrow record batches, and single-entry dicts
        keyed by the name of the result field that was set on the response. Extension payloads
        (``b.extension``) are yielded as well, although the return annotation of this method
        does not list their type.
        """
        if logger.isEnabledFor(logging.DEBUG):
            # inside an if statement to not incur a performance cost converting proto to string
            # when not at debug log level.
            logger.debug(f"ExecuteAndFetchAsIterator. Request: {self._proto_to_string(req)}")

        # Running count of the rows received so far, used to validate arrow batch offsets.
        num_records = 0

        def handle_response(
            b: pb2.ExecutePlanResponse,
        ) -> Iterator[
            Union[
                "pa.RecordBatch",
                StructType,
                PlanMetrics,
                PlanObservedMetrics,
                Dict[str, Any],
                any_pb2.Any,
            ]
        ]:
            """
            Verify one response and yield an item for each supported field that is set on it.
            """
            nonlocal num_records
            # The session ID is the local session ID and should match what we expect.
            self._verify_response_integrity(b)
            if logger.isEnabledFor(logging.DEBUG):
                # inside an if statement to not incur a performance cost converting proto to string
                # when not at debug log level.
                logger.debug(
                    f"ExecuteAndFetchAsIterator. Response received: {self._proto_to_string(b)}"
                )

            if b.HasField("metrics"):
                logger.debug("Received metric batch.")
                yield from self._build_metrics(b.metrics)
            if b.observed_metrics:
                logger.debug("Received observed metric batch.")
                for observed_metrics in self._build_observed_metrics(b.observed_metrics):
                    if observed_metrics.name == "__python_accumulator__":
                        from pyspark.worker_util import pickleSer

                        # NOTE(review): this unpickles payloads taken from the server response;
                        # presumably the server is trusted - confirm.
                        for metric in observed_metrics.metrics:
                            (aid, update) = pickleSer.loads(LiteralExpression._to_value(metric))
                            if aid == SpecialAccumulatorIds.SQL_UDF_PROFIER:
                                self._profiler_collector._update(update)
                    elif observed_metrics.name in observations:
                        # Store the observed values on the matching observation, keyed by name.
                        observation_result = observations[observed_metrics.name]._result
                        assert observation_result is not None
                        observation_result.update(
                            {
                                key: LiteralExpression._to_value(metric)
                                for key, metric in zip(
                                    observed_metrics.keys, observed_metrics.metrics
                                )
                            }
                        )
                    yield observed_metrics
            if b.HasField("schema"):
                logger.debug("Received the schema.")
                dt = types.proto_schema_to_pyspark_data_type(b.schema)
                assert isinstance(dt, StructType)
                yield dt
            if b.HasField("sql_command_result"):
                logger.debug("Received the SQL command result.")
                yield {"sql_command_result": b.sql_command_result.relation}
            if b.HasField("write_stream_operation_start_result"):
                field = "write_stream_operation_start_result"
                yield {field: b.write_stream_operation_start_result}
            if b.HasField("streaming_query_command_result"):
                yield {"streaming_query_command_result": b.streaming_query_command_result}
            if b.HasField("streaming_query_manager_command_result"):
                cmd_result = b.streaming_query_manager_command_result
                yield {"streaming_query_manager_command_result": cmd_result}
            if b.HasField("streaming_query_listener_events_result"):
                event_result = b.streaming_query_listener_events_result
                yield {"streaming_query_listener_events_result": event_result}
            if b.HasField("get_resources_command_result"):
                # Convert the proto resources into ResourceInformation objects.
                resources = {}
                for key, resource in b.get_resources_command_result.resources.items():
                    name = resource.name
                    addresses = [address for address in resource.addresses]
                    resources[key] = ResourceInformation(name, addresses)
                yield {"get_resources_command_result": resources}
            if b.HasField("extension"):
                yield b.extension
            if b.HasField("execution_progress"):
                # Progress responses yield nothing; they only update the progress reporter.
                if progress:
                    p = from_proto(b.execution_progress)
                    progress.update_ticks(*p, operation_id=b.operation_id)
            if b.HasField("arrow_batch"):
                logger.debug(
                    f"Received arrow batch rows={b.arrow_batch.row_count} "
                    f"size={len(b.arrow_batch.data)}"
                )

                # When the batch declares a start offset, it must equal the number of rows
                # received so far.
                if (
                    b.arrow_batch.HasField("start_offset")
                    and num_records != b.arrow_batch.start_offset
                ):
                    raise SparkConnectException(
                        f"Expected arrow batch to start at row offset {num_records} in results, "
                        + "but received arrow batch starting at offset "
                        + f"{b.arrow_batch.start_offset}."
                    )

                num_records_in_batch = 0
                with pa.ipc.open_stream(b.arrow_batch.data) as reader:
                    for batch in reader:
                        assert isinstance(batch, pa.RecordBatch)
                        num_records_in_batch += batch.num_rows
                        yield batch

                # The declared row count is checked after the record batches were yielded.
                if num_records_in_batch != b.arrow_batch.row_count:
                    raise SparkConnectException(
                        f"Expected {b.arrow_batch.row_count} rows in arrow batch but got "
                        + f"{num_records_in_batch}."
                    )
                num_records += num_records_in_batch
            if b.HasField("create_resource_profile_command_result"):
                profile_id = b.create_resource_profile_command_result.profile_id
                yield {"create_resource_profile_command_result": profile_id}
            if b.HasField("checkpoint_command_result"):
                yield {
                    "checkpoint_command_result": proto_to_remote_cached_dataframe(
                        b.checkpoint_command_result.relation
                    )
                }
            if b.HasField("ml_command_result"):
                yield {"ml_command_result": b.ml_command_result}

        try:
            if self._use_reattachable_execute:
                # Don't use retryHandler - own retry handling is inside.
                generator = ExecutePlanResponseReattachableIterator(
                    req, self._stub, self._retrying, self._builder.metadata()
                )
                for b in generator:
                    yield from handle_response(b)
            else:
                for attempt in self._retrying():
                    with attempt:
                        for b in self._stub.ExecutePlan(req, metadata=self._builder.metadata()):
                            yield from handle_response(b)
        except KeyboardInterrupt as kb:
            # Finish the progress reporting and interrupt the operation through
            # interrupt_operation before re-raising.
            logger.debug(f"Interrupt request received for operation={req.operation_id}")
            if progress is not None:
                progress.finish()
            self.interrupt_operation(req.operation_id)
            raise kb
        except Exception as error:
            self._handle_error(error)

    def _execute_and_fetch(
        self,
        req: pb2.ExecutePlanRequest,
        observations: Dict[str, Observation],
        self_destruct: bool = False,
    ) -> Tuple[
        Optional["pa.Table"],
        Optional[StructType],
        List[PlanMetrics],
        List[PlanObservedMetrics],
        Dict[str, Any],
    ]:
        """
        Execute the request and collect the whole response stream in memory.

        Parameters
        ----------
        req : pb2.ExecutePlanRequest
            Proto representation of the plan.
        observations : dict
            Observations by name, forwarded to :meth:`_execute_and_fetch_as_iterator`.
        self_destruct : bool
            When True, every column of every record batch is copied into its own allocation
            before the table is assembled, so that a later self-destructing conversion can free
            memory column by column.

        Returns
        -------
        A tuple of the PyArrow Table (None when no record batch was received), the schema (None
        when none was received), the plan metrics, the observed metrics and a dict merging all
        dict responses.

        Raises
        ------
        PySparkValueError
            If a response of an unexpected type is received.
        """
        logger.debug("ExecuteAndFetch")

        observed_metrics: List[PlanObservedMetrics] = []
        metrics: List[PlanMetrics] = []
        batches: List[pa.RecordBatch] = []
        schema: Optional[StructType] = None
        properties: Dict[str, Any] = {}

        with Progress(handlers=self._progress_handlers, operation_id=req.operation_id) as progress:
            for response in self._execute_and_fetch_as_iterator(
                req, observations, progress=progress
            ):
                if isinstance(response, StructType):
                    schema = response
                elif isinstance(response, pa.RecordBatch):
                    batches.append(response)
                elif isinstance(response, PlanMetrics):
                    metrics.append(response)
                elif isinstance(response, PlanObservedMetrics):
                    observed_metrics.append(response)
                elif isinstance(response, dict):
                    properties.update(**response)
                else:
                    # Message parameters are strings; stringify the response, as
                    # execute_command_as_iterator does for the same error class.
                    raise PySparkValueError(
                        errorClass="UNKNOWN_RESPONSE",
                        messageParameters={
                            "response": str(response),
                        },
                    )

        if len(batches) > 0:
            if self_destruct:
                results = []
                for batch in batches:
                    # self_destruct frees memory column-wise, but Arrow record batches are
                    # oriented row-wise, so copies each column into its own allocation
                    batch = pa.RecordBatch.from_arrays(
                        [
                            # This call actually reallocates the array
                            pa.concat_arrays([array])
                            for array in batch
                        ],
                        schema=batch.schema,
                    )
                    results.append(batch)
                table = pa.Table.from_batches(batches=results)
                # Ensure only the table has a reference to the batches, so that
                # self_destruct (if enabled) is effective
                del results
                del batches
            else:
                table = pa.Table.from_batches(batches=batches)
            return table, schema, metrics, observed_metrics, properties
        else:
            return None, schema, metrics, observed_metrics, properties

    def _config_request_with_metadata(self) -> pb2.ConfigRequest:
        """
        Create a ``ConfigRequest`` pre-filled with this client's session metadata.

        The observed server-side session id and the user id are only set when available.
        """
        request = pb2.ConfigRequest(
            session_id=self._session_id,
            client_type=self._builder.userAgent,
        )
        if self._server_session_id is not None:
            request.client_observed_server_side_session_id = self._server_session_id
        if self._user_id:
            request.user_context.user_id = self._user_id
        return request

    def get_configs(self, *keys: str) -> Tuple[Optional[str], ...]:
        """
        Fetch the values of the given configuration keys through the config RPC.

        Returns
        -------
        A tuple with one entry per requested key, in the same order. An entry is None when the
        key is missing from the returned pairs.
        """
        get_op = pb2.ConfigRequest.Get(keys=keys)
        result = self.config(pb2.ConfigRequest.Operation(get=get_op))
        values = dict(result.pairs)
        return tuple(values.get(key) for key in keys)

    def get_config_dict(self, *keys: str) -> Mapping[str, Optional[str]]:
        """
        Fetch the given configuration keys and return the resulting pairs as a dict.
        """
        get_op = pb2.ConfigRequest.Get(keys=keys)
        result = self.config(pb2.ConfigRequest.Operation(get=get_op))
        return dict(result.pairs)

    def get_config_with_defaults(
        self, *pairs: Tuple[str, Optional[str]]
    ) -> Tuple[Optional[str], ...]:
        """
        Fetch configuration values, sending a default value along with every key.

        Parameters
        ----------
        *pairs : tuple of (key, default)
            The configuration keys to fetch, each with its default value.

        Returns
        -------
        A tuple with one entry per requested pair, in the same order. An entry is None when the
        key is missing from the returned pairs.
        """
        key_values = [pb2.KeyValue(key=name, value=fallback) for name, fallback in pairs]
        with_default = pb2.ConfigRequest.GetWithDefault(pairs=key_values)
        result = self.config(pb2.ConfigRequest.Operation(get_with_default=with_default))
        values = dict(result.pairs)
        return tuple(values.get(name) for name, _ in pairs)

    def config(self, operation: pb2.ConfigRequest.Operation) -> ConfigResult:
        """
        Call the config RPC of Spark Connect.

        Parameters
        ----------
        operation : pb2.ConfigRequest.Operation
           The config operation message to send; it is copied into the request.

        Returns
        -------
        ConfigResult
            The result of the config call, converted from the server response.
        """
        # `_config_request_with_metadata` already fills in the session id, the observed
        # server-side session id, the client type and the user id, so only the
        # operation has to be attached here.
        req = self._config_request_with_metadata()
        req.operation.CopyFrom(operation)
        try:
            for attempt in self._retrying():
                with attempt:
                    resp = self._stub.Config(req, metadata=self._builder.metadata())
                    self._verify_response_integrity(resp)
                    return ConfigResult.fromProto(resp)
            # Reached only if the retry loop finishes without returning or raising.
            raise SparkConnectException("Invalid state during retry exception handling.")
        except Exception as error:
            # `_handle_error` always raises, translating gRPC errors where possible.
            self._handle_error(error)

    def _interrupt_request(
        self, interrupt_type: str, id_or_tag: Optional[str] = None
    ) -> pb2.InterruptRequest:
        """
        Build an ``InterruptRequest`` for the given kind of interrupt.

        Parameters
        ----------
        interrupt_type : str
            One of ``"all"``, ``"tag"`` or ``"operation"``.
        id_or_tag : str, optional
            The operation tag (for ``"tag"``) or the operation id (for
            ``"operation"``). Must not be ``None`` for those two kinds.

        Returns
        -------
        pb2.InterruptRequest
            The populated request.

        Raises
        ------
        PySparkValueError
            If ``interrupt_type`` is not one of the supported kinds.
        """
        request = pb2.InterruptRequest()
        request.session_id = self._session_id
        if self._server_session_id is not None:
            request.client_observed_server_side_session_id = self._server_session_id
        request.client_type = self._builder.userAgent

        # Reject unknown kinds up front so the branches below only cover valid ones.
        if interrupt_type not in ("all", "tag", "operation"):
            raise PySparkValueError(
                errorClass="UNKNOWN_INTERRUPT_TYPE",
                messageParameters={
                    "interrupt_type": str(interrupt_type),
                },
            )

        kinds = pb2.InterruptRequest.InterruptType
        if interrupt_type == "all":
            request.interrupt_type = kinds.INTERRUPT_TYPE_ALL
        elif interrupt_type == "tag":
            assert id_or_tag is not None
            request.interrupt_type = kinds.INTERRUPT_TYPE_TAG
            request.operation_tag = id_or_tag
        else:
            assert id_or_tag is not None
            request.interrupt_type = kinds.INTERRUPT_TYPE_OPERATION_ID
            request.operation_id = id_or_tag

        if self._user_id:
            request.user_context.user_id = self._user_id
        return request

    def interrupt_all(self) -> Optional[List[str]]:
        """
        Send an interrupt request of type ``"all"``.

        Returns
        -------
        list of str
            The ``interrupted_ids`` reported in the server response.
        """
        request = self._interrupt_request("all")
        try:
            for attempt in self._retrying():
                with attempt:
                    response = self._stub.Interrupt(request, metadata=self._builder.metadata())
                    self._verify_response_integrity(response)
                    return [interrupted for interrupted in response.interrupted_ids]
            # Reached only if the retry loop finishes without returning or raising.
            raise SparkConnectException("Invalid state during retry exception handling.")
        except Exception as exc:
            self._handle_error(exc)

    def interrupt_tag(self, tag: str) -> Optional[List[str]]:
        """
        Send an interrupt request of type ``"tag"`` for the given operation tag.

        Parameters
        ----------
        tag : str
            The operation tag placed on the request.

        Returns
        -------
        list of str
            The ``interrupted_ids`` reported in the server response.
        """
        request = self._interrupt_request("tag", tag)
        try:
            for attempt in self._retrying():
                with attempt:
                    response = self._stub.Interrupt(request, metadata=self._builder.metadata())
                    self._verify_response_integrity(response)
                    return [interrupted for interrupted in response.interrupted_ids]
            # Reached only if the retry loop finishes without returning or raising.
            raise SparkConnectException("Invalid state during retry exception handling.")
        except Exception as exc:
            self._handle_error(exc)

    def interrupt_operation(self, op_id: str) -> Optional[List[str]]:
        """
        Send an interrupt request of type ``"operation"`` for the given operation id.

        Parameters
        ----------
        op_id : str
            The operation id placed on the request.

        Returns
        -------
        list of str
            The ``interrupted_ids`` reported in the server response.
        """
        request = self._interrupt_request("operation", op_id)
        try:
            for attempt in self._retrying():
                with attempt:
                    response = self._stub.Interrupt(request, metadata=self._builder.metadata())
                    self._verify_response_integrity(response)
                    return [interrupted for interrupted in response.interrupted_ids]
            # Reached only if the retry loop finishes without returning or raising.
            raise SparkConnectException("Invalid state during retry exception handling.")
        except Exception as exc:
            self._handle_error(exc)

    def release_session(self) -> None:
        """
        Send a ``ReleaseSession`` RPC for this client's session.

        The response is only checked for integrity; nothing is returned. Errors
        are routed through ``_handle_error``.
        """
        request = pb2.ReleaseSessionRequest()
        request.session_id = self._session_id
        request.client_type = self._builder.userAgent
        if self._user_id:
            request.user_context.user_id = self._user_id

        try:
            for attempt in self._retrying():
                with attempt:
                    response = self._stub.ReleaseSession(
                        request, metadata=self._builder.metadata()
                    )
                    self._verify_response_integrity(response)
                    return
            # Reached only if the retry loop finishes without returning or raising.
            raise SparkConnectException("Invalid state during retry exception handling.")
        except Exception as exc:
            self._handle_error(exc)

    def add_tag(self, tag: str) -> None:
        """
        Add a tag to the current thread's tag set.

        Parameters
        ----------
        tag : str
            The tag to add; it is validated with ``_throw_if_invalid_tag`` first.
        """
        self._throw_if_invalid_tag(tag)
        local_state = self.thread_local
        try:
            current_tags = local_state.tags
        except AttributeError:
            # This thread has no tag set yet; create it lazily.
            current_tags = local_state.tags = set()
        current_tags.add(tag)

    def remove_tag(self, tag: str) -> None:
        """
        Remove a tag from the current thread's tag set.

        Parameters
        ----------
        tag : str
            The tag to remove; it is validated with ``_throw_if_invalid_tag`` first.

        Raises
        ------
        KeyError
            If the tag is not present in the current thread's tag set.
        """
        self._throw_if_invalid_tag(tag)
        local_state = self.thread_local
        try:
            current_tags = local_state.tags
        except AttributeError:
            # This thread has no tag set yet; create it lazily.
            current_tags = local_state.tags = set()
        current_tags.remove(tag)

    def get_tags(self) -> Set[str]:
        """
        Return the current thread's tag set, creating an empty one if needed.

        Returns
        -------
        set of str
            The set object stored on ``thread_local`` itself (not a copy).
        """
        local_state = self.thread_local
        try:
            return local_state.tags
        except AttributeError:
            # This thread has no tag set yet; create it lazily.
            local_state.tags = set()
            return local_state.tags

    def clear_tags(self) -> None:
        """Replace the current thread's tag set with a new, empty set."""
        empty_tags: Set[str] = set()
        self.thread_local.tags = empty_tags

    def _throw_if_invalid_tag(self, tag: str) -> None:
        """
        Validate a tag intended for ``ExecutePlanRequest.tags``.

        Parameters
        ----------
        tag : str
            The tag to validate.

        Raises
        ------
        PySparkValueError
            If ``tag`` is ``None``, contains the separator ``","``, or is empty.
        """
        spark_job_tags_sep = ","
        if tag is None:
            # The keyword must be spelled `messageParameters`; a misspelled keyword
            # would surface as a TypeError instead of the intended error.
            raise PySparkValueError(
                errorClass="CANNOT_BE_NONE",
                messageParameters={"arg_name": "Spark Connect tag"},
            )
        if spark_job_tags_sep in tag:
            raise PySparkValueError(
                errorClass="VALUE_ALLOWED",
                messageParameters={
                    "arg_name": "Spark Connect tag",
                    "disallowed_value": spark_job_tags_sep,
                },
            )
        if len(tag) == 0:
            raise PySparkValueError(
                errorClass="VALUE_NOT_NON_EMPTY_STR",
                messageParameters={"arg_name": "Spark Connect tag", "arg_value": tag},
            )

    def _handle_error(self, error: Exception) -> NoReturn:
        """
        Handle errors that occur during RPC calls.

        Parameters
        ----------
        error : Exception
            An exception thrown during RPC calls.

        Returns
        -------
        Never returns; always raises. gRPC errors are passed to
        ``_handle_rpc_error``, any other exception is re-raised unchanged.
        """
        local_state = self.thread_local

        # A per-thread flag marks that error handling is already in progress, so an
        # error raised while handling another one is re-raised as is instead of
        # being processed again (which could recurse without bound).
        already_handling = getattr(local_state, "inside_error_handling", False)
        if already_handling:
            raise error

        try:
            local_state.inside_error_handling = True
            if not isinstance(error, grpc.RpcError):
                raise error
            self._handle_rpc_error(error)
            raise error
        finally:
            local_state.inside_error_handling = False

    def _fetch_enriched_error(self, info: "ErrorInfo") -> Optional[pb2.FetchErrorDetailsResponse]:
        """
        Fetch additional error details from the server for the given error info.

        Parameters
        ----------
        info : ErrorInfo
            Error info whose metadata may contain an ``errorId`` entry.

        Returns
        -------
        pb2.FetchErrorDetailsResponse or None
            The server response, or ``None`` when the metadata has no ``errorId``
            or the ``FetchErrorDetails`` call fails with a gRPC error.
        """
        error_metadata = info.metadata
        if "errorId" not in error_metadata:
            return None

        request = pb2.FetchErrorDetailsRequest(
            session_id=self._session_id,
            client_type=self._builder.userAgent,
            error_id=error_metadata["errorId"],
        )
        observed_server_session = self._server_session_id
        if observed_server_session is not None:
            request.client_observed_server_side_session_id = observed_server_session
        if self._user_id:
            request.user_context.user_id = self._user_id

        try:
            details = self._stub.FetchErrorDetails(request, metadata=self._builder.metadata())
        except grpc.RpcError:
            # A failed lookup is not an error for the caller.
            return None
        return details

    def _display_server_stack_trace(self) -> bool:
        """
        Decide whether the server-side stack trace should be displayed.

        Returns
        -------
        bool
            ``True`` if either ``spark.sql.connect.serverStacktrace.enabled`` or
            ``spark.sql.pyspark.jvmStacktrace.enabled`` reads as ``"true"``, and
            also ``True`` if reading the configuration raises.
        """
        from pyspark.sql.connect.conf import RuntimeConf

        runtime_conf = RuntimeConf(self)
        try:
            return (
                runtime_conf.get("spark.sql.connect.serverStacktrace.enabled") == "true"
                or runtime_conf.get("spark.sql.pyspark.jvmStacktrace.enabled") == "true"
            )
        except Exception:
            # Falls back to true if an exception occurs during reading the config.
            # Otherwise, it will recursively try to get the conf when it consistently
            # fails, ending up with `RecursionError`.
            return True

    def _handle_rpc_error(self, rpc_error: grpc.RpcError) -> NoReturn:
        """
        Error handling helper for dealing with GRPC Errors. On the server side, certain
        exceptions are enriched with additional RPC Status information. These are
        unpacked in this function and put into the exception.

        To avoid overloading the user with GRPC errors, the raised exception explicitly
        drops the error context from the call (``raise ... from None``). The GRPC Error
        is still logged.

        Parameters
        ----------
        rpc_error : grpc.RpcError
           RPC Error containing the details of the exception.

        Returns
        -------
        Never returns; throws the appropriate internal Python exception.
        """
        logger.exception("GRPC Error received")
        # We have to cast the value here because, a RpcError is a Call as well.
        # https://grpc.github.io/grpc/python/grpc.html#grpc.UnaryUnaryMultiCallable.__call__
        status = rpc_status.from_call(cast(grpc.Call, rpc_error))
        if not status:
            # No rich status attached: fall back to the string form of the RPC error.
            raise SparkConnectGrpcException(str(rpc_error)) from None

        for detail in status.details:
            if not detail.Is(error_details_pb2.ErrorInfo.DESCRIPTOR):
                continue

            info = error_details_pb2.ErrorInfo()
            detail.Unpack(info)
            logger.debug(f"Received ErrorInfo: {info}")

            # A changed session invalidates this client; mark it closed.
            if info.metadata["errorClass"] == "INVALID_HANDLE.SESSION_CHANGED":
                self._closed = True

            enriched_error = self._fetch_enriched_error(info)
            show_server_stack_trace = self._display_server_stack_trace()
            raise convert_exception(
                info,
                status.message,
                enriched_error,
                show_server_stack_trace,
                status.code,
            ) from None

        # A status without any ErrorInfo detail.
        raise SparkConnectGrpcException(
            message=status.message, grpc_status_code=status.code
        ) from None

    def add_artifacts(self, *paths: str, pyfile: bool, archive: bool, file: bool) -> None:
        """
        Add each of the given paths as an artifact through the artifact manager.

        Parameters
        ----------
        *paths : str
            Paths to add; each one is submitted in its own retry loop.
        pyfile : bool
            Forwarded to the artifact manager as ``pyfile``.
        archive : bool
            Forwarded to the artifact manager as ``archive``.
        file : bool
            Forwarded to the artifact manager as ``file``.
        """
        artifact_flags = dict(pyfile=pyfile, archive=archive, file=file)
        try:
            for artifact_path in paths:
                for attempt in self._retrying():
                    with attempt:
                        self._artifact_manager.add_artifacts(artifact_path, **artifact_flags)
        except Exception as exc:
            self._handle_error(exc)

    def copy_from_local_to_fs(self, local_path: str, dest_path: str) -> None:
        """
        Forward a local path to a destination path through the artifact manager.

        Parameters
        ----------
        local_path : str
            Passed to the artifact manager as the local path.
        dest_path : str
            Passed to the artifact manager as the destination path.
        """
        manager = self._artifact_manager
        for retry_attempt in self._retrying():
            with retry_attempt:
                manager._add_forward_to_fs_artifacts(local_path, dest_path)

    def cache_artifact(self, blob: bytes) -> str:
        """
        Cache a blob through the artifact manager.

        Parameters
        ----------
        blob : bytes
            The payload handed to the artifact manager.

        Returns
        -------
        str
            Whatever the artifact manager's ``cache_artifact`` returns for the blob.
        """
        for retry_attempt in self._retrying():
            with retry_attempt:
                cached = self._artifact_manager.cache_artifact(blob)
                return cached
        # Reached only if the retry loop finishes without returning or raising.
        raise SparkConnectException("Invalid state during retry exception handling.")

    def _verify_response_integrity(
        self,
        response: Union[
            pb2.ConfigResponse,
            pb2.ExecutePlanResponse,
            pb2.InterruptResponse,
            pb2.ReleaseExecuteResponse,
            pb2.AddArtifactsResponse,
            pb2.AnalyzePlanResponse,
            pb2.FetchErrorDetailsResponse,
            pb2.ReleaseSessionResponse,
        ],
    ) -> None:
        """
        Verifies the integrity of the response. The session ID in the response must match
        this client's session ID. If the client already knows a server-side session ID,
        a non-empty server-side session ID in the response must match it too; otherwise
        the one from the response is stored on the client.

        Parameters
        ----------
        response - One of the different response types handled by the Spark Connect service

        Raises
        ------
        PySparkAssertionError
            If either identifier does not match. On a server-side session mismatch the
            client is also marked as closed.
        """
        if self._session_id != response.session_id:
            raise PySparkAssertionError(
                "Received incorrect session identifier for request:"
                f"{response.session_id} != {self._session_id}"
            )

        if self._server_session_id is None:
            # Update the server side session ID.
            self._server_session_id = response.server_side_session_id
            return

        reported_server_session = response.server_side_session_id
        # An empty server-side session ID in the response is not treated as a mismatch.
        if not reported_server_session:
            return
        if reported_server_session != self._server_session_id:
            self._closed = True
            raise PySparkAssertionError(
                "Received incorrect server side session identifier for request. "
                "Please create a new Spark Session to reconnect. ("
                f"{response.server_side_session_id} != {self._server_session_id})"
            )

    def _create_profile(self, profile: pb2.ResourceProfile) -> int:
        """
        Create the ResourceProfile on the server side and return the profile ID.

        Parameters
        ----------
        profile : pb2.ResourceProfile
            The profile to create; it is copied into the command sent to the server.

        Returns
        -------
        int
            The ``create_resource_profile_command_result`` entry of the command's
            result properties.
        """
        logger.debug("Creating the ResourceProfile")
        command = pb2.Command()
        command.create_resource_profile_command.profile.CopyFrom(profile)
        _, properties, _ = self.execute_command(command)
        return properties["create_resource_profile_command_result"]

    def add_ml_cache(self, cache_id: str) -> None:
        """
        Record an ML cache id in the current thread's set of cache ids.

        Parameters
        ----------
        cache_id : str
            The cache id to record.
        """
        local_state = self.thread_local
        try:
            tracked = local_state.ml_caches
        except AttributeError:
            # This thread has no cache set yet; create it lazily.
            tracked = local_state.ml_caches = set()
        tracked.add(cache_id)

    def remove_ml_cache(self, cache_id: str) -> None:
        """
        Delete an ML cache on the server and stop tracking the deleted ids locally.

        Parameters
        ----------
        cache_id : str
            The cache id to delete.
        """
        deleted = self._delete_ml_cache([cache_id])
        # TODO: Fix the code: change thread-local `ml_caches` to global `ml_caches`.
        if hasattr(self.thread_local, "ml_caches"):
            tracked = self.thread_local.ml_caches
            if cache_id in tracked:
                # Only `cache_id` is known to be tracked by this thread. The ids
                # reported as deleted may include others, so drop them with discard
                # semantics instead of `set.remove`, which would raise KeyError.
                tracked.difference_update(deleted)

    def _delete_ml_cache(self, cache_ids: List[str]) -> List[str]:
        """
        Ask the server to delete the given ML caches, on a best-effort basis.

        Parameters
        ----------
        cache_ids : list of str
            Ids of the caches to delete.

        Returns
        -------
        list of str
            The comma-separated ids from the ``ml_command_result`` of the response,
            split into a list. An empty list if no ids were given, if the response
            has no ``ml_command_result``, or if anything raises.
        """
        # try best to delete the cache
        try:
            if len(cache_ids) == 0:
                return []

            command = pb2.Command()
            refs = [pb2.ObjectRef(id=cache_id) for cache_id in cache_ids]
            command.ml_command.delete.obj_refs.extend(refs)
            _, properties, _ = self.execute_command(command)

            assert properties is not None

            if properties is None or "ml_command_result" not in properties:
                return []
            result = properties["ml_command_result"]
            return cast(List[str], result.operator_info.obj_ref.id.split(","))
        except Exception:
            # Deliberately swallowed: a failed delete is reported as "nothing deleted".
            return []

    def _cleanup_ml_cache(self) -> None:
        """
        Send a ``clean_cache`` ML command and clear the current thread's cache ids.

        Does nothing when the current thread has no ``ml_caches`` set. Any exception
        is swallowed, in which case the local set may be left untouched.
        """
        if not hasattr(self.thread_local, "ml_caches"):
            return

        try:
            cleanup = pb2.Command()
            cleanup.ml_command.clean_cache.SetInParent()
            self.execute_command(cleanup)
            self.thread_local.ml_caches.clear()
        except Exception:
            # Best-effort cleanup: failures are intentionally ignored.
            pass

    def _get_ml_cache_info(self) -> List[str]:
        """
        Query the server for ML cache information.

        Returns
        -------
        list of str
            The ``string`` field of every array element in the ``ml_command_result``
            of the response. An empty list if the current thread has no ``ml_caches``
            set or the response has no ``ml_command_result``.
        """
        if not hasattr(self.thread_local, "ml_caches"):
            return []

        query = pb2.Command()
        query.ml_command.get_cache_info.SetInParent()
        _, properties, _ = self.execute_command(query)

        assert properties is not None

        if properties is None or "ml_command_result" not in properties:
            return []
        result = properties["ml_command_result"]
        return [element.string for element in result.param.array.elements]
