dev/gen.py (98 lines of code) (raw):

# Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. # The ASF licenses this file to You under the Apache License, Version 2.0 # (the "License"); you may not use this file except in compliance with # the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # This is a basic script to generate the builtin functions based on the # currently available PySpark installation. # Simply call the script as follows: # # python gen.py > spark/client/functions/generated.go import pyspark.sql.connect.functions as F import inspect import typing import types def normalize(input: str) -> str: vals = [x[0].upper() + x[1:] for x in input.split("_")] return "".join(vals) print(""" // Licensed to the Apache Software Foundation (ASF) under one or more // contributor license agreements. See the NOTICE file distributed with // this work for additional information regarding copyright ownership. // The ASF licenses this file to You under the Apache License, Version 2.0 // (the "License"); you may not use this file except in compliance with // the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package functions import "github.com/apache/spark-connect-go/v35/spark/sql/column" """) for fun in F.__dict__: if fun.startswith("_"): continue if not callable(F.__dict__[fun]): continue if "pyspark.sql.connect.functions" not in F.__dict__[fun].__module__: continue if fun == "expr" or fun == "col" or fun == "column" or fun == "lit": continue # Ignore the aliases of the old distinct. if "Distinct" in fun: continue sig = inspect.signature(F.__dict__[fun]) # Ignore all functions that take callables as parameters has_callable = False for p in sig.parameters: param = sig.parameters[p] if "Callable" in str(param.annotation): has_callable = True break if has_callable: print(f"// TODO: {fun}: {sig}") print() continue if "udf" in fun.lower(): print(f"// Ignore UDF: {fun}: {sig}") print() continue if "udt" in fun.lower(): print(f"// Ignore UDT: {fun}: {sig}") print() continue # Convert parameters into Golang res_params = [] conversions = [] args = [] valid = True for p in sig.parameters: param = sig.parameters[p] if param.annotation == inspect.Parameter.empty: res_params.append(f"{p} interface{{}}") args.append(p) elif param.kind == inspect.Parameter.VAR_POSITIONAL and param.annotation == "ColumnOrName": res_params.append(f"{p} ...column.Column") conversions.append("vals := make([]column.Column, 0)") for x in args: conversions.append(f"vals = append(vals, {x})") conversions.append(f"vals = append(vals, {p}...)") args = ["vals..."] elif type(param.annotation) == str and str(param.annotation) == "ColumnOrName" and param.kind != inspect.Parameter.VAR_POSITIONAL and param.kind != inspect.Parameter.VAR_KEYWORD: res_params.append(f"{p} column.Column") args.append(p) elif len(typing.get_args(param.annotation)) > 1 and typing.ForwardRef("ColumnOrName") in typing.get_args(param.annotation): # Find the parameter with ColumnOrName tmp = [x for x in typing.get_args(param.annotation) if typing.ForwardRef("ColumnOrName") == x] assert len(tmp) == 1 res_params.append(f"{p} column.Column") args.append(p) elif param.annotation == str or typing.get_args(param.annotation) == (str, types.NoneType): res_params.append(f"{p} string") conversions.append(f"lit_{p} := StringLit({p})") args.append(f"lit_{p}") elif param.annotation == int or typing.get_args(param.annotation) == (int, types.NoneType): res_params.append(f"{p} int64") conversions.append(f"lit_{p} := Int64Lit({p})") args.append(f"lit_{p}") elif param.annotation == float or typing.get_args(param.annotation) == (float, types.NoneType): res_params.append(f"{p} float64") conversions.append(f"lit_{p} := Float64Lit({p})") args.append(f"lit_{p}") else: valid = False break if not valid: print(f"// TODO: {fun}: {sig}") print() else: name = normalize(fun) # Generate the doc string if F.__dict__[fun].__doc__ is not None: lines = list(map(str.lstrip, F.__dict__[fun].__doc__.split("\n"))) pos = list(map(lambda x: x.startswith("..") or x.startswith("Parameters"), lines)).index(True) lines = "\n".join(lines[:pos]).strip().split("\n") lines[0] = name + " - " + lines[0] lines = ["// " + l for l in lines] doc = "\n".join(lines) + "\n//" print(doc) print(f"// {name} is the Golang equivalent of {fun}: {sig}") print(f"func {name}({', '.join(res_params)}) column.Column {{") for c in conversions: print(f" {c}") print(f" return column.NewColumn(column.NewUnresolvedFunctionWithColumns(\"{fun}\", {', '.join(args)}))") print(f"}}") print()