python/pyspark/sql/classic/dataframe.py (1,539 lines of code) (raw):

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import json
import sys
import random
import warnings
from collections.abc import Iterable
from functools import reduce, cached_property
from typing import (
    Any,
    Callable,
    Dict,
    Iterator,
    List,
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
    cast,
    overload,
    TYPE_CHECKING,
)

from pyspark import _NoValue
from pyspark.resource import ResourceProfile
from pyspark._globals import _NoValueType
from pyspark.errors import (
    AnalysisException,
    PySparkTypeError,
    PySparkValueError,
    PySparkIndexError,
    PySparkAttributeError,
)
from pyspark.util import (
    _load_from_socket,
    _local_iterator_from_socket,
)
from pyspark.serializers import BatchedSerializer, CPickleSerializer, UTF8Deserializer
from pyspark.storagelevel import StorageLevel
from pyspark.traceback_utils import SCCallSiteSync
from pyspark.sql.column import Column
from pyspark.sql.functions import builtin as F
from pyspark.sql.classic.column import _to_seq, _to_list, _to_java_column
from pyspark.sql.readwriter import DataFrameWriter, DataFrameWriterV2
from pyspark.sql.merge import MergeIntoWriter
from pyspark.sql.streaming import DataStreamWriter
from pyspark.sql.types import (
    StructType,
    Row,
    _parse_datatype_json_string,
)
from pyspark.sql.dataframe import (
    DataFrame as ParentDataFrame,
    DataFrameNaFunctions as ParentDataFrameNaFunctions,
    DataFrameStatFunctions as ParentDataFrameStatFunctions,
)
from pyspark.sql.utils import get_active_spark_context, to_java_array, to_scala_map
from pyspark.sql.pandas.conversion import PandasConversionMixin
from pyspark.sql.pandas.map_ops import PandasMapOpsMixin
from pyspark.sql.table_arg import TableArg

if TYPE_CHECKING:
    from py4j.java_gateway import JavaObject
    import pyarrow as pa
    from pyspark.core.rdd import RDD
    from pyspark.core.context import SparkContext
    from pyspark._typing import PrimitiveType
    from pyspark.pandas.frame import DataFrame as PandasOnSparkDataFrame
    from pyspark.sql._typing import (
        ColumnOrName,
        ColumnOrNameOrOrdinal,
        LiteralType,
        OptionalPrimitiveType,
    )
    from pyspark.sql.pandas._typing import (
        PandasMapIterFunction,
        ArrowMapIterFunction,
        DataFrameLike as PandasDataFrameLike,
    )
    from pyspark.sql.context import SQLContext
    from pyspark.sql.session import SparkSession
    from pyspark.sql.group import GroupedData
    from pyspark.sql.observation import Observation
    from pyspark.sql.metrics import ExecutionInfo


# Classic (py4j-backed) DataFrame: each operation delegates to the JVM object held
# in ``self._jdf`` and wraps any resulting JVM frame in a new ``DataFrame``.
# Public methods override ``ParentDataFrame``; notes are kept as ``#`` comments so a
# docstring here does not shadow the parent's docstring in ``help()``/``inspect.getdoc``.
class DataFrame(ParentDataFrame, PandasMapOpsMixin, PandasConversionMixin):
    def __new__(
        cls,
        jdf: "JavaObject",
        sql_ctx: Union["SQLContext", "SparkSession"],
    ) -> "DataFrame":
        # Build the instance and run __init__ explicitly with the same arguments.
        self = object.__new__(cls)
        self.__init__(jdf, sql_ctx)  # type: ignore[misc]
        return self

    def __init__(
        self,
        jdf: "JavaObject",
        sql_ctx: Union["SQLContext", "SparkSession"],
    ):
        from pyspark.sql.context import SQLContext

        if isinstance(sql_ctx, SQLContext):
            assert not os.environ.get("SPARK_TESTING")  # Sanity check for our internal usage.
            assert isinstance(sql_ctx, SQLContext)
            # We should remove this if-else branch in the future release, and rename
            # sql_ctx to session in the constructor. This is an internal code path but
            # was kept with a warning because it's used intensively by third-party libraries.
            warnings.warn("DataFrame constructor is internal. Do not directly use it.")
            self._sql_ctx = sql_ctx
            session = sql_ctx.sparkSession
        else:
            session = sql_ctx
        self._session: "SparkSession" = session

        self._sc: "SparkContext" = sql_ctx._sc
        self._jdf: "JavaObject" = jdf
        self.is_cached = False
        # Check whether _repr_html is supported or not, we use it to avoid calling _jdf twice
        # by __repr__ and _repr_html_ while eager evaluation opens.
        self._support_repr_html = False

    @cached_property
    def sql_ctx(self) -> "SQLContext":
        from pyspark.sql.context import SQLContext

        warnings.warn(
            "DataFrame.sql_ctx is an internal property, and will be removed "
            "in future releases. Use DataFrame.sparkSession instead."
        )
        return SQLContext._get_or_create(self._sc)

    @property
    def sparkSession(self) -> "SparkSession":
        return self._session

    @cached_property
    def rdd(self) -> "RDD[Row]":
        from pyspark.core.rdd import RDD

        # Rows come back from the JVM pickled in batches, hence the serializer pair.
        jrdd = self._jdf.javaToPython()
        return RDD(jrdd, self.sparkSession._sc, BatchedSerializer(CPickleSerializer()))

    @property
    def na(self) -> ParentDataFrameNaFunctions:
        return DataFrameNaFunctions(self)

    @property
    def stat(self) -> ParentDataFrameStatFunctions:
        return DataFrameStatFunctions(self)

    def toJSON(self, use_unicode: bool = True) -> "RDD[str]":
        from pyspark.core.rdd import RDD

        rdd = self._jdf.toJSON()
        return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode))

    def registerTempTable(self, name: str) -> None:
        warnings.warn("Deprecated in 2.0, use createOrReplaceTempView instead.", FutureWarning)
        self._jdf.createOrReplaceTempView(name)

    def createTempView(self, name: str) -> None:
        self._jdf.createTempView(name)

    def createOrReplaceTempView(self, name: str) -> None:
        self._jdf.createOrReplaceTempView(name)

    def createGlobalTempView(self, name: str) -> None:
        self._jdf.createGlobalTempView(name)

    def createOrReplaceGlobalTempView(self, name: str) -> None:
        self._jdf.createOrReplaceGlobalTempView(name)

    @property
    def write(self) -> DataFrameWriter:
        return DataFrameWriter(self)

    @property
    def writeStream(self) -> DataStreamWriter:
        return DataStreamWriter(self)

    @cached_property
    def schema(self) -> StructType:
        # The JVM schema is shipped as JSON and parsed on the Python side.
        try:
            return cast(StructType, _parse_datatype_json_string(self._jdf.schema().json()))
        except AnalysisException:
            # Analysis errors are meaningful to the user; propagate them untouched.
            raise
        except Exception as e:
            # Anything else is a parsing failure; keep the original as __cause__.
            raise PySparkValueError(
                errorClass="CANNOT_PARSE_DATATYPE",
                messageParameters={"error": str(e)},
            ) from e

    def printSchema(self, level: Optional[int] = None) -> None:
        # A falsy level (None or 0) prints the full tree.
        if level:
            print(self._jdf.schema().treeString(level))
        else:
            print(self._jdf.schema().treeString())

    def explain(
        self, extended: Optional[Union[bool, str]] = None, mode: Optional[str] = None
    ) -> None:
        if extended is not None and mode is not None:
            raise PySparkValueError(
                errorClass="CANNOT_SET_TOGETHER",
                messageParameters={"arg_list": "extended and mode"},
            )

        # For the no argument case: df.explain()
        is_no_argument = extended is None and mode is None

        # For the cases below:
        #   explain(True)
        #   explain(extended=False)
        is_extended_case = isinstance(extended, bool) and mode is None

        # For the case when extended is mode:
        #   df.explain("formatted")
        is_extended_as_mode = isinstance(extended, str) and mode is None

        # For the mode specified:
        #   df.explain(mode="formatted")
        is_mode_case = extended is None and isinstance(mode, str)

        if not (is_no_argument or is_extended_case or is_extended_as_mode or is_mode_case):
            if (extended is not None) and (not isinstance(extended, (bool, str))):
                raise PySparkTypeError(
                    errorClass="NOT_BOOL_OR_STR",
                    messageParameters={
                        "arg_name": "extended",
                        "arg_type": type(extended).__name__,
                    },
                )
            if (mode is not None) and (not isinstance(mode, str)):
                raise PySparkTypeError(
                    errorClass="NOT_STR",
                    messageParameters={"arg_name": "mode", "arg_type": type(mode).__name__},
                )

        # Sets an explain mode depending on a given argument
        if is_no_argument:
            explain_mode = "simple"
        elif is_extended_case:
            explain_mode = "extended" if extended else "simple"
        elif is_mode_case:
            explain_mode = cast(str, mode)
        elif is_extended_as_mode:
            explain_mode = cast(str, extended)
        assert self._sc._jvm is not None
        print(self._sc._jvm.PythonSQLUtils.explainString(self._jdf.queryExecution(), explain_mode))

    def exceptAll(self, other: ParentDataFrame) -> ParentDataFrame:
        return DataFrame(self._jdf.exceptAll(other._jdf), self.sparkSession)

    def isLocal(self) -> bool:
        return self._jdf.isLocal()

    @property
    def isStreaming(self) -> bool:
        return self._jdf.isStreaming()

    def isEmpty(self) -> bool:
        return self._jdf.isEmpty()

    def show(self, n: int = 20, truncate: Union[bool, int] = True, vertical: bool = False) -> None:
        print(self._show_string(n, truncate, vertical))

    def _show_string(
        self, n: int = 20, truncate: Union[bool, int] = True, vertical: bool = False
    ) -> str:
        # Validate arguments, then render the table text on the JVM side.
        # ``truncate=True`` means "truncate cells to 20 chars"; any other value is
        # coerced with int() (so False -> 0, i.e. no truncation).
        if not isinstance(n, int) or isinstance(n, bool):
            raise PySparkTypeError(
                errorClass="NOT_INT",
                messageParameters={"arg_name": "n", "arg_type": type(n).__name__},
            )

        if not isinstance(vertical, bool):
            raise PySparkTypeError(
                errorClass="NOT_BOOL",
                messageParameters={"arg_name": "vertical", "arg_type": type(vertical).__name__},
            )

        if isinstance(truncate, bool) and truncate:
            return self._jdf.showString(n, 20, vertical)
        else:
            try:
                int_truncate = int(truncate)
            except (ValueError, TypeError):
                # int() raises ValueError for bad strings but TypeError for values such
                # as None or lists; both must surface as the same PySpark error.
                raise PySparkTypeError(
                    errorClass="NOT_BOOL",
                    messageParameters={
                        "arg_name": "truncate",
                        "arg_type": type(truncate).__name__,
                    },
                )

            return self._jdf.showString(n, int_truncate, vertical)

    def __repr__(self) -> str:
        # With REPL eager evaluation on (and no HTML rendering already used), show the
        # data itself; otherwise show only the column names and types.
        if not self._support_repr_html and self.sparkSession._jconf.isReplEagerEvalEnabled():
            vertical = False
            return self._jdf.showString(
                self.sparkSession._jconf.replEagerEvalMaxNumRows(),
                self.sparkSession._jconf.replEagerEvalTruncate(),
                vertical,
            )
        else:
            return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes))

    def _repr_html_(self) -> Optional[str]:
        """Return an HTML rendering of this :class:`DataFrame` for REPLs that support it.

        The HTML is produced only when eager evaluation is enabled via
        'spark.sql.repl.eagerEval.enabled'; otherwise ``None`` is returned. The first
        call also marks HTML rendering as supported so that ``__repr__`` stops
        rendering the data a second time.
        """
        if not self._support_repr_html:
            self._support_repr_html = True
        if self.sparkSession._jconf.isReplEagerEvalEnabled():
            return self._jdf.htmlString(
                self.sparkSession._jconf.replEagerEvalMaxNumRows(),
                self.sparkSession._jconf.replEagerEvalTruncate(),
            )
        else:
            return None

    def checkpoint(self, eager: bool = True) -> ParentDataFrame:
        jdf = self._jdf.checkpoint(eager)
        return DataFrame(jdf, self.sparkSession)

    def localCheckpoint(
        self, eager: bool = True, storageLevel: Optional[StorageLevel] = None
    ) -> ParentDataFrame:
        if storageLevel is None:
            jdf = self._jdf.localCheckpoint(eager)
        else:
            jdf = self._jdf.localCheckpoint(eager, self._sc._getJavaStorageLevel(storageLevel))
        return DataFrame(jdf, self.sparkSession)

    def withWatermark(self, eventTime: str, delayThreshold: str) -> ParentDataFrame:
        # Both arguments must be non-empty strings.
        if not eventTime or type(eventTime) is not str:
            raise PySparkTypeError(
                errorClass="NOT_STR",
                messageParameters={"arg_name": "eventTime", "arg_type": type(eventTime).__name__},
            )
        if not delayThreshold or type(delayThreshold) is not str:
            raise PySparkTypeError(
                errorClass="NOT_STR",
                messageParameters={
                    "arg_name": "delayThreshold",
                    "arg_type": type(delayThreshold).__name__,
                },
            )
        jdf = self._jdf.withWatermark(eventTime, delayThreshold)
        return DataFrame(jdf, self.sparkSession)

    def hint(
        self, name: str, *parameters: Union["PrimitiveType", "Column", List["PrimitiveType"]]
    ) -> ParentDataFrame:
        # A single list argument is treated as the full parameter list.
        if len(parameters) == 1 and isinstance(parameters[0], list):
            parameters = parameters[0]  # type: ignore[assignment]

        if not isinstance(name, str):
            raise PySparkTypeError(
                errorClass="NOT_STR",
                messageParameters={"arg_name": "name", "arg_type": type(name).__name__},
            )

        allowed_types = (str, float, int, Column, list)
        allowed_primitive_types = (str, float, int)
        allowed_types_repr = ", ".join(
            [t.__name__ for t in allowed_types[:-1]]
            + ["list[" + t.__name__ + "]" for t in allowed_primitive_types]
        )
        for p in parameters:
            if not isinstance(p, allowed_types):
                raise PySparkTypeError(
                    errorClass="DISALLOWED_TYPE_FOR_CONTAINER",
                    messageParameters={
                        "arg_name": "parameters",
                        "arg_type": type(parameters).__name__,
                        "allowed_types": allowed_types_repr,
                        "item_type": type(p).__name__,
                    },
                )
            if isinstance(p, list):
                # Report the element that actually failed the check; p[0] itself may be
                # perfectly valid (e.g. [1, None] must report list[NoneType]).
                invalid = [e for e in p if not isinstance(e, allowed_primitive_types)]
                if invalid:
                    bad_type = type(invalid[0]).__name__
                    raise PySparkTypeError(
                        errorClass="DISALLOWED_TYPE_FOR_CONTAINER",
                        messageParameters={
                            "arg_name": "parameters",
                            "arg_type": type(parameters).__name__,
                            "allowed_types": allowed_types_repr,
                            "item_type": type(p).__name__ + "[" + bad_type + "]",
                        },
                    )

        def _converter(parameter: Union[str, list, float, int, Column]) -> Any:
            if isinstance(parameter, Column):
                return _to_java_column(parameter)
            elif isinstance(parameter, list):
                # for list input, we are assuming only one element type exist in the list.
                # for empty list, we are converting it into an empty long[] in the JVM side.
                gateway = self._sc._gateway
                assert gateway is not None
                jclass = gateway.jvm.long
                if len(parameter) >= 1:
                    mapping = {
                        str: gateway.jvm.java.lang.String,
                        float: gateway.jvm.double,
                        int: gateway.jvm.long,
                    }
                    jclass = mapping[type(parameter[0])]
                return to_java_array(gateway, jclass, parameter)
            else:
                return parameter

        jdf = self._jdf.hint(name, self._jseq(parameters, _converter))
        return DataFrame(jdf, self.sparkSession)

    def count(self) -> int:
        return int(self._jdf.count())

    def collect(self) -> List[Row]:
        # The JVM hands back socket info; rows are read from it and unpickled in batches.
        with SCCallSiteSync(self._sc):
            sock_info = self._jdf.collectToPython()
        return list(_load_from_socket(sock_info, BatchedSerializer(CPickleSerializer())))

    def toLocalIterator(self, prefetchPartitions: bool = False) -> Iterator[Row]:
        with SCCallSiteSync(self._sc):
            sock_info = self._jdf.toPythonIterator(prefetchPartitions)
        return _local_iterator_from_socket(sock_info, BatchedSerializer(CPickleSerializer()))

    def limit(self, num: int) -> ParentDataFrame:
        jdf = self._jdf.limit(num)
        return DataFrame(jdf, self.sparkSession)

    def offset(self, num: int) -> ParentDataFrame:
        jdf = self._jdf.offset(num)
        return DataFrame(jdf, self.sparkSession)

    def take(self, num: int) -> List[Row]:
        return self.limit(num).collect()

    def tail(self, num: int) -> List[Row]:
        with SCCallSiteSync(self._sc):
            sock_info = self._jdf.tailToPython(num)
        return list(_load_from_socket(sock_info, BatchedSerializer(CPickleSerializer())))

    def foreach(self, f: Callable[[Row], None]) -> None:
        self.rdd.foreach(f)

    def foreachPartition(self, f: Callable[[Iterator[Row]], None]) -> None:
        self.rdd.foreachPartition(f)  # type: ignore[arg-type]

    def cache(self) -> ParentDataFrame:
        self.is_cached = True
        self._jdf.cache()
        return self

    def persist(
        self,
        storageLevel: StorageLevel = (StorageLevel.MEMORY_AND_DISK_DESER),
    ) -> ParentDataFrame:
        self.is_cached = True
        javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel)
        self._jdf.persist(javaStorageLevel)
        return self

    @property
    def storageLevel(self) -> StorageLevel:
        # Rebuild a Python StorageLevel from the JVM one's individual flags.
        java_storage_level = self._jdf.storageLevel()
        storage_level = StorageLevel(
            java_storage_level.useDisk(),
            java_storage_level.useMemory(),
            java_storage_level.useOffHeap(),
            java_storage_level.deserialized(),
            java_storage_level.replication(),
        )
        return storage_level

    def unpersist(self, blocking: bool = False) -> ParentDataFrame:
        self.is_cached = False
        self._jdf.unpersist(blocking)
        return self

    def coalesce(self, numPartitions: int) -> ParentDataFrame:
        return DataFrame(self._jdf.coalesce(numPartitions), self.sparkSession)

    @overload
    def repartition(self, numPartitions: int, *cols: "ColumnOrName") -> ParentDataFrame:
        ...

    @overload
    def repartition(self, *cols: "ColumnOrName") -> ParentDataFrame:
        ...
def repartition( # type: ignore[misc] self, numPartitions: Union[int, "ColumnOrName"], *cols: "ColumnOrName" ) -> ParentDataFrame: if isinstance(numPartitions, int): if len(cols) == 0: return DataFrame(self._jdf.repartition(numPartitions), self.sparkSession) else: return DataFrame( self._jdf.repartition(numPartitions, self._jcols(*cols)), self.sparkSession, ) elif isinstance(numPartitions, (str, Column)): cols = (numPartitions,) + cols return DataFrame(self._jdf.repartition(self._jcols(*cols)), self.sparkSession) else: raise PySparkTypeError( errorClass="NOT_COLUMN_OR_STR", messageParameters={ "arg_name": "numPartitions", "arg_type": type(numPartitions).__name__, }, ) @overload def repartitionByRange(self, numPartitions: int, *cols: "ColumnOrName") -> ParentDataFrame: ... @overload def repartitionByRange(self, *cols: "ColumnOrName") -> ParentDataFrame: ... def repartitionByRange( # type: ignore[misc] self, numPartitions: Union[int, "ColumnOrName"], *cols: "ColumnOrName" ) -> ParentDataFrame: if isinstance(numPartitions, int): if len(cols) == 0: raise PySparkValueError( errorClass="CANNOT_BE_EMPTY", messageParameters={"item": "partition-by expression"}, ) else: return DataFrame( self._jdf.repartitionByRange(numPartitions, self._jcols(*cols)), self.sparkSession, ) elif isinstance(numPartitions, (str, Column)): cols = (numPartitions,) + cols return DataFrame(self._jdf.repartitionByRange(self._jcols(*cols)), self.sparkSession) else: raise PySparkTypeError( errorClass="NOT_COLUMN_OR_INT_OR_STR", messageParameters={ "arg_name": "numPartitions", "arg_type": type(numPartitions).__name__, }, ) def distinct(self) -> ParentDataFrame: return DataFrame(self._jdf.distinct(), self.sparkSession) @overload def sample(self, fraction: float, seed: Optional[int] = ...) -> ParentDataFrame: ... @overload def sample( self, withReplacement: Optional[bool], fraction: float, seed: Optional[int] = ..., ) -> ParentDataFrame: ... 
def sample( # type: ignore[misc] self, withReplacement: Optional[Union[float, bool]] = None, fraction: Optional[Union[int, float]] = None, seed: Optional[int] = None, ) -> ParentDataFrame: _w, _f, _s = self._preapare_args_for_sample(withReplacement, fraction, seed) jdf = self._jdf.sample(*[_w, _f, _s]) return DataFrame(jdf, self.sparkSession) def sampleBy( self, col: "ColumnOrName", fractions: Dict[Any, float], seed: Optional[int] = None ) -> ParentDataFrame: if isinstance(col, str): col = Column(col) elif not isinstance(col, Column): raise PySparkTypeError( errorClass="NOT_COLUMN_OR_STR", messageParameters={"arg_name": "col", "arg_type": type(col).__name__}, ) if not isinstance(fractions, dict): raise PySparkTypeError( errorClass="NOT_DICT", messageParameters={"arg_name": "fractions", "arg_type": type(fractions).__name__}, ) for k, v in fractions.items(): if not isinstance(k, (float, int, str)): raise PySparkTypeError( errorClass="DISALLOWED_TYPE_FOR_CONTAINER", messageParameters={ "arg_name": "fractions", "arg_type": type(fractions).__name__, "allowed_types": "float, int, str", "item_type": type(k).__name__, }, ) fractions[k] = float(v) col = col._jc seed = seed if seed is not None else random.randint(0, sys.maxsize) return DataFrame( self._jdf.stat().sampleBy(col, self._jmap(fractions), seed), self.sparkSession ) def randomSplit( self, weights: List[float], seed: Optional[int] = None ) -> List[ParentDataFrame]: for w in weights: if w < 0.0: raise PySparkValueError( errorClass="VALUE_NOT_POSITIVE", messageParameters={"arg_name": "weights", "arg_value": str(w)}, ) seed = seed if seed is not None else random.randint(0, sys.maxsize) df_array = self._jdf.randomSplit( _to_list(self.sparkSession._sc, cast(List["ColumnOrName"], weights)), int(seed) ) return [DataFrame(df, self.sparkSession) for df in df_array] @property def dtypes(self) -> List[Tuple[str, str]]: return [(str(f.name), f.dataType.simpleString()) for f in self.schema.fields] @property def columns(self) -> 
List[str]: return [f.name for f in self.schema.fields] def metadataColumn(self, colName: str) -> Column: if not isinstance(colName, str): raise PySparkTypeError( errorClass="NOT_STR", messageParameters={"arg_name": "colName", "arg_type": type(colName).__name__}, ) jc = self._jdf.metadataColumn(colName) return Column(jc) def colRegex(self, colName: str) -> Column: if not isinstance(colName, str): raise PySparkTypeError( errorClass="NOT_STR", messageParameters={"arg_name": "colName", "arg_type": type(colName).__name__}, ) jc = self._jdf.colRegex(colName) return Column(jc) def to(self, schema: StructType) -> ParentDataFrame: assert schema is not None jschema = self._jdf.sparkSession().parseDataType(schema.json()) return DataFrame(self._jdf.to(jschema), self.sparkSession) def alias(self, alias: str) -> ParentDataFrame: assert isinstance(alias, str), "alias should be a string" return DataFrame(getattr(self._jdf, "as")(alias), self.sparkSession) def crossJoin(self, other: ParentDataFrame) -> ParentDataFrame: jdf = self._jdf.crossJoin(other._jdf) return DataFrame(jdf, self.sparkSession) def join( self, other: ParentDataFrame, on: Optional[Union[str, List[str], Column, List[Column]]] = None, how: Optional[str] = None, ) -> ParentDataFrame: if on is not None and not isinstance(on, list): on = [on] # type: ignore[assignment] if on is not None: if isinstance(on[0], str): on = self._jseq(cast(List[str], on)) else: assert isinstance(on[0], Column), "on should be Column or list of Column" on = reduce(lambda x, y: x.__and__(y), cast(List[Column], on)) on = on._jc if on is None and how is None: jdf = self._jdf.join(other._jdf) else: if how is None: how = "inner" if on is None: on = self._jseq([]) assert isinstance(how, str), "how should be a string" jdf = self._jdf.join(other._jdf, on, how) return DataFrame(jdf, self.sparkSession) def lateralJoin( self, other: ParentDataFrame, on: Optional[Column] = None, how: Optional[str] = None, ) -> ParentDataFrame: if on is None and how is 
None: jdf = self._jdf.lateralJoin(other._jdf) elif on is None: jdf = self._jdf.lateralJoin(other._jdf, how) elif how is None: jdf = self._jdf.lateralJoin(other._jdf, on._jc) else: jdf = self._jdf.lateralJoin(other._jdf, on._jc, how) return DataFrame(jdf, self.sparkSession) # TODO(SPARK-22947): Fix the DataFrame API. def _joinAsOf( self, other: ParentDataFrame, leftAsOfColumn: Union[str, Column], rightAsOfColumn: Union[str, Column], on: Optional[Union[str, List[str], Column, List[Column]]] = None, how: Optional[str] = None, *, tolerance: Optional[Column] = None, allowExactMatches: bool = True, direction: str = "backward", ) -> ParentDataFrame: """ Perform an as-of join. This is similar to a left-join except that we match on the nearest key rather than equal keys. .. versionchanged:: 4.0.0 Supports Spark Connect. Parameters ---------- other : :class:`DataFrame` Right side of the join leftAsOfColumn : str or :class:`Column` a string for the as-of join column name, or a Column rightAsOfColumn : str or :class:`Column` a string for the as-of join column name, or a Column on : str, list or :class:`Column`, optional a string for the join column name, a list of column names, a join expression (Column), or a list of Columns. If `on` is a string or a list of strings indicating the name of the join column(s), the column(s) must exist on both sides, and this performs an equi-join. how : str, optional default ``inner``. Must be one of: ``inner`` and ``left``. tolerance : :class:`Column`, optional an asof tolerance within this range; must be compatible with the merge index. allowExactMatches : bool, optional default ``True``. direction : str, optional default ``backward``. Must be one of: ``backward``, ``forward``, and ``nearest``. Examples -------- The following performs an as-of join between ``left`` and ``right``. >>> left = spark.createDataFrame([(1, "a"), (5, "b"), (10, "c")], ["a", "left_val"]) >>> right = spark.createDataFrame([(1, 1), (2, 2), (3, 3), (6, 6), (7, 7)], ... 
["a", "right_val"]) >>> left._joinAsOf( ... right, leftAsOfColumn="a", rightAsOfColumn="a" ... ).select(left.a, 'left_val', 'right_val').sort("a").collect() [Row(a=1, left_val='a', right_val=1), Row(a=5, left_val='b', right_val=3), Row(a=10, left_val='c', right_val=7)] >>> from pyspark.sql import functions as sf >>> left._joinAsOf( ... right, leftAsOfColumn="a", rightAsOfColumn="a", tolerance=sf.lit(1) ... ).select(left.a, 'left_val', 'right_val').sort("a").collect() [Row(a=1, left_val='a', right_val=1)] >>> left._joinAsOf( ... right, leftAsOfColumn="a", rightAsOfColumn="a", how="left", tolerance=sf.lit(1) ... ).select(left.a, 'left_val', 'right_val').sort("a").collect() [Row(a=1, left_val='a', right_val=1), Row(a=5, left_val='b', right_val=None), Row(a=10, left_val='c', right_val=None)] >>> left._joinAsOf( ... right, leftAsOfColumn="a", rightAsOfColumn="a", allowExactMatches=False ... ).select(left.a, 'left_val', 'right_val').sort("a").collect() [Row(a=5, left_val='b', right_val=3), Row(a=10, left_val='c', right_val=7)] >>> left._joinAsOf( ... right, leftAsOfColumn="a", rightAsOfColumn="a", direction="forward" ... 
).select(left.a, 'left_val', 'right_val').sort("a").collect() [Row(a=1, left_val='a', right_val=1), Row(a=5, left_val='b', right_val=6)] """ if isinstance(leftAsOfColumn, str): leftAsOfColumn = self[leftAsOfColumn] left_as_of_jcol = leftAsOfColumn._jc if isinstance(rightAsOfColumn, str): rightAsOfColumn = other[rightAsOfColumn] right_as_of_jcol = rightAsOfColumn._jc if on is not None and not isinstance(on, list): on = [on] # type: ignore[assignment] if on is not None: if isinstance(on[0], str): on = self._jseq(cast(List[str], on)) else: assert isinstance(on[0], Column), "on should be Column or list of Column" on = reduce(lambda x, y: x.__and__(y), cast(List[Column], on)) on = on._jc if how is None: how = "inner" assert isinstance(how, str), "how should be a string" if tolerance is not None: assert isinstance(tolerance, Column), "tolerance should be Column" tolerance = tolerance._jc jdf = self._jdf.joinAsOf( other._jdf, left_as_of_jcol, right_as_of_jcol, on, how, tolerance, allowExactMatches, direction, ) return DataFrame(jdf, self.sparkSession) def sortWithinPartitions( self, *cols: Union[int, str, Column, List[Union[int, str, Column]]], **kwargs: Any, ) -> ParentDataFrame: _cols = self._preapare_cols_for_sort(F.col, cols, kwargs) jdf = self._jdf.sortWithinPartitions(self._jseq(_cols, _to_java_column)) return DataFrame(jdf, self.sparkSession) def sort( self, *cols: Union[int, str, Column, List[Union[int, str, Column]]], **kwargs: Any, ) -> ParentDataFrame: _cols = self._preapare_cols_for_sort(F.col, cols, kwargs) jdf = self._jdf.sort(self._jseq(_cols, _to_java_column)) return DataFrame(jdf, self.sparkSession) orderBy = sort def _jseq( self, cols: Sequence, converter: Optional[Callable[..., Union["PrimitiveType", "JavaObject"]]] = None, ) -> "JavaObject": """Return a JVM Seq of Columns from a list of Column or names""" return _to_seq(self.sparkSession._sc, cols, converter) def _jmap(self, jm: Dict) -> "JavaObject": """Return a JVM Scala Map from a dict""" return 
to_scala_map(self.sparkSession._sc._jvm, jm) def _jcols(self, *cols: "ColumnOrName") -> "JavaObject": """Return a JVM Seq of Columns from a list of Column or column names If `cols` has only one list in it, cols[0] will be used as the list. """ if len(cols) == 1 and isinstance(cols[0], list): cols = cols[0] return self._jseq(cols, _to_java_column) def _jcols_ordinal(self, *cols: "ColumnOrNameOrOrdinal") -> "JavaObject": """Return a JVM Seq of Columns from a list of Column or column names or column ordinals. If `cols` has only one list in it, cols[0] will be used as the list. """ if len(cols) == 1 and isinstance(cols[0], list): cols = cols[0] _cols = [] for c in cols: if isinstance(c, int) and not isinstance(c, bool): if c < 1: raise PySparkIndexError( errorClass="INDEX_NOT_POSITIVE", messageParameters={"index": str(c)} ) # ordinal is 1-based _cols.append(self[c - 1]) else: _cols.append(c) # type: ignore[arg-type] return self._jseq(_cols, _to_java_column) def describe(self, *cols: Union[str, List[str]]) -> ParentDataFrame: if len(cols) == 1 and isinstance(cols[0], list): cols = cols[0] # type: ignore[assignment] jdf = self._jdf.describe(self._jseq(cols)) return DataFrame(jdf, self.sparkSession) def summary(self, *statistics: str) -> ParentDataFrame: if len(statistics) == 1 and isinstance(statistics[0], list): statistics = statistics[0] jdf = self._jdf.summary(self._jseq(statistics)) return DataFrame(jdf, self.sparkSession) @overload def head(self) -> Optional[Row]: ... @overload def head(self, n: int) -> List[Row]: ... def head(self, n: Optional[int] = None) -> Union[Optional[Row], List[Row]]: if n is None: rs = self.head(1) return rs[0] if rs else None return self.take(n) def first(self) -> Optional[Row]: return self.head() @overload def __getitem__(self, item: Union[int, str]) -> Column: ... @overload def __getitem__(self, item: Union[Column, List, Tuple]) -> ParentDataFrame: ... 
def __getitem__( self, item: Union[int, str, Column, List, Tuple] ) -> Union[Column, ParentDataFrame]: if isinstance(item, str): jc = self._jdf.apply(item) return Column(jc) elif isinstance(item, Column): return self.filter(item) elif isinstance(item, (list, tuple)): return self.select(*item) elif isinstance(item, int): jc = self._jdf.apply(self.columns[item]) return Column(jc) else: raise PySparkTypeError( errorClass="NOT_COLUMN_OR_FLOAT_OR_INT_OR_LIST_OR_STR", messageParameters={"arg_name": "item", "arg_type": type(item).__name__}, ) def __getattr__(self, name: str) -> Column: if name not in self.columns: raise PySparkAttributeError( errorClass="ATTRIBUTE_NOT_SUPPORTED", messageParameters={"attr_name": name} ) jc = self._jdf.apply(name) return Column(jc) def __dir__(self) -> List[str]: attrs = set(dir(DataFrame)) attrs.update(filter(lambda s: s.isidentifier(), self.columns)) return sorted(attrs) @overload def select(self, *cols: "ColumnOrName") -> ParentDataFrame: ... @overload def select(self, __cols: Union[List[Column], List[str]]) -> ParentDataFrame: ... def select(self, *cols: "ColumnOrName") -> ParentDataFrame: # type: ignore[misc] jdf = self._jdf.select(self._jcols(*cols)) return DataFrame(jdf, self.sparkSession) @overload def selectExpr(self, *expr: str) -> ParentDataFrame: ... @overload def selectExpr(self, *expr: List[str]) -> ParentDataFrame: ... 
def selectExpr(self, *expr: Union[str, List[str]]) -> ParentDataFrame: if len(expr) == 1 and isinstance(expr[0], list): expr = expr[0] # type: ignore[assignment] jdf = self._jdf.selectExpr(self._jseq(expr)) return DataFrame(jdf, self.sparkSession) def filter(self, condition: Union[Column, str]) -> ParentDataFrame: if isinstance(condition, str): jdf = self._jdf.filter(condition) elif isinstance(condition, Column): jdf = self._jdf.filter(condition._jc) else: raise PySparkTypeError( errorClass="NOT_COLUMN_OR_STR", messageParameters={"arg_name": "condition", "arg_type": type(condition).__name__}, ) return DataFrame(jdf, self.sparkSession) @overload def groupBy(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData": ... @overload def groupBy(self, __cols: Union[List[Column], List[str], List[int]]) -> "GroupedData": ... def groupBy(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData": # type: ignore[misc] jgd = self._jdf.groupBy(self._jcols_ordinal(*cols)) from pyspark.sql.group import GroupedData return GroupedData(jgd, self) @overload def rollup(self, *cols: "ColumnOrName") -> "GroupedData": ... @overload def rollup(self, __cols: Union[List[Column], List[str]]) -> "GroupedData": ... def rollup(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData": # type: ignore[misc] jgd = self._jdf.rollup(self._jcols_ordinal(*cols)) from pyspark.sql.group import GroupedData return GroupedData(jgd, self) @overload def cube(self, *cols: "ColumnOrName") -> "GroupedData": ... @overload def cube(self, __cols: Union[List[Column], List[str]]) -> "GroupedData": ... 
def cube(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":  # type: ignore[misc]
    # Cube grouping. Docstrings are intentionally omitted on these overrides
    # so the parent DataFrame's documentation is inherited.
    #
    # Columns go through _jcols_ordinal (as in groupBy/rollup), so integer
    # ordinals are accepted; the annotation reflects that.
    jgd = self._jdf.cube(self._jcols_ordinal(*cols))
    from pyspark.sql.group import GroupedData

    return GroupedData(jgd, self)

def groupingSets(
    self, groupingSets: Sequence[Sequence["ColumnOrName"]], *cols: "ColumnOrName"
) -> "GroupedData":
    # Each inner sequence of ``groupingSets`` becomes one Java column Seq;
    # the outer list is then wrapped in a Seq as well.
    from pyspark.sql.group import GroupedData

    jgrouping_sets = _to_seq(self._sc, [self._jcols(*inner) for inner in groupingSets])

    jgd = self._jdf.groupingSets(jgrouping_sets, self._jcols(*cols))
    return GroupedData(jgd, self)

def unpivot(
    self,
    ids: Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]],
    values: Optional[Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]],
    variableColumnName: str,
    valueColumnName: str,
) -> ParentDataFrame:
    # Delegates to the JVM ``unpivotWithSeq``. ``ids`` and ``values`` may each
    # be a single column/name or a list/tuple of them; with ``values=None``
    # the overload without value columns is used.
    assert ids is not None, "ids must not be None"

    def to_jcols(
        cols: Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]
    ) -> "JavaObject":
        # Normalize a one-or-many column argument into Java columns.
        if isinstance(cols, (list, tuple)):
            return self._jcols(*cols)
        return self._jcols(cols)

    jids = to_jcols(ids)
    if values is None:
        jdf = self._jdf.unpivotWithSeq(jids, variableColumnName, valueColumnName)
    else:
        jvals = to_jcols(values)
        jdf = self._jdf.unpivotWithSeq(jids, jvals, variableColumnName, valueColumnName)

    return DataFrame(jdf, self.sparkSession)

def melt(
    self,
    ids: Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]],
    values: Optional[Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]],
    variableColumnName: str,
    valueColumnName: str,
) -> ParentDataFrame:
    # Alias of unpivot.
    return self.unpivot(ids, values, variableColumnName, valueColumnName)

def agg(self, *exprs: Union[Column, Dict[str, str]]) -> ParentDataFrame:
    # Aggregate over the whole DataFrame: an empty groupBy followed by agg.
    return self.groupBy().agg(*exprs)  # type: ignore[arg-type]

def observe(
    self,
    observation: Union["Observation", str],
    *exprs: Column,
) -> ParentDataFrame:
    # Attach observed metrics. An Observation object registers itself via
    # its own _on(); a str name is passed straight to the JVM observe.
    from pyspark.sql import Observation

    if len(exprs) == 0:
        raise PySparkValueError(
            errorClass="CANNOT_BE_EMPTY",
            messageParameters={"item": "exprs"},
        )
    if not all(isinstance(c, Column) for c in exprs):
        raise PySparkTypeError(
            errorClass="NOT_LIST_OF_COLUMN",
            messageParameters={"arg_name": "exprs"},
        )

    if isinstance(observation, Observation):
        return observation._on(self, *exprs)
    elif isinstance(observation, str):
        # The JVM signature takes the first expression separately from the rest.
        return DataFrame(
            self._jdf.observe(
                observation, exprs[0]._jc, _to_seq(self._sc, [c._jc for c in exprs[1:]])
            ),
            self.sparkSession,
        )
    else:
        # NOTE(review): NOT_LIST_OF_COLUMN describes a column list, not the
        # ``observation`` argument; a dedicated error class would be clearer.
        # Confirm which error classes exist before changing it.
        raise PySparkTypeError(
            errorClass="NOT_LIST_OF_COLUMN",
            messageParameters={
                "arg_name": "observation",
                "arg_type": type(observation).__name__,
            },
        )

def union(self, other: ParentDataFrame) -> ParentDataFrame:
    # Positional union via the JVM Dataset.union.
    return DataFrame(self._jdf.union(other._jdf), self.sparkSession)

def unionAll(self, other: ParentDataFrame) -> ParentDataFrame:
    # Alias of union.
    return self.union(other)

def unionByName(
    self, other: ParentDataFrame, allowMissingColumns: bool = False
) -> ParentDataFrame:
    # Union resolved by column name; the flag is forwarded to the JVM.
    return DataFrame(self._jdf.unionByName(other._jdf, allowMissingColumns), self.sparkSession)

def intersect(self, other: ParentDataFrame) -> ParentDataFrame:
    return DataFrame(self._jdf.intersect(other._jdf), self.sparkSession)

def intersectAll(self, other: ParentDataFrame) -> ParentDataFrame:
    return DataFrame(self._jdf.intersectAll(other._jdf), self.sparkSession)

def subtract(self, other: ParentDataFrame) -> ParentDataFrame:
    # ``except`` is a Python keyword, so the JVM method is fetched by name.
    return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sparkSession)

def dropDuplicates(self, subset: Optional[List[str]] = None) -> ParentDataFrame:
    # Deduplicate on all columns (subset=None) or on the given column names.
    if subset is not None and (not isinstance(subset, Iterable) or isinstance(subset, str)):
        raise PySparkTypeError(
            errorClass="NOT_LIST_OR_TUPLE",
            messageParameters={"arg_name": "subset", "arg_type": type(subset).__name__},
        )

    if subset is None:
        jdf = self._jdf.dropDuplicates()
    else:
        # ``subset`` may be any non-str iterable. Materialize it once so a
        # one-shot iterator (e.g. a generator) is not exhausted by the
        # validation loop, and the same concrete list is sent to the JVM.
        subset = list(subset)
        for c in subset:
            if not isinstance(c, str):
                raise PySparkTypeError(
                    errorClass="NOT_STR",
                    messageParameters={"arg_name": "subset", "arg_type": type(c).__name__},
                )
        jdf = self._jdf.dropDuplicates(self._jseq(subset))
    return DataFrame(jdf, self.sparkSession)

def dropDuplicatesWithinWatermark(self, subset: Optional[List[str]] = None) -> ParentDataFrame:
    # Same argument handling as dropDuplicates, calling the JVM
    # dropDuplicatesWithinWatermark instead.
    if subset is not None and (not isinstance(subset, Iterable) or isinstance(subset, str)):
        raise PySparkTypeError(
            errorClass="NOT_LIST_OR_TUPLE",
            messageParameters={"arg_name": "subset", "arg_type": type(subset).__name__},
        )

    if subset is None:
        jdf = self._jdf.dropDuplicatesWithinWatermark()
    else:
        # Materialize once; see dropDuplicates for the rationale.
        subset = list(subset)
        for c in subset:
            if not isinstance(c, str):
                raise PySparkTypeError(
                    errorClass="NOT_STR",
                    messageParameters={"arg_name": "subset", "arg_type": type(c).__name__},
                )
        jdf = self._jdf.dropDuplicatesWithinWatermark(self._jseq(subset))
    return DataFrame(jdf, self.sparkSession)

def dropna(
    self,
    how: str = "any",
    thresh: Optional[int] = None,
    subset: Optional[Union[str, Tuple[str, ...], List[str]]] = None,
) -> ParentDataFrame:
    # Drop rows with nulls via na().drop(thresh, cols). ``subset`` defaults
    # to every column; a single str is wrapped in a list.
    if how is not None and how not in ["any", "all"]:
        raise PySparkValueError(
            errorClass="VALUE_NOT_ANY_OR_ALL",
            messageParameters={"arg_name": "how", "arg_type": how},
        )

    if subset is None:
        subset = self.columns
    elif isinstance(subset, str):
        subset = [subset]
    elif not isinstance(subset, (list, tuple)):
        raise PySparkTypeError(
            errorClass="NOT_LIST_OR_STR_OR_TUPLE",
            messageParameters={"arg_name": "subset", "arg_type": type(subset).__name__},
        )

    if thresh is None:
        # Translate ``how`` into a threshold: "any" -> len(subset), else 1.
        thresh = len(subset) if how == "any" else 1

    return DataFrame(self._jdf.na().drop(thresh, self._jseq(subset)), self.sparkSession)

@overload
def fillna(
    self,
    value: "LiteralType",
    subset: Optional[Union[str, Tuple[str, ...], List[str]]] = ...,
) -> ParentDataFrame:
    ...

@overload
def fillna(self, value: Dict[str, "LiteralType"]) -> ParentDataFrame:
    ...
def fillna(
    self,
    value: Union["LiteralType", Dict[str, "LiteralType"]],
    subset: Optional[Union[str, Tuple[str, ...], List[str]]] = None,
) -> ParentDataFrame:
    # Replace nulls with ``value``. Docstrings are intentionally omitted on
    # these overrides so the parent DataFrame's documentation is inherited.
    #
    # A dict value carries its own column -> value mapping, so ``subset`` is
    # not consulted in that case.
    if not isinstance(value, (float, int, str, bool, dict)):
        raise PySparkTypeError(
            errorClass="NOT_BOOL_OR_DICT_OR_FLOAT_OR_INT_OR_STR",
            messageParameters={"arg_name": "value", "arg_type": type(value).__name__},
        )

    # bool is a subclass of int; only genuine integers are widened to float,
    # booleans are passed through unchanged.
    if isinstance(value, int) and not isinstance(value, bool):
        value = float(value)

    if isinstance(value, dict) or subset is None:
        jdf = self._jdf.na().fill(value)
    else:
        columns = [subset] if isinstance(subset, str) else subset
        if not isinstance(columns, (list, tuple)):
            raise PySparkTypeError(
                errorClass="NOT_LIST_OR_TUPLE",
                messageParameters={"arg_name": "subset", "arg_type": type(subset).__name__},
            )
        jdf = self._jdf.na().fill(value, self._jseq(columns))
    return DataFrame(jdf, self.sparkSession)

@overload
def replace(
    self,
    to_replace: "LiteralType",
    value: "OptionalPrimitiveType",
    subset: Optional[List[str]] = ...,
) -> ParentDataFrame:
    ...

@overload
def replace(
    self,
    to_replace: List["LiteralType"],
    value: List["OptionalPrimitiveType"],
    subset: Optional[List[str]] = ...,
) -> ParentDataFrame:
    ...

@overload
def replace(
    self,
    to_replace: Dict["LiteralType", "OptionalPrimitiveType"],
    subset: Optional[List[str]] = ...,
) -> ParentDataFrame:
    ...

@overload
def replace(
    self,
    to_replace: List["LiteralType"],
    value: "OptionalPrimitiveType",
    subset: Optional[List[str]] = ...,
) -> ParentDataFrame:
    ...
def replace(  # type: ignore[misc]
    self,
    to_replace: Union[
        "LiteralType", List["LiteralType"], Dict["LiteralType", "OptionalPrimitiveType"]
    ],
    value: Optional[
        Union["OptionalPrimitiveType", List["OptionalPrimitiveType"], _NoValueType]
    ] = _NoValue,
    subset: Optional[List[str]] = None,
) -> ParentDataFrame:
    # Replace values. Docstrings are intentionally omitted on these overrides
    # so the parent DataFrame's documentation is inherited.
    #
    # ``value`` defaults to the _NoValue sentinel rather than None because
    # None is itself a legal replacement value. Omitting ``value`` is only
    # allowed when ``to_replace`` is a dict.
    if value is _NoValue:
        if isinstance(to_replace, dict):
            value = None
        else:
            # The condition text states when ``value`` is required, i.e. when
            # ``to_replace`` is NOT a dict.
            raise PySparkTypeError(
                errorClass="ARGUMENT_REQUIRED",
                messageParameters={
                    "arg_name": "value",
                    "condition": "`to_replace` is not a dict",
                },
            )

    # Helper functions
    def all_of(types: Union[Type, Tuple[Type, ...]]) -> Callable[[Iterable], bool]:
        """Given a type or tuple of types and a sequence of xs
        check if each x is instance of type(s)

        >>> all_of(bool)([True, False])
        True
        >>> all_of(str)(["a", 1])
        False
        """

        def all_of_(xs: Iterable) -> bool:
            return all(isinstance(x, types) for x in xs)

        return all_of_

    all_of_bool = all_of(bool)
    all_of_str = all_of(str)
    all_of_numeric = all_of((float, int))

    # Validate input types
    valid_types = (bool, float, int, str, list, tuple)
    if not isinstance(to_replace, valid_types + (dict,)):
        raise PySparkTypeError(
            errorClass="NOT_BOOL_OR_DICT_OR_FLOAT_OR_INT_OR_LIST_OR_STR_OR_TUPLE",
            messageParameters={
                "arg_name": "to_replace",
                "arg_type": type(to_replace).__name__,
            },
        )

    # ``value`` is ignored for dict ``to_replace``, so it is only type-checked
    # in the non-dict case.
    if (
        not isinstance(value, valid_types)
        and value is not None
        and not isinstance(to_replace, dict)
    ):
        raise PySparkTypeError(
            errorClass="NOT_BOOL_OR_FLOAT_OR_INT_OR_LIST_OR_NONE_OR_STR_OR_TUPLE",
            messageParameters={
                "arg_name": "value",
                "arg_type": type(value).__name__,
            },
        )

    if isinstance(to_replace, (list, tuple)) and isinstance(value, (list, tuple)):
        if len(to_replace) != len(value):
            raise PySparkValueError(
                errorClass="LENGTH_SHOULD_BE_THE_SAME",
                messageParameters={
                    "arg1": "to_replace",
                    "arg2": "value",
                    "arg1_length": str(len(to_replace)),
                    "arg2_length": str(len(value)),
                },
            )

    if not (subset is None or isinstance(subset, (list, tuple, str))):
        raise PySparkTypeError(
            errorClass="NOT_LIST_OR_STR_OR_TUPLE",
            messageParameters={"arg_name": "subset", "arg_type": type(subset).__name__},
        )

    # Reshape input arguments if necessary: a scalar ``to_replace`` becomes a
    # one-element list, and a scalar/None ``value`` is broadcast to match it.
    if isinstance(to_replace, (float, int, str)):
        to_replace = [to_replace]

    if isinstance(to_replace, dict):
        rep_dict = to_replace
        if value is not None:
            warnings.warn("to_replace is a dict and value is not None. value will be ignored.")
    else:
        if isinstance(value, (float, int, str)) or value is None:
            value = [value for _ in range(len(to_replace))]
        rep_dict = dict(zip(to_replace, cast("Iterable[Optional[Union[float, str]]]", value)))

    if isinstance(subset, str):
        subset = [subset]

    # Verify we were not passed in mixed type generics: all keys and all
    # non-None values must be bools, or all strs, or all numerics.
    if not any(
        all_of_type(rep_dict.keys())
        and all_of_type(x for x in rep_dict.values() if x is not None)
        for all_of_type in [all_of_bool, all_of_str, all_of_numeric]
    ):
        raise PySparkValueError(
            errorClass="MIXED_TYPE_REPLACEMENT",
            messageParameters={},
        )

    if subset is None:
        # "*" targets every column.
        return DataFrame(self._jdf.na().replace("*", rep_dict), self.sparkSession)
    else:
        return DataFrame(
            self._jdf.na().replace(self._jseq(subset), self._jmap(rep_dict)),
            self.sparkSession,
        )

@overload
def approxQuantile(
    self,
    col: str,
    probabilities: Union[List[float], Tuple[float]],
    relativeError: float,
) -> List[float]:
    ...

@overload
def approxQuantile(
    self,
    col: Union[List[str], Tuple[str]],
    probabilities: Union[List[float], Tuple[float]],
    relativeError: float,
) -> List[List[float]]:
    ...
def approxQuantile(
    self,
    col: Union[str, List[str], Tuple[str]],
    probabilities: Union[List[float], Tuple[float]],
    relativeError: float,
) -> Union[List[float], List[List[float]]]:
    # Approximate quantiles via the JVM stat functions. Docstrings are
    # intentionally omitted on these overrides so the parent DataFrame's
    # documentation is inherited.
    #
    # A single column name yields one list of quantiles; a list/tuple of
    # names yields one list per column.
    if not isinstance(col, (str, list, tuple)):
        raise PySparkTypeError(
            errorClass="NOT_LIST_OR_STR_OR_TUPLE",
            messageParameters={"arg_name": "col", "arg_type": type(col).__name__},
        )

    single_column = isinstance(col, str)
    if single_column:
        col = [cast(str, col)]
    elif isinstance(col, tuple):
        col = list(col)

    for name in col:
        if isinstance(name, str):
            continue
        raise PySparkTypeError(
            errorClass="DISALLOWED_TYPE_FOR_CONTAINER",
            messageParameters={
                "arg_name": "col",
                "arg_type": type(col).__name__,
                "allowed_types": "str",
                "item_type": type(name).__name__,
            },
        )
    col = _to_list(self._sc, cast(List["ColumnOrName"], col))

    if isinstance(probabilities, tuple):
        probabilities = list(probabilities)
    elif not isinstance(probabilities, list):
        raise PySparkTypeError(
            errorClass="NOT_LIST_OR_TUPLE",
            messageParameters={
                "arg_name": "probabilities",
                "arg_type": type(probabilities).__name__,
            },
        )

    for q in probabilities:
        # Each probability must be numeric and must not be below 0 or above 1.
        if not isinstance(q, (float, int)) or q < 0 or q > 1:
            raise PySparkTypeError(
                errorClass="NOT_LIST_OF_FLOAT_OR_INT",
                messageParameters={
                    "arg_name": "probabilities",
                    "arg_type": type(q).__name__,
                },
            )
    probabilities = _to_list(self._sc, cast(List["ColumnOrName"], probabilities))

    if not isinstance(relativeError, (float, int)):
        raise PySparkTypeError(
            errorClass="NOT_FLOAT_OR_INT",
            messageParameters={
                "arg_name": "relativeError",
                "arg_type": type(relativeError).__name__,
            },
        )
    if relativeError < 0:
        raise PySparkValueError(
            errorClass="NEGATIVE_VALUE",
            messageParameters={
                "arg_name": "relativeError",
                "arg_value": str(relativeError),
            },
        )
    relativeError = float(relativeError)

    java_result = self._jdf.stat().approxQuantile(col, probabilities, relativeError)
    per_column = [list(jq) for jq in java_result]
    if single_column:
        return per_column[0]
    return per_column

def corr(self, col1: str, col2: str, method: Optional[str] = None) -> float:
    # Correlation of two columns; only the Pearson method is supported and
    # it is the default when ``method`` is falsy.
    for arg_name, arg in (("col1", col1), ("col2", col2)):
        if not isinstance(arg, str):
            raise PySparkTypeError(
                errorClass="NOT_STR",
                messageParameters={"arg_name": arg_name, "arg_type": type(arg).__name__},
            )
    method = method or "pearson"
    if method != "pearson":
        raise PySparkValueError(
            errorClass="VALUE_NOT_PEARSON",
            messageParameters={"arg_name": "method", "arg_value": method},
        )
    return self._jdf.stat().corr(col1, col2, method)

def cov(self, col1: str, col2: str) -> float:
    # Covariance of two named columns.
    for arg_name, arg in (("col1", col1), ("col2", col2)):
        if not isinstance(arg, str):
            raise PySparkTypeError(
                errorClass="NOT_STR",
                messageParameters={"arg_name": arg_name, "arg_type": type(arg).__name__},
            )
    return self._jdf.stat().cov(col1, col2)

def crosstab(self, col1: str, col2: str) -> ParentDataFrame:
    # Pair-wise frequency table of two named columns.
    for arg_name, arg in (("col1", col1), ("col2", col2)):
        if not isinstance(arg, str):
            raise PySparkTypeError(
                errorClass="NOT_STR",
                messageParameters={"arg_name": arg_name, "arg_type": type(arg).__name__},
            )
    return DataFrame(self._jdf.stat().crosstab(col1, col2), self.sparkSession)

def freqItems(
    self, cols: Union[List[str], Tuple[str]], support: Optional[float] = None
) -> ParentDataFrame:
    # Frequent items per column; a falsy ``support`` falls back to 0.01.
    if isinstance(cols, tuple):
        cols = list(cols)
    if not isinstance(cols, list):
        raise PySparkTypeError(
            errorClass="NOT_LIST_OR_TUPLE",
            messageParameters={"arg_name": "cols", "arg_type": type(cols).__name__},
        )
    support = support or 0.01
    jdf = self._jdf.stat().freqItems(_to_seq(self._sc, cols), support)
    return DataFrame(jdf, self.sparkSession)

def _ipython_key_completions_(self) -> List[str]:
    """Return the column names of this :class:`DataFrame` for key completion.

    Names are returned verbatim, including ones that are not valid Python
    identifiers.

    Examples
    --------
    >>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], ["age", "name"])
    >>> df._ipython_key_completions_()
    ['age', 'name']

    Would return illegal identifiers.

    >>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], ["age 1", "name?1"])
    >>> df._ipython_key_completions_()
    ['age 1', 'name?1']
    """
    return self.columns

def withColumns(self, *colsMap: Dict[str, Column]) -> ParentDataFrame:
    # Add/replace several columns at once. Exactly one positional dict of
    # {name: Column} is accepted; the varargs signature is kept so keyword
    # arguments can be supported in the future.
    assert len(colsMap) == 1
    mapping = colsMap[0]
    if not isinstance(mapping, dict):
        raise PySparkTypeError(
            errorClass="NOT_DICT",
            messageParameters={"arg_name": "colsMap", "arg_type": type(mapping).__name__},
        )
    names = list(mapping.keys())
    exprs = list(mapping.values())
    return DataFrame(
        self._jdf.withColumns(_to_seq(self._sc, names), self._jcols(*exprs)),
        self.sparkSession,
    )

def withColumn(self, colName: str, col: Column) -> ParentDataFrame:
    # Add or replace a single column; ``col`` must be a Column.
    if isinstance(col, Column):
        return DataFrame(self._jdf.withColumn(colName, col._jc), self.sparkSession)
    raise PySparkTypeError(
        errorClass="NOT_COLUMN",
        messageParameters={"arg_name": "col", "arg_type": type(col).__name__},
    )

def withColumnRenamed(self, existing: str, new: str) -> ParentDataFrame:
    renamed = self._jdf.withColumnRenamed(existing, new)
    return DataFrame(renamed, self.sparkSession)

def withColumnsRenamed(self, colsMap: Dict[str, str]) -> ParentDataFrame:
    # Rename several columns given an {old_name: new_name} dict; the keys and
    # values are sent to the JVM as two parallel sequences.
    if not isinstance(colsMap, dict):
        raise PySparkTypeError(
            errorClass="NOT_DICT",
            messageParameters={"arg_name": "colsMap", "arg_type": type(colsMap).__name__},
        )
    renames = list(colsMap.items())
    old_names: List[str] = [old for old, _ in renames]
    new_names: List[str] = [new for _, new in renames]
    return DataFrame(
        self._jdf.withColumnsRenamed(_to_seq(self._sc, old_names), _to_seq(self._sc, new_names)),
        self.sparkSession,
    )

def withMetadata(self, columnName: str, metadata: Dict[str, Any]) -> ParentDataFrame:
    # Attach metadata to a column: the dict is serialized to JSON and parsed
    # by the JVM Metadata.fromJson.
    from py4j.java_gateway import JVMView

    if not isinstance(metadata, dict):
        raise PySparkTypeError(
            errorClass="NOT_DICT",
            messageParameters={"arg_name": "metadata", "arg_type": type(metadata).__name__},
        )
    sc = get_active_spark_context()
    metadata_cls = cast(JVMView, sc._jvm).org.apache.spark.sql.types.Metadata
    jmeta = metadata_cls.fromJson(json.dumps(metadata))
    return DataFrame(self._jdf.withMetadata(columnName, jmeta), self.sparkSession)

@overload
def drop(self, cols: "ColumnOrName") -> ParentDataFrame:
    ...

@overload
def drop(self, *cols: str) -> ParentDataFrame:
    ...

def drop(self, *cols: "ColumnOrName") -> ParentDataFrame:  # type: ignore[misc]
    # Drop columns given as names and/or Column objects. The two kinds use
    # different JVM overloads: Column objects are dropped first, then names.
    names: List[str] = []
    jcols: List["JavaObject"] = []
    for item in cols:
        if isinstance(item, str):
            names.append(item)
        elif isinstance(item, Column):
            jcols.append(item._jc)
        else:
            raise PySparkTypeError(
                errorClass="NOT_COLUMN_OR_STR",
                messageParameters={"arg_name": "col", "arg_type": type(item).__name__},
            )

    jdf = self._jdf
    if jcols:
        jdf = jdf.drop(jcols[0], self._jseq(jcols[1:]))
    if names:
        jdf = jdf.drop(self._jseq(names))
    return DataFrame(jdf, self.sparkSession)

def toDF(self, *cols: str) -> ParentDataFrame:
    # Rename all columns positionally; every new name must be a str.
    non_strings = [c for c in cols if not isinstance(c, str)]
    if non_strings:
        raise PySparkTypeError(
            errorClass="NOT_LIST_OF_STR",
            messageParameters={"arg_name": "cols", "arg_type": type(non_strings[0]).__name__},
        )
    return DataFrame(self._jdf.toDF(self._jseq(cols)), self.sparkSession)

def transform(
    self, func: Callable[..., ParentDataFrame], *args: Any, **kwargs: Any
) -> ParentDataFrame:
    # Apply ``func(self, *args, **kwargs)`` and check it returned a DataFrame.
    result = func(self, *args, **kwargs)
    assert isinstance(result, DataFrame), (
        "Func returned an instance of type [%s], should have been DataFrame." % type(result)
    )
    return result

def sameSemantics(self, other: ParentDataFrame) -> bool:
    # Compare logical plans on the JVM side; ``other`` must be a DataFrame.
    if isinstance(other, DataFrame):
        return self._jdf.sameSemantics(other._jdf)
    raise PySparkTypeError(
        errorClass="NOT_DATAFRAME",
        messageParameters={"arg_name": "other", "arg_type": type(other).__name__},
    )

def semanticHash(self) -> int:
    return self._jdf.semanticHash()

def inputFiles(self) -> List[str]:
    # Copy the JVM result into a plain Python list.
    files = self._jdf.inputFiles()
    return list(files)

def where(self, condition: Union[Column, str]) -> ParentDataFrame:
    # Alias of filter.
    return self.filter(condition)

# The two aliases below were added long ago for pandas compatibility. The
# API differs from pandas in too many ways to become "compatible" through
# aliases alone, so no new aliases have been added since Spark 3.0; these
# two remain only for legacy users.
@overload
def groupby(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":
    ...

@overload
def groupby(self, __cols: Union[List[Column], List[str], List[int]]) -> "GroupedData":
    ...
def groupby(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":  # type: ignore[misc]
    # Legacy lowercase alias of groupBy. Docstrings are intentionally omitted
    # on these overrides so the parent DataFrame's documentation is inherited.
    return self.groupBy(*cols)

def drop_duplicates(self, subset: Optional[List[str]] = None) -> ParentDataFrame:
    # Legacy snake_case alias of dropDuplicates.
    return self.dropDuplicates(subset)

def writeTo(self, table: str) -> "DataFrameWriterV2":
    return DataFrameWriterV2(self, table)

def mergeInto(self, table: str, condition: Column) -> "MergeIntoWriter":
    return MergeIntoWriter(self, table, condition)

def pandas_api(
    self, index_col: Optional[Union[str, List[str]]] = None
) -> "PandasOnSparkDataFrame":
    # Wrap this DataFrame in a pandas-on-Spark DataFrame, using ``index_col``
    # (if given) to build the index mapping. Imports are local so pandas-on-
    # Spark is only loaded when this method is called.
    from pyspark.pandas.namespace import _get_index_map
    from pyspark.pandas.frame import DataFrame as PandasOnSparkDataFrame
    from pyspark.pandas.internal import InternalFrame

    index_spark_columns, index_names = _get_index_map(self, index_col)
    internal = InternalFrame(
        spark_frame=self,
        index_spark_columns=index_spark_columns,
        index_names=index_names,  # type: ignore[arg-type]
    )
    return PandasOnSparkDataFrame(internal)

def mapInPandas(
    self,
    func: "PandasMapIterFunction",
    schema: Union[StructType, str],
    barrier: bool = False,
    profile: Optional[ResourceProfile] = None,
) -> ParentDataFrame:
    # Delegates to the shared PandasMapOpsMixin implementation.
    return PandasMapOpsMixin.mapInPandas(self, func, schema, barrier, profile)

def mapInArrow(
    self,
    func: "ArrowMapIterFunction",
    schema: Union[StructType, str],
    barrier: bool = False,
    profile: Optional[ResourceProfile] = None,
) -> ParentDataFrame:
    # Delegates to the shared PandasMapOpsMixin implementation.
    return PandasMapOpsMixin.mapInArrow(self, func, schema, barrier, profile)

def toArrow(self) -> "pa.Table":
    return PandasConversionMixin.toArrow(self)

def toPandas(self) -> "PandasDataFrameLike":
    return PandasConversionMixin.toPandas(self)

def transpose(self, indexColumn: Optional["ColumnOrName"] = None) -> ParentDataFrame:
    # With no index column, the JVM overload without arguments is used.
    if indexColumn is not None:
        return DataFrame(self._jdf.transpose(_to_java_column(indexColumn)), self.sparkSession)
    else:
        return DataFrame(self._jdf.transpose(), self.sparkSession)

def asTable(self) -> TableArg:
    from pyspark.sql.classic.table_arg import TableArg as ClassicTableArg

    return ClassicTableArg(self._jdf.asTable())

def scalar(self) -> Column:
    return Column(self._jdf.scalar())

def exists(self) -> Column:
    return Column(self._jdf.exists())

@property
def executionInfo(self) -> Optional["ExecutionInfo"]:
    # Not available in Classic; the error names the member that was accessed.
    raise PySparkValueError(
        errorClass="CLASSIC_OPERATION_NOT_SUPPORTED_ON_DF",
        messageParameters={"member": "executionInfo"},
    )

@property
def plot(self) -> "PySparkPlotAccessor":  # type: ignore[name-defined] # noqa: F821
    from pyspark.sql.plot import PySparkPlotAccessor

    return PySparkPlotAccessor(self)


class DataFrameNaFunctions(ParentDataFrameNaFunctions):
    # ``df.na`` accessor: every method forwards to the corresponding
    # DataFrame method (dropna / fillna / replace) of the wrapped frame.
    def __init__(self, df: ParentDataFrame):
        self.df = df

    def drop(
        self,
        how: str = "any",
        thresh: Optional[int] = None,
        subset: Optional[Union[str, Tuple[str, ...], List[str]]] = None,
    ) -> ParentDataFrame:
        return self.df.dropna(how=how, thresh=thresh, subset=subset)

    @overload
    def fill(self, value: "LiteralType", subset: Optional[List[str]] = ...) -> ParentDataFrame:
        ...

    @overload
    def fill(self, value: Dict[str, "LiteralType"]) -> ParentDataFrame:
        ...

    def fill(
        self,
        value: Union["LiteralType", Dict[str, "LiteralType"]],
        subset: Optional[List[str]] = None,
    ) -> ParentDataFrame:
        return self.df.fillna(value=value, subset=subset)  # type: ignore[arg-type]

    @overload
    def replace(
        self,
        to_replace: List["LiteralType"],
        value: List["OptionalPrimitiveType"],
        subset: Optional[List[str]] = ...,
    ) -> ParentDataFrame:
        ...

    @overload
    def replace(
        self,
        to_replace: Dict["LiteralType", "OptionalPrimitiveType"],
        subset: Optional[List[str]] = ...,
    ) -> ParentDataFrame:
        ...

    @overload
    def replace(
        self,
        to_replace: List["LiteralType"],
        value: "OptionalPrimitiveType",
        subset: Optional[List[str]] = ...,
    ) -> ParentDataFrame:
        ...

    def replace(  # type: ignore[misc]
        self,
        to_replace: Union[List["LiteralType"], Dict["LiteralType", "OptionalPrimitiveType"]],
        value: Optional[
            Union["OptionalPrimitiveType", List["OptionalPrimitiveType"], _NoValueType]
        ] = _NoValue,
        subset: Optional[List[str]] = None,
    ) -> ParentDataFrame:
        # The _NoValue default is forwarded unchanged so DataFrame.replace can
        # tell "not given" apart from an explicit None.
        return self.df.replace(to_replace, value, subset)  # type: ignore[arg-type]


class DataFrameStatFunctions(ParentDataFrameStatFunctions):
    # ``df.stat`` accessor: every method forwards to the corresponding
    # DataFrame method of the wrapped frame.
    def __init__(self, df: ParentDataFrame):
        self.df = df

    @overload
    def approxQuantile(
        self,
        col: str,
        probabilities: Union[List[float], Tuple[float]],
        relativeError: float,
    ) -> List[float]:
        ...

    @overload
    def approxQuantile(
        self,
        col: Union[List[str], Tuple[str]],
        probabilities: Union[List[float], Tuple[float]],
        relativeError: float,
    ) -> List[List[float]]:
        ...

    def approxQuantile(
        self,
        col: Union[str, List[str], Tuple[str]],
        probabilities: Union[List[float], Tuple[float]],
        relativeError: float,
    ) -> Union[List[float], List[List[float]]]:
        return self.df.approxQuantile(col, probabilities, relativeError)

    def corr(self, col1: str, col2: str, method: Optional[str] = None) -> float:
        return self.df.corr(col1, col2, method)

    def cov(self, col1: str, col2: str) -> float:
        return self.df.cov(col1, col2)

    def crosstab(self, col1: str, col2: str) -> ParentDataFrame:
        return self.df.crosstab(col1, col2)

    def freqItems(
        self, cols: Union[List[str], Tuple[str]], support: Optional[float] = None
    ) -> ParentDataFrame:
        # Annotation matches DataFrame.freqItems, which also accepts tuples.
        return self.df.freqItems(cols, support)

    def sampleBy(
        self, col: str, fractions: Dict[Any, float], seed: Optional[int] = None
    ) -> ParentDataFrame:
        return self.df.sampleBy(col, fractions, seed)


def _test() -> None:
    """Run the parent ``pyspark.sql.dataframe`` doctests against a local session."""
    import doctest
    from pyspark.sql import SparkSession
    import pyspark.sql.dataframe

    # The classes here inherit their docstrings, but doctest cannot discover
    # inherited ones, so the parent class's doctests are run directly.
    globs = pyspark.sql.dataframe.__dict__.copy()
    spark = (
        SparkSession.builder.master("local[4]").appName("sql.classic.dataframe tests").getOrCreate()
    )
    globs["spark"] = spark
    (failure_count, test_count) = doctest.testmod(
        pyspark.sql.dataframe,
        globs=globs,
        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF,
    )
    spark.stop()
    if failure_count:
        sys.exit(-1)


if __name__ == "__main__":
    _test()