tools/scripts/gen-function-support-docs.py:

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import argparse
import logging
import os
import re
import subprocess

import findspark
import tabulate

# Fetched from org.apache.spark.sql.catalyst.analysis.FunctionRegistry.
SPARK35_EXPRESSION_MAPPINGS = '''
    // misc non-aggregate functions
    expression[Abs]("abs"),
    expression[Coalesce]("coalesce"),
    expressionBuilder("explode", ExplodeExpressionBuilder),
    expressionGeneratorBuilderOuter("explode_outer", ExplodeExpressionBuilder),
    expression[Greatest]("greatest"),
    expression[If]("if"),
    expression[Inline]("inline"),
    expressionGeneratorOuter[Inline]("inline_outer"),
    expression[IsNaN]("isnan"),
    expression[Nvl]("ifnull", setAlias = true),
    expression[IsNull]("isnull"),
    expression[IsNotNull]("isnotnull"),
    expression[Least]("least"),
    expression[NaNvl]("nanvl"),
    expression[NullIf]("nullif"),
    expression[Nvl]("nvl"),
    expression[Nvl2]("nvl2"),
    expression[PosExplode]("posexplode"),
    expressionGeneratorOuter[PosExplode]("posexplode_outer"),
    expression[Rand]("rand"),
    expression[Rand]("random", true),
    expression[Randn]("randn"),
    expression[Stack]("stack"),
    expression[CaseWhen]("when"),

    // math functions
    expression[Acos]("acos"),
    expression[Acosh]("acosh"),
    expression[Asin]("asin"),
    expression[Asinh]("asinh"),
    expression[Atan]("atan"),
    expression[Atan2]("atan2"),
    expression[Atanh]("atanh"),
    expression[Bin]("bin"),
    expression[BRound]("bround"),
    expression[Cbrt]("cbrt"),
    expressionBuilder("ceil", CeilExpressionBuilder),
    expressionBuilder("ceiling", CeilExpressionBuilder, true),
    expression[Cos]("cos"),
    expression[Sec]("sec"),
    expression[Cosh]("cosh"),
    expression[Conv]("conv"),
    expression[ToDegrees]("degrees"),
    expression[EulerNumber]("e"),
    expression[Exp]("exp"),
    expression[Expm1]("expm1"),
    expressionBuilder("floor", FloorExpressionBuilder),
    expression[Factorial]("factorial"),
    expression[Hex]("hex"),
    expression[Hypot]("hypot"),
    expression[Logarithm]("log"),
    expression[Log10]("log10"),
    expression[Log1p]("log1p"),
    expression[Log2]("log2"),
    expression[Log]("ln"),
    expression[Remainder]("mod", true),
    expression[UnaryMinus]("negative", true),
    expression[Pi]("pi"),
    expression[Pmod]("pmod"),
    expression[UnaryPositive]("positive"),
    expression[Pow]("pow", true),
    expression[Pow]("power"),
    expression[ToRadians]("radians"),
    expression[Rint]("rint"),
    expression[Round]("round"),
    expression[ShiftLeft]("shiftleft"),
    expression[ShiftRight]("shiftright"),
    expression[ShiftRightUnsigned]("shiftrightunsigned"),
    expression[Signum]("sign", true),
    expression[Signum]("signum"),
    expression[Sin]("sin"),
    expression[Csc]("csc"),
    expression[Sinh]("sinh"),
    expression[StringToMap]("str_to_map"),
    expression[Sqrt]("sqrt"),
    expression[Tan]("tan"),
    expression[Cot]("cot"),
    expression[Tanh]("tanh"),
    expression[WidthBucket]("width_bucket"),

    expression[Add]("+"),
    expression[Subtract]("-"),
    expression[Multiply]("*"),
    expression[Divide]("/"),
    expression[IntegralDivide]("div"),
    expression[Remainder]("%"),

    // "try_*" function which always return Null instead of runtime error.
    expression[TryAdd]("try_add"),
    expression[TryDivide]("try_divide"),
    expression[TrySubtract]("try_subtract"),
    expression[TryMultiply]("try_multiply"),
    expression[TryElementAt]("try_element_at"),
    expressionBuilder("try_avg", TryAverageExpressionBuilder, setAlias = true),
    expressionBuilder("try_sum", TrySumExpressionBuilder, setAlias = true),
    expression[TryToBinary]("try_to_binary"),
    expressionBuilder("try_to_timestamp", TryToTimestampExpressionBuilder, setAlias = true),
    expression[TryAesDecrypt]("try_aes_decrypt"),

    // aggregate functions
    expression[HyperLogLogPlusPlus]("approx_count_distinct"),
    expression[Average]("avg"),
    expression[Corr]("corr"),
    expression[Count]("count"),
    expression[CountIf]("count_if"),
    expression[CovPopulation]("covar_pop"),
    expression[CovSample]("covar_samp"),
    expression[First]("first"),
    expression[First]("first_value", true),
    expression[AnyValue]("any_value"),
    expression[Kurtosis]("kurtosis"),
    expression[Last]("last"),
    expression[Last]("last_value", true),
    expression[Max]("max"),
    expression[MaxBy]("max_by"),
    expression[Average]("mean", true),
    expression[Min]("min"),
    expression[MinBy]("min_by"),
    expression[Percentile]("percentile"),
    expression[Median]("median"),
    expression[Skewness]("skewness"),
    expression[ApproximatePercentile]("percentile_approx"),
    expression[ApproximatePercentile]("approx_percentile", true),
    expression[HistogramNumeric]("histogram_numeric"),
    expression[StddevSamp]("std", true),
    expression[StddevSamp]("stddev", true),
    expression[StddevPop]("stddev_pop"),
    expression[StddevSamp]("stddev_samp"),
    expression[Sum]("sum"),
    expression[VarianceSamp]("variance", true),
    expression[VariancePop]("var_pop"),
    expression[VarianceSamp]("var_samp"),
    expression[CollectList]("collect_list"),
    expression[CollectList]("array_agg", true, Some("3.3.0")),
    expression[CollectSet]("collect_set"),
    expressionBuilder("count_min_sketch", CountMinSketchAggExpressionBuilder),
    expression[BoolAnd]("every", true),
    expression[BoolAnd]("bool_and"),
    expression[BoolOr]("any", true),
    expression[BoolOr]("some", true),
    expression[BoolOr]("bool_or"),
    expression[RegrCount]("regr_count"),
    expression[RegrAvgX]("regr_avgx"),
    expression[RegrAvgY]("regr_avgy"),
    expression[RegrR2]("regr_r2"),
    expression[RegrSXX]("regr_sxx"),
    expression[RegrSXY]("regr_sxy"),
    expression[RegrSYY]("regr_syy"),
    expression[RegrSlope]("regr_slope"),
    expression[RegrIntercept]("regr_intercept"),
    expression[Mode]("mode"),
    expression[HllSketchAgg]("hll_sketch_agg"),
    expression[HllUnionAgg]("hll_union_agg"),

    // string functions
    expression[Ascii]("ascii"),
    expression[Chr]("char", true),
    expression[Chr]("chr"),
    expressionBuilder("contains", ContainsExpressionBuilder),
    expressionBuilder("startswith", StartsWithExpressionBuilder),
    expressionBuilder("endswith", EndsWithExpressionBuilder),
    expression[Base64]("base64"),
    expression[BitLength]("bit_length"),
    expression[Length]("char_length", true),
    expression[Length]("character_length", true),
    expression[ConcatWs]("concat_ws"),
    expression[Decode]("decode"),
    expression[Elt]("elt"),
    expression[Encode]("encode"),
    expression[FindInSet]("find_in_set"),
    expression[FormatNumber]("format_number"),
    expression[FormatString]("format_string"),
    expression[ToNumber]("to_number"),
    expression[TryToNumber]("try_to_number"),
    expression[ToCharacter]("to_char"),
    expression[ToCharacter]("to_varchar", setAlias = true, Some("3.5.0")),
    expression[GetJsonObject]("get_json_object"),
    expression[InitCap]("initcap"),
    expression[StringInstr]("instr"),
    expression[Lower]("lcase", true),
    expression[Length]("length"),
    expression[Length]("len", setAlias = true, Some("3.4.0")),
    expression[Levenshtein]("levenshtein"),
    expression[Luhncheck]("luhn_check"),
    expression[Like]("like"),
    expression[ILike]("ilike"),
    expression[Lower]("lower"),
    expression[OctetLength]("octet_length"),
    expression[StringLocate]("locate"),
    expressionBuilder("lpad", LPadExpressionBuilder),
    expression[StringTrimLeft]("ltrim"),
    expression[JsonTuple]("json_tuple"),
    expression[StringLocate]("position", true),
    expression[FormatString]("printf", true),
    expression[RegExpExtract]("regexp_extract"),
    expression[RegExpExtractAll]("regexp_extract_all"),
    expression[RegExpReplace]("regexp_replace"),
    expression[StringRepeat]("repeat"),
    expression[StringReplace]("replace"),
    expression[Overlay]("overlay"),
    expression[RLike]("rlike"),
    expression[RLike]("regexp_like", true, Some("3.2.0")),
    expression[RLike]("regexp", true, Some("3.2.0")),
    expressionBuilder("rpad", RPadExpressionBuilder),
    expression[StringTrimRight]("rtrim"),
    expression[Sentences]("sentences"),
    expression[SoundEx]("soundex"),
    expression[StringSpace]("space"),
    expression[StringSplit]("split"),
    expression[SplitPart]("split_part"),
    expression[Substring]("substr", true),
    expression[Substring]("substring"),
    expression[Left]("left"),
    expression[Right]("right"),
    expression[SubstringIndex]("substring_index"),
    expression[StringTranslate]("translate"),
    expression[StringTrim]("trim"),
    expression[StringTrimBoth]("btrim"),
    expression[Upper]("ucase", true),
    expression[UnBase64]("unbase64"),
    expression[Unhex]("unhex"),
    expression[Upper]("upper"),
    expression[XPathList]("xpath"),
    expression[XPathBoolean]("xpath_boolean"),
    expression[XPathDouble]("xpath_double"),
    expression[XPathDouble]("xpath_number", true),
    expression[XPathFloat]("xpath_float"),
    expression[XPathInt]("xpath_int"),
    expression[XPathLong]("xpath_long"),
    expression[XPathShort]("xpath_short"),
    expression[XPathString]("xpath_string"),
    expression[RegExpCount]("regexp_count"),
    expression[RegExpSubStr]("regexp_substr"),
    expression[RegExpInStr]("regexp_instr"),

    // url functions
    expression[UrlEncode]("url_encode"),
    expression[UrlDecode]("url_decode"),
    expression[ParseUrl]("parse_url"),

    // datetime functions
    expression[AddMonths]("add_months"),
    expression[CurrentDate]("current_date"),
    expressionBuilder("curdate", CurDateExpressionBuilder, setAlias = true),
    expression[CurrentTimestamp]("current_timestamp"),
    expression[CurrentTimeZone]("current_timezone"),
    expression[LocalTimestamp]("localtimestamp"),
    expression[DateDiff]("datediff"),
    expression[DateDiff]("date_diff", setAlias = true, Some("3.4.0")),
    expression[DateAdd]("date_add"),
    expression[DateAdd]("dateadd", setAlias = true, Some("3.4.0")),
    expression[DateFormatClass]("date_format"),
    expression[DateSub]("date_sub"),
    expression[DayOfMonth]("day", true),
    expression[DayOfYear]("dayofyear"),
    expression[DayOfMonth]("dayofmonth"),
    expression[FromUnixTime]("from_unixtime"),
    expression[FromUTCTimestamp]("from_utc_timestamp"),
    expression[Hour]("hour"),
    expression[LastDay]("last_day"),
    expression[Minute]("minute"),
    expression[Month]("month"),
    expression[MonthsBetween]("months_between"),
    expression[NextDay]("next_day"),
    expression[Now]("now"),
    expression[Quarter]("quarter"),
    expression[Second]("second"),
    expression[ParseToTimestamp]("to_timestamp"),
    expression[ParseToDate]("to_date"),
    expression[ToBinary]("to_binary"),
    expression[ToUnixTimestamp]("to_unix_timestamp"),
    expression[ToUTCTimestamp]("to_utc_timestamp"),
    // We keep the 2 expression builders below to have different function docs.
    expressionBuilder("to_timestamp_ntz", ParseToTimestampNTZExpressionBuilder, setAlias = true),
    expressionBuilder("to_timestamp_ltz", ParseToTimestampLTZExpressionBuilder, setAlias = true),
    expression[TruncDate]("trunc"),
    expression[TruncTimestamp]("date_trunc"),
    expression[UnixTimestamp]("unix_timestamp"),
    expression[DayOfWeek]("dayofweek"),
    expression[WeekDay]("weekday"),
    expression[WeekOfYear]("weekofyear"),
    expression[Year]("year"),
    expression[TimeWindow]("window"),
    expression[SessionWindow]("session_window"),
    expression[WindowTime]("window_time"),
    expression[MakeDate]("make_date"),
    expression[MakeTimestamp]("make_timestamp"),
    // We keep the 2 expression builders below to have different function docs.
    expressionBuilder("make_timestamp_ntz", MakeTimestampNTZExpressionBuilder, setAlias = true),
    expressionBuilder("make_timestamp_ltz", MakeTimestampLTZExpressionBuilder, setAlias = true),
    expression[MakeInterval]("make_interval"),
    expression[MakeDTInterval]("make_dt_interval"),
    expression[MakeYMInterval]("make_ym_interval"),
    expression[Extract]("extract"),
    // We keep the `DatePartExpressionBuilder` to have different function docs.
    expressionBuilder("date_part", DatePartExpressionBuilder, setAlias = true),
    expressionBuilder("datepart", DatePartExpressionBuilder, setAlias = true, Some("3.4.0")),
    expression[DateFromUnixDate]("date_from_unix_date"),
    expression[UnixDate]("unix_date"),
    expression[SecondsToTimestamp]("timestamp_seconds"),
    expression[MillisToTimestamp]("timestamp_millis"),
    expression[MicrosToTimestamp]("timestamp_micros"),
    expression[UnixSeconds]("unix_seconds"),
    expression[UnixMillis]("unix_millis"),
    expression[UnixMicros]("unix_micros"),
    expression[ConvertTimezone]("convert_timezone"),

    // collection functions
    expression[CreateArray]("array"),
    expression[ArrayContains]("array_contains"),
    expression[ArraysOverlap]("arrays_overlap"),
    expression[ArrayInsert]("array_insert"),
    expression[ArrayIntersect]("array_intersect"),
    expression[ArrayJoin]("array_join"),
    expression[ArrayPosition]("array_position"),
    expression[ArraySize]("array_size"),
    expression[ArraySort]("array_sort"),
    expression[ArrayExcept]("array_except"),
    expression[ArrayUnion]("array_union"),
    expression[ArrayCompact]("array_compact"),
    expression[CreateMap]("map"),
    expression[CreateNamedStruct]("named_struct"),
    expression[ElementAt]("element_at"),
    expression[MapContainsKey]("map_contains_key"),
    expression[MapFromArrays]("map_from_arrays"),
    expression[MapKeys]("map_keys"),
    expression[MapValues]("map_values"),
    expression[MapEntries]("map_entries"),
    expression[MapFromEntries]("map_from_entries"),
    expression[MapConcat]("map_concat"),
    expression[Size]("size"),
    expression[Slice]("slice"),
    expression[Size]("cardinality", true),
    expression[ArraysZip]("arrays_zip"),
    expression[SortArray]("sort_array"),
    expression[Shuffle]("shuffle"),
    expression[ArrayMin]("array_min"),
    expression[ArrayMax]("array_max"),
    expression[ArrayAppend]("array_append"),
    expression[Reverse]("reverse"),
    expression[Concat]("concat"),
    expression[Flatten]("flatten"),
    expression[Sequence]("sequence"),
    expression[ArrayRepeat]("array_repeat"),
    expression[ArrayRemove]("array_remove"),
    expression[ArrayPrepend]("array_prepend"),
    expression[ArrayDistinct]("array_distinct"),
    expression[ArrayTransform]("transform"),
    expression[MapFilter]("map_filter"),
    expression[ArrayFilter]("filter"),
    expression[ArrayExists]("exists"),
    expression[ArrayForAll]("forall"),
    expression[ArrayAggregate]("aggregate"),
    expression[ArrayAggregate]("reduce", setAlias = true, Some("3.4.0")),
    expression[TransformValues]("transform_values"),
    expression[TransformKeys]("transform_keys"),
    expression[MapZipWith]("map_zip_with"),
    expression[ZipWith]("zip_with"),
    expression[Get]("get"),
    CreateStruct.registryEntry,

    // misc functions
    expression[AssertTrue]("assert_true"),
    expression[RaiseError]("raise_error"),
    expression[Crc32]("crc32"),
    expression[Md5]("md5"),
    expression[Uuid]("uuid"),
    expression[Murmur3Hash]("hash"),
    expression[XxHash64]("xxhash64"),
    expression[Sha1]("sha", true),
    expression[Sha1]("sha1"),
    expression[Sha2]("sha2"),
    expression[AesEncrypt]("aes_encrypt"),
    expression[AesDecrypt]("aes_decrypt"),
    expression[SparkPartitionID]("spark_partition_id"),
    expression[InputFileName]("input_file_name"),
    expression[InputFileBlockStart]("input_file_block_start"),
    expression[InputFileBlockLength]("input_file_block_length"),
    expression[MonotonicallyIncreasingID]("monotonically_increasing_id"),
    expression[CurrentDatabase]("current_database"),
    expression[CurrentDatabase]("current_schema", true),
    expression[CurrentCatalog]("current_catalog"),
    expression[CurrentUser]("current_user"),
    expression[CurrentUser]("user", setAlias = true),
    expression[CallMethodViaReflection]("reflect"),
    expression[CallMethodViaReflection]("java_method", true),
    expression[SparkVersion]("version"),
    expression[TypeOf]("typeof"),
    expression[EqualNull]("equal_null"),
    expression[HllSketchEstimate]("hll_sketch_estimate"),
    expression[HllUnion]("hll_union"),

    // grouping sets
    expression[Grouping]("grouping"),
    expression[GroupingID]("grouping_id"),

    // window functions
    expression[Lead]("lead"),
    expression[Lag]("lag"),
    expression[RowNumber]("row_number"),
    expression[CumeDist]("cume_dist"),
    expression[NthValue]("nth_value"),
    expression[NTile]("ntile"),
    expression[Rank]("rank"),
    expression[DenseRank]("dense_rank"),
    expression[PercentRank]("percent_rank"),

    // predicates
    expression[And]("and"),
    expression[In]("in"),
    expression[Not]("not"),
    expression[Or]("or"),

    // comparison operators
    expression[EqualNullSafe]("<=>"),
    expression[EqualTo]("="),
    expression[EqualTo]("=="),
    expression[GreaterThan](">"),
    expression[GreaterThanOrEqual](">="),
    expression[LessThan]("<"),
    expression[LessThanOrEqual]("<="),
    expression[Not]("!"),

    // bitwise
    expression[BitwiseAnd]("&"),
    expression[BitwiseNot]("~"),
    expression[BitwiseOr]("|"),
    expression[BitwiseXor]("^"),
    expression[BitwiseCount]("bit_count"),
    expression[BitAndAgg]("bit_and"),
    expression[BitOrAgg]("bit_or"),
    expression[BitXorAgg]("bit_xor"),
    expression[BitwiseGet]("bit_get"),
    expression[BitwiseGet]("getbit", true),

    // bitmap functions and aggregates
    expression[BitmapBucketNumber]("bitmap_bucket_number"),
    expression[BitmapBitPosition]("bitmap_bit_position"),
    expression[BitmapConstructAgg]("bitmap_construct_agg"),
    expression[BitmapCount]("bitmap_count"),
    expression[BitmapOrAgg]("bitmap_or_agg"),

    // json
    expression[StructsToJson]("to_json"),
    expression[JsonToStructs]("from_json"),
    expression[SchemaOfJson]("schema_of_json"),
    expression[LengthOfJsonArray]("json_array_length"),
    expression[JsonObjectKeys]("json_object_keys"),

    // cast
    expression[Cast]("cast"),
    // Cast aliases (SPARK-16730)
    castAlias("boolean", BooleanType),
    castAlias("tinyint", ByteType),
    castAlias("smallint", ShortType),
    castAlias("int", IntegerType),
    castAlias("bigint", LongType),
    castAlias("float", FloatType),
    castAlias("double", DoubleType),
    castAlias("decimal", DecimalType.USER_DEFAULT),
    castAlias("date", DateType),
    castAlias("timestamp", TimestampType),
    castAlias("binary", BinaryType),
    castAlias("string", StringType),

    // mask functions
    expressionBuilder("mask", MaskExpressionBuilder),

    // csv
    expression[CsvToStructs]("from_csv"),
    expression[SchemaOfCsv]("schema_of_csv"),
    expression[StructsToCsv]("to_csv")
'''

FUNCTION_CATEGORIES = ['scalar', 'aggregate', 'window', 'generator']

STATIC_INVOKES = {
    "luhn_check": ("org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils", "isLuhnNumber"),
    "base64": ("org.apache.spark.sql.catalyst.expressions.Base64", "encode"),
    "contains": ("org.apache.spark.unsafe.array.ByteArrayMethods", "contains"),
    "startsWith": ("org.apache.spark.unsafe.array.ByteArrayMethods", "startsWith"),
    "endsWith": ("org.apache.spark.unsafe.array.ByteArrayMethods", "endsWith"),
    "lpad": ("org.apache.spark.unsafe.array.ByteArrayMethods", "lpad"),
    "rpad": ("org.apache.spark.unsafe.array.ByteArrayMethods", "rpad"),
}

# Known restrictions in Gluten.
LOOKAROUND_UNSUPPORTED = 'Lookaround unsupported'
BINARY_TYPE_UNSUPPORTED = 'BinaryType unsupported'
GLUTEN_RESTRICTIONS = {
    'scalar': {
        'regexp': LOOKAROUND_UNSUPPORTED,
        'regexp_like': LOOKAROUND_UNSUPPORTED,
        'rlike': LOOKAROUND_UNSUPPORTED,
        'regexp_extract': LOOKAROUND_UNSUPPORTED,
        'regexp_extract_all': LOOKAROUND_UNSUPPORTED,
        'regexp_replace': LOOKAROUND_UNSUPPORTED,
        'contains': BINARY_TYPE_UNSUPPORTED,
        'startswith': BINARY_TYPE_UNSUPPORTED,
        'endswith': BINARY_TYPE_UNSUPPORTED,
        'lpad': BINARY_TYPE_UNSUPPORTED,
        'rpad': BINARY_TYPE_UNSUPPORTED
    },
    'aggregate': {},
    'window': {},
    'generator': {}
}

SPARK_FUNCTION_GROUPS = {
    "agg_funcs", "array_funcs", "datetime_funcs", "json_funcs", "map_funcs",
    "window_funcs", "math_funcs", "conditional_funcs", "generator_funcs",
    "predicate_funcs", "string_funcs", "misc_funcs", "bitwise_funcs",
    "conversion_funcs", "csv_funcs",
}

# Function groups that are not listed in the Spark doc.
spark_function_missing_groups = {
    'collection_funcs', 'hash_funcs', 'lambda_funcs', 'struct_funcs', 'url_funcs', 'xml_funcs'
}
SPARK_FUNCTION_GROUPS = SPARK_FUNCTION_GROUPS.union(spark_function_missing_groups)

SCALAR_FUNCTION_GROUPS = {
    'array_funcs': "Array Functions",
    'map_funcs': "Map Functions",
    'datetime_funcs': "Date and Timestamp Functions",
    'json_funcs': "JSON Functions",
    'math_funcs': "Mathematical Functions",
    'string_funcs': "String Functions",
    'bitwise_funcs': "Bitwise Functions",
    'conversion_funcs': "Conversion Functions",
    'conditional_funcs': "Conditional Functions",
    'predicate_funcs': "Predicate Functions",
    'csv_funcs': "Csv Functions",
    'misc_funcs': "Misc Functions",
    'collection_funcs': "Collection Functions",
    'hash_funcs': "Hash Functions",
    'lambda_funcs': "Lambda Functions",
    'struct_funcs': "Struct Functions",
    'url_funcs': "URL Functions",
    'xml_funcs': "XML Functions"
}

FUNCTION_GROUPS = {
    'scalar': SCALAR_FUNCTION_GROUPS,
    'aggregate': {'agg_funcs': 'Aggregate Functions'},
    'window': {'window_funcs': 'Window Functions'},
    'generator': {'generator_funcs': "Generator Functions"}
}

FUNCTION_SUITE_PACKAGE = 'org.apache.spark.sql.'
FUNCTION_SUITES = {
    'scalar': {'GlutenSQLQueryTestSuite', 'GlutenDataFrameSessionWindowingSuite',
               'GlutenDataFrameTimeWindowingSuite', 'GlutenMiscFunctionsSuite',
               'GlutenDateFunctionsSuite', 'GlutenDataFrameFunctionsSuite',
               'GlutenBitmapExpressionsQuerySuite', 'GlutenMathFunctionsSuite',
               'GlutenColumnExpressionSuite', 'GlutenStringFunctionsSuite',
               'GlutenXPathFunctionsSuite', 'GlutenSQLQuerySuite'},
    'aggregate': {'GlutenSQLQueryTestSuite', 'GlutenApproxCountDistinctForIntervalsQuerySuite',
                  'GlutenBitmapExpressionsQuerySuite', 'GlutenDataFrameAggregateSuite'},
    # All window functions are supported.
    'window': {},
    'generator': {'GlutenGeneratorFunctionSuite'}
}


def create_spark_function_map():
    """Parse SPARK35_EXPRESSION_MAPPINGS into a {function_name: expression_class_name} dict."""
    exprs = list(map(lambda x: x if x[-1] != ',' else x[:-1],
                     map(lambda x: x.strip(),
                         filter(lambda x: 'expression' in x, SPARK35_EXPRESSION_MAPPINGS.split('\n')))))
    func_map = {}
    expression_pattern = r'expression[GeneratorOuter]*\[([\w0-9]+)\]\("([^\s]+)".*'
    expression_builder_pattern = r'expression[Generator]*Builder[Outer]*\("([^\s]+)", ([\w0-9]+).*'
    for r in exprs:
        match = re.search(expression_pattern, r)
        if match:
            class_name = match.group(1)
            function_name = match.group(2)
            func_map[function_name] = class_name
        else:
            match = re.search(expression_builder_pattern, r)
            if match:
                class_name = match.group(2)
                function_name = match.group(1)
                func_map[function_name] = class_name
            else:
                logging.log(logging.WARNING, f'Could not parse expression: {r}')
    return func_map


def generate_function_list():
    """Populate the global function tables from Spark's built-in function infos."""
    jinfos = jvm.org.apache.spark.sql.api.python.PythonSQLUtils.listBuiltinFunctionInfos()
    infos = [["!=", '', 'predicate_funcs'],
             ["<>", "", "predicate_funcs"],
             ['between', '', 'predicate_funcs'],
             ['case', '', 'predicate_funcs'],
             ["||", '', 'misc_funcs']]
    for jinfo in filter(lambda x: x.getGroup() in SPARK_FUNCTION_GROUPS, jinfos):
        infos.append([jinfo.getName(), jinfo.getClassName().split('.')[-1], jinfo.getGroup()])
    for info in infos:
        name, classname, groupname = info
        if name == "raise_error":
            continue
        all_function_names.append(name)
        classname_to_function[classname] = name
        function_to_classname[name] = classname
        if groupname not in group_functions:
            group_functions[groupname] = []
        group_functions[groupname].append(name)
        if groupname in SCALAR_FUNCTION_GROUPS:
            functions['scalar'].add(name)
        elif groupname == 'agg_funcs':
            functions['aggregate'].add(name)
        elif groupname == 'window_funcs':
            functions['window'].add(name)
        elif groupname == 'generator_funcs':
            functions['generator'].add(name)
        else:
            logging.log(logging.WARNING, f"No matching group name for function {name}: " + groupname)
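
# Illustrative sketch (never called) of the regex extraction that parse_logs() below
# applies to each fallback reason. The input line is synthetic, shaped only to match
# the pattern used in the real code; actual Velox messages may carry extra context.
def _demo_extract_function_name():
    reason = "Scalar function name not registered: foo, signature: (VARCHAR)."  # synthetic example
    match = re.search(r"Scalar function name not registered:\s+([\w0-9]+)", reason)
    return match.group(1) if match else None  # -> 'foo'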
def parse_logs(log_file):
    """Parse the fallback reasons from the test log into per-category support lists."""
    # "<>", "!=", "between", "case", and "||" are hard-coded in Spark; there are no
    # corresponding functions.
    builtin_functions = ['<>', '!=', 'between', 'case', '||']
    function_names = all_function_names.copy()
    for f in builtin_functions:
        function_names.remove(f)
    print(function_names)
    generator_functions = ['explode', 'explode_outer', 'inline', 'inline_outer',
                           'posexplode', 'posexplode_outer', 'stack']
    # Unknown functions are not in the all_function_names list. Perhaps Spark implements
    # such a function but does not expose it to users in the current version.
    support_list = {
        'scalar': {'partial': set(), 'unsupported': set(), 'unsupported_expr': set(), 'unknown': set()},
        'aggregate': {'partial': set(), 'unsupported': set(), 'unsupported_expr': set(), 'unknown': set()},
        'generator': {'partial': set(), 'unsupported': set(), 'unsupported_expr': set(), 'unknown': set()},
        'window': {'partial': set(), 'unsupported': set(), 'unsupported_expr': set(), 'unknown': set()}
    }
    try_to_binary_funcs = {'unhex', 'encode', 'unbase64'}
    unresolved = []

    def filter_fallback_reasons():
        with open(log_file, 'r') as f:
            lines = f.readlines()
        # Filter validation logs.
        validation_logs = []
        for l in lines:
            if (l.startswith(' - ') and 'Native validation failed:' not in l) or l.startswith(' |- '):
                validation_logs.append(l)
        # Extract fallback reasons.
        fallback_reasons = set()
        for l in validation_logs:
            if 'due to:' in l:
                fallback_reasons.add(l.split('due to:')[-1].strip())
            elif 'reason:' in l:
                fallback_reasons.add(l.split('reason:')[-1].strip())
            else:
                fallback_reasons.add(l)
        fallback_reasons = sorted(fallback_reasons)
        # Remove udf.
        return list(filter(lambda x: 'Not supported python udf' not in x
                                     and 'Not supported scala udf' not in x,
                           fallback_reasons))

    def function_name_tuple(function_name):
        return (function_name,
                None if function_name not in function_to_classname else function_to_classname[function_name])

    def function_not_found(r):
        logging.log(logging.WARNING, f"No function name or class name found in: {r}")
        unresolved.append(r)

    java_import(jvm, "org.apache.gluten.expression.ExpressionMappings")
    jexpression_mappings = jvm.org.apache.gluten.expression.ExpressionMappings.listExpressionMappings()
    gluten_expressions = {}
    for item in jexpression_mappings:
        gluten_expressions[item._1()] = item._2()

    for category in FUNCTION_CATEGORIES:
        if category == 'scalar':
            for f in functions[category]:
                # TODO: Remove this filter as it may exclude supported expressions, such as Builder.
                if (f not in builtin_functions and f not in gluten_expressions.values()
                        and function_to_classname[f] not in gluten_expressions.keys()):
                    logging.log(logging.WARNING, f"Function not found in gluten expressions: {f}")
                    support_list[category]['unsupported'].add(function_name_tuple(f))
        for f in GLUTEN_RESTRICTIONS[category].keys():
            support_list[category]['partial'].add(function_name_tuple(f))

    for r in filter_fallback_reasons():
        ############## Scalar functions ##############
        # Not supported: Expression not in ExpressionMappings.
        if 'Not supported to map spark function name to substrait function name' in r:
            # Extract the class name.
            pattern = r"class name: ([\w0-9]+)."
            match = re.search(pattern, r)
            if match:
                class_name = match.group(1)
                if class_name in classname_to_function:
                    function_name = classname_to_function[class_name]
                    if function_name in function_names:
                        support_list['scalar']['unsupported'].add((function_name, class_name))
                    else:
                        support_list['scalar']['unknown'].add((function_name, class_name))
                else:
                    logging.log(logging.INFO, f"No function name for class: {class_name}. Adding class name")
                    support_list['scalar']['unsupported_expr'].add(class_name)
            else:
                function_not_found(r)
        # Not supported: Function not registered in Velox.
        elif 'Scalar function name not registered:' in r:
            # Extract the function name.
            pattern = r"Scalar function name not registered:\s+([\w0-9]+)"
            match = re.search(pattern, r)
            if match:
                function_name = match.group(1)
                if function_name in function_names:
                    support_list['scalar']['unsupported'].add(function_name_tuple(function_name))
                else:
                    support_list['scalar']['unknown'].add(function_name_tuple(function_name))
            else:
                function_not_found(r)
        # Partially supported: Function registered in Velox, but not with these specific arguments.
        elif 'not registered with arguments:' in r:
            # Extract the function name.
            pattern = r"Scalar function ([\w0-9]+) not registered with arguments:"
            match = re.search(pattern, r)
            if match:
                function_name = match.group(1)
                if function_name in function_names:
                    support_list['scalar']['partial'].add(function_name_tuple(function_name))
                else:
                    support_list['scalar']['unknown'].add(function_name_tuple(function_name))
            else:
                function_not_found(r)
        # Not supported: Special case for unsupported expressions.
        elif 'Not support expression' in r:
            # Extract the class name.
            pattern = r"Not support expression ([\w0-9]+)"
            match = re.search(pattern, r)
            if match:
                class_name = match.group(1)
                if class_name in classname_to_function:
                    function_name = classname_to_function[class_name]
                    if function_name in function_names:
                        support_list['scalar']['unsupported'].add((function_name, class_name))
                    else:
                        support_list['scalar']['unknown'].add((function_name, class_name))
                else:
                    logging.log(logging.INFO, f"No function name for class: {class_name}. Adding class name")
                    support_list['scalar']['unsupported_expr'].add(class_name)
            else:
                function_not_found(r)
        # Not supported: Special case for unsupported functions.
        elif 'Function is not supported:' in r:
            # Extract the function name.
            pattern = r"Function is not supported:\s+([\w0-9]+)"
            match = re.search(pattern, r)
            if match:
                function_name = match.group(1)
                if function_name in function_names:
                    support_list['scalar']['unsupported'].add(function_name_tuple(function_name))
                else:
                    support_list['scalar']['unknown'].add(function_name_tuple(function_name))
            else:
                function_not_found(r)
        ############## Aggregate functions ##############
        elif 'Could not find a valid substrait mapping' in r:
            # Extract the function name.
            pattern = r"Could not find a valid substrait mapping name for ([\w0-9]+)\("
            match = re.search(pattern, r)
            if match:
                function_name = match.group(1)
                if function_name in function_names:
                    support_list['aggregate']['unsupported'].add(function_name_tuple(function_name))
                else:
                    support_list['aggregate']['unknown'].add(function_name_tuple(function_name))
            else:
                function_not_found(r)
        elif 'Unsupported aggregate mode' in r:
            # Extract the function name.
            pattern = r"Unsupported aggregate mode: [\w]+ for ([\w0-9]+)"
            match = re.search(pattern, r)
            if match:
                function_name = match.group(1)
                if function_name in function_names:
                    support_list['aggregate']['partial'].add(function_name_tuple(function_name))
                else:
                    support_list['aggregate']['unknown'].add(function_name_tuple(function_name))
            else:
                function_not_found(r)
        ############## Generator functions ##############
        elif 'Velox backend does not support this generator:' in r:
            # Extract the function name.
            pattern = r"Velox backend does not support this generator:\s+([\w0-9]+)"
            match = re.search(pattern, r)
            if match:
                class_name = match.group(1)
                function_name = class_name.lower()
                if function_name not in generator_functions:
                    support_list['generator']['unknown'].add((None, class_name))
                elif 'outer: true' in r:
                    support_list['generator']['unsupported'].add((function_name + '_outer', None))
                else:
                    support_list['generator']['unsupported'].add(function_name_tuple(function_name))
            else:
                function_not_found(r)
        ############## Special judgements ##############
        elif 'try_eval' in r and ' is not supported' in r:
            pattern = r"try_eval\((\w+)\) is not supported"
            match = re.search(pattern, r)
            if match:
                function_name = match.group(1)
                if function_name in try_to_binary_funcs:
                    try_to_binary_funcs.remove(function_name)
                    function_name = 'try_to_binary'
                    p = function_name_tuple(function_name)
                    if len(try_to_binary_funcs) == 0:
                        if p in support_list['scalar']['partial']:
                            support_list['scalar']['partial'].remove(p)
                        support_list['scalar']['unsupported'].add(p)
                elif 'add' in function_name:
                    function_name = 'try_add'
                    support_list['scalar']['partial'].add(function_name_tuple(function_name))
            else:
                function_not_found(r)
        elif 'Pattern is not string literal for regexp_extract' == r:
            function_name = 'regexp_extract'
            support_list['scalar']['partial'].add(function_name_tuple(function_name))
        elif 'Pattern is not string literal for regexp_extract_all' == r:
            function_name = 'regexp_extract_all'
            support_list['scalar']['partial'].add(function_name_tuple(function_name))
        else:
            unresolved.append(r)
    return support_list, unresolved
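
# A small illustration (never called) of the GitHub-flavored markdown table that
# generate_function_doc() below renders with tabulate. The rows are made-up examples;
# only the 'rlike' restriction text is taken from GLUTEN_RESTRICTIONS above.
def _demo_support_table():
    headers = ['Spark Functions', 'Spark Expressions', 'Status', 'Restrictions']
    # 'S' = fully supported, 'PS' = partially supported, '' = unsupported.
    data = [['abs', 'Abs', 'S', ''],
            ['rlike', 'RLike', 'PS', LOOKAROUND_UNSUPPORTED]]
    return tabulate.tabulate(data, headers, tablefmt="github")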
def generate_function_doc(category, output):
    """Write the support-status markdown doc for one function category to `output`."""
    def support_str(num_functions):
        return f"{num_functions} functions" if num_functions > 1 else f"{num_functions} function"

    num_unsupported = len(list(filter(lambda x: x[0] is not None, support_list[category]['unsupported'])))
    num_unsupported_expression = len(support_list[category]['unsupported_expr'])
    num_unknown_function = len(support_list[category]['unknown'])
    num_partially_supported = len(list(filter(lambda x: x[0] is not None, support_list[category]['partial'])))
    num_supported = len(functions[category]) - num_unsupported - num_partially_supported
    logging.log(logging.WARNING, f'Number of {category} functions: {len(functions[category])}')
    logging.log(logging.WARNING, f'Number of fully supported {category} functions: {num_supported}')
    logging.log(logging.WARNING, f'Number of unsupported {category} functions: {num_unsupported}')
    logging.log(logging.WARNING, f'Number of partially supported {category} functions: {num_partially_supported}')
    logging.log(logging.WARNING, f'Number of unsupported {category} expressions: {num_unsupported_expression}')
    logging.log(logging.WARNING,
                f'Number of unknown {category} functions: {num_unknown_function}. '
                f'List: {support_list[category]["unknown"]}')

    headers = ['Spark Functions', 'Spark Expressions', 'Status', 'Restrictions']
    partially_supports = ('.' if not num_partially_supported
                          else f' and partially supports {support_str(num_partially_supported)}.')
    lines = f'''# {category.capitalize()} Functions Support Status

**Out of {len(functions[category])} {category} functions in Spark 3.5, Gluten currently fully supports {support_str(num_supported)}{partially_supports}**

'''
    for g in sorted(SPARK_FUNCTION_GROUPS):
        if g in FUNCTION_GROUPS[category]:
            lines += '## ' + FUNCTION_GROUPS[category][g] + '\n\n'
            data = []
            for f in sorted(group_functions[g]):
                classname = '' if f not in spark_function_map else spark_function_map[f]
                support = None
                for item in support_list[category]['partial']:
                    if (item[0] and item[0] == f) or (item[1] and item[1] == classname):
                        support = 'PS'
                        break
                if support is None:
                    for item in support_list[category]['unsupported']:
                        if (item[0] and item[0] == f) or (item[1] and item[1] == classname):
                            support = ''
                            break
                if support is None:
                    support = 'S'
                if f == '|':
                    f = '&#124;'
                elif f == '||':
                    f = '&#124;&#124;'
                data.append([f, classname, support,
                             '' if f not in GLUTEN_RESTRICTIONS[category] else GLUTEN_RESTRICTIONS[category][f]])
            table = tabulate.tabulate(data, headers, tablefmt="github")
            lines += table + '\n\n'
    with open(output, 'w') as fd:
        fd.write(lines)


def run_test_suites(categories):
    """Run the Gluten UT suites for the given categories via Maven."""
    log4j_properties_file = os.path.abspath(
        os.path.join(os.path.dirname(os.path.abspath(__file__)), 'log4j2.properties'))
    suite_list = []
    for category in categories:
        if FUNCTION_SUITES[category]:
            suite_list.append(','.join([FUNCTION_SUITE_PACKAGE + name for name in FUNCTION_SUITES[category]]))
    suites = ','.join(suite_list)
    if not suites:
        logging.log(logging.WARNING, "No test suites to run.")
        return
    command = [
        "mvn", "test",
        "-Pspark-3.5", "-Pspark-ut", "-Pbackends-velox",
        f"-DargLine=-Dspark.test.home={spark_home} -Dlog4j2.configurationFile=file:{log4j_properties_file}",
        f"-DwildcardSuites={suites}",
        "-Dtest=none",
        "-Dsurefire.failIfNoSpecifiedTests=false"
    ]
    subprocess.Popen(command, cwd=gluten_home).wait()


def get_maven_project_version():
    """Return the Gluten project version as reported by Maven."""
    result = subprocess.run(
        ['mvn', 'help:evaluate', '-Dexpression=project.version', '-q', '-DforceStdout'],
        capture_output=True, text=True, cwd=gluten_home)
    if result.returncode == 0:
        version = result.stdout.strip()
        return version
    else:
        raise RuntimeError(f"Error running Maven command: {result.stderr}")
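
# Quick sketch (never called) of the -DwildcardSuites value that run_test_suites()
# above assembles from the constants at the top of this file. For the 'generator'
# category this yields a single fully-qualified suite name; for 'window' the suite
# set is empty and nothing would be run.
def _demo_wildcard_suites():
    return ','.join(FUNCTION_SUITE_PACKAGE + name
                    for name in sorted(FUNCTION_SUITES['generator']))
    # -> 'org.apache.spark.sql.GlutenGeneratorFunctionSuite'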
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--spark_home", type=str, required=True,
                        help="Directory of the Spark source code for the newest Spark version supported "
                             "by Gluten. The Spark project must have been built from source.")
    parser.add_argument("--skip_test_suite", action='store_true',
                        help="If set, skip running the test suites and generate the docs from the "
                             "existing test log.")
    parser.add_argument("--categories", type=str, default=','.join(FUNCTION_CATEGORIES),
                        help="Comma-separated list of the function categories to generate the docs for. "
                             "Default is all categories.")
    args = parser.parse_args()

    spark_home = args.spark_home
    findspark.init(spark_home)
    gluten_home = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../'))

    if not args.skip_test_suite:
        run_test_suites(args.categories.split(','))

    gluten_version = get_maven_project_version()
    gluten_jar = os.path.join(gluten_home, 'package', 'target', f'gluten-package-{gluten_version}.jar')
    if not os.path.exists(gluten_jar):
        raise Exception(f"Gluten jar not found at {gluten_jar}")

    # Import the pyspark/py4j modules only after findspark.init() has set up the environment.
    from py4j.java_gateway import java_import
    from pyspark.java_gateway import launch_gateway
    from pyspark.conf import SparkConf

    conf = SparkConf().set("spark.jars", gluten_jar)
    jvm = launch_gateway(conf=conf).jvm

    # Generate the function list into the global variables below.
    all_function_names = []
    functions = {'scalar': set(), 'aggregate': set(), 'window': set(), 'generator': set()}
    classname_to_function = {}
    function_to_classname = {}
    group_functions = {}
    generate_function_list()

    spark_function_map = create_spark_function_map()
    support_list, unresolved = parse_logs(
        os.path.join(gluten_home, 'gluten-ut', 'spark35', 'target', 'gen-function-support-docs-tests.log'))
    for category in args.categories.split(','):
        generate_function_doc(category,
                              os.path.join(gluten_home, 'docs',
                                           f'velox-backend-{category}-function-support.md'))
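
# Example invocations (paths are placeholders):
#   python tools/scripts/gen-function-support-docs.py --spark_home=/path/to/spark
#   python tools/scripts/gen-function-support-docs.py --spark_home=/path/to/spark \
#       --skip_test_suite --categories=scalar,aggregate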