# python/sedona/spark/core/jvm/config.py

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import functools
import inspect
import logging
import os
import warnings
from re import findall
from typing import Optional, Tuple

from py4j.protocol import Py4JJavaError
from pyspark.sql import SparkSession

from sedona.spark.utils.decorators import classproperty

string_types = (bytes, str)


def is_greater_or_equal_version(version_a: str, version_b: str) -> bool:
    if all([version_b, version_a]):
        version_numbers = version_a.split("."), version_b.split(".")
        if any([version[0] == "" for version in version_numbers]):
            return False
        for ver_a, ver_b in zip(*version_numbers):
            if int(ver_a) > int(ver_b):
                return True
            elif int(ver_a) < int(ver_b):
                return False
    else:
        return False
    return True


def since(version: str):
    def wrapper(function):
        def applier(*args, **kwargs):
            sedona_version = SedonaMeta.version
            if not is_greater_or_equal_version(sedona_version, version):
                logging.warning(
                    f"This function is not available for {sedona_version}, "
                    f"please use a version higher than {version}"
                )
                raise AttributeError(f"Not available before Sedona version {version}")
            result = function(*args, **kwargs)
            return result

        return applier

    return wrapper


def deprecated(reason):
    """
    This is a decorator which can be used to mark functions
    as deprecated. It will result in a warning being emitted
    when the function is used.
    """
    if isinstance(reason, string_types):
        # The @deprecated is used with a 'reason'.
        #
        # .. code-block:: python
        #
        #    @deprecated("please, use another function")
        #    def old_function(x, y):
        #        pass

        def decorator(func1):
            if inspect.isclass(func1):
                fmt1 = "Call to deprecated class {name} ({reason})."
            else:
                fmt1 = "Call to deprecated function {name} ({reason})."

            @functools.wraps(func1)
            def new_func1(*args, **kwargs):
                warnings.warn(
                    fmt1.format(name=func1.__name__, reason=reason),
                    category=DeprecationWarning,
                    stacklevel=2,
                )
                return func1(*args, **kwargs)

            return new_func1

        return decorator

    elif inspect.isclass(reason) or inspect.isfunction(reason):
        # The @deprecated is used without any 'reason'.
        #
        # .. code-block:: python
        #
        #    @deprecated
        #    def old_function(x, y):
        #        pass

        func2 = reason
        if inspect.isclass(func2):
            fmt2 = "Call to deprecated class {name}."
        else:
            fmt2 = "Call to deprecated function {name}."

        @functools.wraps(func2)
        def new_func2(*args, **kwargs):
            warnings.warn(
                fmt2.format(name=func2.__name__),
                category=DeprecationWarning,
                stacklevel=2,
            )
            return func2(*args, **kwargs)

        return new_func2

    else:
        raise TypeError(repr(type(reason)))
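
# A minimal usage sketch for the two decorators above (the function name and
# version number are hypothetical, for illustration only): `since` raises an
# AttributeError when the detected Sedona jar is older than the required
# version, and `deprecated` emits a DeprecationWarning on each call.
#
#     @since("1.5.0")
#     @deprecated("use st_new_function instead")
#     def st_old_function(df):
#         ...
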
class SparkJars:
    @staticmethod
    def get_used_jars():
        spark = SparkSession._instantiatedSession
        # When deployed normally, Sedona appears in `spark.jars`; when it's
        # submitted via YARN, it's in `spark.yarn.dist.jars`.
        used_jar_files_lookup_with_errors = [
            SparkJars.get_spark_java_config(spark, config_id)
            for config_id in ("spark.jars", "spark.yarn.dist.jars")
        ]
        used_jar_files_lookup = [
            lookup_result for lookup_result, _ in used_jar_files_lookup_with_errors
        ]
        used_jar_files = (
            ",".join(jars for jars in used_jar_files_lookup if jars)
            if any(used_jar_files_lookup)
            else None
        )
        if not used_jar_files:
            for _, error in used_jar_files_lookup_with_errors:
                logging.warning(error)
            logging.info("Trying to get filenames from the $SPARK_HOME/jars directory")
            used_jar_files = ",".join(
                os.listdir(os.path.join(os.environ["SPARK_HOME"], "jars"))
            )
        return used_jar_files

    @property
    def jars(self):
        if not hasattr(self, "__spark_jars"):
            setattr(self, "__spark_jars", self.get_used_jars())
        return getattr(self, "__spark_jars")

    @staticmethod
    def get_spark_java_config(
        spark: SparkSession, value: str
    ) -> Tuple[Optional[str], Optional[str]]:
        if spark is not None:
            spark_conf = spark.conf
        else:
            raise TypeError("SparkSession is not initialized")

        java_spark_conf = spark_conf._jconf
        used_jar_files = None
        error_message = None

        try:
            used_jar_files = java_spark_conf.get(value)
        except Py4JJavaError:
            error_message = f"Didn't find the value of {value} in SparkConf"
            logging.info(error_message)

        return used_jar_files, error_message


class SedonaMeta:
    @classmethod
    def get_version(cls, spark_jars: str) -> Optional[str]:
        # Extract the Spark version, Scala version and Sedona version
        # encoded in the jar file name.
        versions = findall(
            r"sedona-(?:python-adapter|spark-shaded|spark)-([^,\n]{3})_([^,\n]{4})-([^,\n]{5})",
            spark_jars,
        )
        logging.debug(versions)
        try:
            sedona_version = versions[0][2]
        except IndexError:
            sedona_version = None
        return sedona_version

    @classproperty
    def version(cls):
        spark_jars = SparkJars.get_used_jars()
        if not hasattr(cls, "__version"):
            setattr(cls, "__version", cls.get_version(spark_jars))
        return getattr(cls, "__version")


if __name__ == "__main__":
    assert not is_greater_or_equal_version("1.1.5", "1.2.0")
    assert is_greater_or_equal_version("1.2.0", "1.1.5")
    assert is_greater_or_equal_version("1.3.5", "1.2.0")
    assert not is_greater_or_equal_version("", "1.2.0")
    assert not is_greater_or_equal_version("1.3.5", "")
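
    # A hedged check of SedonaMeta.get_version: the jar name below is an
    # illustrative example of the sedona-<artifact>-<spark>_<scala>-<version>
    # naming scheme, not one read from a real deployment.
    assert SedonaMeta.get_version("sedona-spark-shaded-3.4_2.12-1.5.1.jar") == "1.5.1"
    assert SedonaMeta.get_version("unrelated.jar") is None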