python/src/datasketches_spark/common.py
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from pyspark import SparkContext
from pyspark.sql.column import Column, _to_java_column, _to_seq, _create_column_from_literal
from pyspark.sql.utils import try_remote_functions
from py4j.java_gateway import JavaClass, JavaObject
from typing import Any, Optional, TypeVar, Union, Callable
from functools import lru_cache
from ._version import __version__
import os
from importlib.resources import files, as_file
ColumnOrName = Union[Column, str]
ColumnOrName_ = TypeVar("ColumnOrName_", bound=ColumnOrName)
def get_dependency_path(filename: str) -> str:
"""
Returns a list of absolute paths to the specified file,\n
it is included in the package's /deps subdir.
:param filename: Name of file to retrieve
:return: Absolute paths to filename
:exception FileNotFoundError: If a file is not found in the package
"""
try:
with as_file(files("datasketches_spark.deps") / filename) as file_path:
return str(file_path)
except FileNotFoundError:
raise FileNotFoundError(f"File {filename} not found in datasketches_spark.deps")
def get_dependency_classpath() -> str:
    """
    Returns a classpath string: the absolute paths to the jar files included
    in the package, joined with ':'.

    Assumes that the jar files are located in the package's /deps subdir.
    """
    # we need whatever is listed in dependencies.txt as well as
    # datasketches-spark_<scala_version>-<ds-spark_version>.jar
jar_files = []
with (files("datasketches_spark.deps") / "dependencies.txt").open('r') as dependencies:
for dep in dependencies:
jar_files.append(dep.strip())
ds_spark_version = __version__
jar_files.append(f"datasketches-spark_{os.environ.get('SCALA_VERSION', '2.12')}-{ds_spark_version}.jar")
return ":".join([get_dependency_path(jar) for jar in jar_files])
# Since we have functions from different packages, rather than the
# single 16k+ line functions class in core Spark, we'll have each
# sketch family grab its own functions class from the JVM and cache it
def _get_jvm_class(name: str) -> JavaClass:
"""
Retrieves JVM class identified by name from
Java gateway associated with the current active Spark context.
"""
assert SparkContext._active_spark_context is not None
return getattr(SparkContext._active_spark_context._jvm, name)
@lru_cache
def _get_jvm_function(cls: JavaClass, name: str) -> Callable:
"""
Retrieves JVM function identified by name from
Java gateway associated with sc.
"""
assert cls is not None
return getattr(cls, name)
def _invoke_function(cls: JavaClass, name: str, *args: Any) -> Column:
"""
Invokes JVM function identified by name with args
and wraps the result with :class:`~pyspark.sql.Column`.
"""
    assert cls is not None
    jf = _get_jvm_function(cls, name)
    return Column(jf(*args))
def _invoke_function_over_columns(cls: JavaClass, name: str, *cols: "ColumnOrName") -> Column:
"""
Invokes n-ary JVM function identified by name
and wraps the result with :class:`~pyspark.sql.Column`.
"""
return _invoke_function(cls, name, *(_to_java_column(col) for col in cols))
# lazy init so we know the SparkContext exists first
_spark_functions_class: Optional[JavaClass] = None
def _get_spark_functions_class() -> JavaClass:
global _spark_functions_class
if _spark_functions_class is None:
_spark_functions_class = _get_jvm_class("org.apache.spark.sql.functions")
return _spark_functions_class
# borrowed from PySpark
def _array_as_java_column(data: Union[list, tuple]) -> JavaObject:
    """
    Converts a Python list or tuple into a JVM array column, returning the
    underlying py4j JavaObject rather than a pyspark Column.
    """
    sc = SparkContext._active_spark_context
    col = _to_seq(sc, [_create_column_from_literal(x) for x in data])
    return _invoke_function(_get_spark_functions_class(), "array", col)._jc
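# A hypothetical illustration of how the helper above would be used; the
# wrapper name and JVM function below are assumptions, not part of this module:
#
#   def get_quantiles(sketch: "ColumnOrName", fractions: list) -> Column:
#       return _invoke_function(_get_common_functions_class(), "getQuantiles",
#                               _to_java_column(sketch),
#                               _array_as_java_column(fractions))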
_common_functions_class: Optional[JavaClass] = None
def _get_common_functions_class() -> JavaClass:
global _common_functions_class
if _common_functions_class is None:
_common_functions_class = _get_jvm_class("org.apache.spark.sql.datasketches.common.functions")
return _common_functions_class
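# Per the note above, each sketch family is meant to grab and cache its own
# functions class the same way. A sketch of that per-family pattern (the
# placeholder names and JVM class path here are illustrative, not real):
#
#   _family_functions_class: Optional[JavaClass] = None
#   def _get_family_functions_class() -> JavaClass:
#       global _family_functions_class
#       if _family_functions_class is None:
#           _family_functions_class = _get_jvm_class(
#               "org.apache.spark.sql.datasketches.<family>.functions")
#       return _family_functions_class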
@try_remote_functions
def cast_as_binary(col: "ColumnOrName") -> Column:
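    """
    Wraps the JVM cast_as_binary function, returning the input column
    cast to a binary (BinaryType) column.
    """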
return _invoke_function_over_columns(_get_common_functions_class(), "cast_as_binary", col)
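# A minimal usage sketch (the DataFrame and column names are illustrative and
# assume an active SparkSession with the package jars on the classpath):
#
#   df = spark.read.parquet("sketches.parquet")
#   df = df.withColumn("sketch_bytes", cast_as_binary("sketch"))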