#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import json
import sys
import random
import warnings
from collections.abc import Iterable
from functools import reduce, cached_property
from typing import (
    Any,
    Callable,
    Dict,
    Iterator,
    List,
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
    cast,
    overload,
    TYPE_CHECKING,
)

from pyspark import _NoValue
from pyspark.resource import ResourceProfile
from pyspark._globals import _NoValueType
from pyspark.errors import (
    AnalysisException,
    PySparkTypeError,
    PySparkValueError,
    PySparkIndexError,
    PySparkAttributeError,
)
from pyspark.util import (
    _load_from_socket,
    _local_iterator_from_socket,
)
from pyspark.serializers import BatchedSerializer, CPickleSerializer, UTF8Deserializer
from pyspark.storagelevel import StorageLevel
from pyspark.traceback_utils import SCCallSiteSync
from pyspark.sql.column import Column
from pyspark.sql.functions import builtin as F
from pyspark.sql.classic.column import _to_seq, _to_list, _to_java_column
from pyspark.sql.readwriter import DataFrameWriter, DataFrameWriterV2
from pyspark.sql.merge import MergeIntoWriter
from pyspark.sql.streaming import DataStreamWriter
from pyspark.sql.types import (
    StructType,
    Row,
    _parse_datatype_json_string,
)
from pyspark.sql.dataframe import (
    DataFrame as ParentDataFrame,
    DataFrameNaFunctions as ParentDataFrameNaFunctions,
    DataFrameStatFunctions as ParentDataFrameStatFunctions,
)
from pyspark.sql.utils import get_active_spark_context, to_java_array, to_scala_map
from pyspark.sql.pandas.conversion import PandasConversionMixin
from pyspark.sql.pandas.map_ops import PandasMapOpsMixin
from pyspark.sql.table_arg import TableArg


if TYPE_CHECKING:
    from py4j.java_gateway import JavaObject
    import pyarrow as pa
    from pyspark.core.rdd import RDD
    from pyspark.core.context import SparkContext
    from pyspark._typing import PrimitiveType
    from pyspark.pandas.frame import DataFrame as PandasOnSparkDataFrame
    from pyspark.sql._typing import (
        ColumnOrName,
        ColumnOrNameOrOrdinal,
        LiteralType,
        OptionalPrimitiveType,
    )
    from pyspark.sql.pandas._typing import (
        PandasMapIterFunction,
        ArrowMapIterFunction,
        DataFrameLike as PandasDataFrameLike,
    )
    from pyspark.sql.context import SQLContext
    from pyspark.sql.session import SparkSession
    from pyspark.sql.group import GroupedData
    from pyspark.sql.observation import Observation
    from pyspark.sql.metrics import ExecutionInfo


class DataFrame(ParentDataFrame, PandasMapOpsMixin, PandasConversionMixin):
    def __new__(
        cls,
        jdf: "JavaObject",
        sql_ctx: Union["SQLContext", "SparkSession"],
    ) -> "DataFrame":
        """Allocate a classic DataFrame and initialize it immediately.

        ``object.__new__`` is called directly, so a parent class's ``__new__``
        (if any) is not invoked for this class.

        NOTE: the returned object is an instance of ``cls``, so Python's normal
        construction protocol calls ``__init__`` once more after this returns.
        """
        instance = object.__new__(cls)
        instance.__init__(jdf, sql_ctx)  # type: ignore[misc]
        return instance

    def __init__(
        self,
        jdf: "JavaObject",
        sql_ctx: Union["SQLContext", "SparkSession"],
    ):
        """Wrap a py4j handle to a JVM DataFrame.

        Parameters
        ----------
        jdf : JavaObject
            The py4j handle of the underlying Java DataFrame.
        sql_ctx : SQLContext or SparkSession
            The owning session. A legacy :class:`SQLContext` is still accepted
            (with a warning); it is kept as ``_sql_ctx`` and unwrapped to its
            ``sparkSession``.
        """
        from pyspark.sql.context import SQLContext

        if not isinstance(sql_ctx, SQLContext):
            session = sql_ctx
        else:
            assert not os.environ.get("SPARK_TESTING")  # Sanity check for our internal usage.
            assert isinstance(sql_ctx, SQLContext)
            # Legacy path: this branch should eventually go away and the constructor
            # argument be renamed to ``session``. It is an internal code path, kept
            # with a warning because third-party libraries rely on it heavily.
            warnings.warn("DataFrame constructor is internal. Do not directly use it.")
            self._sql_ctx = sql_ctx
            session = sql_ctx.sparkSession
        self._session: "SparkSession" = session

        # Either accepted argument type exposes ``_sc``.
        self._sc: "SparkContext" = sql_ctx._sc
        self._jdf: "JavaObject" = jdf
        self.is_cached = False
        # Set to True once ``_repr_html_`` is used, so ``__repr__`` skips its own
        # eager rendering and the JVM is not asked to render twice.
        self._support_repr_html = False

    @cached_property
    def sql_ctx(self) -> "SQLContext":
        """Deprecated: a :class:`SQLContext` for this DataFrame's SparkContext.

        A warning is emitted when the value is computed. Because this is a
        ``cached_property``, that happens once per instance.
        """
        from pyspark.sql.context import SQLContext

        warnings.warn(
            "DataFrame.sql_ctx is an internal property, and will be removed "
            "in future releases. Use DataFrame.sparkSession instead."
        )
        spark_context = self._sc
        return SQLContext._get_or_create(spark_context)

    @property
    def sparkSession(self) -> "SparkSession":
        """The :class:`SparkSession` this DataFrame belongs to."""
        owning_session = self._session
        return owning_session

    @cached_property
    def rdd(self) -> "RDD[Row]":
        """This DataFrame as an :class:`RDD` of :class:`Row`.

        Rows come from the JVM's ``javaToPython`` and are decoded with a batched
        pickle serializer. The RDD is built once per instance and cached.
        """
        from pyspark.core.rdd import RDD

        java_rdd = self._jdf.javaToPython()
        row_serializer = BatchedSerializer(CPickleSerializer())
        return RDD(java_rdd, self.sparkSession._sc, row_serializer)

    @property
    def na(self) -> ParentDataFrameNaFunctions:
        """A :class:`DataFrameNaFunctions` bound to this DataFrame (new object per access)."""
        na_functions = DataFrameNaFunctions(self)
        return na_functions

    @property
    def stat(self) -> ParentDataFrameStatFunctions:
        """A :class:`DataFrameStatFunctions` bound to this DataFrame (new object per access)."""
        stat_functions = DataFrameStatFunctions(self)
        return stat_functions

    def toJSON(self, use_unicode: bool = True) -> "RDD[str]":
        """Return an :class:`RDD` of strings produced by the JVM ``toJSON``.

        Parameters
        ----------
        use_unicode : bool
            Forwarded to :class:`UTF8Deserializer` when decoding each string.
        """
        from pyspark.core.rdd import RDD

        java_json_rdd = self._jdf.toJSON().toJavaRDD()
        decoder = UTF8Deserializer(use_unicode)
        return RDD(java_json_rdd, self._sc, decoder)

    def registerTempTable(self, name: str) -> None:
        """Deprecated alias: warns with ``FutureWarning``, then calls the JVM
        ``createOrReplaceTempView(name)``."""
        warnings.warn(
            "Deprecated in 2.0, use createOrReplaceTempView instead.",
            category=FutureWarning,
        )
        java_df = self._jdf
        java_df.createOrReplaceTempView(name)

    def createTempView(self, name: str) -> None:
        """Register this DataFrame as a temporary view ``name`` (JVM ``createTempView``)."""
        java_df = self._jdf
        java_df.createTempView(name)

    def createOrReplaceTempView(self, name: str) -> None:
        """Create or replace the temporary view ``name`` (JVM ``createOrReplaceTempView``)."""
        java_df = self._jdf
        java_df.createOrReplaceTempView(name)

    def createGlobalTempView(self, name: str) -> None:
        """Register this DataFrame as global temporary view ``name`` (JVM ``createGlobalTempView``)."""
        java_df = self._jdf
        java_df.createGlobalTempView(name)

    def createOrReplaceGlobalTempView(self, name: str) -> None:
        """Create or replace global temporary view ``name`` (JVM ``createOrReplaceGlobalTempView``)."""
        java_df = self._jdf
        java_df.createOrReplaceGlobalTempView(name)

    @property
    def write(self) -> DataFrameWriter:
        """A fresh :class:`DataFrameWriter` for this DataFrame (new object per access)."""
        writer = DataFrameWriter(self)
        return writer

    @property
    def writeStream(self) -> DataStreamWriter:
        """A fresh :class:`DataStreamWriter` for this DataFrame (new object per access)."""
        stream_writer = DataStreamWriter(self)
        return stream_writer

    @cached_property
    def schema(self) -> StructType:
        """The schema of this DataFrame as a :class:`StructType`.

        The JVM schema is serialized to JSON and parsed on the Python side. The
        result is cached per instance.

        Raises
        ------
        AnalysisException
            Propagated unchanged from the JVM.
        PySparkValueError
            ``CANNOT_PARSE_DATATYPE`` if fetching or parsing the schema fails for
            any other reason. The original error is chained as ``__cause__``.
        """
        try:
            return cast(StructType, _parse_datatype_json_string(self._jdf.schema().json()))
        except AnalysisException:
            # Bare ``raise`` re-raises with the original traceback intact.
            raise
        except Exception as e:
            raise PySparkValueError(
                errorClass="CANNOT_PARSE_DATATYPE",
                messageParameters={"error": str(e)},
            ) from e

    def printSchema(self, level: Optional[int] = None) -> None:
        """Print the JVM schema's ``treeString`` to stdout.

        Parameters
        ----------
        level : int, optional
            Passed to ``treeString(level)`` when truthy. ``None`` or ``0`` uses
            the no-argument ``treeString()``.
        """
        java_schema = self._jdf.schema()
        tree = java_schema.treeString(level) if level else java_schema.treeString()
        print(tree)

    def explain(
        self, extended: Optional[Union[bool, str]] = None, mode: Optional[str] = None
    ) -> None:
        """Print the JVM ``PythonSQLUtils.explainString`` output for this DataFrame.

        Accepted call shapes:

        * ``explain()``: mode ``"simple"``
        * ``explain(True)`` / ``explain(extended=False)``: ``"extended"`` / ``"simple"``
        * ``explain("formatted")``: a string in ``extended`` is used as the mode
        * ``explain(mode="formatted")``

        Raises
        ------
        PySparkValueError
            ``CANNOT_SET_TOGETHER`` if both ``extended`` and ``mode`` are given.
        PySparkTypeError
            ``NOT_BOOL_OR_STR`` for a bad ``extended``, ``NOT_STR`` for a bad ``mode``.
        """
        if extended is not None and mode is not None:
            raise PySparkValueError(
                errorClass="CANNOT_SET_TOGETHER",
                messageParameters={"arg_list": "extended and mode"},
            )

        # From here on, at most one of ``extended`` / ``mode`` is set.
        if extended is None and mode is None:
            explain_mode = "simple"
        elif isinstance(extended, bool):
            explain_mode = "extended" if extended else "simple"
        elif isinstance(extended, str):
            explain_mode = extended
        elif extended is not None:
            raise PySparkTypeError(
                errorClass="NOT_BOOL_OR_STR",
                messageParameters={
                    "arg_name": "extended",
                    "arg_type": type(extended).__name__,
                },
            )
        elif isinstance(mode, str):
            explain_mode = mode
        else:
            raise PySparkTypeError(
                errorClass="NOT_STR",
                messageParameters={"arg_name": "mode", "arg_type": type(mode).__name__},
            )

        assert self._sc._jvm is not None
        print(self._sc._jvm.PythonSQLUtils.explainString(self._jdf.queryExecution(), explain_mode))

    def exceptAll(self, other: ParentDataFrame) -> ParentDataFrame:
        """Wrap the result of the JVM ``exceptAll`` against ``other``."""
        java_result = self._jdf.exceptAll(other._jdf)
        return DataFrame(java_result, self.sparkSession)

    def isLocal(self) -> bool:
        """Return the JVM DataFrame's ``isLocal()`` flag."""
        local_flag = self._jdf.isLocal()
        return local_flag

    @property
    def isStreaming(self) -> bool:
        """The JVM DataFrame's ``isStreaming()`` flag."""
        streaming_flag = self._jdf.isStreaming()
        return streaming_flag

    def isEmpty(self) -> bool:
        """Return the JVM DataFrame's ``isEmpty()`` result."""
        empty_flag = self._jdf.isEmpty()
        return empty_flag

    def show(self, n: int = 20, truncate: Union[bool, int] = True, vertical: bool = False) -> None:
        """Print the table rendered by :meth:`_show_string` with the same arguments."""
        rendered = self._show_string(n, truncate, vertical)
        print(rendered)

    def _show_string(
        self, n: int = 20, truncate: Union[bool, int] = True, vertical: bool = False
    ) -> str:
        """Render the first ``n`` rows via the JVM ``showString``.

        Parameters
        ----------
        n : int
            Number of rows. ``bool`` is rejected even though it subclasses ``int``.
        truncate : bool or int
            ``True`` truncates cells to 20 characters. Any other value is
            converted with ``int()`` (so ``False`` becomes ``0``).
        vertical : bool
            Passed through to ``showString``.

        Raises
        ------
        PySparkTypeError
            If ``n`` is not an int, ``vertical`` is not a bool, or ``truncate``
            cannot be converted to an int.
        """
        if not isinstance(n, int) or isinstance(n, bool):
            raise PySparkTypeError(
                errorClass="NOT_INT",
                messageParameters={"arg_name": "n", "arg_type": type(n).__name__},
            )

        if not isinstance(vertical, bool):
            raise PySparkTypeError(
                errorClass="NOT_BOOL",
                messageParameters={"arg_name": "vertical", "arg_type": type(vertical).__name__},
            )

        if isinstance(truncate, bool) and truncate:
            return self._jdf.showString(n, 20, vertical)

        try:
            int_truncate = int(truncate)
        except (ValueError, TypeError):
            # int() raises TypeError (not ValueError) for e.g. None or a list,
            # so both must be caught to report a consistent PySparkTypeError.
            raise PySparkTypeError(
                errorClass="NOT_BOOL",
                messageParameters={
                    "arg_name": "truncate",
                    "arg_type": type(truncate).__name__,
                },
            )

        return self._jdf.showString(n, int_truncate, vertical)

    def __repr__(self) -> str:
        """Eagerly render rows when REPL eager evaluation is on, else a dtype summary.

        Eager rendering is skipped once ``_repr_html_`` has been used, so the
        JVM is not asked to render the DataFrame twice.
        """
        if not self._support_repr_html:
            jconf = self.sparkSession._jconf
            if jconf.isReplEagerEvalEnabled():
                return self._jdf.showString(
                    jconf.replEagerEvalMaxNumRows(),
                    jconf.replEagerEvalTruncate(),
                    False,  # vertical
                )
        column_summary = ", ".join("%s: %s" % c for c in self.dtypes)
        return "DataFrame[%s]" % (column_summary)

    def _repr_html_(self) -> Optional[str]:
        """Return the JVM ``htmlString`` rendering when eager evaluation is enabled.

        Called by REPLs that support HTML output. The first call marks this
        DataFrame so that ``__repr__`` no longer performs its own eager
        rendering. Returns ``None`` when ``spark.sql.repl.eagerEval.enabled``
        is off.
        """
        self._support_repr_html = True
        jconf = self.sparkSession._jconf
        if not jconf.isReplEagerEvalEnabled():
            return None
        return self._jdf.htmlString(
            jconf.replEagerEvalMaxNumRows(),
            jconf.replEagerEvalTruncate(),
        )

    def checkpoint(self, eager: bool = True) -> ParentDataFrame:
        """Wrap the result of the JVM ``checkpoint(eager)``."""
        return DataFrame(self._jdf.checkpoint(eager), self.sparkSession)

    def localCheckpoint(
        self, eager: bool = True, storageLevel: Optional[StorageLevel] = None
    ) -> ParentDataFrame:
        """Wrap the result of the JVM ``localCheckpoint``.

        ``storageLevel`` is converted to its JVM form and passed only when given.
        Otherwise the one-argument JVM overload is used.
        """
        if storageLevel is not None:
            java_level = self._sc._getJavaStorageLevel(storageLevel)
            return DataFrame(self._jdf.localCheckpoint(eager, java_level), self.sparkSession)
        return DataFrame(self._jdf.localCheckpoint(eager), self.sparkSession)

    def withWatermark(self, eventTime: str, delayThreshold: str) -> ParentDataFrame:
        """Wrap the JVM ``withWatermark(eventTime, delayThreshold)``.

        Raises
        ------
        PySparkTypeError
            ``NOT_STR`` if either argument is empty or not exactly a ``str``.
        """
        for arg_name, arg_value in (("eventTime", eventTime), ("delayThreshold", delayThreshold)):
            if not arg_value or type(arg_value) is not str:
                raise PySparkTypeError(
                    errorClass="NOT_STR",
                    messageParameters={
                        "arg_name": arg_name,
                        "arg_type": type(arg_value).__name__,
                    },
                )
        java_result = self._jdf.withWatermark(eventTime, delayThreshold)
        return DataFrame(java_result, self.sparkSession)

    def hint(
        self, name: str, *parameters: Union["PrimitiveType", "Column", List["PrimitiveType"]]
    ) -> ParentDataFrame:
        """Attach the query hint ``name`` with optional ``parameters``.

        Each parameter may be a str/float/int scalar, a :class:`Column`, or a
        list of str/float/int. If exactly one list is passed, it is unpacked and
        used as the parameter list itself. A list parameter becomes a JVM array
        whose element type comes from its first element. An empty list becomes
        an empty ``long[]``.

        Raises
        ------
        PySparkTypeError
            ``NOT_STR`` if ``name`` is not a string, or
            ``DISALLOWED_TYPE_FOR_CONTAINER`` if a parameter or list element has
            an unsupported type.
        """
        if len(parameters) == 1 and isinstance(parameters[0], list):
            parameters = parameters[0]  # type: ignore[assignment]

        if not isinstance(name, str):
            raise PySparkTypeError(
                errorClass="NOT_STR",
                messageParameters={"arg_name": "name", "arg_type": type(name).__name__},
            )

        allowed_types = (str, float, int, Column, list)
        allowed_primitive_types = (str, float, int)
        allowed_types_repr = ", ".join(
            [t.__name__ for t in allowed_types[:-1]]
            + ["list[" + t.__name__ + "]" for t in allowed_primitive_types]
        )
        for p in parameters:
            if not isinstance(p, allowed_types):
                raise PySparkTypeError(
                    errorClass="DISALLOWED_TYPE_FOR_CONTAINER",
                    messageParameters={
                        "arg_name": "parameters",
                        "arg_type": type(parameters).__name__,
                        "allowed_types": allowed_types_repr,
                        "item_type": type(p).__name__,
                    },
                )
            if isinstance(p, list):
                invalid = [e for e in p if not isinstance(e, allowed_primitive_types)]
                if invalid:
                    raise PySparkTypeError(
                        errorClass="DISALLOWED_TYPE_FOR_CONTAINER",
                        messageParameters={
                            "arg_name": "parameters",
                            "arg_type": type(parameters).__name__,
                            "allowed_types": allowed_types_repr,
                            # Report the offending element's type, not p[0]'s
                            # (p[0] may itself be valid, e.g. [1, None]).
                            "item_type": type(p).__name__ + "[" + type(invalid[0]).__name__ + "]",
                        },
                    )

        def _converter(parameter: Union[str, list, float, int, Column]) -> Any:
            if isinstance(parameter, Column):
                return _to_java_column(parameter)
            elif isinstance(parameter, list):
                # for list input, we are assuming only one element type exist in the list.
                # for empty list, we are converting it into an empty long[] in the JVM side.
                gateway = self._sc._gateway
                assert gateway is not None
                jclass = gateway.jvm.long
                if len(parameter) >= 1:
                    mapping = {
                        str: gateway.jvm.java.lang.String,
                        float: gateway.jvm.double,
                        int: gateway.jvm.long,
                    }
                    jclass = mapping[type(parameter[0])]
                return to_java_array(gateway, jclass, parameter)
            else:
                return parameter

        jdf = self._jdf.hint(name, self._jseq(parameters, _converter))
        return DataFrame(jdf, self.sparkSession)

    def count(self) -> int:
        """Return the JVM ``count()`` result as a Python ``int``."""
        row_count = self._jdf.count()
        return int(row_count)

    def collect(self) -> List[Row]:
        """Return all rows to the driver as a list of :class:`Row`.

        The JVM ``collectToPython`` call runs inside ``SCCallSiteSync(self._sc)``.
        Rows are then read from the returned socket with a batched pickle
        serializer.
        """
        with SCCallSiteSync(self._sc):
            socket_info = self._jdf.collectToPython()
        row_stream = _load_from_socket(socket_info, BatchedSerializer(CPickleSerializer()))
        return list(row_stream)

    def toLocalIterator(self, prefetchPartitions: bool = False) -> Iterator[Row]:
        """Return an iterator over the rows, served by the JVM ``toPythonIterator``.

        ``prefetchPartitions`` is forwarded to the JVM unchanged.
        """
        with SCCallSiteSync(self._sc):
            socket_info = self._jdf.toPythonIterator(prefetchPartitions)
        serializer = BatchedSerializer(CPickleSerializer())
        return _local_iterator_from_socket(socket_info, serializer)

    def limit(self, num: int) -> ParentDataFrame:
        """Wrap the JVM ``limit(num)``."""
        return DataFrame(self._jdf.limit(num), self.sparkSession)

    def offset(self, num: int) -> ParentDataFrame:
        """Wrap the JVM ``offset(num)``."""
        return DataFrame(self._jdf.offset(num), self.sparkSession)

    def take(self, num: int) -> List[Row]:
        """Return up to ``num`` rows, computed as ``limit(num).collect()``."""
        limited = self.limit(num)
        return limited.collect()

    def tail(self, num: int) -> List[Row]:
        """Return the JVM ``tailToPython(num)`` rows as a list of :class:`Row`."""
        with SCCallSiteSync(self._sc):
            socket_info = self._jdf.tailToPython(num)
        row_stream = _load_from_socket(socket_info, BatchedSerializer(CPickleSerializer()))
        return list(row_stream)

    def foreach(self, f: Callable[[Row], None]) -> None:
        """Apply ``f`` to each row using :attr:`rdd`'s ``foreach``."""
        row_rdd = self.rdd
        row_rdd.foreach(f)

    def foreachPartition(self, f: Callable[[Iterator[Row]], None]) -> None:
        """Apply ``f`` to each partition's row iterator using :attr:`rdd`'s ``foreachPartition``."""
        row_rdd = self.rdd
        row_rdd.foreachPartition(f)  # type: ignore[arg-type]

    def cache(self) -> ParentDataFrame:
        """Set ``is_cached``, call the JVM ``cache()``, and return ``self``."""
        self.is_cached = True
        java_df = self._jdf
        java_df.cache()
        return self

    def persist(
        self,
        storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_DESER,
    ) -> ParentDataFrame:
        """Set ``is_cached``, persist on the JVM with ``storageLevel``, and return ``self``.

        Parameters
        ----------
        storageLevel : StorageLevel
            Converted to its JVM form before the call. Defaults to
            ``StorageLevel.MEMORY_AND_DISK_DESER``.
        """
        self.is_cached = True
        self._jdf.persist(self._sc._getJavaStorageLevel(storageLevel))
        return self

    @property
    def storageLevel(self) -> StorageLevel:
        """The current storage level, rebuilt from the JVM's ``storageLevel()``.

        The flags are copied in :class:`StorageLevel` constructor order:
        useDisk, useMemory, useOffHeap, deserialized, replication.
        """
        jlevel = self._jdf.storageLevel()
        return StorageLevel(
            jlevel.useDisk(),
            jlevel.useMemory(),
            jlevel.useOffHeap(),
            jlevel.deserialized(),
            jlevel.replication(),
        )

    def unpersist(self, blocking: bool = False) -> ParentDataFrame:
        """Clear ``is_cached``, call the JVM ``unpersist(blocking)``, and return ``self``."""
        self.is_cached = False
        java_df = self._jdf
        java_df.unpersist(blocking)
        return self

    def coalesce(self, numPartitions: int) -> ParentDataFrame:
        """Wrap the JVM ``coalesce(numPartitions)``."""
        java_result = self._jdf.coalesce(numPartitions)
        return DataFrame(java_result, self.sparkSession)

    @overload
    def repartition(self, numPartitions: int, *cols: "ColumnOrName") -> ParentDataFrame:
        ...

    @overload
    def repartition(self, *cols: "ColumnOrName") -> ParentDataFrame:
        ...

    def repartition(  # type: ignore[misc]
        self, numPartitions: Union[int, "ColumnOrName"], *cols: "ColumnOrName"
    ) -> ParentDataFrame:
        """Return a new DataFrame produced by the JVM ``repartition``.

        Parameters
        ----------
        numPartitions : int, str or Column
            Target partition count. A str or Column is treated as the first
            partitioning column instead.
        *cols : str or Column
            Partitioning columns.

        Raises
        ------
        PySparkTypeError
            ``NOT_COLUMN_OR_INT_OR_STR`` if ``numPartitions`` has another type.
        """
        if isinstance(numPartitions, int):
            if len(cols) == 0:
                return DataFrame(self._jdf.repartition(numPartitions), self.sparkSession)
            else:
                return DataFrame(
                    self._jdf.repartition(numPartitions, self._jcols(*cols)),
                    self.sparkSession,
                )
        elif isinstance(numPartitions, (str, Column)):
            cols = (numPartitions,) + cols
            return DataFrame(self._jdf.repartition(self._jcols(*cols)), self.sparkSession)
        else:
            # int is accepted too, so the error must say so (as repartitionByRange does).
            raise PySparkTypeError(
                errorClass="NOT_COLUMN_OR_INT_OR_STR",
                messageParameters={
                    "arg_name": "numPartitions",
                    "arg_type": type(numPartitions).__name__,
                },
            )

    @overload
    def repartitionByRange(self, numPartitions: int, *cols: "ColumnOrName") -> ParentDataFrame:
        ...

    @overload
    def repartitionByRange(self, *cols: "ColumnOrName") -> ParentDataFrame:
        ...

    def repartitionByRange(  # type: ignore[misc]
        self, numPartitions: Union[int, "ColumnOrName"], *cols: "ColumnOrName"
    ) -> ParentDataFrame:
        """Return a new DataFrame produced by the JVM ``repartitionByRange``.

        ``numPartitions`` is either a partition count (then at least one column
        is required) or the first partitioning column.

        Raises
        ------
        PySparkValueError
            ``CANNOT_BE_EMPTY`` if a count is given without any columns.
        PySparkTypeError
            ``NOT_COLUMN_OR_INT_OR_STR`` for any other ``numPartitions`` type.
        """
        if isinstance(numPartitions, (str, Column)):
            all_cols = (numPartitions,) + cols
            return DataFrame(
                self._jdf.repartitionByRange(self._jcols(*all_cols)), self.sparkSession
            )
        if not isinstance(numPartitions, int):
            raise PySparkTypeError(
                errorClass="NOT_COLUMN_OR_INT_OR_STR",
                messageParameters={
                    "arg_name": "numPartitions",
                    "arg_type": type(numPartitions).__name__,
                },
            )
        if not cols:
            raise PySparkValueError(
                errorClass="CANNOT_BE_EMPTY",
                messageParameters={"item": "partition-by expression"},
            )
        return DataFrame(
            self._jdf.repartitionByRange(numPartitions, self._jcols(*cols)),
            self.sparkSession,
        )

    def distinct(self) -> ParentDataFrame:
        """Wrap the JVM ``distinct()``."""
        java_result = self._jdf.distinct()
        return DataFrame(java_result, self.sparkSession)

    @overload
    def sample(self, fraction: float, seed: Optional[int] = ...) -> ParentDataFrame:
        ...

    @overload
    def sample(
        self,
        withReplacement: Optional[bool],
        fraction: float,
        seed: Optional[int] = ...,
    ) -> ParentDataFrame:
        ...

    def sample(  # type: ignore[misc]
        self,
        withReplacement: Optional[Union[float, bool]] = None,
        fraction: Optional[Union[int, float]] = None,
        seed: Optional[int] = None,
    ) -> ParentDataFrame:
        """Wrap the JVM ``sample``.

        Argument shuffling and validation (e.g. a fraction passed positionally
        as ``withReplacement``) are delegated to ``_preapare_args_for_sample``,
        which is expected to return exactly three values.
        """
        with_replacement, sample_fraction, sample_seed = self._preapare_args_for_sample(
            withReplacement, fraction, seed
        )
        java_result = self._jdf.sample(with_replacement, sample_fraction, sample_seed)
        return DataFrame(java_result, self.sparkSession)

    def sampleBy(
        self, col: "ColumnOrName", fractions: Dict[Any, float], seed: Optional[int] = None
    ) -> ParentDataFrame:
        """Return a stratified sample via the JVM ``stat().sampleBy``.

        Parameters
        ----------
        col : str or Column
            The stratum column.
        fractions : dict
            Maps stratum values (float, int or str) to sampling fractions. The
            caller's dict is left unchanged; values are coerced to ``float`` in
            a copy.
        seed : int, optional
            Random seed. When omitted, one is drawn from ``[0, sys.maxsize]``.

        Raises
        ------
        PySparkTypeError
            ``NOT_COLUMN_OR_STR`` for a bad ``col``, ``NOT_DICT`` for a bad
            ``fractions``, or ``DISALLOWED_TYPE_FOR_CONTAINER`` for a bad key.
        """
        if isinstance(col, str):
            col = Column(col)
        elif not isinstance(col, Column):
            raise PySparkTypeError(
                errorClass="NOT_COLUMN_OR_STR",
                messageParameters={"arg_name": "col", "arg_type": type(col).__name__},
            )
        if not isinstance(fractions, dict):
            raise PySparkTypeError(
                errorClass="NOT_DICT",
                messageParameters={"arg_name": "fractions", "arg_type": type(fractions).__name__},
            )
        # Build a normalized copy instead of writing float values back into the
        # caller's dict, which would be a surprising side effect.
        float_fractions: Dict[Any, float] = {}
        for k, v in fractions.items():
            if not isinstance(k, (float, int, str)):
                raise PySparkTypeError(
                    errorClass="DISALLOWED_TYPE_FOR_CONTAINER",
                    messageParameters={
                        "arg_name": "fractions",
                        "arg_type": type(fractions).__name__,
                        "allowed_types": "float, int, str",
                        "item_type": type(k).__name__,
                    },
                )
            float_fractions[k] = float(v)
        jcol = col._jc
        seed = seed if seed is not None else random.randint(0, sys.maxsize)
        return DataFrame(
            self._jdf.stat().sampleBy(jcol, self._jmap(float_fractions), seed), self.sparkSession
        )

    def randomSplit(
        self, weights: List[float], seed: Optional[int] = None
    ) -> List[ParentDataFrame]:
        """Split into DataFrames according to ``weights`` (JVM ``randomSplit``).

        Parameters
        ----------
        weights : list of float
            Must all be non-negative.
        seed : int, optional
            Random seed. When omitted, one is drawn from ``[0, sys.maxsize]``.

        Raises
        ------
        PySparkValueError
            ``VALUE_NOT_POSITIVE`` for the first negative weight.
        """
        first_negative = next((w for w in weights if w < 0.0), None)
        if first_negative is not None:
            raise PySparkValueError(
                errorClass="VALUE_NOT_POSITIVE",
                messageParameters={"arg_name": "weights", "arg_value": str(first_negative)},
            )
        if seed is None:
            seed = random.randint(0, sys.maxsize)
        java_weights = _to_list(self.sparkSession._sc, cast(List["ColumnOrName"], weights))
        java_splits = self._jdf.randomSplit(java_weights, int(seed))
        return [DataFrame(split, self.sparkSession) for split in java_splits]

    @property
    def dtypes(self) -> List[Tuple[str, str]]:
        """``(name, simpleString type)`` pairs for every field of :attr:`schema`."""
        schema_fields = self.schema.fields
        return [(str(field.name), field.dataType.simpleString()) for field in schema_fields]

    @property
    def columns(self) -> List[str]:
        """Field names of :attr:`schema`, in order."""
        schema_fields = self.schema.fields
        return [field.name for field in schema_fields]

    def metadataColumn(self, colName: str) -> Column:
        """Return the :class:`Column` for the JVM ``metadataColumn(colName)``.

        Raises
        ------
        PySparkTypeError
            ``NOT_STR`` if ``colName`` is not a string.
        """
        if isinstance(colName, str):
            return Column(self._jdf.metadataColumn(colName))
        raise PySparkTypeError(
            errorClass="NOT_STR",
            messageParameters={"arg_name": "colName", "arg_type": type(colName).__name__},
        )

    def colRegex(self, colName: str) -> Column:
        """Return the :class:`Column` for the JVM ``colRegex(colName)``.

        Raises
        ------
        PySparkTypeError
            ``NOT_STR`` if ``colName`` is not a string.
        """
        if isinstance(colName, str):
            return Column(self._jdf.colRegex(colName))
        raise PySparkTypeError(
            errorClass="NOT_STR",
            messageParameters={"arg_name": "colName", "arg_type": type(colName).__name__},
        )

    def to(self, schema: StructType) -> ParentDataFrame:
        """Wrap the JVM ``to`` with ``schema``, which is passed to the JVM as JSON."""
        assert schema is not None
        java_session = self._jdf.sparkSession()
        java_schema = java_session.parseDataType(schema.json())
        return DataFrame(self._jdf.to(java_schema), self.sparkSession)

    def alias(self, alias: str) -> ParentDataFrame:
        """Wrap the JVM ``as(alias)``.

        ``as`` is a Python keyword, so the JVM method is fetched with ``getattr``.
        """
        assert isinstance(alias, str), "alias should be a string"
        java_as = getattr(self._jdf, "as")
        return DataFrame(java_as(alias), self.sparkSession)

    def crossJoin(self, other: ParentDataFrame) -> ParentDataFrame:
        """Wrap the JVM ``crossJoin`` with ``other``."""
        return DataFrame(self._jdf.crossJoin(other._jdf), self.sparkSession)

    def join(
        self,
        other: ParentDataFrame,
        on: Optional[Union[str, List[str], Column, List[Column]]] = None,
        how: Optional[str] = None,
    ) -> ParentDataFrame:
        """Join with ``other`` through the JVM ``join`` overloads.

        Parameters
        ----------
        other : DataFrame
            Right side of the join.
        on : str, list of str, Column or list of Column, optional
            Column name(s) to join on, or join condition(s). A list of Columns
            is combined with ``&``. An empty list is treated like ``None``.
        how : str, optional
            Join type. Defaults to ``"inner"`` whenever ``on`` is given.
        """
        if on is not None and not isinstance(on, list):
            on = [on]  # type: ignore[assignment]

        if on is not None and len(on) == 0:
            # ``on[0]`` below needs a non-empty list. An empty key list carries
            # no join condition, which is what ``on=None`` means.
            on = None

        if on is not None:
            if isinstance(on[0], str):
                on = self._jseq(cast(List[str], on))
            else:
                assert isinstance(on[0], Column), "on should be Column or list of Column"
                on = reduce(lambda x, y: x.__and__(y), cast(List[Column], on))
                on = on._jc

        if on is None and how is None:
            jdf = self._jdf.join(other._jdf)
        else:
            if how is None:
                how = "inner"
            if on is None:
                on = self._jseq([])
            assert isinstance(how, str), "how should be a string"
            jdf = self._jdf.join(other._jdf, on, how)
        return DataFrame(jdf, self.sparkSession)

    def lateralJoin(
        self,
        other: ParentDataFrame,
        on: Optional[Column] = None,
        how: Optional[str] = None,
    ) -> ParentDataFrame:
        """Wrap the JVM ``lateralJoin`` with ``other``.

        The optional condition ``on`` and join type ``how`` are passed only
        when given, selecting the matching JVM overload.
        """
        join_args: List[Any] = [other._jdf]
        if on is not None:
            join_args.append(on._jc)
        if how is not None:
            join_args.append(how)
        jdf = self._jdf.lateralJoin(*join_args)
        return DataFrame(jdf, self.sparkSession)

    # TODO(SPARK-22947): Fix the DataFrame API.
    def _joinAsOf(
        self,
        other: ParentDataFrame,
        leftAsOfColumn: Union[str, Column],
        rightAsOfColumn: Union[str, Column],
        on: Optional[Union[str, List[str], Column, List[Column]]] = None,
        how: Optional[str] = None,
        *,
        tolerance: Optional[Column] = None,
        allowExactMatches: bool = True,
        direction: str = "backward",
    ) -> ParentDataFrame:
        """
        Perform an as-of join.

        This is similar to a left-join except that we match on the nearest
        key rather than equal keys.

        .. versionchanged:: 4.0.0
            Supports Spark Connect.

        Parameters
        ----------
        other : :class:`DataFrame`
            Right side of the join
        leftAsOfColumn : str or :class:`Column`
            a string for the as-of join column name, or a Column
        rightAsOfColumn : str or :class:`Column`
            a string for the as-of join column name, or a Column
        on : str, list or :class:`Column`, optional
            a string for the join column name, a list of column names,
            a join expression (Column), or a list of Columns.
            If `on` is a string or a list of strings indicating the name of the join column(s),
            the column(s) must exist on both sides, and this performs an equi-join.
            An empty list is treated the same as ``None``.
        how : str, optional
            default ``inner``. Must be one of: ``inner`` and ``left``.
        tolerance : :class:`Column`, optional
            an asof tolerance within this range; must be compatible
            with the merge index.
        allowExactMatches : bool, optional
            default ``True``.
        direction : str, optional
            default ``backward``. Must be one of: ``backward``, ``forward``, and ``nearest``.

        Examples
        --------
        The following performs an as-of join between ``left`` and ``right``.

        >>> left = spark.createDataFrame([(1, "a"), (5, "b"), (10,  "c")], ["a", "left_val"])
        >>> right = spark.createDataFrame([(1, 1), (2, 2), (3, 3), (6, 6), (7, 7)],
        ...                               ["a", "right_val"])
        >>> left._joinAsOf(
        ...     right, leftAsOfColumn="a", rightAsOfColumn="a"
        ... ).select(left.a, 'left_val', 'right_val').sort("a").collect()
        [Row(a=1, left_val='a', right_val=1),
         Row(a=5, left_val='b', right_val=3),
         Row(a=10, left_val='c', right_val=7)]

        >>> from pyspark.sql import functions as sf
        >>> left._joinAsOf(
        ...     right, leftAsOfColumn="a", rightAsOfColumn="a", tolerance=sf.lit(1)
        ... ).select(left.a, 'left_val', 'right_val').sort("a").collect()
        [Row(a=1, left_val='a', right_val=1)]

        >>> left._joinAsOf(
        ...     right, leftAsOfColumn="a", rightAsOfColumn="a", how="left", tolerance=sf.lit(1)
        ... ).select(left.a, 'left_val', 'right_val').sort("a").collect()
        [Row(a=1, left_val='a', right_val=1),
         Row(a=5, left_val='b', right_val=None),
         Row(a=10, left_val='c', right_val=None)]

        >>> left._joinAsOf(
        ...     right, leftAsOfColumn="a", rightAsOfColumn="a", allowExactMatches=False
        ... ).select(left.a, 'left_val', 'right_val').sort("a").collect()
        [Row(a=5, left_val='b', right_val=3),
         Row(a=10, left_val='c', right_val=7)]

        >>> left._joinAsOf(
        ...     right, leftAsOfColumn="a", rightAsOfColumn="a", direction="forward"
        ... ).select(left.a, 'left_val', 'right_val').sort("a").collect()
        [Row(a=1, left_val='a', right_val=1),
         Row(a=5, left_val='b', right_val=6)]
        """
        if isinstance(leftAsOfColumn, str):
            leftAsOfColumn = self[leftAsOfColumn]
        left_as_of_jcol = leftAsOfColumn._jc
        if isinstance(rightAsOfColumn, str):
            rightAsOfColumn = other[rightAsOfColumn]
        right_as_of_jcol = rightAsOfColumn._jc

        if on is not None and not isinstance(on, list):
            on = [on]  # type: ignore[assignment]

        if on is not None and len(on) == 0:
            # ``on[0]`` below needs a non-empty list. An empty key list carries
            # no join condition, which is what ``on=None`` means.
            on = None

        if on is not None:
            if isinstance(on[0], str):
                on = self._jseq(cast(List[str], on))
            else:
                assert isinstance(on[0], Column), "on should be Column or list of Column"
                on = reduce(lambda x, y: x.__and__(y), cast(List[Column], on))
                on = on._jc

        if how is None:
            how = "inner"
        assert isinstance(how, str), "how should be a string"

        if tolerance is not None:
            assert isinstance(tolerance, Column), "tolerance should be Column"
            tolerance = tolerance._jc

        jdf = self._jdf.joinAsOf(
            other._jdf,
            left_as_of_jcol,
            right_as_of_jcol,
            on,
            how,
            tolerance,
            allowExactMatches,
            direction,
        )
        return DataFrame(jdf, self.sparkSession)

    def sortWithinPartitions(
        self,
        *cols: Union[int, str, Column, List[Union[int, str, Column]]],
        **kwargs: Any,
    ) -> ParentDataFrame:
        # Sort rows inside each partition. Column specs and keyword options are
        # normalized by the shared `_preapare_cols_for_sort` helper (names are
        # resolved with `F.col`) and sent to the JVM as a Seq of Java columns.
        sort_columns = self._preapare_cols_for_sort(F.col, cols, kwargs)
        java_sort_seq = self._jseq(sort_columns, _to_java_column)
        return DataFrame(self._jdf.sortWithinPartitions(java_sort_seq), self.sparkSession)

    def sort(
        self,
        *cols: Union[int, str, Column, List[Union[int, str, Column]]],
        **kwargs: Any,
    ) -> ParentDataFrame:
        # Sort the whole DataFrame. Column specs and keyword options go through
        # the shared `_preapare_cols_for_sort` helper (names resolved with
        # `F.col`) and reach the JVM as a Seq of Java columns.
        sort_columns = self._preapare_cols_for_sort(F.col, cols, kwargs)
        java_sort_seq = self._jseq(sort_columns, _to_java_column)
        return DataFrame(self._jdf.sort(java_sort_seq), self.sparkSession)

    # `orderBy` is the very same function object as `sort`.
    orderBy = sort

    def _jseq(
        self,
        cols: Sequence,
        converter: Optional[Callable[..., Union["PrimitiveType", "JavaObject"]]] = None,
    ) -> "JavaObject":
        """Convert ``cols`` into a JVM Seq using this session's SparkContext.

        ``converter`` is forwarded unchanged to ``_to_seq``; callers in this class
        pass ``_to_java_column`` when ``cols`` holds Columns or column names, and
        pass nothing when the elements (e.g. plain strings) should go through as-is.
        """
        spark_context = self.sparkSession._sc
        return _to_seq(spark_context, cols, converter)

    def _jmap(self, jm: Dict) -> "JavaObject":
        """Convert the Python dict ``jm`` into a JVM Scala Map."""
        jvm = self.sparkSession._sc._jvm
        return to_scala_map(jvm, jm)

    def _jcols(self, *cols: "ColumnOrName") -> "JavaObject":
        """Build a JVM Seq of Java columns from Columns and/or column names.

        A single list argument is unwrapped, so ``_jcols(["a", "b"])`` and
        ``_jcols("a", "b")`` produce the same Seq.
        """
        if len(cols) == 1 and isinstance(cols[0], list):
            return self._jseq(cols[0], _to_java_column)
        return self._jseq(cols, _to_java_column)

    def _jcols_ordinal(self, *cols: "ColumnOrNameOrOrdinal") -> "JavaObject":
        """Build a JVM Seq of Java columns from Columns, names or 1-based ordinals.

        A single list argument is unwrapped. Integers (but not ``bool``) are
        treated as 1-based positions and resolved with ``self[ordinal - 1]``;
        an ordinal below 1 raises ``PySparkIndexError`` (INDEX_NOT_POSITIVE).
        Everything else is passed on unchanged.
        """
        if len(cols) == 1 and isinstance(cols[0], list):
            cols = cols[0]

        def resolve(spec: Any) -> Any:
            # bool subclasses int, yet True/False must never act as ordinals.
            if isinstance(spec, bool) or not isinstance(spec, int):
                return spec
            if spec < 1:
                raise PySparkIndexError(
                    errorClass="INDEX_NOT_POSITIVE", messageParameters={"index": str(spec)}
                )
            return self[spec - 1]

        resolved = [resolve(spec) for spec in cols]
        return self._jseq(resolved, _to_java_column)

    def describe(self, *cols: Union[str, List[str]]) -> ParentDataFrame:
        # Delegates to the JVM `describe`. A lone list argument is unwrapped; the
        # names are forwarded without column conversion.
        names = cols[0] if len(cols) == 1 and isinstance(cols[0], list) else cols
        return DataFrame(self._jdf.describe(self._jseq(names)), self.sparkSession)

    def summary(self, *statistics: str) -> ParentDataFrame:
        # Delegates to the JVM `summary`. Statistic names may be given as varargs
        # or as one list, which is unwrapped here.
        requested = statistics
        if len(requested) == 1 and isinstance(requested[0], list):
            requested = requested[0]
        return DataFrame(self._jdf.summary(self._jseq(requested)), self.sparkSession)

    @overload
    def head(self) -> Optional[Row]:
        ...

    @overload
    def head(self, n: int) -> List[Row]:
        ...

    def head(self, n: Optional[int] = None) -> Union[Optional[Row], List[Row]]:
        # With an explicit `n`, delegate to `take(n)` and return its list.
        # Without one, return just the first Row, or None if there is none.
        if n is not None:
            return self.take(n)
        first_rows = self.head(1)
        if not first_rows:
            return None
        return first_rows[0]

    def first(self) -> Optional[Row]:
        # Same as `head()` without arguments: the first Row, or None.
        first_row = self.head()
        return first_row

    @overload
    def __getitem__(self, item: Union[int, str]) -> Column:
        ...

    @overload
    def __getitem__(self, item: Union[Column, List, Tuple]) -> ParentDataFrame:
        ...

    def __getitem__(
        self, item: Union[int, str, Column, List, Tuple]
    ) -> Union[Column, ParentDataFrame]:
        # Dispatch on the key's type:
        #   str          -> the Column with that name
        #   Column       -> `filter(item)`
        #   list / tuple -> `select(*item)`
        #   int          -> the Column named `self.columns[item]` (bool, an int
        #                   subclass, is handled here as well)
        if isinstance(item, str):
            return Column(self._jdf.apply(item))
        if isinstance(item, Column):
            return self.filter(item)
        if isinstance(item, (list, tuple)):
            return self.select(*item)
        if isinstance(item, int):
            positional_name = self.columns[item]
            return Column(self._jdf.apply(positional_name))
        raise PySparkTypeError(
            errorClass="NOT_COLUMN_OR_FLOAT_OR_INT_OR_LIST_OR_STR",
            messageParameters={"arg_name": "item", "arg_type": type(item).__name__},
        )

    def __getattr__(self, name: str) -> Column:
        # Python calls this only after normal attribute lookup fails. Existing
        # column names are exposed as attributes (`df.age`); anything else is
        # reported as an unsupported attribute.
        # NOTE(review): this reads `self.columns`, so it relies on that lookup
        # never falling back into `__getattr__` itself.
        if name in self.columns:
            return Column(self._jdf.apply(name))
        raise PySparkAttributeError(
            errorClass="ATTRIBUTE_NOT_SUPPORTED", messageParameters={"attr_name": name}
        )

    def __dir__(self) -> List[str]:
        # Class attributes plus every column name that is a valid identifier, so
        # attribute-style completion can offer `df.<column>`.
        identifier_columns = {name for name in self.columns if name.isidentifier()}
        return sorted(set(dir(DataFrame)) | identifier_columns)

    @overload
    def select(self, *cols: "ColumnOrName") -> ParentDataFrame:
        ...

    @overload
    def select(self, __cols: Union[List[Column], List[str]]) -> ParentDataFrame:
        ...

    def select(self, *cols: "ColumnOrName") -> ParentDataFrame:  # type: ignore[misc]
        # Projection; `_jcols` also unwraps a single list of Columns/names.
        projection = self._jcols(*cols)
        return DataFrame(self._jdf.select(projection), self.sparkSession)

    @overload
    def selectExpr(self, *expr: str) -> ParentDataFrame:
        ...

    @overload
    def selectExpr(self, *expr: List[str]) -> ParentDataFrame:
        ...

    def selectExpr(self, *expr: Union[str, List[str]]) -> ParentDataFrame:
        # The SQL expression strings are passed to the JVM without column
        # conversion; a lone list argument is unwrapped first.
        expressions = expr
        if len(expressions) == 1 and isinstance(expressions[0], list):
            expressions = expressions[0]  # type: ignore[assignment]
        return DataFrame(self._jdf.selectExpr(self._jseq(expressions)), self.sparkSession)

    def filter(self, condition: Union[Column, str]) -> ParentDataFrame:
        # The condition is either a SQL predicate string (sent as-is) or a
        # Column (sent as its underlying Java column); anything else is rejected.
        jcondition: Any
        if isinstance(condition, Column):
            jcondition = condition._jc
        elif isinstance(condition, str):
            jcondition = condition
        else:
            raise PySparkTypeError(
                errorClass="NOT_COLUMN_OR_STR",
                messageParameters={"arg_name": "condition", "arg_type": type(condition).__name__},
            )
        return DataFrame(self._jdf.filter(jcondition), self.sparkSession)

    @overload
    def groupBy(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":
        ...

    @overload
    def groupBy(self, __cols: Union[List[Column], List[str], List[int]]) -> "GroupedData":
        ...

    def groupBy(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":  # type: ignore[misc]
        # Grouping keys may be Columns, names or 1-based ordinals (resolved by
        # `_jcols_ordinal`).
        grouping_keys = self._jcols_ordinal(*cols)
        jgrouped = self._jdf.groupBy(grouping_keys)

        from pyspark.sql.group import GroupedData

        return GroupedData(jgrouped, self)

    @overload
    def rollup(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":
        ...

    @overload
    def rollup(self, __cols: Union[List[Column], List[str], List[int]]) -> "GroupedData":
        ...

    def rollup(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":  # type: ignore[misc]
        # Multi-dimensional rollup over the given grouping columns.
        # Keys are resolved through `_jcols_ordinal`, exactly like `groupBy`, so
        # Columns, names and 1-based ordinals are all accepted. The overloads
        # above advertise the ordinal forms too, so type checkers accept such calls.
        jgd = self._jdf.rollup(self._jcols_ordinal(*cols))
        from pyspark.sql.group import GroupedData

        return GroupedData(jgd, self)

    @overload
    def cube(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":
        ...

    @overload
    def cube(self, __cols: Union[List[Column], List[str], List[int]]) -> "GroupedData":
        ...

    def cube(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":  # type: ignore[misc]
        # Multi-dimensional cube over the given grouping columns.
        # Keys are resolved through `_jcols_ordinal`, exactly like `groupBy`, so
        # Columns, names and 1-based ordinals are all accepted. The signature and
        # overloads say so, so type checkers accept ordinal calls.
        jgd = self._jdf.cube(self._jcols_ordinal(*cols))
        from pyspark.sql.group import GroupedData

        return GroupedData(jgd, self)

    def groupingSets(
        self, groupingSets: Sequence[Sequence["ColumnOrName"]], *cols: "ColumnOrName"
    ) -> "GroupedData":
        from pyspark.sql.group import GroupedData

        # Each inner grouping set becomes its own JVM Seq of columns, and the
        # collection of those Seqs is wrapped in an outer Seq.
        java_sets = [self._jcols(*grouping_set) for grouping_set in groupingSets]
        jgd = self._jdf.groupingSets(_to_seq(self._sc, java_sets), self._jcols(*cols))
        return GroupedData(jgd, self)

    def unpivot(
        self,
        ids: Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]],
        values: Optional[Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]],
        variableColumnName: str,
        valueColumnName: str,
    ) -> ParentDataFrame:
        # `ids` and `values` each accept one column/name or a list/tuple of them.
        # With `values=None` the JVM overload without a value list is used.
        assert ids is not None, "ids must not be None"

        def as_java_columns(
            spec: Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]
        ) -> "JavaObject":
            if isinstance(spec, (list, tuple)):
                return self._jcols(*spec)
            return self._jcols(spec)

        java_ids = as_java_columns(ids)
        if values is None:
            jdf = self._jdf.unpivotWithSeq(java_ids, variableColumnName, valueColumnName)
        else:
            java_values = as_java_columns(values)
            jdf = self._jdf.unpivotWithSeq(
                java_ids, java_values, variableColumnName, valueColumnName
            )
        return DataFrame(jdf, self.sparkSession)

    def melt(
        self,
        ids: Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]],
        values: Optional[Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]],
        variableColumnName: str,
        valueColumnName: str,
    ) -> ParentDataFrame:
        # Alias of `unpivot`; every argument is forwarded unchanged.
        return self.unpivot(
            ids=ids,
            values=values,
            variableColumnName=variableColumnName,
            valueColumnName=valueColumnName,
        )

    def agg(self, *exprs: Union[Column, Dict[str, str]]) -> ParentDataFrame:
        # Aggregate the whole DataFrame: `groupBy()` with no keys, then `agg`.
        ungrouped = self.groupBy()
        return ungrouped.agg(*exprs)  # type: ignore[arg-type]

    def observe(
        self,
        observation: Union["Observation", str],
        *exprs: Column,
    ) -> ParentDataFrame:
        # Define metrics (`exprs`) to be observed on this DataFrame.
        # `exprs` must be a non-empty sequence of Columns. An `Observation`
        # instance attaches itself through its `_on` hook; a plain string is passed
        # as the first argument of the JVM `observe` together with the expressions.
        # Raises PySparkValueError for empty `exprs` and PySparkTypeError for
        # non-Column `exprs` or an `observation` that is neither type.
        from pyspark.sql import Observation

        if len(exprs) == 0:
            raise PySparkValueError(
                errorClass="CANNOT_BE_EMPTY",
                messageParameters={"item": "exprs"},
            )
        if not all(isinstance(c, Column) for c in exprs):
            raise PySparkTypeError(
                errorClass="NOT_LIST_OF_COLUMN",
                messageParameters={"arg_name": "exprs"},
            )

        if isinstance(observation, Observation):
            return observation._on(self, *exprs)
        elif isinstance(observation, str):
            # The JVM signature takes the first expression separately from a Seq
            # of the remaining ones.
            return DataFrame(
                self._jdf.observe(
                    observation, exprs[0]._jc, _to_seq(self._sc, [c._jc for c in exprs[1:]])
                ),
                self.sparkSession,
            )
        else:
            # The offending argument is `observation` itself. NOT_LIST_OF_COLUMN
            # (used above with only `arg_name`) would tell the user to pass a list
            # of Columns and does not take an `arg_type` parameter.
            raise PySparkTypeError(
                errorClass="NOT_OBSERVATION_OR_STR",
                messageParameters={
                    "arg_name": "observation",
                    "arg_type": type(observation).__name__,
                },
            )

    def union(self, other: ParentDataFrame) -> ParentDataFrame:
        # Delegates to the JVM `union` of the two underlying DataFrames.
        combined = self._jdf.union(other._jdf)
        return DataFrame(combined, self.sparkSession)

    def unionAll(self, other: ParentDataFrame) -> ParentDataFrame:
        # Alias that simply delegates to `union`.
        combined = self.union(other)
        return combined

    def unionByName(
        self, other: ParentDataFrame, allowMissingColumns: bool = False
    ) -> ParentDataFrame:
        # `allowMissingColumns` is forwarded unchanged to the JVM `unionByName`.
        combined = self._jdf.unionByName(other._jdf, allowMissingColumns)
        return DataFrame(combined, self.sparkSession)

    def intersect(self, other: ParentDataFrame) -> ParentDataFrame:
        # Delegates to the JVM `intersect`.
        common = self._jdf.intersect(other._jdf)
        return DataFrame(common, self.sparkSession)

    def intersectAll(self, other: ParentDataFrame) -> ParentDataFrame:
        # Delegates to the JVM `intersectAll`.
        common = self._jdf.intersectAll(other._jdf)
        return DataFrame(common, self.sparkSession)

    def subtract(self, other: ParentDataFrame) -> ParentDataFrame:
        # The JVM method is named `except`, a Python keyword, so it cannot be
        # written as `self._jdf.except(...)` and is looked up with getattr.
        jvm_except = getattr(self._jdf, "except")
        return DataFrame(jvm_except(other._jdf), self.sparkSession)

    def dropDuplicates(self, subset: Optional[List[str]] = None) -> ParentDataFrame:
        # Remove duplicate rows, optionally comparing only the columns named in
        # `subset`. `subset` may be any non-string iterable of column names.
        # Raises PySparkTypeError for a non-iterable/str `subset` (NOT_LIST_OR_TUPLE)
        # or a non-str element (NOT_STR).
        if subset is not None and (not isinstance(subset, Iterable) or isinstance(subset, str)):
            raise PySparkTypeError(
                errorClass="NOT_LIST_OR_TUPLE",
                messageParameters={"arg_name": "subset", "arg_type": type(subset).__name__},
            )

        if subset is None:
            jdf = self._jdf.dropDuplicates()
        else:
            # Materialize first: the validation loop below would otherwise
            # exhaust a one-shot iterable (e.g. a generator) before it is
            # converted for the JVM.
            subset = list(subset)
            for c in subset:
                if not isinstance(c, str):
                    raise PySparkTypeError(
                        errorClass="NOT_STR",
                        messageParameters={"arg_name": "subset", "arg_type": type(c).__name__},
                    )
            jdf = self._jdf.dropDuplicates(self._jseq(subset))
        return DataFrame(jdf, self.sparkSession)

    def dropDuplicatesWithinWatermark(self, subset: Optional[List[str]] = None) -> ParentDataFrame:
        # Watermark-aware variant of `dropDuplicates`, delegated to the JVM
        # `dropDuplicatesWithinWatermark`. `subset` may be any non-string iterable
        # of column names. Raises PySparkTypeError for a non-iterable/str `subset`
        # (NOT_LIST_OR_TUPLE) or a non-str element (NOT_STR).
        if subset is not None and (not isinstance(subset, Iterable) or isinstance(subset, str)):
            raise PySparkTypeError(
                errorClass="NOT_LIST_OR_TUPLE",
                messageParameters={"arg_name": "subset", "arg_type": type(subset).__name__},
            )

        if subset is None:
            jdf = self._jdf.dropDuplicatesWithinWatermark()
        else:
            # Materialize first: the validation loop below would otherwise
            # exhaust a one-shot iterable (e.g. a generator) before it is
            # converted for the JVM.
            subset = list(subset)
            for c in subset:
                if not isinstance(c, str):
                    raise PySparkTypeError(
                        errorClass="NOT_STR",
                        messageParameters={"arg_name": "subset", "arg_type": type(c).__name__},
                    )
            jdf = self._jdf.dropDuplicatesWithinWatermark(self._jseq(subset))
        return DataFrame(jdf, self.sparkSession)

    def dropna(
        self,
        how: str = "any",
        thresh: Optional[int] = None,
        subset: Optional[Union[str, Tuple[str, ...], List[str]]] = None,
    ) -> ParentDataFrame:
        # Drop rows containing nulls, looking only at `subset` (all columns by
        # default). Everything is reduced to one threshold handed to the JVM
        # `na().drop(minNonNulls, cols)`, which keeps rows having at least
        # `thresh` non-null values among `subset`:
        #   how="any" -> thresh = len(subset): a single null drops the row;
        #   how="all" -> thresh = 1: only rows that are entirely null are dropped.
        # An explicit `thresh` takes precedence over `how`.
        if how is None:
            # Treat an explicit None as the default "any" instead of letting it
            # fall through to the "all" threshold (1) below.
            how = "any"
        elif how not in ["any", "all"]:
            raise PySparkValueError(
                errorClass="VALUE_NOT_ANY_OR_ALL",
                # The offending *value* is reported, keyed like VALUE_NOT_PEARSON
                # in `corr`. NOTE(review): confirm against error-conditions.json.
                messageParameters={"arg_name": "how", "arg_value": str(how)},
            )

        if subset is None:
            subset = self.columns
        elif isinstance(subset, str):
            subset = [subset]
        elif not isinstance(subset, (list, tuple)):
            raise PySparkTypeError(
                errorClass="NOT_LIST_OR_STR_OR_TUPLE",
                messageParameters={"arg_name": "subset", "arg_type": type(subset).__name__},
            )

        if thresh is None:
            thresh = len(subset) if how == "any" else 1

        return DataFrame(self._jdf.na().drop(thresh, self._jseq(subset)), self.sparkSession)

    @overload
    def fillna(
        self,
        value: "LiteralType",
        subset: Optional[Union[str, Tuple[str, ...], List[str]]] = ...,
    ) -> ParentDataFrame:
        ...

    @overload
    def fillna(self, value: Dict[str, "LiteralType"]) -> ParentDataFrame:
        ...

    def fillna(
        self,
        value: Union["LiteralType", Dict[str, "LiteralType"]],
        subset: Optional[Union[str, Tuple[str, ...], List[str]]] = None,
    ) -> ParentDataFrame:
        # Replace nulls with `value`.
        #   dict   -> per-column replacements; `subset` is ignored.
        #   scalar -> applied to every column, or only to `subset` (one name or a
        #             list/tuple of names).
        if not isinstance(value, (float, int, str, bool, dict)):
            raise PySparkTypeError(
                errorClass="NOT_BOOL_OR_DICT_OR_FLOAT_OR_INT_OR_STR",
                messageParameters={"arg_name": "value", "arg_type": type(value).__name__},
            )

        # bool is an int subclass: widen genuine ints to float, but keep bools
        # as-is so they are not turned into 0.0 / 1.0.
        if isinstance(value, int) and not isinstance(value, bool):
            value = float(value)

        if isinstance(value, dict) or subset is None:
            return DataFrame(self._jdf.na().fill(value), self.sparkSession)

        if isinstance(subset, str):
            subset = [subset]
        elif not isinstance(subset, (list, tuple)):
            raise PySparkTypeError(
                errorClass="NOT_LIST_OR_TUPLE",
                messageParameters={"arg_name": "subset", "arg_type": type(subset).__name__},
            )
        return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sparkSession)

    @overload
    def replace(
        self,
        to_replace: "LiteralType",
        value: "OptionalPrimitiveType",
        subset: Optional[List[str]] = ...,
    ) -> ParentDataFrame:
        ...

    @overload
    def replace(
        self,
        to_replace: List["LiteralType"],
        value: List["OptionalPrimitiveType"],
        subset: Optional[List[str]] = ...,
    ) -> ParentDataFrame:
        ...

    @overload
    def replace(
        self,
        to_replace: Dict["LiteralType", "OptionalPrimitiveType"],
        subset: Optional[List[str]] = ...,
    ) -> ParentDataFrame:
        ...

    @overload
    def replace(
        self,
        to_replace: List["LiteralType"],
        value: "OptionalPrimitiveType",
        subset: Optional[List[str]] = ...,
    ) -> ParentDataFrame:
        ...

    def replace(  # type: ignore[misc]
        self,
        to_replace: Union[
            "LiteralType", List["LiteralType"], Dict["LiteralType", "OptionalPrimitiveType"]
        ],
        value: Optional[
            Union["OptionalPrimitiveType", List["OptionalPrimitiveType"], _NoValueType]
        ] = _NoValue,
        subset: Optional[List[str]] = None,
    ) -> ParentDataFrame:
        # Replace values, building one {old: new} mapping for the JVM.
        #
        # Accepted shapes:
        #   * dict `to_replace` ({old: new}): `value` may be omitted; a non-None
        #     `value` is ignored with a warning.
        #   * scalar or list/tuple `to_replace` with a scalar/None `value`: every
        #     element maps to that single value.
        #   * scalar or list/tuple `to_replace` with a list/tuple `value`: mapped
        #     pairwise, so the lengths must match (a scalar counts as length 1).
        # `subset` limits the replacement to the named column(s); otherwise "*"
        # is used. All keys and all non-None values must come from one family:
        # bool, str, or numeric (checked with isinstance, so bools also pass as
        # numerics).
        if value is _NoValue:
            if isinstance(to_replace, dict):
                value = None
            else:
                # `value` may only be omitted when `to_replace` is a dict.
                raise PySparkTypeError(
                    errorClass="ARGUMENT_REQUIRED",
                    messageParameters={
                        "arg_name": "value",
                        "condition": "`to_replace` is not a dict",
                    },
                )

        # Helper functions
        def all_of(types: Union[Type, Tuple[Type, ...]]) -> Callable[[Iterable], bool]:
            """Given a type or tuple of types and a sequence of xs
            check if each x is instance of type(s)

            >>> all_of(bool)([True, False])
            True
            >>> all_of(str)(["a", 1])
            False
            """

            def all_of_(xs: Iterable) -> bool:
                return all(isinstance(x, types) for x in xs)

            return all_of_

        all_of_bool = all_of(bool)
        all_of_str = all_of(str)
        all_of_numeric = all_of((float, int))

        # Validate input types
        valid_types = (bool, float, int, str, list, tuple)
        if not isinstance(to_replace, valid_types + (dict,)):
            raise PySparkTypeError(
                errorClass="NOT_BOOL_OR_DICT_OR_FLOAT_OR_INT_OR_LIST_OR_STR_OR_TUPLE",
                messageParameters={
                    "arg_name": "to_replace",
                    "arg_type": type(to_replace).__name__,
                },
            )

        if (
            not isinstance(value, valid_types)
            and value is not None
            and not isinstance(to_replace, dict)
        ):
            raise PySparkTypeError(
                errorClass="NOT_BOOL_OR_FLOAT_OR_INT_OR_LIST_OR_NONE_OR_STR_OR_TUPLE",
                messageParameters={
                    "arg_name": "value",
                    "arg_type": type(value).__name__,
                },
            )

        if isinstance(value, (list, tuple)) and not isinstance(to_replace, dict):
            # A scalar `to_replace` is wrapped into a one-element list below, so it
            # has length 1. Without this, `zip` would silently drop every value
            # after the first one.
            to_replace_length = len(to_replace) if isinstance(to_replace, (list, tuple)) else 1
            if to_replace_length != len(value):
                raise PySparkValueError(
                    errorClass="LENGTH_SHOULD_BE_THE_SAME",
                    messageParameters={
                        "arg1": "to_replace",
                        "arg2": "value",
                        "arg1_length": str(to_replace_length),
                        "arg2_length": str(len(value)),
                    },
                )

        if not (subset is None or isinstance(subset, (list, tuple, str))):
            raise PySparkTypeError(
                errorClass="NOT_LIST_OR_STR_OR_TUPLE",
                messageParameters={"arg_name": "subset", "arg_type": type(subset).__name__},
            )

        # Reshape input arguments if necessary
        if isinstance(to_replace, (float, int, str)):
            to_replace = [to_replace]

        if isinstance(to_replace, dict):
            rep_dict = to_replace
            if value is not None:
                warnings.warn("to_replace is a dict and value is not None. value will be ignored.")
        else:
            if isinstance(value, (float, int, str)) or value is None:
                value = [value for _ in range(len(to_replace))]
            rep_dict = dict(zip(to_replace, cast("Iterable[Optional[Union[float, str]]]", value)))

        if isinstance(subset, str):
            subset = [subset]

        # Verify we were not passed in mixed type generics.
        if not any(
            all_of_type(rep_dict.keys())
            and all_of_type(x for x in rep_dict.values() if x is not None)
            for all_of_type in [all_of_bool, all_of_str, all_of_numeric]
        ):
            raise PySparkValueError(
                errorClass="MIXED_TYPE_REPLACEMENT",
                messageParameters={},
            )

        if subset is None:
            return DataFrame(self._jdf.na().replace("*", rep_dict), self.sparkSession)
        else:
            return DataFrame(
                self._jdf.na().replace(self._jseq(subset), self._jmap(rep_dict)),
                self.sparkSession,
            )

    @overload
    def approxQuantile(
        self,
        col: str,
        probabilities: Union[List[float], Tuple[float]],
        relativeError: float,
    ) -> List[float]:
        ...

    @overload
    def approxQuantile(
        self,
        col: Union[List[str], Tuple[str]],
        probabilities: Union[List[float], Tuple[float]],
        relativeError: float,
    ) -> List[List[float]]:
        ...

    def approxQuantile(
        self,
        col: Union[str, List[str], Tuple[str]],
        probabilities: Union[List[float], Tuple[float]],
        relativeError: float,
    ) -> Union[List[float], List[List[float]]]:
        # Quantiles computed by the JVM `stat().approxQuantile`. Arguments are
        # validated in Python first:
        #   col           - one column name, or a list/tuple of names (str only);
        #   probabilities - a list/tuple of numbers, each rejected if < 0 or > 1;
        #   relativeError - a number >= 0, sent to the JVM as float.
        # A single name yields one list of quantiles; several names yield one list
        # per column.
        if not isinstance(col, (str, list, tuple)):
            raise PySparkTypeError(
                errorClass="NOT_LIST_OR_STR_OR_TUPLE",
                messageParameters={"arg_name": "col", "arg_type": type(col).__name__},
            )

        single_column = isinstance(col, str)
        if single_column:
            col = [cast(str, col)]
        elif isinstance(col, tuple):
            col = list(col)

        for name in col:
            if isinstance(name, str):
                continue
            raise PySparkTypeError(
                errorClass="DISALLOWED_TYPE_FOR_CONTAINER",
                messageParameters={
                    "arg_name": "col",
                    "arg_type": type(col).__name__,
                    "allowed_types": "str",
                    "item_type": type(name).__name__,
                },
            )
        col = _to_list(self._sc, cast(List["ColumnOrName"], col))

        if not isinstance(probabilities, (list, tuple)):
            raise PySparkTypeError(
                errorClass="NOT_LIST_OR_TUPLE",
                messageParameters={
                    "arg_name": "probabilities",
                    "arg_type": type(probabilities).__name__,
                },
            )
        if isinstance(probabilities, tuple):
            probabilities = list(probabilities)
        for prob in probabilities:
            # Kept as explicit comparisons so NaN is accepted exactly as before.
            if not isinstance(prob, (float, int)) or prob < 0 or prob > 1:
                raise PySparkTypeError(
                    errorClass="NOT_LIST_OF_FLOAT_OR_INT",
                    messageParameters={
                        "arg_name": "probabilities",
                        "arg_type": type(prob).__name__,
                    },
                )
        probabilities = _to_list(self._sc, cast(List["ColumnOrName"], probabilities))

        if not isinstance(relativeError, (float, int)):
            raise PySparkTypeError(
                errorClass="NOT_FLOAT_OR_INT",
                messageParameters={
                    "arg_name": "relativeError",
                    "arg_type": type(relativeError).__name__,
                },
            )
        if relativeError < 0:
            raise PySparkValueError(
                errorClass="NEGATIVE_VALUE",
                messageParameters={
                    "arg_name": "relativeError",
                    "arg_value": str(relativeError),
                },
            )
        relativeError = float(relativeError)

        java_quantiles = self._jdf.stat().approxQuantile(col, probabilities, relativeError)
        per_column = [list(column_quantiles) for column_quantiles in java_quantiles]
        if single_column:
            return per_column[0]
        return per_column

    def corr(self, col1: str, col2: str, method: Optional[str] = None) -> float:
        # Correlation of two columns via the JVM stat functions. Only "pearson" is
        # supported; a falsy `method` (None or "") falls back to it.
        for arg_name, arg in (("col1", col1), ("col2", col2)):
            if not isinstance(arg, str):
                raise PySparkTypeError(
                    errorClass="NOT_STR",
                    messageParameters={"arg_name": arg_name, "arg_type": type(arg).__name__},
                )
        method = method or "pearson"
        if method != "pearson":
            raise PySparkValueError(
                errorClass="VALUE_NOT_PEARSON",
                messageParameters={"arg_name": "method", "arg_value": method},
            )
        return self._jdf.stat().corr(col1, col2, method)

    def cov(self, col1: str, col2: str) -> float:
        # Covariance of two columns via the JVM stat functions; both names must be str.
        for arg_name, arg in (("col1", col1), ("col2", col2)):
            if not isinstance(arg, str):
                raise PySparkTypeError(
                    errorClass="NOT_STR",
                    messageParameters={"arg_name": arg_name, "arg_type": type(arg).__name__},
                )
        return self._jdf.stat().cov(col1, col2)

    def crosstab(self, col1: str, col2: str) -> ParentDataFrame:
        # Delegates to the JVM `stat().crosstab`; both column names must be str.
        for arg_name, arg in (("col1", col1), ("col2", col2)):
            if not isinstance(arg, str):
                raise PySparkTypeError(
                    errorClass="NOT_STR",
                    messageParameters={"arg_name": arg_name, "arg_type": type(arg).__name__},
                )
        table = self._jdf.stat().crosstab(col1, col2)
        return DataFrame(table, self.sparkSession)

    def freqItems(
        self, cols: Union[List[str], Tuple[str]], support: Optional[float] = None
    ) -> ParentDataFrame:
        # Delegates to the JVM `stat().freqItems`. `cols` must be a list or tuple;
        # a falsy `support` (None or 0) falls back to 0.01.
        column_list = list(cols) if isinstance(cols, tuple) else cols
        if not isinstance(column_list, list):
            raise PySparkTypeError(
                errorClass="NOT_LIST_OR_TUPLE",
                messageParameters={"arg_name": "cols", "arg_type": type(column_list).__name__},
            )
        min_support = support if support else 0.01
        jdf = self._jdf.stat().freqItems(_to_seq(self._sc, column_list), min_support)
        return DataFrame(jdf, self.sparkSession)

    def _ipython_key_completions_(self) -> List[str]:
        """Return the column names of this :class:`DataFrame` for IPython key completion.

        IPython uses this hook to complete ``df["...`` lookups. Unlike ``__dir__``,
        which only offers identifier-like names, every column name is returned here.

        Examples
        --------
        >>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], ["age", "name"])
        >>> df._ipython_key_completions_()
        ['age', 'name']

        Names that are not valid Python identifiers are included as well.

        >>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], ["age 1", "name?1"])
        >>> df._ipython_key_completions_()
        ['age 1', 'name?1']
        """
        column_names = self.columns
        return column_names

    def withColumns(self, *colsMap: Dict[str, Column]) -> ParentDataFrame:
        # Add or replace several columns from one {name: Column} mapping. Varargs
        # leave room for keyword support later; today exactly one positional
        # mapping is expected.
        assert len(colsMap) == 1
        mapping = colsMap[0]

        if not isinstance(mapping, dict):
            raise PySparkTypeError(
                errorClass="NOT_DICT",
                messageParameters={"arg_name": "colsMap", "arg_type": type(mapping).__name__},
            )

        new_names = list(mapping)
        new_columns = list(mapping.values())
        jdf = self._jdf.withColumns(_to_seq(self._sc, new_names), self._jcols(*new_columns))
        return DataFrame(jdf, self.sparkSession)

    def withColumn(self, colName: str, col: Column) -> ParentDataFrame:
        # Add or replace the column `colName`; the value must be a Column.
        if isinstance(col, Column):
            return DataFrame(self._jdf.withColumn(colName, col._jc), self.sparkSession)
        raise PySparkTypeError(
            errorClass="NOT_COLUMN",
            messageParameters={"arg_name": "col", "arg_type": type(col).__name__},
        )

    def withColumnRenamed(self, existing: str, new: str) -> ParentDataFrame:
        # Delegates to the JVM `withColumnRenamed` without extra validation.
        renamed = self._jdf.withColumnRenamed(existing, new)
        return DataFrame(renamed, self.sparkSession)

    def withColumnsRenamed(self, colsMap: Dict[str, str]) -> ParentDataFrame:
        # Rename several columns from an {existing: new} mapping. The keys and
        # values are sent to the JVM as two parallel Seqs in the dict's order.
        if not isinstance(colsMap, dict):
            raise PySparkTypeError(
                errorClass="NOT_DICT",
                messageParameters={"arg_name": "colsMap", "arg_type": type(colsMap).__name__},
            )

        old_names: List[str] = list(colsMap.keys())
        replacement_names: List[str] = list(colsMap.values())
        jdf = self._jdf.withColumnsRenamed(
            _to_seq(self._sc, old_names), _to_seq(self._sc, replacement_names)
        )
        return DataFrame(jdf, self.sparkSession)

    def withMetadata(self, columnName: str, metadata: Dict[str, Any]) -> ParentDataFrame:
        from py4j.java_gateway import JVMView

        # Attach `metadata` to `columnName`. The dict is serialized to JSON and
        # parsed by the JVM `Metadata.fromJson` before the JVM call.
        if not isinstance(metadata, dict):
            raise PySparkTypeError(
                errorClass="NOT_DICT",
                messageParameters={"arg_name": "metadata", "arg_type": type(metadata).__name__},
            )
        jvm = cast(JVMView, get_active_spark_context()._jvm)
        metadata_json = json.dumps(metadata)
        jmeta = jvm.org.apache.spark.sql.types.Metadata.fromJson(metadata_json)
        return DataFrame(self._jdf.withMetadata(columnName, jmeta), self.sparkSession)

    @overload
    def drop(self, cols: "ColumnOrName") -> ParentDataFrame:
        ...

    @overload
    def drop(self, *cols: str) -> ParentDataFrame:
        ...

    def drop(self, *cols: "ColumnOrName") -> ParentDataFrame:  # type: ignore[misc]
        # Column objects and column names are dropped in two separate JVM calls
        # (Column objects first, then names). Any other argument type is
        # rejected before the JVM is touched.
        names: List[str] = []
        jcolumns: List["JavaObject"] = []
        for c in cols:
            if isinstance(c, str):
                names.append(c)
            elif isinstance(c, Column):
                jcolumns.append(c._jc)
            else:
                raise PySparkTypeError(
                    errorClass="NOT_COLUMN_OR_STR",
                    messageParameters={"arg_name": "col", "arg_type": type(c).__name__},
                )

        jdf = self._jdf
        if jcolumns:
            # The JVM overload used here takes one Column plus a Seq of the rest.
            first_jcol, *other_jcols = jcolumns
            jdf = jdf.drop(first_jcol, self._jseq(other_jcols))
        if names:
            jdf = jdf.drop(self._jseq(names))
        return DataFrame(jdf, self.sparkSession)

    def toDF(self, *cols: str) -> ParentDataFrame:
        # Delegates to the JVM `toDF` with the given names; every name must be a str.
        for new_name in cols:
            if isinstance(new_name, str):
                continue
            raise PySparkTypeError(
                errorClass="NOT_LIST_OF_STR",
                messageParameters={"arg_name": "cols", "arg_type": type(new_name).__name__},
            )
        return DataFrame(self._jdf.toDF(self._jseq(cols)), self.sparkSession)

    def transform(
        self, func: Callable[..., ParentDataFrame], *args: Any, **kwargs: Any
    ) -> ParentDataFrame:
        # Chaining helper: returns `func(self, *args, **kwargs)`, asserting that the
        # result is a DataFrame.
        transformed = func(self, *args, **kwargs)
        assert isinstance(transformed, DataFrame), (
            "Func returned an instance of type [%s], "
            "should have been DataFrame." % type(transformed)
        )
        return transformed

    def sameSemantics(self, other: ParentDataFrame) -> bool:
        # Delegate to the JVM sameSemantics. Only a classic DataFrame is
        # accepted as ``other``, because its ``_jdf`` is needed.
        if isinstance(other, DataFrame):
            return self._jdf.sameSemantics(other._jdf)
        raise PySparkTypeError(
            errorClass="NOT_DATAFRAME",
            messageParameters={"arg_name": "other", "arg_type": type(other).__name__},
        )

    def semanticHash(self) -> int:
        # Return the value computed by the JVM Dataset unchanged.
        jvm_hash = self._jdf.semanticHash()
        return jvm_hash

    def inputFiles(self) -> List[str]:
        # Copy the JVM-side result into a plain Python list.
        jfiles = self._jdf.inputFiles()
        return list(jfiles)

    def where(self, condition: Union[Column, str]) -> ParentDataFrame:
        # Alias: pass the condition unchanged to ``filter``.
        filtered = self.filter(condition)
        return filtered

    # The two aliases below were added for pandas compatibility many years ago.
    # There are too many differences from pandas to make the API "compatible"
    # merely by adding aliases, so as of Spark 3.0 we stopped adding such
    # aliases. The two methods below remain only for legacy users.
    @overload
    def groupby(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":
        ...

    @overload
    def groupby(self, __cols: Union[List[Column], List[str], List[int]]) -> "GroupedData":
        ...

    def groupby(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":  # type: ignore[misc]
        # Legacy lowercase alias: pass every positional argument on to groupBy.
        grouped = self.groupBy(*cols)
        return grouped

    def drop_duplicates(self, subset: Optional[List[str]] = None) -> ParentDataFrame:
        # Legacy snake_case alias of dropDuplicates. ``subset`` is passed as-is.
        deduplicated = self.dropDuplicates(subset)
        return deduplicated

    def writeTo(self, table: str) -> "DataFrameWriterV2":
        # Build a V2 writer bound to this DataFrame and the target table name.
        writer = DataFrameWriterV2(self, table)
        return writer

    def mergeInto(self, table: str, condition: Column) -> "MergeIntoWriter":
        # Build a MergeIntoWriter for this source DataFrame, the target table
        # and the merge condition.
        merge_writer = MergeIntoWriter(self, table, condition)
        return merge_writer

    def pandas_api(
        self, index_col: Optional[Union[str, List[str]]] = None
    ) -> "PandasOnSparkDataFrame":
        # pandas-on-Spark modules are imported lazily, only when this is called.
        # The original import order is kept.
        from pyspark.pandas.namespace import _get_index_map
        from pyspark.pandas.frame import DataFrame as PandasOnSparkDataFrame
        from pyspark.pandas.internal import InternalFrame

        # Resolve ``index_col`` into the Spark index columns and their names,
        # then wrap this DataFrame in an InternalFrame carrying that index.
        spark_index_cols, names = _get_index_map(self, index_col)
        return PandasOnSparkDataFrame(
            InternalFrame(
                spark_frame=self,
                index_spark_columns=spark_index_cols,
                index_names=names,  # type: ignore[arg-type]
            )
        )

    def mapInPandas(
        self,
        func: "PandasMapIterFunction",
        schema: Union[StructType, str],
        barrier: bool = False,
        profile: Optional[ResourceProfile] = None,
    ) -> ParentDataFrame:
        # Reuse the shared mixin implementation, passing this DataFrame as ``self``.
        impl = PandasMapOpsMixin.mapInPandas
        return impl(self, func, schema, barrier, profile)

    def mapInArrow(
        self,
        func: "ArrowMapIterFunction",
        schema: Union[StructType, str],
        barrier: bool = False,
        profile: Optional[ResourceProfile] = None,
    ) -> ParentDataFrame:
        # Reuse the shared mixin implementation, passing this DataFrame as ``self``.
        impl = PandasMapOpsMixin.mapInArrow
        return impl(self, func, schema, barrier, profile)

    def toArrow(self) -> "pa.Table":
        # Conversion logic lives in PandasConversionMixin.
        convert = PandasConversionMixin.toArrow
        return convert(self)

    def toPandas(self) -> "PandasDataFrameLike":
        # Conversion logic lives in PandasConversionMixin.
        convert = PandasConversionMixin.toPandas
        return convert(self)

    def transpose(self, indexColumn: Optional["ColumnOrName"] = None) -> ParentDataFrame:
        # Without an index column, the JVM overload with no arguments is used.
        # Otherwise the column (name or Column) is converted to a JVM Column first.
        if indexColumn is None:
            jdf = self._jdf.transpose()
        else:
            jdf = self._jdf.transpose(_to_java_column(indexColumn))
        return DataFrame(jdf, self.sparkSession)

    def asTable(self) -> TableArg:
        from pyspark.sql.classic.table_arg import TableArg as ClassicTableArg

        # Wrap the JVM object returned by asTable() in the classic Python TableArg.
        jtable_arg = self._jdf.asTable()
        return ClassicTableArg(jtable_arg)

    def scalar(self) -> Column:
        # Wrap the JVM column returned by scalar() in a Python Column.
        jc = self._jdf.scalar()
        return Column(jc)

    def exists(self) -> Column:
        # Wrap the JVM column returned by exists() in a Python Column.
        jc = self._jdf.exists()
        return Column(jc)

    @property
    def executionInfo(self) -> Optional["ExecutionInfo"]:
        # Not supported in PySpark Classic: accessing this property always raises.
        # The reported member must be this property's own name, so the error
        # message points the user at the API they actually touched.
        raise PySparkValueError(
            errorClass="CLASSIC_OPERATION_NOT_SUPPORTED_ON_DF",
            messageParameters={"member": "executionInfo"},
        )

    @property
    def plot(self) -> "PySparkPlotAccessor":  # type: ignore[name-defined] # noqa: F821
        # Imported lazily. Each access returns a new accessor bound to this DataFrame.
        from pyspark.sql.plot import PySparkPlotAccessor

        accessor = PySparkPlotAccessor(self)
        return accessor


class DataFrameNaFunctions(ParentDataFrameNaFunctions):
    # Classic implementation behind ``DataFrame.na``. Every method forwards to
    # the matching DataFrame method (dropna / fillna / replace) on ``self.df``.

    def __init__(self, df: ParentDataFrame):
        # The DataFrame whose null-handling methods are exposed.
        self.df = df

    def drop(
        self,
        how: str = "any",
        thresh: Optional[int] = None,
        subset: Optional[Union[str, Tuple[str, ...], List[str]]] = None,
    ) -> ParentDataFrame:
        # Forwards to DataFrame.dropna, passing every option by keyword.
        target = self.df
        return target.dropna(subset=subset, how=how, thresh=thresh)

    @overload
    def fill(self, value: "LiteralType", subset: Optional[List[str]] = ...) -> ParentDataFrame:
        ...

    @overload
    def fill(self, value: Dict[str, "LiteralType"]) -> ParentDataFrame:
        ...

    def fill(
        self,
        value: Union["LiteralType", Dict[str, "LiteralType"]],
        subset: Optional[List[str]] = None,
    ) -> ParentDataFrame:
        # Forwards to DataFrame.fillna, passing both arguments by keyword.
        target = self.df
        return target.fillna(subset=subset, value=value)  # type: ignore[arg-type]

    @overload
    def replace(
        self,
        to_replace: List["LiteralType"],
        value: List["OptionalPrimitiveType"],
        subset: Optional[List[str]] = ...,
    ) -> ParentDataFrame:
        ...

    @overload
    def replace(
        self,
        to_replace: Dict["LiteralType", "OptionalPrimitiveType"],
        subset: Optional[List[str]] = ...,
    ) -> ParentDataFrame:
        ...

    @overload
    def replace(
        self,
        to_replace: List["LiteralType"],
        value: "OptionalPrimitiveType",
        subset: Optional[List[str]] = ...,
    ) -> ParentDataFrame:
        ...

    def replace(  # type: ignore[misc]
        self,
        to_replace: Union[List["LiteralType"], Dict["LiteralType", "OptionalPrimitiveType"]],
        value: Optional[
            Union["OptionalPrimitiveType", List["OptionalPrimitiveType"], _NoValueType]
        ] = _NoValue,
        subset: Optional[List[str]] = None,
    ) -> ParentDataFrame:
        # Forwards positionally to DataFrame.replace. The ``_NoValue`` sentinel
        # default is passed through untouched so the callee can tell that
        # ``value`` was omitted.
        target = self.df
        return target.replace(to_replace, value, subset)  # type: ignore[arg-type]


class DataFrameStatFunctions(ParentDataFrameStatFunctions):
    # Classic implementation behind ``DataFrame.stat``. Every method forwards
    # its arguments, in the same order, to the same-named method on ``self.df``.

    def __init__(self, df: ParentDataFrame):
        # The DataFrame whose statistics methods are exposed.
        self.df = df

    @overload
    def approxQuantile(
        self,
        col: str,
        probabilities: Union[List[float], Tuple[float]],
        relativeError: float,
    ) -> List[float]:
        ...

    @overload
    def approxQuantile(
        self,
        col: Union[List[str], Tuple[str]],
        probabilities: Union[List[float], Tuple[float]],
        relativeError: float,
    ) -> List[List[float]]:
        ...

    def approxQuantile(
        self,
        col: Union[str, List[str], Tuple[str]],
        probabilities: Union[List[float], Tuple[float]],
        relativeError: float,
    ) -> Union[List[float], List[List[float]]]:
        source = self.df
        return source.approxQuantile(col, probabilities, relativeError)

    def corr(self, col1: str, col2: str, method: Optional[str] = None) -> float:
        source = self.df
        return source.corr(col1, col2, method)

    def cov(self, col1: str, col2: str) -> float:
        source = self.df
        return source.cov(col1, col2)

    def crosstab(self, col1: str, col2: str) -> ParentDataFrame:
        source = self.df
        return source.crosstab(col1, col2)

    def freqItems(self, cols: List[str], support: Optional[float] = None) -> ParentDataFrame:
        source = self.df
        return source.freqItems(cols, support)

    def sampleBy(
        self, col: str, fractions: Dict[Any, float], seed: Optional[int] = None
    ) -> ParentDataFrame:
        source = self.df
        return source.sampleBy(col, fractions, seed)


def _test() -> None:
    """Run the ``pyspark.sql.dataframe`` doctests against a local SparkSession.

    Exits the process with status -1 if any doctest fails.
    """
    import doctest
    from pyspark.sql import SparkSession
    import pyspark.sql.dataframe

    # The classic DataFrame inherits its docstrings, but doctest cannot
    # detect inherited docstrings, so the parent class's doctests are run
    # here directly.
    globs = pyspark.sql.dataframe.__dict__.copy()
    spark = (
        SparkSession.builder.master("local[4]").appName("sql.classic.dataframe tests").getOrCreate()
    )
    globs["spark"] = spark
    try:
        failure_count, _ = doctest.testmod(
            pyspark.sql.dataframe,
            globs=globs,
            optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF,
        )
    finally:
        # Stop the session even if doctest itself raises, so no local Spark
        # context is left running.
        spark.stop()
    if failure_count:
        sys.exit(-1)


if __name__ == "__main__":
    # Running this module directly executes the doctest suite in _test().
    _test()
