#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import shutil
import tempfile
import typing
import os
import functools
import unittest
import uuid
import contextlib
from pyspark.testing import (
grpc_requirement_message,
have_grpc,
grpc_status_requirement_message,
have_grpc_status,
googleapis_common_protos_requirement_message,
have_googleapis_common_protos,
graphviz_requirement_message,
have_graphviz,
connect_requirement_message,
should_test_connect,
)
from pyspark import Row, SparkConf
from pyspark.util import is_remote_only
from pyspark.testing.utils import PySparkErrorTestUtils
from pyspark.testing.sqlutils import (
have_pandas,
pandas_requirement_message,
pyarrow_requirement_message,
SQLTestUtils,
)
from pyspark.sql.session import SparkSession as PySparkSession

if should_test_connect:
from pyspark.sql.connect.dataframe import DataFrame
from pyspark.sql.connect.plan import Read, Range, SQL, LogicalPlan
    from pyspark.sql.connect.session import SparkSession


class MockRemoteSession:
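    """
    Minimal stand-in for a remote Spark Connect session, used by the plan-only tests.
    Attribute access is dispatched to hooks registered via ``set_hook``, for example:

        session = MockRemoteSession()
        session.set_hook("sql", lambda query: query.upper())
        session.sql("select 1")  # -> "SELECT 1"
    """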
def __init__(self):
self.hooks = {}
self.session_id = str(uuid.uuid4())
self.is_mock_session = True
def set_hook(self, name, hook):
self.hooks[name] = hook
def drop_hook(self, name):
self.hooks.pop(name)
def __getattr__(self, item):
if item not in self.hooks:
raise LookupError(f"{item} is not defined as a method hook in MockRemoteSession")
        return functools.partial(self.hooks[item])


@unittest.skipIf(not should_test_connect, connect_requirement_message)
class PlanOnlyTestFixture(unittest.TestCase, PySparkErrorTestUtils):
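    """
    Fixture for plan-only tests: DataFrames are built against a
    :class:`MockRemoteSession` via hooks, so the tests only construct and inspect
    logical plans instead of talking to a real Spark Connect server.
    """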
if should_test_connect:
class MockDF(DataFrame):
"""Helper class that must only be used for the mock plan tests."""
def __init__(self, plan: LogicalPlan, session: SparkSession):
super().__init__(plan, session)
def __getattr__(self, name):
"""All attributes are resolved to columns, because none really exist in the
mocked DataFrame."""
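                # e.g. ``mock_df.id`` resolves to ``mock_df["id"]``, i.e. a column.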
return self[name]
@classmethod
def _read_table(cls, table_name):
return cls._df_mock(Read(table_name))
@classmethod
def _udf_mock(cls, *args, **kwargs):
return "internal_name"
@classmethod
def _df_mock(cls, plan: LogicalPlan) -> MockDF:
return PlanOnlyTestFixture.MockDF(plan, cls.connect)
@classmethod
def _session_range(
cls,
start,
end,
step=1,
num_partitions=None,
):
return cls._df_mock(Range(start, end, step, num_partitions))
@classmethod
def _session_sql(cls, query):
return cls._df_mock(SQL(query))
if have_pandas:
@classmethod
def _with_plan(cls, plan):
return cls._df_mock(plan)
@classmethod
def setUpClass(cls):
cls.connect = MockRemoteSession()
cls.session = SparkSession.builder.remote().getOrCreate()
cls.tbl_name = "test_connect_plan_only_table_1"
cls.connect.set_hook("readTable", cls._read_table)
cls.connect.set_hook("range", cls._session_range)
cls.connect.set_hook("sql", cls._session_sql)
cls.connect.set_hook("with_plan", cls._with_plan)
@classmethod
def tearDownClass(cls):
cls.connect.drop_hook("readTable")
cls.connect.drop_hook("range")
cls.connect.drop_hook("sql")
cls.connect.drop_hook("with_plan")
@unittest.skipIf(not should_test_connect, connect_requirement_message)
class ReusedConnectTestCase(unittest.TestCase, SQLTestUtils, PySparkErrorTestUtils):
"""
Spark Connect version of :class:`pyspark.testing.sqlutils.ReusedSQLTestCase`.
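
    Subclasses get a shared Spark Connect session as ``cls.spark``, for example
    (hypothetical test case):

        class MyConnectTest(ReusedConnectTestCase):
            def test_range(self):
                self.assertEqual(self.spark.range(3).count(), 3)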
"""
@classmethod
def conf(cls):
"""
        Override this in subclasses to supply a more specific conf.
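
        For example, a subclass could extend it as follows (hypothetical override):

            @classmethod
            def conf(cls):
                conf = super().conf()
                conf.set("spark.sql.shuffle.partitions", "4")
                return conf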
"""
conf = SparkConf(loadDefaults=False)
        # Make the server terminate reattachable streams after at most 1 second or
        # 123 bytes, so that the tests exercise the reattach logic.
if conf._jconf is not None:
conf._jconf.remove("spark.master")
conf.set("spark.connect.execute.reattachable.senderMaxStreamDuration", "1s")
conf.set("spark.connect.execute.reattachable.senderMaxStreamSize", "123")
        # Set a static token for all tests so that parallel test runs don't overwrite
        # each test's environment variables.
conf.set("spark.connect.authenticate.token", "deadbeef")
        # Disable ML cache offloading; offloading does not yet support APIs such as
        # model.summary / model.evaluate.
conf.set("spark.connect.session.connectML.mlCache.memoryControl.enabled", "false")
return conf
@classmethod
def master(cls):
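        # SPARK_CONNECT_TESTING_REMOTE may point at an externally running Connect
        # server (e.g. "sc://localhost:15002"); otherwise a local[4] master is used.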
return os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]")
@classmethod
def setUpClass(cls):
cls.spark = (
PySparkSession.builder.config(conf=cls.conf())
.appName(cls.__name__)
.remote(cls.master())
.getOrCreate()
)
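        # Keep a handle to the classic SparkContext behind the locally started Connect
        # server, if any; ``quiet()`` uses it to silence JVM log output during tests.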
cls._legacy_sc = None
if not is_remote_only():
cls._legacy_sc = PySparkSession._instantiatedSession._sc
cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
os.unlink(cls.tempdir.name)
cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
cls.df = cls.spark.createDataFrame(cls.testData)
@classmethod
def tearDownClass(cls):
shutil.rmtree(cls.tempdir.name, ignore_errors=True)
cls.spark.stop()
def setUp(self) -> None:
        # Force a cleanup of the ML cache before each test.
self.spark.client._cleanup_ml_cache()
def test_assert_remote_mode(self):
from pyspark.sql import is_remote
self.assertTrue(is_remote())
def quiet(self):
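        # Return a context manager that suppresses JVM log output (via QuietTest) when
        # a classic SparkContext is available, and a no-op context otherwise.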
from pyspark.testing.utils import QuietTest
if self._legacy_sc is not None:
return QuietTest(self._legacy_sc)
else:
            return contextlib.nullcontext()


@unittest.skipIf(
not should_test_connect or is_remote_only(),
connect_requirement_message or "Requires JVM access",
)
class ReusedMixedTestCase(ReusedConnectTestCase, SQLTestUtils):
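    """
    Mixed-mode variant of :class:`ReusedConnectTestCase`: ``cls.connect`` holds the
    Spark Connect session while ``cls.spark`` holds the classic JVM-backed session,
    so the same query can be run through both and compared, e.g. via
    ``compare_by_show``.
    """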
@classmethod
def setUpClass(cls):
super(ReusedMixedTestCase, cls).setUpClass()
        # Disable the shared namespace so that pyspark.sql.functions, etc. point to
        # the regular PySpark libraries.
os.environ["PYSPARK_NO_NAMESPACE_SHARE"] = "1"
        cls.connect = cls.spark  # Swap the Spark Connect session and the regular PySpark session.
cls.spark = PySparkSession._instantiatedSession
assert cls.spark is not None
@classmethod
def tearDownClass(cls):
try:
            # Stopping the Spark Connect session also closes the corresponding session
            # in the JVM on the server side.
cls.spark = cls.connect
del os.environ["PYSPARK_NO_NAMESPACE_SHARE"]
finally:
super(ReusedMixedTestCase, cls).tearDownClass()
def setUp(self) -> None:
        # Force a cleanup of the ML cache before each test.
self.connect.client._cleanup_ml_cache()
def compare_by_show(self, df1, df2, n: int = 20, truncate: int = 20):
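        # Compare two DataFrames (classic or Connect, in any combination) by the string
        # that ``show()`` would print, rather than by collected rows.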
from pyspark.sql.classic.dataframe import DataFrame as SDF
from pyspark.sql.connect.dataframe import DataFrame as CDF
assert isinstance(df1, (SDF, CDF))
if isinstance(df1, SDF):
str1 = df1._jdf.showString(n, truncate, False)
else:
str1 = df1._show_string(n, truncate, False)
assert isinstance(df2, (SDF, CDF))
if isinstance(df2, SDF):
str2 = df2._jdf.showString(n, truncate, False)
else:
str2 = df2._show_string(n, truncate, False)
self.assertEqual(str1, str2)
def test_assert_remote_mode(self):
# no need to test this in mixed mode
pass