#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from pyspark.sql.connect.utils import check_dependencies
check_dependencies(__name__)
from typing import (
cast,
TYPE_CHECKING,
Any,
Callable,
Union,
Sequence,
Tuple,
Optional,
)
import json
import decimal
import datetime
import warnings
from threading import Lock
import numpy as np
from pyspark.serializers import CloudPickleSerializer
from pyspark.sql.types import (
_from_numpy_type,
DateType,
ArrayType,
NullType,
BooleanType,
BinaryType,
ByteType,
ShortType,
IntegerType,
LongType,
FloatType,
DoubleType,
DecimalType,
StringType,
DataType,
TimestampType,
TimestampNTZType,
DayTimeIntervalType,
)
import pyspark.sql.connect.proto as proto
from pyspark.util import (
JVM_BYTE_MIN,
JVM_BYTE_MAX,
JVM_SHORT_MIN,
JVM_SHORT_MAX,
JVM_INT_MIN,
JVM_INT_MAX,
JVM_LONG_MIN,
JVM_LONG_MAX,
)
from pyspark.sql.connect.types import (
UnparsedDataType,
pyspark_types_to_proto_types,
proto_schema_to_pyspark_data_type,
)
from pyspark.errors import PySparkTypeError, PySparkValueError
from pyspark.errors.utils import current_origin
from pyspark.sql.utils import is_timestamp_ntz_preferred, enum_to_value
if TYPE_CHECKING:
from pyspark.sql.connect.client import SparkConnectClient
from pyspark.sql.connect.window import WindowSpec
from pyspark.sql.connect.plan import LogicalPlan
class Expression:
"""
Expression base class.
"""
def __init__(self) -> None:
origin = current_origin()
fragment = origin.fragment
call_site = origin.call_site
self.origin = None
if fragment is not None and call_site is not None:
self.origin = proto.Origin(
python_origin=proto.PythonOrigin(
fragment=origin.fragment, call_site=origin.call_site
)
)
def to_plan( # type: ignore[empty-body]
self, session: "SparkConnectClient"
) -> "proto.Expression":
...
def __repr__(self) -> str: # type: ignore[empty-body]
...
def alias(self, *alias: str, **kwargs: Any) -> "ColumnAlias":
metadata = kwargs.pop("metadata", None)
if len(alias) > 1 and metadata is not None:
raise PySparkValueError(
errorClass="ONLY_ALLOWED_FOR_SINGLE_COLUMN",
messageParameters={"arg_name": "metadata"},
)
assert not kwargs, "Unexpected kwargs were passed: %s" % kwargs
return ColumnAlias(self, list(alias), metadata)
def name(self) -> str: # type: ignore[empty-body]
...
def _create_proto_expression(self) -> proto.Expression:
plan = proto.Expression()
if self.origin is not None:
plan.common.origin.CopyFrom(self.origin)
return plan
@property
def children(self) -> Sequence["Expression"]:
return []
def foreach(self, f: Callable[["Expression"], None]) -> None:
f(self)
for c in self.children:
c.foreach(f)
class CaseWhen(Expression):
def __init__(
self, branches: Sequence[Tuple[Expression, Expression]], else_value: Optional[Expression]
):
super().__init__()
assert isinstance(branches, list)
for branch in branches:
assert (
isinstance(branch, tuple)
and len(branch) == 2
and all(isinstance(expr, Expression) for expr in branch)
)
self._branches = branches
if else_value is not None:
assert isinstance(else_value, Expression)
self._else_value = else_value
def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
args = []
for condition, value in self._branches:
args.append(condition)
args.append(value)
if self._else_value is not None:
args.append(self._else_value)
unresolved_function = UnresolvedFunction(name="when", args=args)
return unresolved_function.to_plan(session)
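# The lowering above flattens the branches into alternating (condition, value) arguments:
# for example, CaseWhen([(c1, v1), (c2, v2)], e) becomes the unresolved call
# when(c1, v1, c2, v2, e), i.e. CASE WHEN c1 THEN v1 WHEN c2 THEN v2 ELSE e END.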
@property
def children(self) -> Sequence["Expression"]:
children = []
for branch in self._branches:
children.append(branch[0])
children.append(branch[1])
if self._else_value is not None:
children.append(self._else_value)
return children
def __repr__(self) -> str:
_cases = "".join([f" WHEN {c} THEN {v}" for c, v in self._branches])
_else = f" ELSE {self._else_value}" if self._else_value is not None else ""
return "CASE" + _cases + _else + " END"
class ColumnAlias(Expression):
def __init__(self, child: Expression, alias: Sequence[str], metadata: Any):
super().__init__()
self._alias = alias
self._metadata = metadata
self._child = child
def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
if len(self._alias) == 1:
exp = self._create_proto_expression()
exp.alias.name.append(self._alias[0])
exp.alias.expr.CopyFrom(self._child.to_plan(session))
if self._metadata:
exp.alias.metadata = json.dumps(self._metadata)
return exp
else:
if self._metadata:
raise PySparkValueError(
errorClass="CANNOT_PROVIDE_METADATA",
messageParameters={},
)
exp = self._create_proto_expression()
exp.alias.name.extend(self._alias)
exp.alias.expr.CopyFrom(self._child.to_plan(session))
return exp
@property
def children(self) -> Sequence["Expression"]:
return [self._child]
def __repr__(self) -> str:
return f"{self._child} AS {','.join(self._alias)}"
class LiteralExpression(Expression):
"""A literal expression.
The Python types are converted on a best-effort basis into the relevant proto types. On the
Spark Connect server side, the proto types are converted to their Catalyst equivalents."""
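# A minimal sketch of the typical flow (assuming a `SparkConnectClient` instance
# named `client`):
#
#   lit = LiteralExpression._from_value(42)   # _infer_type picks IntegerType for small ints
#   msg = lit.to_plan(client)                 # msg.literal.integer == 42
#   LiteralExpression._to_value(msg.literal)  # -> 42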
def __init__(self, value: Any, dataType: DataType) -> None:
super().__init__()
assert isinstance(
dataType,
(
NullType,
BinaryType,
BooleanType,
ByteType,
ShortType,
IntegerType,
LongType,
FloatType,
DoubleType,
DecimalType,
StringType,
DateType,
TimestampType,
TimestampNTZType,
DayTimeIntervalType,
ArrayType,
),
)
value = enum_to_value(value)
if isinstance(dataType, NullType):
assert value is None
if value is not None:
if isinstance(dataType, BinaryType):
assert isinstance(value, (bytes, bytearray))
elif isinstance(dataType, BooleanType):
assert isinstance(value, (bool, np.bool_))
value = bool(value)
elif isinstance(dataType, ByteType):
assert isinstance(value, (int, np.int8))
assert JVM_BYTE_MIN <= int(value) <= JVM_BYTE_MAX
value = int(value)
elif isinstance(dataType, ShortType):
assert isinstance(value, (int, np.int8, np.int16))
assert JVM_SHORT_MIN <= int(value) <= JVM_SHORT_MAX
value = int(value)
elif isinstance(dataType, IntegerType):
assert isinstance(value, (int, np.int8, np.int16, np.int32))
assert JVM_INT_MIN <= int(value) <= JVM_INT_MAX
value = int(value)
elif isinstance(dataType, LongType):
assert isinstance(value, (int, np.int8, np.int16, np.int32, np.int64))
assert JVM_LONG_MIN <= int(value) <= JVM_LONG_MAX
value = int(value)
elif isinstance(dataType, FloatType):
assert isinstance(value, (float, np.float32))
value = float(value)
elif isinstance(dataType, DoubleType):
assert isinstance(value, (float, np.float32, np.float64))
value = float(value)
elif isinstance(dataType, DecimalType):
assert isinstance(value, decimal.Decimal)
elif isinstance(dataType, StringType):
assert isinstance(value, (str, np.str_))
value = str(value)
elif isinstance(dataType, DateType):
assert isinstance(value, (datetime.date, datetime.datetime))
if isinstance(value, datetime.date):
value = DateType().toInternal(value)
else:
value = DateType().toInternal(value.date())
elif isinstance(dataType, TimestampType):
assert isinstance(value, datetime.datetime)
value = TimestampType().toInternal(value)
elif isinstance(dataType, TimestampNTZType):
assert isinstance(value, datetime.datetime)
value = TimestampNTZType().toInternal(value)
elif isinstance(dataType, DayTimeIntervalType):
assert isinstance(value, datetime.timedelta)
value = DayTimeIntervalType().toInternal(value)
assert value is not None
elif isinstance(dataType, ArrayType):
assert isinstance(value, list)
else:
raise PySparkTypeError(
errorClass="UNSUPPORTED_DATA_TYPE",
messageParameters={"data_type": str(dataType)},
)
self._value = value
self._dataType = dataType
@classmethod
def _infer_type(cls, value: Any) -> DataType:
value = enum_to_value(value)
if value is None:
return NullType()
elif isinstance(value, (bytes, bytearray)):
return BinaryType()
elif isinstance(value, (bool, np.bool_)):
return BooleanType()
elif isinstance(value, int):
if JVM_INT_MIN <= value <= JVM_INT_MAX:
return IntegerType()
elif JVM_LONG_MIN <= value <= JVM_LONG_MAX:
return LongType()
else:
raise PySparkValueError(
errorClass="VALUE_NOT_BETWEEN",
messageParameters={
"arg_name": "value",
"min": str(JVM_LONG_MIN),
"max": str(JVM_SHORT_MAX),
},
)
elif isinstance(value, float):
return DoubleType()
elif isinstance(value, (str, np.str_)):
return StringType()
elif isinstance(value, decimal.Decimal):
return DecimalType()
elif isinstance(value, datetime.datetime):
return TimestampNTZType() if is_timestamp_ntz_preferred() else TimestampType()
elif isinstance(value, datetime.date):
return DateType()
elif isinstance(value, datetime.timedelta):
return DayTimeIntervalType()
elif isinstance(value, np.generic):
dt = _from_numpy_type(value.dtype)
if dt is not None:
return dt
elif isinstance(value, list):
# follow the 'infer_array_from_first_element' strategy in 'sql.types._infer_type'
# right now, it is dedicated to pyspark.ml params like array<...>, array<array<...>>
if len(value) == 0 or value[0] is None:
raise PySparkTypeError(
errorClass="CANNOT_INFER_ARRAY_ELEMENT_TYPE",
messageParameters={},
)
return ArrayType(LiteralExpression._infer_type(value[0]), True)
raise PySparkTypeError(
errorClass="UNSUPPORTED_DATA_TYPE",
messageParameters={"data_type": type(value).__name__},
)
@classmethod
def _from_value(cls, value: Any) -> "LiteralExpression":
return LiteralExpression(value=value, dataType=LiteralExpression._infer_type(value))
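# Round-trip sketch: for the supported scalar types, `_to_value` below inverts the
# conversion performed by `to_plan`, e.g. (assuming a client named `client`)
# LiteralExpression._to_value(LiteralExpression._from_value(1.5).to_plan(client).literal)
# returns 1.5 again.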
@classmethod
def _to_value(
cls, literal: "proto.Expression.Literal", dataType: Optional[DataType] = None
) -> Any:
if literal.HasField("null"):
return None
elif literal.HasField("binary"):
assert dataType is None or isinstance(dataType, BinaryType)
return literal.binary
elif literal.HasField("boolean"):
assert dataType is None or isinstance(dataType, BooleanType)
return literal.boolean
elif literal.HasField("byte"):
assert dataType is None or isinstance(dataType, ByteType)
return literal.byte
elif literal.HasField("short"):
assert dataType is None or isinstance(dataType, ShortType)
return literal.short
elif literal.HasField("integer"):
assert dataType is None or isinstance(dataType, IntegerType)
return literal.integer
elif literal.HasField("long"):
assert dataType is None or isinstance(dataType, LongType)
return literal.long
elif literal.HasField("float"):
assert dataType is None or isinstance(dataType, FloatType)
return literal.float
elif literal.HasField("double"):
assert dataType is None or isinstance(dataType, DoubleType)
return literal.double
elif literal.HasField("decimal"):
assert dataType is None or isinstance(dataType, DecimalType)
return decimal.Decimal(literal.decimal.value)
elif literal.HasField("string"):
assert dataType is None or isinstance(dataType, StringType)
return literal.string
elif literal.HasField("date"):
assert dataType is None or isinstance(dataType, DateType)
return DateType().fromInternal(literal.date)
elif literal.HasField("timestamp"):
assert dataType is None or isinstance(dataType, TimestampType)
return TimestampType().fromInternal(literal.timestamp)
elif literal.HasField("timestamp_ntz"):
assert dataType is None or isinstance(dataType, TimestampNTZType)
return TimestampNTZType().fromInternal(literal.timestamp_ntz)
elif literal.HasField("day_time_interval"):
assert dataType is None or isinstance(dataType, DayTimeIntervalType)
return DayTimeIntervalType().fromInternal(literal.day_time_interval)
elif literal.HasField("array"):
elementType = proto_schema_to_pyspark_data_type(literal.array.element_type)
if dataType is not None:
assert isinstance(dataType, ArrayType)
assert elementType == dataType.elementType
return [LiteralExpression._to_value(v, elementType) for v in literal.array.elements]
raise PySparkTypeError(
errorClass="UNSUPPORTED_LITERAL",
messageParameters={"literal": str(literal)},
)
def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
"""Converts the literal expression to the literal in proto."""
expr = self._create_proto_expression()
if self._value is None:
expr.literal.null.CopyFrom(pyspark_types_to_proto_types(self._dataType))
elif isinstance(self._dataType, BinaryType):
expr.literal.binary = bytes(self._value)
elif isinstance(self._dataType, BooleanType):
expr.literal.boolean = bool(self._value)
elif isinstance(self._dataType, ByteType):
expr.literal.byte = int(self._value)
elif isinstance(self._dataType, ShortType):
expr.literal.short = int(self._value)
elif isinstance(self._dataType, IntegerType):
expr.literal.integer = int(self._value)
elif isinstance(self._dataType, LongType):
expr.literal.long = int(self._value)
elif isinstance(self._dataType, FloatType):
expr.literal.float = float(self._value)
elif isinstance(self._dataType, DoubleType):
expr.literal.double = float(self._value)
elif isinstance(self._dataType, DecimalType):
expr.literal.decimal.value = str(self._value)
expr.literal.decimal.precision = self._dataType.precision
expr.literal.decimal.scale = self._dataType.scale
elif isinstance(self._dataType, StringType):
expr.literal.string = str(self._value)
elif isinstance(self._dataType, DateType):
expr.literal.date = int(self._value)
elif isinstance(self._dataType, TimestampType):
expr.literal.timestamp = int(self._value)
elif isinstance(self._dataType, TimestampNTZType):
expr.literal.timestamp_ntz = int(self._value)
elif isinstance(self._dataType, DayTimeIntervalType):
expr.literal.day_time_interval = int(self._value)
elif isinstance(self._dataType, ArrayType):
element_type = self._dataType.elementType
expr.literal.array.element_type.CopyFrom(pyspark_types_to_proto_types(element_type))
for v in self._value:
expr.literal.array.elements.append(
LiteralExpression(v, element_type).to_plan(session).literal
)
else:
raise PySparkTypeError(
errorClass="UNSUPPORTED_DATA_TYPE",
messageParameters={"data_type": str(self._dataType)},
)
return expr
def __repr__(self) -> str:
if self._value is None:
return "NULL"
elif isinstance(self._dataType, DateType):
dt = DateType().fromInternal(self._value)
if dt is not None and isinstance(dt, datetime.date):
return dt.strftime("%Y-%m-%d")
elif isinstance(self._dataType, TimestampType):
ts = TimestampType().fromInternal(self._value)
if ts is not None and isinstance(ts, datetime.datetime):
return ts.strftime("%Y-%m-%d %H:%M:%S.%f")
elif isinstance(self._dataType, TimestampNTZType):
ts = TimestampNTZType().fromInternal(self._value)
if ts is not None and isinstance(ts, datetime.datetime):
return ts.strftime("%Y-%m-%d %H:%M:%S.%f")
elif isinstance(self._dataType, DayTimeIntervalType):
delta = DayTimeIntervalType().fromInternal(self._value)
if delta is not None and isinstance(delta, datetime.timedelta):
import pandas as pd
# Note: timedelta itself does not provide an isoformat method.
# Both Pandas and java.time.Duration provide it, but the format
# is slightly different:
# java.time.Duration only applies HOURS, MINUTES, SECONDS units,
# while Pandas applies all supported units.
return pd.Timedelta(delta).isoformat()
return f"{self._value}"
class ColumnReference(Expression):
"""Represents a column reference. There is no guarantee that this column
actually exists. In the context of this project, we refer by its name and
treat it as an unresolved attribute. Attributes that have the same fully
qualified name are identical"""
def __init__(
self,
unparsed_identifier: str,
plan_id: Optional[int] = None,
is_metadata_column: bool = False,
) -> None:
super().__init__()
assert isinstance(unparsed_identifier, str)
self._unparsed_identifier = unparsed_identifier
assert plan_id is None or isinstance(plan_id, int)
self._plan_id = plan_id
self._is_metadata_column = is_metadata_column
def name(self) -> str:
"""Returns the qualified name of the column reference."""
return self._unparsed_identifier
def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
"""Returns the Proto representation of the expression."""
expr = self._create_proto_expression()
expr.unresolved_attribute.unparsed_identifier = self._unparsed_identifier
if self._plan_id is not None:
expr.unresolved_attribute.plan_id = self._plan_id
expr.unresolved_attribute.is_metadata_column = self._is_metadata_column
return expr
def __repr__(self) -> str:
return f"{self._unparsed_identifier}"
def __eq__(self, other: Any) -> bool:
return (
other is not None
and isinstance(other, ColumnReference)
and other._unparsed_identifier == self._unparsed_identifier
)
class UnresolvedStar(Expression):
def __init__(self, unparsed_target: Optional[str], plan_id: Optional[int] = None):
super().__init__()
if unparsed_target is not None:
assert isinstance(unparsed_target, str) and unparsed_target.endswith(".*")
self._unparsed_target = unparsed_target
assert plan_id is None or isinstance(plan_id, int)
self._plan_id = plan_id
def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
expr = self._create_proto_expression()
expr.unresolved_star.SetInParent()
if self._unparsed_target is not None:
expr.unresolved_star.unparsed_target = self._unparsed_target
if self._plan_id is not None:
expr.unresolved_star.plan_id = self._plan_id
return expr
def __repr__(self) -> str:
if self._unparsed_target is not None:
return f"unresolvedstar({self._unparsed_target})"
else:
return "unresolvedstar()"
def __eq__(self, other: Any) -> bool:
return (
other is not None
and isinstance(other, UnresolvedStar)
and other._unparsed_target == self._unparsed_target
)
class SQLExpression(Expression):
"""Returns Expression which contains a string which is a SQL expression
and server side will parse it by Catalyst
"""
def __init__(self, expr: str) -> None:
super().__init__()
assert isinstance(expr, str)
self._expr: str = expr
def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
"""Returns the Proto representation of the SQL expression."""
expr = self._create_proto_expression()
expr.expression_string.expression = self._expr
return expr
def __eq__(self, other: Any) -> bool:
return other is not None and isinstance(other, SQLExpression) and other._expr == self._expr
def __repr__(self) -> str:
return self._expr
class SortOrder(Expression):
def __init__(self, child: Expression, ascending: bool = True, nullsFirst: bool = True) -> None:
super().__init__()
self._child = child
self._ascending = ascending
self._nullsFirst = nullsFirst
def __repr__(self) -> str:
return (
str(self._child)
+ (" ASC" if self._ascending else " DESC")
+ (" NULLS FIRST" if self._nullsFirst else " NULLS LAST")
)
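# For example, SortOrder(ColumnReference("a"), ascending=False, nullsFirst=False)
# renders as "a DESC NULLS LAST" and, in `to_plan` below, maps to
# SORT_DIRECTION_DESCENDING with SORT_NULLS_LAST.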
def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
sort = self._create_proto_expression()
sort.sort_order.child.CopyFrom(self._child.to_plan(session))
if self._ascending:
sort.sort_order.direction = (
proto.Expression.SortOrder.SortDirection.SORT_DIRECTION_ASCENDING
)
else:
sort.sort_order.direction = (
proto.Expression.SortOrder.SortDirection.SORT_DIRECTION_DESCENDING
)
if self._nullsFirst:
sort.sort_order.null_ordering = proto.Expression.SortOrder.NullOrdering.SORT_NULLS_FIRST
else:
sort.sort_order.null_ordering = proto.Expression.SortOrder.NullOrdering.SORT_NULLS_LAST
return sort
@property
def children(self) -> Sequence["Expression"]:
return [self._child]
class UnresolvedFunction(Expression):
def __init__(
self,
name: str,
args: Sequence["Expression"],
is_distinct: bool = False,
) -> None:
super().__init__()
assert isinstance(name, str)
self._name = name
assert isinstance(args, list) and all(isinstance(arg, Expression) for arg in args)
self._args = args
assert isinstance(is_distinct, bool)
self._is_distinct = is_distinct
def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
fun = self._create_proto_expression()
fun.unresolved_function.function_name = self._name
if len(self._args) > 0:
fun.unresolved_function.arguments.extend([arg.to_plan(session) for arg in self._args])
fun.unresolved_function.is_distinct = self._is_distinct
return fun
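# A minimal sketch: most built-in functions are expressed this way, e.g.
# UnresolvedFunction("sum", [ColumnReference("x")]) renders as "sum(x)" and is
# resolved by name on the server; is_distinct=True yields e.g. "sum(distinct x)".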
@property
def children(self) -> Sequence["Expression"]:
return self._args
def __repr__(self) -> str:
# Default print handling:
if self._is_distinct:
return f"{self._name}(distinct {', '.join([str(arg) for arg in self._args])})"
else:
return f"{self._name}({', '.join([str(arg) for arg in self._args])})"
class PythonUDF:
"""Represents a Python user-defined function."""
def __init__(
self,
output_type: Union[DataType, str],
eval_type: int,
func: Callable[..., Any],
python_ver: str,
) -> None:
self._output_type: DataType = (
UnparsedDataType(output_type) if isinstance(output_type, str) else output_type
)
self._eval_type = eval_type
self._func = func
self._python_ver = python_ver
def to_plan(self, session: "SparkConnectClient") -> proto.PythonUDF:
if isinstance(self._output_type, UnparsedDataType):
parsed = session._analyze(
method="ddl_parse", ddl_string=self._output_type.data_type_string
).parsed
assert isinstance(parsed, DataType)
output_type = parsed
else:
output_type = self._output_type
expr = proto.PythonUDF()
expr.output_type.CopyFrom(pyspark_types_to_proto_types(output_type))
expr.eval_type = self._eval_type
expr.command = CloudPickleSerializer().dumps((self._func, output_type))
expr.python_ver = self._python_ver
return expr
def __repr__(self) -> str:
return f"{self._output_type}, {self._eval_type}, {self._func}, f{self._python_ver}"
class JavaUDF:
"""Represents a Java (aggregate) user-defined function."""
def __init__(
self,
class_name: str,
output_type: Optional[Union[DataType, str]] = None,
aggregate: bool = False,
) -> None:
self._class_name = class_name
self._output_type: Optional[DataType] = (
UnparsedDataType(output_type) if isinstance(output_type, str) else output_type
)
self._aggregate = aggregate
def to_plan(self, session: "SparkConnectClient") -> proto.JavaUDF:
expr = proto.JavaUDF()
expr.class_name = self._class_name
if self._output_type is not None:
expr.output_type.CopyFrom(pyspark_types_to_proto_types(self._output_type))
expr.aggregate = self._aggregate
return expr
def __repr__(self) -> str:
return f"{self._class_name}, {self._output_type}"
class CommonInlineUserDefinedFunction(Expression):
"""Represents a user-defined function with an inlined defined function body of any programming
languages."""
def __init__(
self,
function_name: str,
function: Union[PythonUDF, JavaUDF],
deterministic: bool = False,
arguments: Optional[Sequence[Expression]] = None,
):
super().__init__()
self._function_name = function_name
self._deterministic = deterministic
self._arguments: Sequence[Expression] = arguments or []
self._function = function
def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
expr = self._create_proto_expression()
expr.common_inline_user_defined_function.function_name = self._function_name
expr.common_inline_user_defined_function.deterministic = self._deterministic
if len(self._arguments) > 0:
expr.common_inline_user_defined_function.arguments.extend(
[arg.to_plan(session) for arg in self._arguments]
)
expr.common_inline_user_defined_function.python_udf.CopyFrom(
cast(proto.PythonUDF, self._function.to_plan(session))
)
return expr
def to_plan_udf(self, session: "SparkConnectClient") -> "proto.CommonInlineUserDefinedFunction":
"""Compared to `to_plan`, it returns a CommonInlineUserDefinedFunction instead of an
Expression."""
expr = proto.CommonInlineUserDefinedFunction()
expr.function_name = self._function_name
expr.deterministic = self._deterministic
if len(self._arguments) > 0:
expr.arguments.extend([arg.to_plan(session) for arg in self._arguments])
expr.python_udf.CopyFrom(cast(proto.PythonUDF, self._function.to_plan(session)))
return expr
def to_plan_judf(
self, session: "SparkConnectClient"
) -> "proto.CommonInlineUserDefinedFunction":
expr = proto.CommonInlineUserDefinedFunction()
expr.function_name = self._function_name
expr.java_udf.CopyFrom(cast(proto.JavaUDF, self._function.to_plan(session)))
return expr
@property
def children(self) -> Sequence["Expression"]:
return self._arguments
def __repr__(self) -> str:
return f"{self._function_name}({', '.join([str(arg) for arg in self._arguments])})"
class WithField(Expression):
def __init__(
self,
structExpr: Expression,
fieldName: str,
valueExpr: Expression,
) -> None:
super().__init__()
assert isinstance(structExpr, Expression)
self._structExpr = structExpr
assert isinstance(fieldName, str)
self._fieldName = fieldName
assert isinstance(valueExpr, Expression)
self._valueExpr = valueExpr
def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
expr = self._create_proto_expression()
expr.update_fields.struct_expression.CopyFrom(self._structExpr.to_plan(session))
expr.update_fields.field_name = self._fieldName
expr.update_fields.value_expression.CopyFrom(self._valueExpr.to_plan(session))
return expr
@property
def children(self) -> Sequence["Expression"]:
return [self._structExpr, self._valueExpr]
def __repr__(self) -> str:
return f"update_field({self._structExpr}, {self._fieldName}, {self._valueExpr})"
class DropField(Expression):
def __init__(
self,
structExpr: Expression,
fieldName: str,
) -> None:
super().__init__()
assert isinstance(structExpr, Expression)
self._structExpr = structExpr
assert isinstance(fieldName, str)
self._fieldName = fieldName
def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
expr = self._create_proto_expression()
expr.update_fields.struct_expression.CopyFrom(self._structExpr.to_plan(session))
expr.update_fields.field_name = self._fieldName
return expr
@property
def children(self) -> Sequence["Expression"]:
return [self._structExpr]
def __repr__(self) -> str:
return f"drop_field({self._structExpr}, {self._fieldName})"
class UnresolvedExtractValue(Expression):
def __init__(
self,
child: Expression,
extraction: Expression,
) -> None:
super().__init__()
assert isinstance(child, Expression)
self._child = child
assert isinstance(extraction, Expression)
self._extraction = extraction
def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
expr = self._create_proto_expression()
expr.unresolved_extract_value.child.CopyFrom(self._child.to_plan(session))
expr.unresolved_extract_value.extraction.CopyFrom(self._extraction.to_plan(session))
return expr
@property
def children(self) -> Sequence["Expression"]:
return [self._child, self._extraction]
def __repr__(self) -> str:
return f"{self._child}['{self._extraction}']"
class UnresolvedRegex(Expression):
def __init__(self, col_name: str, plan_id: Optional[int] = None) -> None:
super().__init__()
assert isinstance(col_name, str)
self.col_name = col_name
assert plan_id is None or isinstance(plan_id, int)
self._plan_id = plan_id
def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
expr = self._create_proto_expression()
expr.unresolved_regex.col_name = self.col_name
if self._plan_id is not None:
expr.unresolved_regex.plan_id = self._plan_id
return expr
def __repr__(self) -> str:
return f"UnresolvedRegex({self.col_name})"
class CastExpression(Expression):
def __init__(
self,
expr: Expression,
data_type: Union[DataType, str],
eval_mode: Optional[str] = None,
) -> None:
super().__init__()
self._expr = expr
assert isinstance(data_type, (DataType, str))
self._data_type = data_type
if eval_mode is not None:
assert isinstance(eval_mode, str)
assert eval_mode in ["legacy", "ansi", "try"]
self._eval_mode = eval_mode
def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
fun = self._create_proto_expression()
fun.cast.expr.CopyFrom(self._expr.to_plan(session))
if isinstance(self._data_type, str):
fun.cast.type_str = self._data_type
else:
fun.cast.type.CopyFrom(pyspark_types_to_proto_types(self._data_type))
if self._eval_mode is not None:
if self._eval_mode == "legacy":
fun.cast.eval_mode = proto.Expression.Cast.EvalMode.EVAL_MODE_LEGACY
elif self._eval_mode == "ansi":
fun.cast.eval_mode = proto.Expression.Cast.EvalMode.EVAL_MODE_ANSI
elif self._eval_mode == "try":
fun.cast.eval_mode = proto.Expression.Cast.EvalMode.EVAL_MODE_TRY
return fun
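# A minimal sketch: CastExpression(ColumnReference("a"), "long") sends the type as the
# raw string via `type_str`, while passing a DataType instance (e.g. LongType()) sends
# the resolved proto type instead; eval_mode="try" corresponds to TRY_CAST.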
@property
def children(self) -> Sequence["Expression"]:
return [self._expr]
def __repr__(self) -> str:
# We cannot guarantee the string representations be exactly the same, e.g.
# str(sf.col("a").cast("long")):
# Column<'CAST(a AS BIGINT)'> <- Spark Classic
# Column<'CAST(a AS LONG)'> <- Spark Connect
if isinstance(self._data_type, DataType):
str_data_type = self._data_type.simpleString().upper()
else:
str_data_type = str(self._data_type).upper()
if self._eval_mode is not None and self._eval_mode == "try":
return f"TRY_CAST({self._expr} AS {str_data_type})"
else:
return f"CAST({self._expr} AS {str_data_type})"
class UnresolvedNamedLambdaVariable(Expression):
_lock: Lock = Lock()
_nextVarNameId: int = 0
def __init__(
self,
name_parts: Sequence[str],
) -> None:
super().__init__()
assert (
isinstance(name_parts, list)
and len(name_parts) > 0
and all(isinstance(p, str) for p in name_parts)
)
self._name_parts = name_parts
def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
expr = self._create_proto_expression()
expr.unresolved_named_lambda_variable.name_parts.extend(self._name_parts)
return expr
def __repr__(self) -> str:
return ", ".join(self._name_parts)
@staticmethod
def fresh_var_name(name: str) -> str:
assert isinstance(name, str) and str != ""
_id: Optional[int] = None
with UnresolvedNamedLambdaVariable._lock:
_id = UnresolvedNamedLambdaVariable._nextVarNameId
UnresolvedNamedLambdaVariable._nextVarNameId += 1
assert _id is not None
return f"{name}_{_id}"
class LambdaFunction(Expression):
def __init__(
self,
function: Expression,
arguments: Sequence[UnresolvedNamedLambdaVariable],
) -> None:
super().__init__()
assert isinstance(function, Expression)
assert (
isinstance(arguments, list)
and len(arguments) > 0
and all(isinstance(arg, UnresolvedNamedLambdaVariable) for arg in arguments)
)
self._function = function
self._arguments = arguments
def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
expr = self._create_proto_expression()
expr.lambda_function.function.CopyFrom(self._function.to_plan(session))
expr.lambda_function.arguments.extend(
[arg.to_plan(session).unresolved_named_lambda_variable for arg in self._arguments]
)
return expr
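# A minimal sketch: higher-order functions build a LambdaFunction whose arguments are
# UnresolvedNamedLambdaVariable instances with fresh names, e.g.
#
#   x = UnresolvedNamedLambdaVariable([UnresolvedNamedLambdaVariable.fresh_var_name("x")])
#   fn = LambdaFunction(UnresolvedFunction("+", [x, LiteralExpression._from_value(1)]), [x])
#
# which represents the lambda `x -> x + 1`.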
@property
def children(self) -> Sequence["Expression"]:
return [self._function] + self._arguments
def __repr__(self) -> str:
return (
f"LambdaFunction({str(self._function)}, "
+ f"{', '.join([str(arg) for arg in self._arguments])})"
)
class WindowExpression(Expression):
def __init__(
self,
windowFunction: Expression,
windowSpec: "WindowSpec",
) -> None:
super().__init__()
from pyspark.sql.connect.window import WindowSpec
assert windowFunction is not None and isinstance(windowFunction, Expression)
assert windowSpec is not None and isinstance(windowSpec, WindowSpec)
self._windowFunction = windowFunction
self._windowSpec = windowSpec
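# A minimal sketch: `some_column.over(window_spec)` builds a WindowExpression; in
# `to_plan` below, a frame boundary of 0 maps to current_row, JVM_LONG_MIN / JVM_LONG_MAX
# map to unbounded, and other values become literal offsets.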
def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
expr = self._create_proto_expression()
expr.window.window_function.CopyFrom(self._windowFunction.to_plan(session))
if len(self._windowSpec._partitionSpec) > 0:
expr.window.partition_spec.extend(
[p.to_plan(session) for p in self._windowSpec._partitionSpec]
)
else:
warnings.warn(
"WARN WindowExpression: No Partition Defined for Window operation! "
"Moving all data to a single partition, this can cause serious "
"performance degradation."
)
if len(self._windowSpec._orderSpec) > 0:
expr.window.order_spec.extend(
[s.to_plan(session).sort_order for s in self._windowSpec._orderSpec]
)
if self._windowSpec._frame is not None:
if self._windowSpec._frame._isRowFrame:
expr.window.frame_spec.frame_type = (
proto.Expression.Window.WindowFrame.FrameType.FRAME_TYPE_ROW
)
start = self._windowSpec._frame._start
if start == 0:
expr.window.frame_spec.lower.current_row = True
elif start == JVM_LONG_MIN:
expr.window.frame_spec.lower.unbounded = True
elif JVM_INT_MIN <= start <= JVM_INT_MAX:
expr.window.frame_spec.lower.value.literal.integer = start
else:
raise PySparkValueError(
errorClass="VALUE_NOT_BETWEEN",
messageParameters={
"arg_name": "start",
"min": str(JVM_INT_MIN),
"max": str(JVM_INT_MAX),
},
)
end = self._windowSpec._frame._end
if end == 0:
expr.window.frame_spec.upper.current_row = True
elif end == JVM_LONG_MAX:
expr.window.frame_spec.upper.unbounded = True
elif JVM_INT_MIN <= end <= JVM_INT_MAX:
expr.window.frame_spec.upper.value.literal.integer = end
else:
raise PySparkValueError(
errorClass="VALUE_NOT_BETWEEN",
messageParameters={
"arg_name": "end",
"min": str(JVM_INT_MIN),
"max": str(JVM_INT_MAX),
},
)
else:
expr.window.frame_spec.frame_type = (
proto.Expression.Window.WindowFrame.FrameType.FRAME_TYPE_RANGE
)
start = self._windowSpec._frame._start
if start == 0:
expr.window.frame_spec.lower.current_row = True
elif start == JVM_LONG_MIN:
expr.window.frame_spec.lower.unbounded = True
else:
expr.window.frame_spec.lower.value.literal.long = start
end = self._windowSpec._frame._end
if end == 0:
expr.window.frame_spec.upper.current_row = True
elif end == JVM_LONG_MAX:
expr.window.frame_spec.upper.unbounded = True
else:
expr.window.frame_spec.upper.value.literal.long = end
return expr
@property
def children(self) -> Sequence["Expression"]:
return (
[self._windowFunction] + self._windowSpec._partitionSpec + self._windowSpec._orderSpec
)
def __repr__(self) -> str:
return f"WindowExpression({str(self._windowFunction)}, ({str(self._windowSpec)}))"
class DistributedSequenceID(Expression):
def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
unresolved_function = UnresolvedFunction(name="distributed_sequence_id", args=[])
return unresolved_function.to_plan(session)
def __repr__(self) -> str:
return "DistributedSequenceID()"
class CallFunction(Expression):
def __init__(self, name: str, args: Sequence["Expression"]):
super().__init__()
assert isinstance(name, str)
self._name = name
assert isinstance(args, list) and all(isinstance(arg, Expression) for arg in args)
self._args = args
def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
expr = self._create_proto_expression()
expr.call_function.function_name = self._name
if len(self._args) > 0:
expr.call_function.arguments.extend([arg.to_plan(session) for arg in self._args])
return expr
@property
def children(self) -> Sequence["Expression"]:
return self._args
def __repr__(self) -> str:
if len(self._args) > 0:
return f"CallFunction('{self._name}', {', '.join([str(arg) for arg in self._args])})"
else:
return f"CallFunction('{self._name}')"
class NamedArgumentExpression(Expression):
def __init__(self, key: str, value: Expression):
super().__init__()
assert isinstance(key, str)
self._key = key
assert isinstance(value, Expression)
self._value = value
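# Named arguments mirror the SQL syntax `func(key => value)`; for example,
# NamedArgumentExpression("fmt", LiteralExpression._from_value("yyyy")) renders
# as "fmt => yyyy".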
def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
expr = self._create_proto_expression()
expr.named_argument_expression.key = self._key
expr.named_argument_expression.value.CopyFrom(self._value.to_plan(session))
return expr
@property
def children(self) -> Sequence["Expression"]:
return [self._value]
def __repr__(self) -> str:
return f"{self._key} => {self._value}"
class SubqueryExpression(Expression):
def __init__(
self,
plan: "LogicalPlan",
subquery_type: str,
partition_spec: Optional[Sequence["Expression"]] = None,
order_spec: Optional[Sequence["SortOrder"]] = None,
with_single_partition: Optional[bool] = None,
in_subquery_values: Optional[Sequence["Expression"]] = None,
) -> None:
assert isinstance(subquery_type, str)
assert subquery_type in ("scalar", "exists", "table_arg", "in")
super().__init__()
self._plan = plan
self._subquery_type = subquery_type
self._partition_spec = partition_spec or []
self._order_spec = order_spec or []
self._with_single_partition = with_single_partition
self._in_subquery_values = in_subquery_values or []
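# A minimal sketch: SubqueryExpression(plan, "scalar") serializes with SUBQUERY_TYPE_SCALAR
# and the plan id of `plan`; the "table_arg" type additionally carries partition_spec /
# order_spec / with_single_partition in TableArgOptions, and the "in" type carries the
# left-hand side values in in_subquery_values.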
def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
expr = self._create_proto_expression()
expr.subquery_expression.plan_id = self._plan._plan_id
if self._subquery_type == "scalar":
expr.subquery_expression.subquery_type = proto.SubqueryExpression.SUBQUERY_TYPE_SCALAR
elif self._subquery_type == "exists":
expr.subquery_expression.subquery_type = proto.SubqueryExpression.SUBQUERY_TYPE_EXISTS
elif self._subquery_type == "table_arg":
expr.subquery_expression.subquery_type = (
proto.SubqueryExpression.SUBQUERY_TYPE_TABLE_ARG
)
# Populate TableArgOptions
table_arg_options = expr.subquery_expression.table_arg_options
if len(self._partition_spec) > 0:
table_arg_options.partition_spec.extend(
[p.to_plan(session) for p in self._partition_spec]
)
if len(self._order_spec) > 0:
table_arg_options.order_spec.extend(
[o.to_plan(session).sort_order for o in self._order_spec]
)
if self._with_single_partition is not None:
table_arg_options.with_single_partition = self._with_single_partition
elif self._subquery_type == "in":
expr.subquery_expression.subquery_type = proto.SubqueryExpression.SUBQUERY_TYPE_IN
expr.subquery_expression.in_subquery_values.extend(
[value.to_plan(session) for value in self._in_subquery_values]
)
return expr
def __repr__(self) -> str:
repr_parts = [f"plan={self._plan}", f"type={self._subquery_type}"]
if self._subquery_type == "table_arg":
if self._partition_spec:
repr_parts.append(f"partition_spec={self._partition_spec}")
if self._order_spec:
repr_parts.append(f"order_spec={self._order_spec}")
if self._with_single_partition is not None:
repr_parts.append(f"with_single_partition={self._with_single_partition}")
elif self._subquery_type == "in":
repr_parts.append(f"values={self._in_subquery_values}")
return f"SubqueryExpression({', '.join(repr_parts)})"