in python/pyspark/worker.py [0:0]
def read_udtf(pickleSer, infile, eval_type):
prefers_large_var_types = False
legacy_pandas_conversion = False
if eval_type == PythonEvalType.SQL_ARROW_TABLE_UDF:
runner_conf = {}
# Load conf used for arrow evaluation.
num_conf = read_int(infile)
for i in range(num_conf):
k = utf8_deserializer.loads(infile)
v = utf8_deserializer.loads(infile)
runner_conf[k] = v
prefers_large_var_types = use_large_var_types(runner_conf)
legacy_pandas_conversion = (
runner_conf.get(
"spark.sql.legacy.execution.pythonUDTF.pandas.conversion.enabled", "false"
).lower()
== "true"
)
if legacy_pandas_conversion:
# NOTE: if timezone is set here, that implies respectSessionTimeZone is True
safecheck = (
runner_conf.get(
"spark.sql.execution.pandas.convertToArrowArraySafely", "false"
).lower()
== "true"
)
timezone = runner_conf.get("spark.sql.session.timeZone", None)
ser = ArrowStreamPandasUDTFSerializer(timezone, safecheck)
else:
ser = ArrowStreamUDTFSerializer()
else:
# Each row is a group so do not batch but send one by one.
ser = BatchedSerializer(CPickleSerializer(), 1)
# See 'PythonUDTFRunner.PythonUDFWriterThread.writeCommand'
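# For each of the 'num_arg' arguments, read its column offset in the input row, plus the
# argument name when it was passed as a keyword argument.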
num_arg = read_int(infile)
args_offsets = []
kwargs_offsets = {}
for _ in range(num_arg):
offset = read_int(infile)
if read_bool(infile):
name = utf8_deserializer.loads(infile)
kwargs_offsets[name] = offset
else:
args_offsets.append(offset)
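# Next, read the zero-based indexes of the input-table columns that hold the projected
# PARTITION BY expressions (empty when there is no PARTITION BY clause), followed by an
# optional pickled AnalyzeResult produced by the UDTF's 'analyze' method.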
num_partition_child_indexes = read_int(infile)
partition_child_indexes = [read_int(infile) for i in range(num_partition_child_indexes)]
has_pickled_analyze_result = read_bool(infile)
if has_pickled_analyze_result:
pickled_analyze_result = pickleSer._read_with_length(infile)
else:
pickled_analyze_result = None
# Initially we assume that the UDTF __init__ method accepts the pickled AnalyzeResult,
# although we may find otherwise below once we inspect the constructor's signature.
handler = read_command(pickleSer, infile)
if not isinstance(handler, type):
raise PySparkRuntimeError(
f"Invalid UDTF handler type. Expected a class (type 'type'), but "
f"got an instance of {type(handler).__name__}."
)
return_type = _parse_datatype_json_string(utf8_deserializer.loads(infile))
if not isinstance(return_type, StructType):
raise PySparkRuntimeError(
f"The return type of a UDTF must be a struct type, but got {type(return_type)}."
)
udtf_name = utf8_deserializer.loads(infile)
# Update the handler that creates a new UDTF instance: if the UDTF constructor declares a
# parameter for it, pass in the previous AnalyzeResult; otherwise keep calling the constructor
# with no arguments. In this way each UDTF class instance can decide if it wants to inspect the
# AnalyzeResult.
udtf_init_args = inspect.getfullargspec(handler)
if has_pickled_analyze_result:
if len(udtf_init_args.args) > 2:
raise PySparkRuntimeError(
errorClass="UDTF_CONSTRUCTOR_INVALID_IMPLEMENTS_ANALYZE_METHOD",
messageParameters={"name": udtf_name},
)
elif len(udtf_init_args.args) == 2:
prev_handler = handler
def construct_udtf():
# Here we pass the AnalyzeResult to the UDTF's __init__ method.
return prev_handler(dataclasses.replace(pickled_analyze_result))
handler = construct_udtf
elif len(udtf_init_args.args) > 1:
raise PySparkRuntimeError(
errorClass="UDTF_CONSTRUCTOR_INVALID_NO_ANALYZE_METHOD",
messageParameters={"name": udtf_name},
)
class UDTFWithPartitions:
"""
This implements the logic of a UDTF that accepts an input TABLE argument with one or more
PARTITION BY expressions.
For example, let's assume we have a table like:
CREATE TABLE t (c1 INT, c2 INT) USING delta;
Then for the following queries:
SELECT * FROM my_udtf(TABLE (t) PARTITION BY c1, c2);
The partition_child_indexes will be: 0, 1.
SELECT * FROM my_udtf(TABLE (t) PARTITION BY c1, c2 + 4);
The partition_child_indexes will be: 0, 2 (where we add a projection for "c2 + 4").
"""
def __init__(self, create_udtf: Callable, partition_child_indexes: list):
"""
Creates a new instance of this class to wrap the provided UDTF with another one that
checks the values of projected partitioning expressions on consecutive rows to figure
out when the partition boundaries change.
Parameters
----------
create_udtf: function
Function to create a new instance of the UDTF to be invoked.
partition_child_indexes: list
List of integers identifying zero-based indexes of the columns of the input table
that contain projected partitioning expressions. This class will inspect these
values for each pair of consecutive input rows. When they change, this indicates
the boundary between two partitions, and we will invoke the 'terminate' method on
the UDTF class instance and then destroy it and create a new one to implement the
desired partitioning semantics.
"""
self._create_udtf: Callable = create_udtf
self._udtf = create_udtf()
self._prev_arguments: list = list()
self._partition_child_indexes: list = partition_child_indexes
self._eval_raised_skip_rest_of_input_table: bool = False
def eval(self, *args, **kwargs) -> Iterator:
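# Compare this row's projected PARTITION BY values with the previous row's; when they differ,
# flush the current handler via 'terminate' and create a fresh instance for the new partition.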
changed_partitions = self._check_partition_boundaries(
list(args) + list(kwargs.values())
)
if changed_partitions:
if hasattr(self._udtf, "terminate"):
result = self._udtf.terminate()
if result is not None:
for row in result:
yield row
self._udtf = self._create_udtf()
self._eval_raised_skip_rest_of_input_table = False
if self._udtf.eval is not None and not self._eval_raised_skip_rest_of_input_table:
# Filter the arguments to exclude projected PARTITION BY values added by Catalyst.
filtered_args = [self._remove_partition_by_exprs(arg) for arg in args]
filtered_kwargs = {
key: self._remove_partition_by_exprs(value) for (key, value) in kwargs.items()
}
try:
result = self._udtf.eval(*filtered_args, **filtered_kwargs)
if result is not None:
for row in result:
yield row
except SkipRestOfInputTableException:
# If the 'eval' method raised this exception, then we should skip the rest of
# the rows in the current partition. Set this field to True here and then for
# each subsequent row in the partition, we will skip calling the 'eval' method
# until we see a change in the partition boundaries.
self._eval_raised_skip_rest_of_input_table = True
def terminate(self) -> Iterator:
if hasattr(self._udtf, "terminate"):
return self._udtf.terminate()
return iter(())
def cleanup(self) -> None:
if hasattr(self._udtf, "cleanup"):
self._udtf.cleanup()
def _check_partition_boundaries(self, arguments: list) -> bool:
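# The first row never marks a boundary; afterwards, any change in the projected PARTITION BY
# columns between the previous and current row indicates a new partition.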
result = False
if len(self._prev_arguments) > 0:
cur_table_arg = self._get_table_arg(arguments)
prev_table_arg = self._get_table_arg(self._prev_arguments)
cur_partitions_args = []
prev_partitions_args = []
for i in self._partition_child_indexes:
cur_partitions_args.append(cur_table_arg[i])
prev_partitions_args.append(prev_table_arg[i])
result = any(k != v for k, v in zip(cur_partitions_args, prev_partitions_args))
self._prev_arguments = arguments
return result
def _get_table_arg(self, inputs: list) -> Row:
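# The TABLE argument is passed as a Row; return the first Row found among the argument values.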
return [x for x in inputs if type(x) is Row][0]
def _remove_partition_by_exprs(self, arg: Any) -> Any:
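# Strip the projected PARTITION BY columns from the Row so the wrapped UDTF only sees the
# user-visible columns of the input table; non-Row arguments are returned unchanged.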
if isinstance(arg, Row):
new_row_keys = []
new_row_values = []
for i, (key, value) in enumerate(zip(arg.__fields__, arg)):
if i not in self._partition_child_indexes:
new_row_keys.append(key)
new_row_values.append(value)
return _create_row(new_row_keys, new_row_values)
else:
return arg
# Instantiate the UDTF class.
try:
if len(partition_child_indexes) > 0:
udtf = UDTFWithPartitions(handler, partition_child_indexes)
else:
udtf = handler()
except Exception as e:
raise PySparkRuntimeError(
errorClass="UDTF_EXEC_ERROR",
messageParameters={"method_name": "__init__", "error": str(e)},
)
# Validate the UDTF
if not hasattr(udtf, "eval"):
raise PySparkRuntimeError(
"Failed to execute the user defined table function because it has not "
"implemented the 'eval' method. Please add the 'eval' method and try "
"the query again."
)
# Check that the arguments provided to the UDTF call match the expected parameters defined
# in the 'eval' method signature.
try:
inspect.signature(udtf.eval).bind(*args_offsets, **kwargs_offsets)
except TypeError as e:
raise PySparkRuntimeError(
errorClass="UDTF_EVAL_METHOD_ARGUMENTS_DO_NOT_MATCH_SIGNATURE",
messageParameters={"name": udtf_name, "reason": str(e)},
) from None
def build_null_checker(return_type: StructType) -> Optional[Callable[[Any], None]]:
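# Returns a callable that validates a returned row against the non-nullable fields of the
# declared schema (recursing into arrays, maps and structs), or None when every field is
# nullable and no checking is needed.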
def raise_(result_column_index):
raise PySparkRuntimeError(
errorClass="UDTF_EXEC_ERROR",
messageParameters={
"method_name": "eval' or 'terminate",
"error": f"Column {result_column_index} within a returned row had a "
+ "value of None, either directly or within array/struct/map "
+ "subfields, but the corresponding column type was declared as "
+ "non-nullable; please update the UDTF to return a non-None value at "
+ "this location or otherwise declare the column type as nullable.",
},
)
def checker(data_type: DataType, result_column_index: int):
if isinstance(data_type, ArrayType):
element_checker = checker(data_type.elementType, result_column_index)
contains_null = data_type.containsNull
if element_checker is None and contains_null:
return None
def check_array(arr):
if isinstance(arr, list):
for e in arr:
if e is None:
if not contains_null:
raise_(result_column_index)
elif element_checker is not None:
element_checker(e)
return check_array
elif isinstance(data_type, MapType):
key_checker = checker(data_type.keyType, result_column_index)
value_checker = checker(data_type.valueType, result_column_index)
value_contains_null = data_type.valueContainsNull
if value_checker is None and value_contains_null:
def check_map(map):
if isinstance(map, dict):
for k, v in map.items():
if k is None:
raise_(result_column_index)
elif key_checker is not None:
key_checker(k)
else:
def check_map(map):
if isinstance(map, dict):
for k, v in map.items():
if k is None:
raise_(result_column_index)
elif key_checker is not None:
key_checker(k)
if v is None:
if not value_contains_null:
raise_(result_column_index)
elif value_checker is not None:
value_checker(v)
return check_map
elif isinstance(data_type, StructType):
field_checkers = [checker(f.dataType, result_column_index) for f in data_type]
nullables = [f.nullable for f in data_type]
if all(c is None for c in field_checkers) and all(nullables):
return None
def check_struct(struct):
if isinstance(struct, tuple):
for value, checker, nullable in zip(struct, field_checkers, nullables):
if value is None:
if not nullable:
raise_(result_column_index)
elif checker is not None:
checker(value)
return check_struct
else:
return None
field_checkers = [
checker(f.dataType, result_column_index=i) for i, f in enumerate(return_type)
]
nullables = [f.nullable for f in return_type]
if all(c is None for c in field_checkers) and all(nullables):
return None
def check(row):
if isinstance(row, tuple):
for i, (value, checker, nullable) in enumerate(zip(row, field_checkers, nullables)):
if value is None:
if not nullable:
raise_(i)
elif checker is not None:
checker(value)
return check
check_output_row_against_schema = build_null_checker(return_type)
if eval_type == PythonEvalType.SQL_ARROW_TABLE_UDF and legacy_pandas_conversion:
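# Legacy Arrow path: each result is materialized as a pandas DataFrame and written with the
# ArrowStreamPandasUDTFSerializer configured above.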
def wrap_arrow_udtf(f, return_type):
import pandas as pd
arrow_return_type = to_arrow_type(
return_type, prefers_large_types=use_large_var_types(runner_conf)
)
return_type_size = len(return_type)
def verify_result(result):
if not isinstance(result, pd.DataFrame):
raise PySparkTypeError(
errorClass="INVALID_ARROW_UDTF_RETURN_TYPE",
messageParameters={
"return_type": type(result).__name__,
"value": str(result),
"func": f.__name__,
},
)
# Validate the output schema when the result dataframe has either output
# rows or columns. Note that we avoid using `df.empty` here because the
# result dataframe may contain an empty row. For example, when a UDTF is
# defined as follows: def eval(self): yield tuple().
if len(result) > 0 or len(result.columns) > 0:
if len(result.columns) != return_type_size:
raise PySparkRuntimeError(
errorClass="UDTF_RETURN_SCHEMA_MISMATCH",
messageParameters={
"expected": str(return_type_size),
"actual": str(len(result.columns)),
"func": f.__name__,
},
)
# Verify the type and the schema of the result.
verify_pandas_result(
result, return_type, assign_cols_by_name=False, truncate_return_schema=False
)
return result
# Wrap the exception thrown from the UDTF in a PySparkRuntimeError.
def func(*a: Any) -> Any:
try:
return f(*a)
except SkipRestOfInputTableException:
raise
except Exception as e:
raise PySparkRuntimeError(
errorClass="UDTF_EXEC_ERROR",
messageParameters={"method_name": f.__name__, "error": str(e)},
)
def check_return_value(res):
# Check whether the result of an arrow UDTF is iterable before
# using it to construct a pandas DataFrame.
if res is not None:
if not isinstance(res, Iterable):
raise PySparkRuntimeError(
errorClass="UDTF_RETURN_NOT_ITERABLE",
messageParameters={
"type": type(res).__name__,
"func": f.__name__,
},
)
if check_output_row_against_schema is not None:
for row in res:
if row is not None:
check_output_row_against_schema(row)
yield row
else:
yield from res
def evaluate(*args: pd.Series):
if len(args) == 0:
res = func()
yield verify_result(pd.DataFrame(check_return_value(res))), arrow_return_type
else:
# Create tuples from the input pandas Series; each tuple
# represents a row across all Series.
row_tuples = zip(*args)
for row in row_tuples:
res = func(*row)
yield verify_result(
pd.DataFrame(check_return_value(res))
), arrow_return_type
return evaluate
eval_func_kwargs_support, args_kwargs_offsets = wrap_kwargs_support(
getattr(udtf, "eval"), args_offsets, kwargs_offsets
)
eval = wrap_arrow_udtf(eval_func_kwargs_support, return_type)
if hasattr(udtf, "terminate"):
terminate = wrap_arrow_udtf(getattr(udtf, "terminate"), return_type)
else:
terminate = None
cleanup = getattr(udtf, "cleanup") if hasattr(udtf, "cleanup") else None
def mapper(_, it):
try:
for a in it:
# The eval function yields an iterator. Each element produced by this
# iterator is a tuple in the form of (pandas.DataFrame, arrow_return_type).
yield from eval(*[a[o] for o in args_kwargs_offsets])
if terminate is not None:
yield from terminate()
except SkipRestOfInputTableException:
if terminate is not None:
yield from terminate()
finally:
if cleanup is not None:
cleanup()
return mapper, None, ser, ser
elif eval_type == PythonEvalType.SQL_ARROW_TABLE_UDF and not legacy_pandas_conversion:
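# Arrow-native path: results are converted directly to pyarrow.RecordBatch objects and written
# with the ArrowStreamUDTFSerializer, without going through pandas.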
def wrap_arrow_udtf(f, return_type):
import pyarrow as pa
arrow_return_type = to_arrow_type(
return_type, prefers_large_types=use_large_var_types(runner_conf)
)
return_type_size = len(return_type)
def verify_result(result):
if not isinstance(result, pa.RecordBatch):
raise PySparkTypeError(
errorClass="INVALID_ARROW_UDTF_RETURN_TYPE",
messageParameters={
"return_type": type(result).__name__,
"value": str(result),
"func": f.__name__,
},
)
# Validate the output schema when the result batch has either output rows or columns.
# We check rows and columns separately because a UDTF may return rows with no columns,
# e.g. when it is defined as follows: def eval(self): yield tuple().
if len(result) > 0 or len(result.columns) > 0:
if len(result.columns) != return_type_size:
raise PySparkRuntimeError(
errorClass="UDTF_RETURN_SCHEMA_MISMATCH",
messageParameters={
"expected": str(return_type_size),
"actual": str(len(result.columns)),
"func": f.__name__,
},
)
# Verify the type and the schema of the result.
verify_arrow_result(
pa.Table.from_batches([result], schema=pa.schema(list(arrow_return_type))),
assign_cols_by_name=False,
expected_cols_and_types=[
(col.name, to_arrow_type(col.dataType)) for col in return_type.fields
],
)
return result
# Wrap the exception thrown from the UDTF in a PySparkRuntimeError.
def func(*a: Any) -> Any:
try:
return f(*a)
except SkipRestOfInputTableException:
raise
except Exception as e:
raise PySparkRuntimeError(
errorClass="UDTF_EXEC_ERROR",
messageParameters={"method_name": f.__name__, "error": str(e)},
)
def check_return_value(res):
# Check whether the result of an arrow UDTF is iterable before
# converting it to Arrow record batches.
if res is not None:
if not isinstance(res, Iterable):
raise PySparkRuntimeError(
errorClass="UDTF_RETURN_NOT_ITERABLE",
messageParameters={
"type": type(res).__name__,
"func": f.__name__,
},
)
if check_output_row_against_schema is not None:
for row in res:
if row is not None:
check_output_row_against_schema(row)
yield row
else:
yield from res
def convert_to_arrow(data: Iterable):
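# Convert the rows yielded by the UDTF into Arrow record batches matching the declared return
# schema; an empty result still produces a single empty batch with that schema, and conversion
# failures are re-raised as UDTF_ARROW_TYPE_CONVERSION_ERROR.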
data = list(check_return_value(data))
if len(data) == 0:
return [
pa.RecordBatch.from_pylist(data, schema=pa.schema(list(arrow_return_type)))
]
try:
ret = LocalDataToArrowConversion.convert(
data, return_type, prefers_large_var_types
).to_batches()
if len(return_type.fields) == 0:
return [pa.RecordBatch.from_struct_array(pa.array([{}] * len(data)))]
return ret
except Exception as e:
raise PySparkRuntimeError(
errorClass="UDTF_ARROW_TYPE_CONVERSION_ERROR",
messageParameters={
"data": str(data),
"schema": return_type.simpleString(),
"arrow_schema": str(arrow_return_type),
},
) from e
def evaluate(*args: pa.ChunkedArray):
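# With no arguments the UDTF is called once per incoming batch; otherwise the Arrow columns
# are converted back to rows and the UDTF is called once per row, emitting the resulting
# record batches.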
if len(args) == 0:
for batch in convert_to_arrow(func()):
yield verify_result(batch), arrow_return_type
else:
list_args = list(args)
names = [f"_{n}" for n in range(len(list_args))]
t = pa.Table.from_arrays(list_args, names=names)
schema = from_arrow_schema(t.schema, prefers_large_var_types)
rows = ArrowTableToRowsConversion.convert(t, schema=schema)
for row in rows:
row = tuple(row) # type: ignore[assignment]
for batch in convert_to_arrow(func(*row)):
yield verify_result(batch), arrow_return_type
return evaluate
eval_func_kwargs_support, args_kwargs_offsets = wrap_kwargs_support(
getattr(udtf, "eval"), args_offsets, kwargs_offsets
)
eval = wrap_arrow_udtf(eval_func_kwargs_support, return_type)
if hasattr(udtf, "terminate"):
terminate = wrap_arrow_udtf(getattr(udtf, "terminate"), return_type)
else:
terminate = None
cleanup = getattr(udtf, "cleanup") if hasattr(udtf, "cleanup") else None
def mapper(_, it):
try:
for a in it:
# The eval function yields an iterator. Each element produced by this
# iterator is a tuple in the form of (pyarrow.RecordBatch, arrow_return_type).
yield from eval(*[a[o] for o in args_kwargs_offsets])
if terminate is not None:
yield from terminate()
except SkipRestOfInputTableException:
if terminate is not None:
yield from terminate()
finally:
if cleanup is not None:
cleanup()
return mapper, None, ser, ser
else:
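# Pickle-based path (non-Arrow evaluation): each output row is validated and converted to
# Spark's internal representation, and the results of each 'eval' call are pickled with the
# BatchedSerializer configured above.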
def wrap_udtf(f, return_type):
assert return_type.needConversion()
toInternal = return_type.toInternal
return_type_size = len(return_type)
def verify_and_convert_result(result):
if result is not None:
if hasattr(result, "__len__") and len(result) != return_type_size:
raise PySparkRuntimeError(
errorClass="UDTF_RETURN_SCHEMA_MISMATCH",
messageParameters={
"expected": str(return_type_size),
"actual": str(len(result)),
"func": f.__name__,
},
)
if not (isinstance(result, (list, dict, tuple)) or hasattr(result, "__dict__")):
raise PySparkRuntimeError(
errorClass="UDTF_INVALID_OUTPUT_ROW_TYPE",
messageParameters={
"type": type(result).__name__,
"func": f.__name__,
},
)
if check_output_row_against_schema is not None:
check_output_row_against_schema(result)
return toInternal(result)
# Evaluate the function and return a tuple back to the executor.
def evaluate(*a) -> tuple:
try:
res = f(*a)
except SkipRestOfInputTableException:
raise
except Exception as e:
raise PySparkRuntimeError(
errorClass="UDTF_EXEC_ERROR",
messageParameters={"method_name": f.__name__, "error": str(e)},
)
if res is None:
# If the function returns None or does not have an explicit return statement,
# an empty tuple is returned to the executor.
# This is because directly constructing tuple(None) results in an exception.
return tuple()
if not isinstance(res, Iterable):
raise PySparkRuntimeError(
errorClass="UDTF_RETURN_NOT_ITERABLE",
messageParameters={
"type": type(res).__name__,
"func": f.__name__,
},
)
# If the function returns a result, we map it to the internal representation and
# return the results as a tuple.
return tuple(map(verify_and_convert_result, res))
return evaluate
eval_func_kwargs_support, args_kwargs_offsets = wrap_kwargs_support(
getattr(udtf, "eval"), args_offsets, kwargs_offsets
)
eval = wrap_udtf(eval_func_kwargs_support, return_type)
if hasattr(udtf, "terminate"):
terminate = wrap_udtf(getattr(udtf, "terminate"), return_type)
else:
terminate = None
cleanup = getattr(udtf, "cleanup") if hasattr(udtf, "cleanup") else None
# Return an iterator of iterators.
def mapper(_, it):
try:
for a in it:
yield eval(*[a[o] for o in args_kwargs_offsets])
if terminate is not None:
yield terminate()
except SkipRestOfInputTableException:
if terminate is not None:
yield terminate()
finally:
if cleanup is not None:
cleanup()
return mapper, None, ser, ser