python/pyspark/sql/connect/plan.py (2,133 lines of code) (raw):

# # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. # The ASF licenses this file to You under the Apache License, Version 2.0 # (the "License"); you may not use this file except in compliance with # the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # mypy: disable-error-code="operator" from pyspark.resource import ResourceProfile from pyspark.sql.connect.utils import check_dependencies check_dependencies(__name__) from typing import ( Any, List, Optional, Type, Sequence, Union, cast, TYPE_CHECKING, Mapping, Dict, Tuple, ) import functools import json import pickle from threading import Lock from inspect import signature, isclass import pyarrow as pa from pyspark.serializers import CloudPickleSerializer from pyspark.storagelevel import StorageLevel from pyspark.sql.types import DataType import pyspark.sql.connect.proto as proto from pyspark.sql.column import Column from pyspark.sql.connect.logging import logger from pyspark.sql.connect.proto import base_pb2 as spark_dot_connect_dot_base__pb2 from pyspark.sql.connect.conversion import storage_level_to_proto from pyspark.sql.connect.expressions import Expression, SubqueryExpression from pyspark.sql.connect.types import pyspark_types_to_proto_types, UnparsedDataType from pyspark.errors import ( AnalysisException, PySparkValueError, PySparkPicklingError, ) if TYPE_CHECKING: from pyspark.sql.connect.client import SparkConnectClient from pyspark.sql.connect.udf import UserDefinedFunction from pyspark.sql.connect.observation import Observation from pyspark.sql.connect.session import SparkSession class LogicalPlan: _lock: Lock = Lock() _nextPlanId: int = 0 INDENT = 2 def __init__( self, child: Optional["LogicalPlan"], references: Optional[Sequence["LogicalPlan"]] = None ) -> None: """ Parameters ---------- child : :class:`LogicalPlan`, optional. The child logical plan. references : list of :class:`LogicalPlan`, optional. The list of logical plans that are referenced as subqueries in this logical plan. """ self._child = child self._root_plan_id = LogicalPlan._fresh_plan_id() self._references: Sequence["LogicalPlan"] = references or [] self._plan_id_with_rel: Optional[int] = None if len(self._references) > 0: assert all(isinstance(r, LogicalPlan) for r in self._references) self._plan_id_with_rel = LogicalPlan._fresh_plan_id() @property def _plan_id(self) -> int: return self._plan_id_with_rel or self._root_plan_id @staticmethod def _fresh_plan_id() -> int: plan_id: Optional[int] = None with LogicalPlan._lock: plan_id = LogicalPlan._nextPlanId LogicalPlan._nextPlanId += 1 assert plan_id is not None return plan_id def _create_proto_relation(self) -> proto.Relation: plan = proto.Relation() plan.common.plan_id = self._root_plan_id return plan def plan(self, session: "SparkConnectClient") -> proto.Relation: # type: ignore[empty-body] ... def command(self, session: "SparkConnectClient") -> proto.Command: # type: ignore[empty-body] ... def _verify(self, session: "SparkConnectClient") -> bool: """This method is used to verify that the current logical plan can be serialized to Proto and back and afterwards is identical.""" plan = proto.Plan() plan.root.CopyFrom(self.plan(session)) serialized_plan = plan.SerializeToString() test_plan = proto.Plan() test_plan.ParseFromString(serialized_plan) return test_plan == plan def to_proto(self, session: "SparkConnectClient", debug: bool = False) -> proto.Plan: """ Generates connect proto plan based on this LogicalPlan. Parameters ---------- session : :class:`SparkConnectClient`, optional. a session that connects remote spark cluster. debug: bool if enabled, the proto plan will be printed. """ plan = proto.Plan() plan.root.CopyFrom(self.plan(session)) if debug: print(plan) return plan @property def observations(self) -> Dict[str, "Observation"]: if self._child is None: return {} else: return self._child.observations @staticmethod def _collect_references( cols_or_exprs: Sequence[Union[Column, Expression]] ) -> Sequence["LogicalPlan"]: references: List[LogicalPlan] = [] def append_reference(e: Expression) -> None: if isinstance(e, SubqueryExpression): references.append(e._plan) for col_or_expr in cols_or_exprs: if isinstance(col_or_expr, Column): col_or_expr._expr.foreach(append_reference) else: col_or_expr.foreach(append_reference) return references def _with_relations( self, root: proto.Relation, session: "SparkConnectClient" ) -> proto.Relation: if len(self._references) == 0: return root else: # When there are references to other DataFrame, e.g., subqueries, build new plan like: # with_relations [id 10] # root: plan [id 9] # reference: # refs#1: [id 8] # refs#2: [id 5] plan = proto.Relation() assert isinstance(self._plan_id_with_rel, int) plan.common.plan_id = self._plan_id_with_rel plan.with_relations.root.CopyFrom(root) plan.with_relations.references.extend([ref.plan(session) for ref in self._references]) return plan def _parameters_to_print(self, parameters: Mapping[str, Any]) -> Mapping[str, Any]: """ Extracts the parameters that are able to be printed. It looks up the signature in the constructor of this :class:`LogicalPlan`, and retrieves the variables from this instance by the same name (or the name with prefix `_`) defined in the constructor. Parameters ---------- parameters : map Parameter mapping from ``inspect.signature(...).parameters`` Returns ------- dict A dictionary consisting of a string name and variable found in this :class:`LogicalPlan`. Notes ----- :class:`LogicalPlan` itself is filtered out and considered as a non-printable parameter. Examples -------- The example below returns a dictionary from `self._start`, `self._end`, `self._num_partitions`. >>> rg = Range(0, 10, 1) >>> rg._parameters_to_print(signature(rg.__class__.__init__).parameters) {'start': 0, 'end': 10, 'step': 1, 'num_partitions': None} If the child is defined, it is not considered as a printable instance >>> project = Project(rg, "value") >>> project._parameters_to_print(signature(project.__class__.__init__).parameters) {'columns': ['value']} """ params = {} for name, tpe in parameters.items(): # LogicalPlan is not to print, e.g., LogicalPlan is_logical_plan = isclass(tpe.annotation) and isinstance(tpe.annotation, LogicalPlan) # Look up the string argument defined as a forward reference e.g., "LogicalPlan" is_forwardref_logical_plan = getattr(tpe.annotation, "__forward_arg__", "").endswith( "LogicalPlan" ) # Wrapped LogicalPlan, e.g., Optional[LogicalPlan] is_nested_logical_plan = any( isclass(a) and issubclass(a, LogicalPlan) for a in getattr(tpe.annotation, "__args__", ()) ) # Wrapped forward reference of LogicalPlan, e.g., Optional["LogicalPlan"]. is_nested_forwardref_logical_plan = any( getattr(a, "__forward_arg__", "").endswith("LogicalPlan") for a in getattr(tpe.annotation, "__args__", ()) ) if ( not is_logical_plan and not is_forwardref_logical_plan and not is_nested_logical_plan and not is_nested_forwardref_logical_plan ): # Searches self.name or self._name try: params[name] = getattr(self, name) except AttributeError: try: params[name] = getattr(self, "_" + name) except AttributeError: pass # Simply ignore return params def print(self, indent: int = 0) -> str: """ Print the simple string representation of the current :class:`LogicalPlan`. Parameters ---------- indent : int The number of leading spaces for the output string. Returns ------- str Simple string representation of this :class:`LogicalPlan`. """ params = self._parameters_to_print(signature(self.__class__.__init__).parameters) pretty_params = [f"{name}='{param}'" for name, param in params.items()] if len(pretty_params) == 0: pretty_str = "" else: pretty_str = " " + ", ".join(pretty_params) return f"{' ' * indent}<{self.__class__.__name__}{pretty_str}>\n{self._child_print(indent)}" def _repr_html_(self) -> str: """Returns a :class:`LogicalPlan` with HTML code. This is generally called in third-party systems such as Jupyter. Returns ------- str HTML representation of this :class:`LogicalPlan`. """ params = self._parameters_to_print(signature(self.__class__.__init__).parameters) pretty_params = [ f"\n {name}: " f"{param} <br/>" for name, param in params.items() ] if len(pretty_params) == 0: pretty_str = "" else: pretty_str = "".join(pretty_params) return f""" <ul> <li> <b>{self.__class__.__name__}</b><br/>{pretty_str} {self._child_repr()} </li> </ul> """ def _child_print(self, indent: int) -> str: return self._child.print(indent + LogicalPlan.INDENT) if self._child else "" def _child_repr(self) -> str: return self._child._repr_html_() if self._child is not None else "" class DataSource(LogicalPlan): """A datasource with a format and optional a schema from which Spark reads data""" def __init__( self, format: Optional[str] = None, schema: Optional[str] = None, options: Optional[Mapping[str, str]] = None, paths: Optional[List[str]] = None, predicates: Optional[List[str]] = None, is_streaming: Optional[bool] = None, ) -> None: super().__init__(None) assert format is None or isinstance(format, str) assert schema is None or isinstance(schema, str) if options is not None: new_options = {} for k, v in options.items(): if v is not None: assert isinstance(k, str) assert isinstance(v, str) new_options[k] = v options = new_options if paths is not None: assert isinstance(paths, list) assert all(isinstance(path, str) for path in paths) if predicates is not None: assert isinstance(predicates, list) assert all(isinstance(predicate, str) for predicate in predicates) self._format = format self._schema = schema self._options = options self._paths = paths self._predicates = predicates self._is_streaming = is_streaming def plan(self, session: "SparkConnectClient") -> proto.Relation: plan = self._create_proto_relation() if self._format is not None: plan.read.data_source.format = self._format if self._schema is not None: plan.read.data_source.schema = self._schema if self._options is not None and len(self._options) > 0: for k, v in self._options.items(): plan.read.data_source.options[k] = v if self._paths is not None and len(self._paths) > 0: plan.read.data_source.paths.extend(self._paths) if self._predicates is not None and len(self._predicates) > 0: plan.read.data_source.predicates.extend(self._predicates) if self._is_streaming is not None: plan.read.is_streaming = self._is_streaming return plan class Read(LogicalPlan): def __init__( self, table_name: str, options: Optional[Dict[str, str]] = None, is_streaming: Optional[bool] = None, ) -> None: super().__init__(None) self.table_name = table_name self.options = options or {} self._is_streaming = is_streaming def plan(self, session: "SparkConnectClient") -> proto.Relation: plan = self._create_proto_relation() plan.read.named_table.unparsed_identifier = self.table_name if self._is_streaming is not None: plan.read.is_streaming = self._is_streaming for k, v in self.options.items(): plan.read.named_table.options[k] = v return plan def print(self, indent: int = 0) -> str: return f"{' ' * indent}<Read table_name={self.table_name}>\n" class LocalRelation(LogicalPlan): """Creates a LocalRelation plan object based on a PyArrow Table.""" def __init__( self, table: Optional["pa.Table"], schema: Optional[str] = None, ) -> None: super().__init__(None) if table is None: assert schema is not None else: assert isinstance(table, pa.Table) assert schema is None or isinstance(schema, str) self._table = table self._schema = schema def plan(self, session: "SparkConnectClient") -> proto.Relation: plan = self._create_proto_relation() if self._table is not None: sink = pa.BufferOutputStream() with pa.ipc.new_stream(sink, self._table.schema) as writer: for b in self._table.to_batches(): writer.write_batch(b) plan.local_relation.data = sink.getvalue().to_pybytes() if self._schema is not None: plan.local_relation.schema = self._schema return plan def serialize(self, session: "SparkConnectClient") -> bytes: p = self.plan(session) return bytes(p.local_relation.SerializeToString()) def print(self, indent: int = 0) -> str: return f"{' ' * indent}<LocalRelation>\n" def _repr_html_(self) -> str: return """ <ul> <li><b>LocalRelation</b></li> </ul> """ class CachedLocalRelation(LogicalPlan): """Creates a CachedLocalRelation plan object based on a hash of a LocalRelation.""" def __init__(self, hash: str) -> None: super().__init__(None) self._hash = hash def plan(self, session: "SparkConnectClient") -> proto.Relation: plan = self._create_proto_relation() clr = plan.cached_local_relation clr.hash = self._hash return plan def print(self, indent: int = 0) -> str: return f"{' ' * indent}<CachedLocalRelation>\n" def _repr_html_(self) -> str: return """ <ul> <li><b>CachedLocalRelation</b></li> </ul> """ class ShowString(LogicalPlan): def __init__( self, child: Optional["LogicalPlan"], num_rows: int, truncate: int, vertical: bool ) -> None: super().__init__(child) self.num_rows = num_rows self.truncate = truncate self.vertical = vertical def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = self._create_proto_relation() plan.show_string.input.CopyFrom(self._child.plan(session)) plan.show_string.num_rows = self.num_rows plan.show_string.truncate = self.truncate plan.show_string.vertical = self.vertical return plan class HtmlString(LogicalPlan): def __init__(self, child: Optional["LogicalPlan"], num_rows: int, truncate: int) -> None: super().__init__(child) self.num_rows = num_rows self.truncate = truncate def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = self._create_proto_relation() plan.html_string.input.CopyFrom(self._child.plan(session)) plan.html_string.num_rows = self.num_rows plan.html_string.truncate = self.truncate return plan class Project(LogicalPlan): """Logical plan object for a projection. All input arguments are directly serialized into the corresponding protocol buffer objects. This class only provides very limited error handling and input validation. To be compatible with PySpark, we validate that the input arguments are all expressions to be able to serialize them to the server. """ def __init__( self, child: Optional["LogicalPlan"], columns: List[Column], ) -> None: assert all(isinstance(c, Column) for c in columns) super().__init__(child, self._collect_references(columns)) self._columns = columns def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = self._create_proto_relation() plan.project.input.CopyFrom(self._child.plan(session)) plan.project.expressions.extend([c.to_plan(session) for c in self._columns]) return self._with_relations(plan, session) class WithColumns(LogicalPlan): """Logical plan object for a withColumns operation.""" def __init__( self, child: Optional["LogicalPlan"], columnNames: Sequence[str], columns: Sequence[Column], metadata: Optional[Sequence[str]] = None, ) -> None: assert isinstance(columnNames, list) assert len(columnNames) > 0 assert all(isinstance(c, str) for c in columnNames) assert isinstance(columns, list) assert len(columns) == len(columnNames) assert all(isinstance(c, Column) for c in columns) if metadata is not None: assert isinstance(metadata, list) assert len(metadata) == len(columnNames) for m in metadata: assert isinstance(m, str) # validate json string assert m == "" or json.loads(m) is not None super().__init__(child, self._collect_references(columns)) self._columnNames = columnNames self._columns = columns self._metadata = metadata def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = self._create_proto_relation() plan.with_columns.input.CopyFrom(self._child.plan(session)) for i in range(0, len(self._columnNames)): alias = proto.Expression.Alias() alias.expr.CopyFrom(self._columns[i].to_plan(session)) alias.name.append(self._columnNames[i]) if self._metadata is not None: alias.metadata = self._metadata[i] plan.with_columns.aliases.append(alias) return self._with_relations(plan, session) class WithWatermark(LogicalPlan): """Logical plan object for a WithWatermark operation.""" def __init__(self, child: Optional["LogicalPlan"], event_time: str, delay_threshold: str): super().__init__(child) self._event_time = event_time self._delay_threshold = delay_threshold def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = self._create_proto_relation() plan.with_watermark.input.CopyFrom(self._child.plan(session)) plan.with_watermark.event_time = self._event_time plan.with_watermark.delay_threshold = self._delay_threshold return plan class CachedRemoteRelation(LogicalPlan): """Logical plan object for a DataFrame reference which represents a DataFrame that's been cached on the server with a given id.""" def __init__(self, relation_id: str, spark_session: "SparkSession"): super().__init__(None) self._relation_id = relation_id # Needs to hold the session to make a request itself. self._spark_session = spark_session def plan(self, session: "SparkConnectClient") -> proto.Relation: plan = self._create_proto_relation() plan.cached_remote_relation.relation_id = self._relation_id return plan def __del__(self) -> None: session = self._spark_session # If session is already closed, all cached DataFrame should be released. if session is not None and not session.client.is_closed and self._relation_id is not None: try: command = RemoveRemoteCachedRelation(self).command(session=session.client) req = session.client._execute_plan_request_with_metadata() if session.client._user_id: req.user_context.user_id = session.client._user_id req.plan.command.CopyFrom(command) for attempt in session.client._retrying(): with attempt: # !!HACK ALERT!! # unary_stream does not work on Python's exit for an unknown reasons # Therefore, here we open unary_unary channel instead. # See also :class:`SparkConnectServiceStub`. request_serializer = ( spark_dot_connect_dot_base__pb2.ExecutePlanRequest.SerializeToString ) response_deserializer = ( spark_dot_connect_dot_base__pb2.ExecutePlanResponse.FromString ) channel = session.client._channel.unary_unary( "/spark.connect.SparkConnectService/ExecutePlan", request_serializer=request_serializer, response_deserializer=response_deserializer, ) metadata = session.client._builder.metadata() channel(req, metadata=metadata) # type: ignore[arg-type] except Exception as e: logger.warn(f"RemoveRemoteCachedRelation failed with exception: {e}.") class Hint(LogicalPlan): """Logical plan object for a Hint operation.""" def __init__( self, child: Optional["LogicalPlan"], name: str, parameters: Sequence[Column], ) -> None: assert isinstance(name, str) assert parameters is not None and isinstance(parameters, List) for param in parameters: assert isinstance(param, Column) super().__init__(child, self._collect_references(parameters)) self._name = name self._parameters = parameters def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = self._create_proto_relation() plan.hint.input.CopyFrom(self._child.plan(session)) plan.hint.name = self._name plan.hint.parameters.extend([param.to_plan(session) for param in self._parameters]) return self._with_relations(plan, session) class Filter(LogicalPlan): def __init__(self, child: Optional["LogicalPlan"], filter: Column) -> None: super().__init__(child, self._collect_references([filter])) self.filter = filter def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = self._create_proto_relation() plan.filter.input.CopyFrom(self._child.plan(session)) plan.filter.condition.CopyFrom(self.filter.to_plan(session)) return self._with_relations(plan, session) class Limit(LogicalPlan): def __init__(self, child: Optional["LogicalPlan"], limit: int) -> None: super().__init__(child) self.limit = limit def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = self._create_proto_relation() plan.limit.input.CopyFrom(self._child.plan(session)) plan.limit.limit = self.limit return plan class Tail(LogicalPlan): def __init__(self, child: Optional["LogicalPlan"], limit: int) -> None: super().__init__(child) self.limit = limit def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = self._create_proto_relation() plan.tail.input.CopyFrom(self._child.plan(session)) plan.tail.limit = self.limit return plan class Offset(LogicalPlan): def __init__(self, child: Optional["LogicalPlan"], offset: int = 0) -> None: super().__init__(child) self.offset = offset def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = self._create_proto_relation() plan.offset.input.CopyFrom(self._child.plan(session)) plan.offset.offset = self.offset return plan class Deduplicate(LogicalPlan): def __init__( self, child: Optional["LogicalPlan"], all_columns_as_keys: bool = False, column_names: Optional[List[str]] = None, within_watermark: bool = False, ) -> None: super().__init__(child) self.all_columns_as_keys = all_columns_as_keys self.column_names = column_names self.within_watermark = within_watermark def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = self._create_proto_relation() plan.deduplicate.input.CopyFrom(self._child.plan(session)) plan.deduplicate.all_columns_as_keys = self.all_columns_as_keys plan.deduplicate.within_watermark = self.within_watermark if self.column_names is not None: plan.deduplicate.column_names.extend(self.column_names) return plan class Sort(LogicalPlan): def __init__( self, child: Optional["LogicalPlan"], columns: List[Column], is_global: bool, ) -> None: assert all(isinstance(c, Column) for c in columns) assert isinstance(is_global, bool) super().__init__(child, self._collect_references(columns)) self.columns = columns self.is_global = is_global def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = self._create_proto_relation() plan.sort.input.CopyFrom(self._child.plan(session)) plan.sort.order.extend([c.to_plan(session).sort_order for c in self.columns]) plan.sort.is_global = self.is_global return self._with_relations(plan, session) class Drop(LogicalPlan): def __init__( self, child: Optional["LogicalPlan"], columns: List[Union[Column, str]], ) -> None: if len(columns) > 0: assert all(isinstance(c, (Column, str)) for c in columns) super().__init__( child, self._collect_references([c for c in columns if isinstance(c, Column)]) ) self._columns = columns def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = self._create_proto_relation() plan.drop.input.CopyFrom(self._child.plan(session)) for c in self._columns: if isinstance(c, Column): plan.drop.columns.append(c.to_plan(session)) else: plan.drop.column_names.append(c) return self._with_relations(plan, session) class Sample(LogicalPlan): def __init__( self, child: Optional["LogicalPlan"], lower_bound: float, upper_bound: float, with_replacement: bool, seed: int, deterministic_order: bool = False, ) -> None: super().__init__(child) self.lower_bound = lower_bound self.upper_bound = upper_bound self.with_replacement = with_replacement self.seed = seed self.deterministic_order = deterministic_order def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = self._create_proto_relation() plan.sample.input.CopyFrom(self._child.plan(session)) plan.sample.lower_bound = self.lower_bound plan.sample.upper_bound = self.upper_bound plan.sample.with_replacement = self.with_replacement plan.sample.seed = self.seed plan.sample.deterministic_order = self.deterministic_order return plan class Aggregate(LogicalPlan): def __init__( self, child: Optional["LogicalPlan"], group_type: str, grouping_cols: Sequence[Column], aggregate_cols: Sequence[Column], pivot_col: Optional[Column], pivot_values: Optional[Sequence[Column]], grouping_sets: Optional[Sequence[Sequence[Column]]], ) -> None: assert isinstance(group_type, str) and group_type in [ "groupby", "rollup", "cube", "pivot", "grouping_sets", ] assert isinstance(grouping_cols, list) and all(isinstance(c, Column) for c in grouping_cols) assert isinstance(aggregate_cols, list) and all( isinstance(c, Column) for c in aggregate_cols ) if group_type == "pivot": assert pivot_col is not None and isinstance(pivot_col, Column) assert pivot_values is None or isinstance(pivot_values, list) elif group_type == "grouping_sets": assert grouping_sets is None or isinstance(grouping_sets, list) else: assert pivot_col is None assert pivot_values is None assert grouping_sets is None super().__init__( child, self._collect_references( grouping_cols + aggregate_cols + ([pivot_col] if pivot_col is not None else []) + (pivot_values if pivot_values is not None else []) + ([g for gs in grouping_sets for g in gs] if grouping_sets is not None else []) ), ) self._group_type = group_type self._grouping_cols = grouping_cols self._aggregate_cols = aggregate_cols self._pivot_col = pivot_col self._pivot_values = pivot_values self._grouping_sets = grouping_sets def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = self._create_proto_relation() plan.aggregate.input.CopyFrom(self._child.plan(session)) plan.aggregate.grouping_expressions.extend( [c.to_plan(session) for c in self._grouping_cols] ) plan.aggregate.aggregate_expressions.extend( [c.to_plan(session) for c in self._aggregate_cols] ) if self._group_type == "groupby": plan.aggregate.group_type = proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY elif self._group_type == "rollup": plan.aggregate.group_type = proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP elif self._group_type == "cube": plan.aggregate.group_type = proto.Aggregate.GroupType.GROUP_TYPE_CUBE elif self._group_type == "pivot": plan.aggregate.group_type = proto.Aggregate.GroupType.GROUP_TYPE_PIVOT assert self._pivot_col is not None plan.aggregate.pivot.col.CopyFrom(self._pivot_col.to_plan(session)) if self._pivot_values is not None and len(self._pivot_values) > 0: plan.aggregate.pivot.values.extend( [v.to_plan(session).literal for v in self._pivot_values] ) elif self._group_type == "grouping_sets": plan.aggregate.group_type = proto.Aggregate.GroupType.GROUP_TYPE_GROUPING_SETS assert self._grouping_sets is not None for grouping_set in self._grouping_sets: plan.aggregate.grouping_sets.append( proto.Aggregate.GroupingSets( grouping_set=[c.to_plan(session) for c in grouping_set] ) ) return self._with_relations(plan, session) class Join(LogicalPlan): def __init__( self, left: Optional["LogicalPlan"], right: "LogicalPlan", on: Optional[Union[str, List[str], Column, List[Column]]], how: Optional[str], ) -> None: super().__init__( left, self._collect_references( [] if on is None or isinstance(on, str) else [on] if isinstance(on, Column) else [c for c in on if isinstance(c, Column)] ), ) self.left = cast(LogicalPlan, left) self.right = right self.on = on if how is None: join_type = proto.Join.JoinType.JOIN_TYPE_INNER elif how == "inner": join_type = proto.Join.JoinType.JOIN_TYPE_INNER elif how in ["outer", "full", "fullouter"]: join_type = proto.Join.JoinType.JOIN_TYPE_FULL_OUTER elif how in ["leftouter", "left"]: join_type = proto.Join.JoinType.JOIN_TYPE_LEFT_OUTER elif how in ["rightouter", "right"]: join_type = proto.Join.JoinType.JOIN_TYPE_RIGHT_OUTER elif how in ["leftsemi", "semi"]: join_type = proto.Join.JoinType.JOIN_TYPE_LEFT_SEMI elif how in ["leftanti", "anti"]: join_type = proto.Join.JoinType.JOIN_TYPE_LEFT_ANTI elif how == "cross": join_type = proto.Join.JoinType.JOIN_TYPE_CROSS else: raise AnalysisException( errorClass="UNSUPPORTED_JOIN_TYPE", messageParameters={ "typ": how, "supported": ( "'" + "', '".join( [ "inner", "outer", "full", "fullouter", "full_outer", "leftouter", "left", "left_outer", "rightouter", "right", "right_outer", "leftsemi", "left_semi", "semi", "leftanti", "left_anti", "anti", "cross", ] ) + "'" ), }, ) self.how = join_type def plan(self, session: "SparkConnectClient") -> proto.Relation: plan = self._create_proto_relation() plan.join.left.CopyFrom(self.left.plan(session)) plan.join.right.CopyFrom(self.right.plan(session)) if self.on is not None: if not isinstance(self.on, list): if isinstance(self.on, str): plan.join.using_columns.append(self.on) else: plan.join.join_condition.CopyFrom(self.on.to_plan(session)) elif len(self.on) > 0: if isinstance(self.on[0], str): plan.join.using_columns.extend(cast(str, self.on)) else: merge_column = functools.reduce(lambda c1, c2: c1 & c2, self.on) plan.join.join_condition.CopyFrom(cast(Column, merge_column).to_plan(session)) plan.join.join_type = self.how return self._with_relations(plan, session) @property def observations(self) -> Dict[str, "Observation"]: return dict(**super().observations, **self.right.observations) def print(self, indent: int = 0) -> str: i = " " * indent o = " " * (indent + LogicalPlan.INDENT) n = indent + LogicalPlan.INDENT * 2 return ( f"{i}<Join on={self.on} how={self.how}>\n{o}" f"left=\n{self.left.print(n)}\n{o}right=\n{self.right.print(n)}" ) def _repr_html_(self) -> str: return f""" <ul> <li> <b>Join</b><br /> Left: {self.left._repr_html_()} Right: {self.right._repr_html_()} </li> </uL> """ class AsOfJoin(LogicalPlan): def __init__( self, left: LogicalPlan, right: LogicalPlan, left_as_of: Column, right_as_of: Column, on: Optional[Union[str, List[str], Column, List[Column]]], how: str, tolerance: Optional[Column], allow_exact_matches: bool, direction: str, ) -> None: super().__init__( left, self._collect_references( [left_as_of, right_as_of] + ( [] if on is None or isinstance(on, str) else [on] if isinstance(on, Column) else [c for c in on if isinstance(c, Column)] ) + ([tolerance] if tolerance is not None else []) ), ) self.left = left self.right = right self.left_as_of = left_as_of self.right_as_of = right_as_of self.on = on self.how = how self.tolerance = tolerance self.allow_exact_matches = allow_exact_matches self.direction = direction def plan(self, session: "SparkConnectClient") -> proto.Relation: plan = self._create_proto_relation() plan.as_of_join.left.CopyFrom(self.left.plan(session)) plan.as_of_join.right.CopyFrom(self.right.plan(session)) plan.as_of_join.left_as_of.CopyFrom(self.left_as_of.to_plan(session)) plan.as_of_join.right_as_of.CopyFrom(self.right_as_of.to_plan(session)) if self.on is not None: if not isinstance(self.on, list): if isinstance(self.on, str): plan.as_of_join.using_columns.append(self.on) else: plan.as_of_join.join_expr.CopyFrom(self.on.to_plan(session)) elif len(self.on) > 0: if isinstance(self.on[0], str): plan.as_of_join.using_columns.extend(cast(List[str], self.on)) else: merge_column = functools.reduce(lambda c1, c2: c1 & c2, self.on) plan.as_of_join.join_expr.CopyFrom(cast(Column, merge_column).to_plan(session)) plan.as_of_join.join_type = self.how if self.tolerance is not None: plan.as_of_join.tolerance.CopyFrom(self.tolerance.to_plan(session)) plan.as_of_join.allow_exact_matches = self.allow_exact_matches plan.as_of_join.direction = self.direction return self._with_relations(plan, session) @property def observations(self) -> Dict[str, "Observation"]: return dict(**super().observations, **self.right.observations) def print(self, indent: int = 0) -> str: assert self.left is not None assert self.right is not None i = " " * indent o = " " * (indent + LogicalPlan.INDENT) n = indent + LogicalPlan.INDENT * 2 return ( f"{i}<AsOfJoin left_as_of={self.left_as_of}, right_as_of={self.right_as_of}, " f"on={self.on} how={self.how}>\n{o}" f"left=\n{self.left.print(n)}\n{o}right=\n{self.right.print(n)}" ) def _repr_html_(self) -> str: assert self.left is not None assert self.right is not None return f""" <ul> <li> <b>AsOfJoin</b><br /> Left: {self.left._repr_html_()} Right: {self.right._repr_html_()} </li> </uL> """ class LateralJoin(LogicalPlan): def __init__( self, left: Optional[LogicalPlan], right: LogicalPlan, on: Optional[Column], how: Optional[str], ) -> None: super().__init__(left, self._collect_references([on] if on is not None else [])) self.left = cast(LogicalPlan, left) self.right = right self.on = on if how is None: join_type = proto.Join.JoinType.JOIN_TYPE_INNER elif how == "inner": join_type = proto.Join.JoinType.JOIN_TYPE_INNER elif how in ["leftouter", "left"]: join_type = proto.Join.JoinType.JOIN_TYPE_LEFT_OUTER elif how == "cross": join_type = proto.Join.JoinType.JOIN_TYPE_CROSS else: raise AnalysisException( errorClass="UNSUPPORTED_JOIN_TYPE", messageParameters={ "typ": how, "supported": ( "'" + "', '".join(["inner", "leftouter", "left", "left_outer", "cross"]) + "'" ), }, ) self.how = join_type def plan(self, session: "SparkConnectClient") -> proto.Relation: plan = self._create_proto_relation() plan.lateral_join.left.CopyFrom(self.left.plan(session)) plan.lateral_join.right.CopyFrom(self.right.plan(session)) if self.on is not None: plan.lateral_join.join_condition.CopyFrom(self.on.to_plan(session)) plan.lateral_join.join_type = self.how return self._with_relations(plan, session) @property def observations(self) -> Dict[str, "Observation"]: return dict(**super().observations, **self.right.observations) def print(self, indent: int = 0) -> str: i = " " * indent o = " " * (indent + LogicalPlan.INDENT) n = indent + LogicalPlan.INDENT * 2 return ( f"{i}<LateralJoin on={self.on} how={self.how}>\n{o}" f"left=\n{self.left.print(n)}\n{o}right=\n{self.right.print(n)}" ) def _repr_html_(self) -> str: return f""" <ul> <li> <b>LateralJoin</b><br /> Left: {self.left._repr_html_()} Right: {self.right._repr_html_()} </li> </uL> """ class SetOperation(LogicalPlan): def __init__( self, child: Optional["LogicalPlan"], other: Optional["LogicalPlan"], set_op: str, is_all: bool = True, by_name: bool = False, allow_missing_columns: bool = False, ) -> None: super().__init__(child) self.other = other self.by_name = by_name self.is_all = is_all self.set_op = set_op self.allow_missing_columns = allow_missing_columns def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = self._create_proto_relation() if self._child is not None: plan.set_op.left_input.CopyFrom(self._child.plan(session)) if self.other is not None: plan.set_op.right_input.CopyFrom(self.other.plan(session)) if self.set_op == "union": plan.set_op.set_op_type = proto.SetOperation.SET_OP_TYPE_UNION elif self.set_op == "intersect": plan.set_op.set_op_type = proto.SetOperation.SET_OP_TYPE_INTERSECT elif self.set_op == "except": plan.set_op.set_op_type = proto.SetOperation.SET_OP_TYPE_EXCEPT else: raise PySparkValueError( errorClass="UNSUPPORTED_OPERATION", messageParameters={"operation": self.set_op}, ) plan.set_op.is_all = self.is_all plan.set_op.by_name = self.by_name plan.set_op.allow_missing_columns = self.allow_missing_columns return plan @property def observations(self) -> Dict[str, "Observation"]: return dict( **super().observations, **(self.other.observations if self.other is not None else {}), ) def print(self, indent: int = 0) -> str: assert self._child is not None assert self.other is not None i = " " * indent o = " " * (indent + LogicalPlan.INDENT) n = indent + LogicalPlan.INDENT * 2 return ( f"{i}SetOperation\n{o}child1=\n{self._child.print(n)}" f"\n{o}child2=\n{self.other.print(n)}" ) def _repr_html_(self) -> str: assert self._child is not None assert self.other is not None return f""" <ul> <li> <b>SetOperation</b><br /> Left: {self._child._repr_html_()} Right: {self.other._repr_html_()} </li> </uL> """ class Repartition(LogicalPlan): """Repartition Relation into a different number of partitions.""" def __init__(self, child: Optional["LogicalPlan"], num_partitions: int, shuffle: bool) -> None: super().__init__(child) self._num_partitions = num_partitions self._shuffle = shuffle def plan(self, session: "SparkConnectClient") -> proto.Relation: plan = self._create_proto_relation() if self._child is not None: plan.repartition.input.CopyFrom(self._child.plan(session)) plan.repartition.shuffle = self._shuffle plan.repartition.num_partitions = self._num_partitions return plan class RepartitionByExpression(LogicalPlan): """Repartition Relation into a different number of partitions using Expression""" def __init__( self, child: Optional["LogicalPlan"], num_partitions: Optional[int], columns: List[Column], ) -> None: assert all(isinstance(c, Column) for c in columns) super().__init__(child, self._collect_references(columns)) self.num_partitions = num_partitions self.columns = columns def plan(self, session: "SparkConnectClient") -> proto.Relation: plan = self._create_proto_relation() plan.repartition_by_expression.partition_exprs.extend( [c.to_plan(session) for c in self.columns] ) if self._child is not None: plan.repartition_by_expression.input.CopyFrom(self._child.plan(session)) if self.num_partitions is not None: plan.repartition_by_expression.num_partitions = self.num_partitions return self._with_relations(plan, session) class SubqueryAlias(LogicalPlan): """Alias for a relation.""" def __init__(self, child: Optional["LogicalPlan"], alias: str) -> None: super().__init__(child) self._alias = alias def plan(self, session: "SparkConnectClient") -> proto.Relation: plan = self._create_proto_relation() if self._child is not None: plan.subquery_alias.input.CopyFrom(self._child.plan(session)) plan.subquery_alias.alias = self._alias return plan class WithRelations(LogicalPlan): def __init__( self, child: Optional["LogicalPlan"], references: Sequence["LogicalPlan"], ) -> None: super().__init__(child) assert references is not None and len(references) > 0 assert all(isinstance(ref, LogicalPlan) for ref in references) self._references = references def plan(self, session: "SparkConnectClient") -> proto.Relation: plan = self._create_proto_relation() if self._child is not None: plan.with_relations.root.CopyFrom(self._child.plan(session)) for ref in self._references: plan.with_relations.references.append(ref.plan(session)) return plan class SQL(LogicalPlan): def __init__( self, query: str, args: Optional[List[Column]] = None, named_args: Optional[Dict[str, Column]] = None, views: Optional[Sequence[SubqueryAlias]] = None, ) -> None: if args is not None: assert isinstance(args, List) assert all(isinstance(arg, Column) for arg in args) if named_args is not None: assert isinstance(named_args, Dict) for k, arg in named_args.items(): assert isinstance(k, str) assert isinstance(arg, Column) if views is not None: assert isinstance(views, List) assert all(isinstance(v, SubqueryAlias) for v in views) super().__init__(None, views) self._query = query self._args = args self._named_args = named_args self._views = views def plan(self, session: "SparkConnectClient") -> proto.Relation: plan = self._create_proto_relation() plan.sql.query = self._query if self._args is not None and len(self._args) > 0: plan.sql.pos_arguments.extend([arg.to_plan(session) for arg in self._args]) if self._named_args is not None and len(self._named_args) > 0: for k, arg in self._named_args.items(): plan.sql.named_arguments[k].CopyFrom(arg.to_plan(session)) return self._with_relations(plan, session) def command(self, session: "SparkConnectClient") -> proto.Command: cmd = proto.Command() cmd.sql_command.input.CopyFrom(self.plan(session)) return cmd class Range(LogicalPlan): def __init__( self, start: int, end: int, step: int, num_partitions: Optional[int] = None, ) -> None: super().__init__(None) self._start = start self._end = end self._step = step self._num_partitions = num_partitions def plan(self, session: "SparkConnectClient") -> proto.Relation: plan = self._create_proto_relation() plan.range.start = self._start plan.range.end = self._end plan.range.step = self._step if self._num_partitions is not None: plan.range.num_partitions = self._num_partitions return plan class ToSchema(LogicalPlan): def __init__(self, child: Optional["LogicalPlan"], schema: DataType) -> None: super().__init__(child) self._schema = schema def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = self._create_proto_relation() plan.to_schema.input.CopyFrom(self._child.plan(session)) plan.to_schema.schema.CopyFrom(pyspark_types_to_proto_types(self._schema)) return plan class WithColumnsRenamed(LogicalPlan): def __init__(self, child: Optional["LogicalPlan"], colsMap: Mapping[str, str]) -> None: super().__init__(child) self._colsMap = colsMap def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = self._create_proto_relation() plan.with_columns_renamed.input.CopyFrom(self._child.plan(session)) if len(self._colsMap) > 0: for k, v in self._colsMap.items(): rename = proto.WithColumnsRenamed.Rename() rename.col_name = k rename.new_col_name = v plan.with_columns_renamed.renames.append(rename) return plan class Unpivot(LogicalPlan): """Logical plan object for a unpivot operation.""" def __init__( self, child: Optional["LogicalPlan"], ids: List[Column], values: Optional[List[Column]], variable_column_name: str, value_column_name: str, ) -> None: super().__init__(child, self._collect_references(ids + (values or []))) self.ids = ids self.values = values self.variable_column_name = variable_column_name self.value_column_name = value_column_name def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = self._create_proto_relation() plan.unpivot.input.CopyFrom(self._child.plan(session)) plan.unpivot.ids.extend([id.to_plan(session) for id in self.ids]) if self.values is not None: plan.unpivot.values.values.extend([v.to_plan(session) for v in self.values]) plan.unpivot.variable_column_name = self.variable_column_name plan.unpivot.value_column_name = self.value_column_name return self._with_relations(plan, session) class Transpose(LogicalPlan): """Logical plan object for a transpose operation.""" def __init__( self, child: Optional["LogicalPlan"], index_columns: Sequence[Column], ) -> None: super().__init__(child, self._collect_references(index_columns)) self.index_columns = index_columns def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = self._create_proto_relation() plan.transpose.input.CopyFrom(self._child.plan(session)) if self.index_columns is not None and len(self.index_columns) > 0: for index_column in self.index_columns: plan.transpose.index_columns.append(index_column.to_plan(session)) return self._with_relations(plan, session) class UnresolvedTableValuedFunction(LogicalPlan): def __init__(self, name: str, args: Sequence[Column]): super().__init__(None, self._collect_references(args)) self._name = name self._args = args def plan(self, session: "SparkConnectClient") -> proto.Relation: plan = self._create_proto_relation() plan.unresolved_table_valued_function.function_name = self._name for arg in self._args: plan.unresolved_table_valued_function.arguments.append(arg.to_plan(session)) return self._with_relations(plan, session) class CollectMetrics(LogicalPlan): """Logical plan object for a CollectMetrics operation.""" def __init__( self, child: Optional["LogicalPlan"], observation: Union[str, "Observation"], exprs: List[Column], ) -> None: assert all(isinstance(e, Column) for e in exprs) super().__init__(child, self._collect_references(exprs)) self._observation = observation self._exprs = exprs def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = self._create_proto_relation() plan.collect_metrics.input.CopyFrom(self._child.plan(session)) plan.collect_metrics.name = ( self._observation if isinstance(self._observation, str) else str(self._observation._name) ) plan.collect_metrics.metrics.extend([e.to_plan(session) for e in self._exprs]) return self._with_relations(plan, session) @property def observations(self) -> Dict[str, "Observation"]: from pyspark.sql.connect.observation import Observation if isinstance(self._observation, Observation): observations = {str(self._observation._name): self._observation} else: observations = {} return dict(**super().observations, **observations) class NAFill(LogicalPlan): def __init__( self, child: Optional["LogicalPlan"], cols: Optional[List[str]], values: List[Any] ) -> None: super().__init__(child) assert ( isinstance(values, list) and len(values) > 0 and all(isinstance(v, (bool, int, float, str)) for v in values) ) if cols is not None and len(cols) > 0: assert isinstance(cols, list) and all(isinstance(c, str) for c in cols) if len(values) > 1: assert len(cols) == len(values) self.cols = cols self.values = values def _convert_value(self, v: Any) -> proto.Expression.Literal: value = proto.Expression.Literal() if isinstance(v, bool): value.boolean = v elif isinstance(v, int): value.long = v elif isinstance(v, float): value.double = v else: value.string = v return value def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = self._create_proto_relation() plan.fill_na.input.CopyFrom(self._child.plan(session)) if self.cols is not None and len(self.cols) > 0: plan.fill_na.cols.extend(self.cols) plan.fill_na.values.extend([self._convert_value(v) for v in self.values]) return plan class NADrop(LogicalPlan): def __init__( self, child: Optional["LogicalPlan"], cols: Optional[List[str]], min_non_nulls: Optional[int], ) -> None: super().__init__(child) self.cols = cols self.min_non_nulls = min_non_nulls def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = self._create_proto_relation() plan.drop_na.input.CopyFrom(self._child.plan(session)) if self.cols is not None and len(self.cols) > 0: plan.drop_na.cols.extend(self.cols) if self.min_non_nulls is not None: plan.drop_na.min_non_nulls = self.min_non_nulls return plan class NAReplace(LogicalPlan): def __init__( self, child: Optional["LogicalPlan"], cols: Optional[List[str]], replacements: Sequence[Tuple[Column, Column]], ) -> None: assert replacements is not None and isinstance(replacements, List) for k, v in replacements: assert k is not None and isinstance(k, Column) assert v is not None and isinstance(v, Column) super().__init__(child, self._collect_references([e for t in replacements for e in t])) self.cols = cols self.replacements = replacements def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = self._create_proto_relation() plan.replace.input.CopyFrom(self._child.plan(session)) if self.cols is not None and len(self.cols) > 0: plan.replace.cols.extend(self.cols) if len(self.replacements) > 0: for old_value, new_value in self.replacements: replacement = proto.NAReplace.Replacement() replacement.old_value.CopyFrom(old_value.to_plan(session).literal) replacement.new_value.CopyFrom(new_value.to_plan(session).literal) plan.replace.replacements.append(replacement) return self._with_relations(plan, session) class StatSummary(LogicalPlan): def __init__(self, child: Optional["LogicalPlan"], statistics: List[str]) -> None: super().__init__(child) self.statistics = statistics def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = self._create_proto_relation() plan.summary.input.CopyFrom(self._child.plan(session)) plan.summary.statistics.extend(self.statistics) return plan class StatDescribe(LogicalPlan): def __init__(self, child: Optional["LogicalPlan"], cols: List[str]) -> None: super().__init__(child) self.cols = cols def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = self._create_proto_relation() plan.describe.input.CopyFrom(self._child.plan(session)) plan.describe.cols.extend(self.cols) return plan class StatCov(LogicalPlan): def __init__(self, child: Optional["LogicalPlan"], col1: str, col2: str) -> None: super().__init__(child) self._col1 = col1 self._col2 = col2 def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = self._create_proto_relation() plan.cov.input.CopyFrom(self._child.plan(session)) plan.cov.col1 = self._col1 plan.cov.col2 = self._col2 return plan class StatApproxQuantile(LogicalPlan): def __init__( self, child: Optional["LogicalPlan"], cols: List[str], probabilities: List[float], relativeError: float, ) -> None: super().__init__(child) self._cols = cols self._probabilities = probabilities self._relativeError = relativeError def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = self._create_proto_relation() plan.approx_quantile.input.CopyFrom(self._child.plan(session)) plan.approx_quantile.cols.extend(self._cols) plan.approx_quantile.probabilities.extend(self._probabilities) plan.approx_quantile.relative_error = self._relativeError return plan class StatCrosstab(LogicalPlan): def __init__(self, child: Optional["LogicalPlan"], col1: str, col2: str) -> None: super().__init__(child) self.col1 = col1 self.col2 = col2 def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = self._create_proto_relation() plan.crosstab.input.CopyFrom(self._child.plan(session)) plan.crosstab.col1 = self.col1 plan.crosstab.col2 = self.col2 return plan class StatFreqItems(LogicalPlan): def __init__( self, child: Optional["LogicalPlan"], cols: List[str], support: float, ) -> None: super().__init__(child) self._cols = cols self._support = support def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = self._create_proto_relation() plan.freq_items.input.CopyFrom(self._child.plan(session)) plan.freq_items.cols.extend(self._cols) plan.freq_items.support = self._support return plan class StatSampleBy(LogicalPlan): def __init__( self, child: Optional["LogicalPlan"], col: Column, fractions: Sequence[Tuple[Column, float]], seed: int, ) -> None: assert col is not None and isinstance(col, (Column, str)) assert fractions is not None and isinstance(fractions, List) for k, v in fractions: assert k is not None and isinstance(k, Column) assert v is not None and isinstance(v, float) assert seed is None or isinstance(seed, int) super().__init__( child, self._collect_references( [col] if isinstance(col, Column) else [] + [c for c, _ in fractions] ), ) self._col = col self._fractions = fractions self._seed = seed def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = self._create_proto_relation() plan.sample_by.input.CopyFrom(self._child.plan(session)) plan.sample_by.col.CopyFrom(self._col._expr.to_plan(session)) if len(self._fractions) > 0: for k, v in self._fractions: fraction = proto.StatSampleBy.Fraction() fraction.stratum.CopyFrom(k.to_plan(session).literal) fraction.fraction = float(v) plan.sample_by.fractions.append(fraction) plan.sample_by.seed = self._seed return self._with_relations(plan, session) class StatCorr(LogicalPlan): def __init__(self, child: Optional["LogicalPlan"], col1: str, col2: str, method: str) -> None: super().__init__(child) self._col1 = col1 self._col2 = col2 self._method = method def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = self._create_proto_relation() plan.corr.input.CopyFrom(self._child.plan(session)) plan.corr.col1 = self._col1 plan.corr.col2 = self._col2 plan.corr.method = self._method return plan class ToDF(LogicalPlan): def __init__(self, child: Optional["LogicalPlan"], cols: Sequence[str]) -> None: super().__init__(child) self._cols = cols def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = self._create_proto_relation() plan.to_df.input.CopyFrom(self._child.plan(session)) plan.to_df.column_names.extend(self._cols) return plan class CreateView(LogicalPlan): def __init__( self, child: Optional["LogicalPlan"], name: str, is_global: bool, replace: bool ) -> None: super().__init__(child) self._name = name self._is_global = is_global self._replace = replace def command(self, session: "SparkConnectClient") -> proto.Command: assert self._child is not None plan = proto.Command() plan.create_dataframe_view.replace = self._replace plan.create_dataframe_view.is_global = self._is_global plan.create_dataframe_view.name = self._name plan.create_dataframe_view.input.CopyFrom(self._child.plan(session)) return plan class WriteOperation(LogicalPlan): def __init__(self, child: "LogicalPlan") -> None: super(WriteOperation, self).__init__(child) self.source: Optional[str] = None self.path: Optional[str] = None self.table_name: Optional[str] = None self.table_save_method: Optional[str] = None self.mode: Optional[str] = None self.sort_cols: List[str] = [] self.partitioning_cols: List[str] = [] self.clustering_cols: List[str] = [] self.options: Dict[str, Optional[str]] = {} self.num_buckets: int = -1 self.bucket_cols: List[str] = [] def command(self, session: "SparkConnectClient") -> proto.Command: assert self._child is not None plan = proto.Command() plan.write_operation.input.CopyFrom(self._child.plan(session)) if self.source is not None: plan.write_operation.source = self.source plan.write_operation.sort_column_names.extend(self.sort_cols) plan.write_operation.partitioning_columns.extend(self.partitioning_cols) plan.write_operation.clustering_columns.extend(self.clustering_cols) if self.num_buckets > 0: plan.write_operation.bucket_by.bucket_column_names.extend(self.bucket_cols) plan.write_operation.bucket_by.num_buckets = self.num_buckets for k in self.options: if self.options[k] is None: plan.write_operation.options.pop(k, None) else: plan.write_operation.options[k] = cast(str, self.options[k]) if self.table_name is not None: plan.write_operation.table.table_name = self.table_name if self.table_save_method is not None: tsm = self.table_save_method.lower() if tsm == "save_as_table": plan.write_operation.table.save_method = ( proto.WriteOperation.SaveTable.TableSaveMethod.TABLE_SAVE_METHOD_SAVE_AS_TABLE # noqa: E501 ) elif tsm == "insert_into": plan.write_operation.table.save_method = ( proto.WriteOperation.SaveTable.TableSaveMethod.TABLE_SAVE_METHOD_INSERT_INTO ) else: raise PySparkValueError( errorClass="UNSUPPORTED_OPERATION", messageParameters={"operation": tsm}, ) elif self.path is not None: plan.write_operation.path = self.path if self.mode is not None: wm = self.mode.lower() if wm == "append": plan.write_operation.mode = proto.WriteOperation.SaveMode.SAVE_MODE_APPEND elif wm == "overwrite": plan.write_operation.mode = proto.WriteOperation.SaveMode.SAVE_MODE_OVERWRITE elif wm == "error": plan.write_operation.mode = proto.WriteOperation.SaveMode.SAVE_MODE_ERROR_IF_EXISTS elif wm == "ignore": plan.write_operation.mode = proto.WriteOperation.SaveMode.SAVE_MODE_IGNORE else: raise PySparkValueError( errorClass="UNSUPPORTED_OPERATION", messageParameters={"operation": self.mode}, ) return plan def print(self, indent: int = 0) -> str: i = " " * indent return ( f"{i}" f"<WriteOperation source='{self.source}' " f"path='{self.path} " f"table_name='{self.table_name}' " f"table_save_method='{self.table_save_method}' " f"mode='{self.mode}' " f"sort_cols='{self.sort_cols}' " f"partitioning_cols='{self.partitioning_cols}' " f"clustering_cols='{self.clustering_cols}' " f"num_buckets='{self.num_buckets}' " f"bucket_cols='{self.bucket_cols}' " f"options='{self.options}'>" ) def _repr_html_(self) -> str: return ( f"<uL><li>WriteOperation <br />source='{self.source}'<br />" f"path: '{self.path}<br />" f"table_name: '{self.table_name}' <br />" f"table_save_method: '{self.table_save_method}' <br />" f"mode: '{self.mode}' <br />" f"sort_cols: '{self.sort_cols}' <br />" f"partitioning_cols: '{self.partitioning_cols}' <br />" f"clustering_cols: '{self.clustering_cols}' <br />" f"num_buckets: '{self.num_buckets}' <br />" f"bucket_cols: '{self.bucket_cols}' <br />" f"options: '{self.options}'<br />" f"</li></ul>" ) class WriteOperationV2(LogicalPlan): def __init__(self, child: "LogicalPlan", table_name: str) -> None: super(WriteOperationV2, self).__init__(child) self.table_name: Optional[str] = table_name self.provider: Optional[str] = None self.partitioning_columns: List[Column] = [] self.clustering_columns: List[str] = [] self.options: dict[str, Optional[str]] = {} self.table_properties: dict[str, Optional[str]] = {} self.mode: Optional[str] = None self.overwrite_condition: Optional[Column] = None def command(self, session: "SparkConnectClient") -> proto.Command: assert self._child is not None plan = proto.Command() plan.write_operation_v2.input.CopyFrom(self._child.plan(session)) if self.table_name is not None: plan.write_operation_v2.table_name = self.table_name if self.provider is not None: plan.write_operation_v2.provider = self.provider plan.write_operation_v2.partitioning_columns.extend( [c.to_plan(session) for c in self.partitioning_columns] ) plan.write_operation_v2.clustering_columns.extend(self.clustering_columns) for k in self.options: if self.options[k] is None: plan.write_operation_v2.options.pop(k, None) else: plan.write_operation_v2.options[k] = cast(str, self.options[k]) for k in self.table_properties: if self.table_properties[k] is None: plan.write_operation_v2.table_properties.pop(k, None) else: plan.write_operation_v2.table_properties[k] = cast(str, self.table_properties[k]) if self.mode is not None: wm = self.mode.lower() if wm == "create": plan.write_operation_v2.mode = proto.WriteOperationV2.Mode.MODE_CREATE elif wm == "overwrite": plan.write_operation_v2.mode = proto.WriteOperationV2.Mode.MODE_OVERWRITE if self.overwrite_condition is not None: plan.write_operation_v2.overwrite_condition.CopyFrom( self.overwrite_condition.to_plan(session) ) elif wm == "overwrite_partitions": plan.write_operation_v2.mode = proto.WriteOperationV2.Mode.MODE_OVERWRITE_PARTITIONS elif wm == "append": plan.write_operation_v2.mode = proto.WriteOperationV2.Mode.MODE_APPEND elif wm == "replace": plan.write_operation_v2.mode = proto.WriteOperationV2.Mode.MODE_REPLACE elif wm == "create_or_replace": plan.write_operation_v2.mode = proto.WriteOperationV2.Mode.MODE_CREATE_OR_REPLACE else: raise PySparkValueError( errorClass="UNSUPPORTED_OPERATION", messageParameters={"operation": self.mode}, ) return plan class WriteStreamOperation(LogicalPlan): def __init__(self, child: "LogicalPlan") -> None: super(WriteStreamOperation, self).__init__(child) self.write_op = proto.WriteStreamOperationStart() def command(self, session: "SparkConnectClient") -> proto.Command: assert self._child is not None self.write_op.input.CopyFrom(self._child.plan(session)) cmd = proto.Command() cmd.write_stream_operation_start.CopyFrom(self.write_op) return cmd class RemoveRemoteCachedRelation(LogicalPlan): def __init__(self, relation: CachedRemoteRelation) -> None: super().__init__(None) self._relation = relation def command(self, session: "SparkConnectClient") -> proto.Command: plan = self._create_proto_relation() plan.cached_remote_relation.relation_id = self._relation._relation_id cmd = proto.Command() cmd.remove_cached_remote_relation_command.relation.CopyFrom(plan.cached_remote_relation) return cmd class Checkpoint(LogicalPlan): def __init__( self, child: Optional["LogicalPlan"], local: bool, eager: bool, storage_level: Optional[StorageLevel] = None, ) -> None: super().__init__(child) self._local = local self._eager = eager self._storage_level = storage_level def command(self, session: "SparkConnectClient") -> proto.Command: cmd = proto.Command() assert self._child is not None checkpoint_command = proto.CheckpointCommand( relation=self._child.plan(session), local=self._local, eager=self._eager, ) if self._storage_level is not None: checkpoint_command.storage_level.CopyFrom(storage_level_to_proto(self._storage_level)) cmd.checkpoint_command.CopyFrom(checkpoint_command) return cmd # Catalog API (internal-only) class CurrentDatabase(LogicalPlan): def __init__(self) -> None: super().__init__(None) def plan(self, session: "SparkConnectClient") -> proto.Relation: plan = self._create_proto_relation() plan.catalog.current_database.SetInParent() return plan class SetCurrentDatabase(LogicalPlan): def __init__(self, db_name: str) -> None: super().__init__(None) self._db_name = db_name def plan(self, session: "SparkConnectClient") -> proto.Relation: plan = self._create_proto_relation() plan.catalog.set_current_database.db_name = self._db_name return plan class ListDatabases(LogicalPlan): def __init__(self, pattern: Optional[str] = None) -> None: super().__init__(None) self._pattern = pattern def plan(self, session: "SparkConnectClient") -> proto.Relation: plan = self._create_proto_relation() plan.catalog.list_databases.SetInParent() if self._pattern is not None: plan.catalog.list_databases.pattern = self._pattern return plan class ListTables(LogicalPlan): def __init__(self, db_name: Optional[str] = None, pattern: Optional[str] = None) -> None: super().__init__(None) self._db_name = db_name self._pattern = pattern def plan(self, session: "SparkConnectClient") -> proto.Relation: plan = self._create_proto_relation() plan.catalog.list_tables.SetInParent() if self._db_name is not None: plan.catalog.list_tables.db_name = self._db_name if self._pattern is not None: plan.catalog.list_tables.pattern = self._pattern return plan class ListFunctions(LogicalPlan): def __init__(self, db_name: Optional[str] = None, pattern: Optional[str] = None) -> None: super().__init__(None) self._db_name = db_name self._pattern = pattern def plan(self, session: "SparkConnectClient") -> proto.Relation: plan = self._create_proto_relation() plan.catalog.list_functions.SetInParent() if self._db_name is not None: plan.catalog.list_functions.db_name = self._db_name if self._pattern is not None: plan.catalog.list_functions.pattern = self._pattern return plan class ListColumns(LogicalPlan): def __init__(self, table_name: str, db_name: Optional[str] = None) -> None: super().__init__(None) self._table_name = table_name self._db_name = db_name def plan(self, session: "SparkConnectClient") -> proto.Relation: plan = self._create_proto_relation() plan.catalog.list_columns.table_name = self._table_name if self._db_name is not None: plan.catalog.list_columns.db_name = self._db_name return plan class GetDatabase(LogicalPlan): def __init__(self, db_name: str) -> None: super().__init__(None) self._db_name = db_name def plan(self, session: "SparkConnectClient") -> proto.Relation: plan = self._create_proto_relation() plan.catalog.get_database.db_name = self._db_name return plan class GetTable(LogicalPlan): def __init__(self, table_name: str, db_name: Optional[str] = None) -> None: super().__init__(None) self._table_name = table_name self._db_name = db_name def plan(self, session: "SparkConnectClient") -> proto.Relation: plan = self._create_proto_relation() plan.catalog.get_table.table_name = self._table_name if self._db_name is not None: plan.catalog.get_table.db_name = self._db_name return plan class GetFunction(LogicalPlan): def __init__(self, function_name: str, db_name: Optional[str] = None) -> None: super().__init__(None) self._function_name = function_name self._db_name = db_name def plan(self, session: "SparkConnectClient") -> proto.Relation: plan = self._create_proto_relation() plan.catalog.get_function.function_name = self._function_name if self._db_name is not None: plan.catalog.get_function.db_name = self._db_name return plan class DatabaseExists(LogicalPlan): def __init__(self, db_name: str) -> None: super().__init__(None) self._db_name = db_name def plan(self, session: "SparkConnectClient") -> proto.Relation: plan = self._create_proto_relation() plan.catalog.database_exists.db_name = self._db_name return plan class TableExists(LogicalPlan): def __init__(self, table_name: str, db_name: Optional[str] = None) -> None: super().__init__(None) self._table_name = table_name self._db_name = db_name def plan(self, session: "SparkConnectClient") -> proto.Relation: plan = self._create_proto_relation() plan.catalog.table_exists.table_name = self._table_name if self._db_name is not None: plan.catalog.table_exists.db_name = self._db_name return plan class FunctionExists(LogicalPlan): def __init__(self, function_name: str, db_name: Optional[str] = None) -> None: super().__init__(None) self._function_name = function_name self._db_name = db_name def plan(self, session: "SparkConnectClient") -> proto.Relation: plan = self._create_proto_relation() plan.catalog.function_exists.function_name = self._function_name if self._db_name is not None: plan.catalog.function_exists.db_name = self._db_name return plan class CreateTable(LogicalPlan): def __init__( self, table_name: str, path: str, source: Optional[str] = None, description: Optional[str] = None, schema: Optional[DataType] = None, options: Mapping[str, str] = {}, ) -> None: super().__init__(None) self._table_name = table_name self._path = path self._source = source self._description = description self._schema = schema self._options = options def plan(self, session: "SparkConnectClient") -> proto.Relation: plan = self._create_proto_relation() plan.catalog.create_table.table_name = self._table_name if self._path is not None: plan.catalog.create_table.path = self._path if self._source is not None: plan.catalog.create_table.source = self._source if self._description is not None: plan.catalog.create_table.description = self._description if self._schema is not None: plan.catalog.create_table.schema.CopyFrom(pyspark_types_to_proto_types(self._schema)) for k in self._options.keys(): v = self._options.get(k) if v is not None: plan.catalog.create_table.options[k] = v return plan class DropTempView(LogicalPlan): def __init__(self, view_name: str) -> None: super().__init__(None) self._view_name = view_name def plan(self, session: "SparkConnectClient") -> proto.Relation: plan = self._create_proto_relation() plan.catalog.drop_temp_view.view_name = self._view_name return plan class DropGlobalTempView(LogicalPlan): def __init__(self, view_name: str) -> None: super().__init__(None) self._view_name = view_name def plan(self, session: "SparkConnectClient") -> proto.Relation: plan = self._create_proto_relation() plan.catalog.drop_global_temp_view.view_name = self._view_name return plan class RecoverPartitions(LogicalPlan): def __init__(self, table_name: str) -> None: super().__init__(None) self._table_name = table_name def plan(self, session: "SparkConnectClient") -> proto.Relation: plan = self._create_proto_relation() plan.catalog.recover_partitions.table_name = self._table_name return plan class IsCached(LogicalPlan): def __init__(self, table_name: str) -> None: super().__init__(None) self._table_name = table_name def plan(self, session: "SparkConnectClient") -> proto.Relation: plan = self._create_proto_relation() plan.catalog.is_cached.table_name = self._table_name return plan class CacheTable(LogicalPlan): def __init__(self, table_name: str, storage_level: Optional[StorageLevel] = None) -> None: super().__init__(None) self._table_name = table_name self._storage_level = storage_level def plan(self, session: "SparkConnectClient") -> proto.Relation: plan = self._create_proto_relation() _cache_table = proto.CacheTable(table_name=self._table_name) if self._storage_level: _cache_table.storage_level.CopyFrom(storage_level_to_proto(self._storage_level)) plan.catalog.cache_table.CopyFrom(_cache_table) return plan class UncacheTable(LogicalPlan): def __init__(self, table_name: str) -> None: super().__init__(None) self._table_name = table_name def plan(self, session: "SparkConnectClient") -> proto.Relation: plan = self._create_proto_relation() plan.catalog.uncache_table.table_name = self._table_name return plan class ClearCache(LogicalPlan): def __init__(self) -> None: super().__init__(None) def plan(self, session: "SparkConnectClient") -> proto.Relation: plan = self._create_proto_relation() plan.catalog.clear_cache.SetInParent() return plan class RefreshTable(LogicalPlan): def __init__(self, table_name: str) -> None: super().__init__(None) self._table_name = table_name def plan(self, session: "SparkConnectClient") -> proto.Relation: plan = self._create_proto_relation() plan.catalog.refresh_table.table_name = self._table_name return plan class RefreshByPath(LogicalPlan): def __init__(self, path: str) -> None: super().__init__(None) self._path = path def plan(self, session: "SparkConnectClient") -> proto.Relation: plan = self._create_proto_relation() plan.catalog.refresh_by_path.path = self._path return plan class CurrentCatalog(LogicalPlan): def __init__(self) -> None: super().__init__(None) def plan(self, session: "SparkConnectClient") -> proto.Relation: plan = self._create_proto_relation() plan.catalog.current_catalog.SetInParent() return plan class SetCurrentCatalog(LogicalPlan): def __init__(self, catalog_name: str) -> None: super().__init__(None) self._catalog_name = catalog_name def plan(self, session: "SparkConnectClient") -> proto.Relation: plan = self._create_proto_relation() plan.catalog.set_current_catalog.catalog_name = self._catalog_name return plan class ListCatalogs(LogicalPlan): def __init__(self, pattern: Optional[str] = None) -> None: super().__init__(None) self._pattern = pattern def plan(self, session: "SparkConnectClient") -> proto.Relation: plan = self._create_proto_relation() plan.catalog.list_catalogs.SetInParent() if self._pattern is not None: plan.catalog.list_catalogs.pattern = self._pattern return plan class MapPartitions(LogicalPlan): """Logical plan object for a mapPartitions-equivalent API: mapInPandas, mapInArrow.""" def __init__( self, child: Optional["LogicalPlan"], function: "UserDefinedFunction", cols: List[str], is_barrier: bool, profile: Optional[ResourceProfile], ) -> None: super().__init__(child) self._function = function._build_common_inline_user_defined_function(*cols) self._is_barrier = is_barrier self._profile = profile def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = self._create_proto_relation() plan.map_partitions.input.CopyFrom(self._child.plan(session)) plan.map_partitions.func.CopyFrom(self._function.to_plan_udf(session)) plan.map_partitions.is_barrier = self._is_barrier if self._profile is not None: plan.map_partitions.profile_id = self._profile.id return plan class GroupMap(LogicalPlan): """Logical plan object for a Group Map API: apply, applyInPandas.""" def __init__( self, child: Optional["LogicalPlan"], grouping_cols: Sequence[Column], function: "UserDefinedFunction", cols: List[str], ): assert isinstance(grouping_cols, list) and all(isinstance(c, Column) for c in grouping_cols) super().__init__(child, self._collect_references(grouping_cols)) self._grouping_cols = grouping_cols self._function = function._build_common_inline_user_defined_function(*cols) def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = self._create_proto_relation() plan.group_map.input.CopyFrom(self._child.plan(session)) plan.group_map.grouping_expressions.extend( [c.to_plan(session) for c in self._grouping_cols] ) plan.group_map.func.CopyFrom(self._function.to_plan_udf(session)) return self._with_relations(plan, session) class CoGroupMap(LogicalPlan): """Logical plan object for a CoGroup Map API: applyInPandas.""" def __init__( self, input: Optional["LogicalPlan"], input_grouping_cols: Sequence[Column], other: Optional["LogicalPlan"], other_grouping_cols: Sequence[Column], function: "UserDefinedFunction", ): assert isinstance(input_grouping_cols, list) and all( isinstance(c, Column) for c in input_grouping_cols ) assert isinstance(other_grouping_cols, list) and all( isinstance(c, Column) for c in other_grouping_cols ) super().__init__(input, self._collect_references(input_grouping_cols + other_grouping_cols)) self._input_grouping_cols = input_grouping_cols self._other_grouping_cols = other_grouping_cols self._other = cast(LogicalPlan, other) # The function takes entire DataFrame as inputs, no need to do # column binding (no input columns). self._function = function._build_common_inline_user_defined_function() def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = self._create_proto_relation() plan.co_group_map.input.CopyFrom(self._child.plan(session)) plan.co_group_map.input_grouping_expressions.extend( [c.to_plan(session) for c in self._input_grouping_cols] ) plan.co_group_map.other.CopyFrom(self._other.plan(session)) plan.co_group_map.other_grouping_expressions.extend( [c.to_plan(session) for c in self._other_grouping_cols] ) plan.co_group_map.func.CopyFrom(self._function.to_plan_udf(session)) return self._with_relations(plan, session) class ApplyInPandasWithState(LogicalPlan): """Logical plan object for a applyInPandasWithState.""" def __init__( self, child: Optional["LogicalPlan"], grouping_cols: Sequence[Column], function: "UserDefinedFunction", output_schema: str, state_schema: str, output_mode: str, timeout_conf: str, cols: List[str], ): assert isinstance(grouping_cols, list) and all(isinstance(c, Column) for c in grouping_cols) super().__init__(child, self._collect_references(grouping_cols)) self._grouping_cols = grouping_cols self._function = function._build_common_inline_user_defined_function(*cols) self._output_schema = output_schema self._state_schema = state_schema self._output_mode = output_mode self._timeout_conf = timeout_conf def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = self._create_proto_relation() plan.apply_in_pandas_with_state.input.CopyFrom(self._child.plan(session)) plan.apply_in_pandas_with_state.grouping_expressions.extend( [c.to_plan(session) for c in self._grouping_cols] ) plan.apply_in_pandas_with_state.func.CopyFrom(self._function.to_plan_udf(session)) plan.apply_in_pandas_with_state.output_schema = self._output_schema plan.apply_in_pandas_with_state.state_schema = self._state_schema plan.apply_in_pandas_with_state.output_mode = self._output_mode plan.apply_in_pandas_with_state.timeout_conf = self._timeout_conf return self._with_relations(plan, session) class BaseTransformWithStateInPySpark(LogicalPlan): """Base implementation of logical plan object for a TransformWithStateIn(PySpark/Pandas).""" def __init__( self, child: Optional["LogicalPlan"], grouping_cols: Sequence[Column], function: "UserDefinedFunction", output_schema: Union[DataType, str], output_mode: str, time_mode: str, event_time_col_name: str, cols: List[str], initial_state_plan: Optional["LogicalPlan"], initial_state_grouping_cols: Optional[Sequence[Column]], ): assert isinstance(grouping_cols, list) and all(isinstance(c, Column) for c in grouping_cols) if initial_state_plan is not None: assert isinstance(initial_state_grouping_cols, list) and all( isinstance(c, Column) for c in initial_state_grouping_cols ) super().__init__( child, self._collect_references(grouping_cols + initial_state_grouping_cols) ) else: super().__init__(child, self._collect_references(grouping_cols)) self._grouping_cols = grouping_cols self._output_schema: DataType = ( UnparsedDataType(output_schema) if isinstance(output_schema, str) else output_schema ) self._output_mode = output_mode self._time_mode = time_mode self._event_time_col_name = event_time_col_name self._function = function._build_common_inline_user_defined_function(*cols) self._initial_state_plan = initial_state_plan self._initial_state_grouping_cols = initial_state_grouping_cols def plan(self, session: "SparkConnectClient") -> proto.Relation: assert self._child is not None plan = self._create_proto_relation() plan.group_map.input.CopyFrom(self._child.plan(session)) plan.group_map.grouping_expressions.extend( [c.to_plan(session) for c in self._grouping_cols] ) plan.group_map.output_mode = self._output_mode # fill in initial state related fields if self._initial_state_plan is not None: plan.group_map.initial_input.CopyFrom(self._initial_state_plan.plan(session)) assert self._initial_state_grouping_cols is not None plan.group_map.initial_grouping_expressions.extend( [c.to_plan(session) for c in self._initial_state_grouping_cols] ) # fill in transformWithStateInPySpark/Pandas related fields tws_info = proto.TransformWithStateInfo() tws_info.time_mode = self._time_mode tws_info.event_time_column_name = self._event_time_col_name tws_info.output_schema.CopyFrom(pyspark_types_to_proto_types(self._output_schema)) plan.group_map.transform_with_state_info.CopyFrom(tws_info) # wrap transformWithStateInPySparkUdf in a function plan.group_map.func.CopyFrom(self._function.to_plan_udf(session)) return self._with_relations(plan, session) class TransformWithStateInPySpark(BaseTransformWithStateInPySpark): """Logical plan object for a TransformWithStateInPySpark.""" pass # Retaining this to avoid breaking backward compatibility. class TransformWithStateInPandas(BaseTransformWithStateInPySpark): """Logical plan object for a TransformWithStateInPandas.""" pass class PythonUDTF: """Represents a Python user-defined table function.""" def __init__( self, func: Type, return_type: Optional[Union[DataType, str]], eval_type: int, python_ver: str, ) -> None: self._func = func self._name = func.__name__ self._return_type: Optional[DataType] = ( None if return_type is None else UnparsedDataType(return_type) if isinstance(return_type, str) else return_type ) self._eval_type = eval_type self._python_ver = python_ver def to_plan(self, session: "SparkConnectClient") -> proto.PythonUDTF: udtf = proto.PythonUDTF() if self._return_type is not None: udtf.return_type.CopyFrom(pyspark_types_to_proto_types(self._return_type)) udtf.eval_type = self._eval_type try: udtf.command = CloudPickleSerializer().dumps(self._func) except pickle.PicklingError: raise PySparkPicklingError( errorClass="UDTF_SERIALIZATION_ERROR", messageParameters={ "name": self._name, "message": "Please check the stack trace and " "make sure the function is serializable.", }, ) udtf.python_ver = self._python_ver return udtf def __repr__(self) -> str: return ( f"PythonUDTF({self._name}, {self._return_type}, " f"{self._eval_type}, {self._python_ver})" ) class CommonInlineUserDefinedTableFunction(LogicalPlan): """ Logical plan object for a user-defined table function with an inlined defined function body. """ def __init__( self, function_name: str, function: PythonUDTF, deterministic: bool, arguments: Sequence[Expression], ) -> None: super().__init__(None, self._collect_references(arguments)) self._function_name = function_name self._deterministic = deterministic self._arguments = arguments self._function = function def plan(self, session: "SparkConnectClient") -> proto.Relation: plan = self._create_proto_relation() plan.common_inline_user_defined_table_function.function_name = self._function_name plan.common_inline_user_defined_table_function.deterministic = self._deterministic if len(self._arguments) > 0: plan.common_inline_user_defined_table_function.arguments.extend( [arg.to_plan(session) for arg in self._arguments] ) plan.common_inline_user_defined_table_function.python_udtf.CopyFrom( self._function.to_plan(session) ) return self._with_relations(plan, session) def udtf_plan( self, session: "SparkConnectClient" ) -> "proto.CommonInlineUserDefinedTableFunction": """ Compared to `plan`, it returns a `proto.CommonInlineUserDefinedTableFunction` instead of a `proto.Relation`. """ plan = proto.CommonInlineUserDefinedTableFunction() plan.function_name = self._function_name plan.deterministic = self._deterministic if len(self._arguments) > 0: plan.arguments.extend([arg.to_plan(session) for arg in self._arguments]) plan.python_udtf.CopyFrom( cast(proto.PythonUDF, self._function.to_plan(session)) # type: ignore[arg-type] ) return plan def __repr__(self) -> str: return f"{self._function_name}({', '.join([str(arg) for arg in self._arguments])})" class PythonDataSource: """Represents a user-defined Python data source.""" def __init__(self, data_source: Type, python_ver: str): self._data_source = data_source self._python_ver = python_ver def to_plan(self, session: "SparkConnectClient") -> proto.PythonDataSource: ds = proto.PythonDataSource() ds.command = CloudPickleSerializer().dumps(self._data_source) ds.python_ver = self._python_ver return ds class CommonInlineUserDefinedDataSource(LogicalPlan): """Logical plan object for a user-defined data source""" def __init__(self, name: str, data_source: PythonDataSource) -> None: super().__init__(None) self._name = name self._data_source = data_source def plan(self, session: "SparkConnectClient") -> proto.Relation: plan = self._create_proto_relation() plan.common_inline_user_defined_data_source.name = self._name plan.common_inline_user_defined_data_source.python_data_source.CopyFrom( self._data_source.to_plan(session) ) return plan def to_data_source_proto( self, session: "SparkConnectClient" ) -> "proto.CommonInlineUserDefinedDataSource": plan = proto.CommonInlineUserDefinedDataSource() plan.name = self._name plan.python_data_source.CopyFrom(self._data_source.to_plan(session)) return plan class CachedRelation(LogicalPlan): def __init__(self, plan: proto.Relation) -> None: super(CachedRelation, self).__init__(None) self._plan = plan # Update the plan ID based on the incremented counter. self._plan.common.plan_id = self._plan_id def plan(self, session: "SparkConnectClient") -> proto.Relation: return self._plan