def read_udtf()

in python/pyspark/worker.py [0:0]


def read_udtf(pickleSer, infile, eval_type):
    """
    Read a Python UDTF definition from the worker input stream and build its row mapper.

    Fields are read from ``infile`` in this order:
    - the runner conf (only for ``SQL_ARROW_TABLE_UDF``);
    - the argument offsets, each optionally named;
    - the PARTITION BY child indexes;
    - an optional pickled AnalyzeResult;
    - the pickled UDTF class;
    - the JSON return type;
    - the UDTF name.

    Parameters
    ----------
    pickleSer
        Serializer used to read the pickled AnalyzeResult and the pickled UDTF class.
    infile
        Input stream the UDTF definition is read from.
    eval_type
        A ``PythonEvalType`` value. ``SQL_ARROW_TABLE_UDF`` selects Arrow-based evaluation, or
        the legacy pandas-based one when
        ``spark.sql.legacy.execution.pythonUDTF.pandas.conversion.enabled`` is "true". Any other
        value uses pickled rows sent one at a time.

    Returns
    -------
    tuple
        ``(mapper, None, ser, ser)``. ``mapper(_, it)`` calls 'eval' for each input row in
        ``it``, then 'terminate' (if defined), and always 'cleanup' (if defined) at the end.
        ``ser`` (returned twice) is the serializer selected for ``eval_type``.

    Raises
    ------
    PySparkRuntimeError
        If any of the following holds:
        - the pickled handler is not a class;
        - the return type is not a struct;
        - the constructor signature does not match whether an AnalyzeResult was provided;
        - the constructor raises;
        - the UDTF has no 'eval' method;
        - the call arguments cannot be bound to the 'eval' signature.
    """
    prefers_large_var_types = False
    legacy_pandas_conversion = False

    if eval_type == PythonEvalType.SQL_ARROW_TABLE_UDF:
        runner_conf = {}
        # Load conf used for arrow evaluation.
        num_conf = read_int(infile)
        for _ in range(num_conf):
            k = utf8_deserializer.loads(infile)
            v = utf8_deserializer.loads(infile)
            runner_conf[k] = v
        prefers_large_var_types = use_large_var_types(runner_conf)
        legacy_pandas_conversion = (
            runner_conf.get(
                "spark.sql.legacy.execution.pythonUDTF.pandas.conversion.enabled", "false"
            ).lower()
            == "true"
        )
        if legacy_pandas_conversion:
            # NOTE: if timezone is set here, that implies respectSessionTimeZone is True
            safecheck = (
                runner_conf.get(
                    "spark.sql.execution.pandas.convertToArrowArraySafely", "false"
                ).lower()
                == "true"
            )
            timezone = runner_conf.get("spark.sql.session.timeZone", None)
            ser = ArrowStreamPandasUDTFSerializer(timezone, safecheck)
        else:
            ser = ArrowStreamUDTFSerializer()

    else:
        # Each row is a group so do not batch but send one by one.
        ser = BatchedSerializer(CPickleSerializer(), 1)

    # See 'PythonUDTFRunner.PythonUDFWriterThread.writeCommand'
    # Each argument is read as an offset followed by a flag; when the flag is set, the
    # argument was passed by name and its name follows.
    num_arg = read_int(infile)
    args_offsets = []
    kwargs_offsets = {}
    for _ in range(num_arg):
        offset = read_int(infile)
        if read_bool(infile):
            name = utf8_deserializer.loads(infile)
            kwargs_offsets[name] = offset
        else:
            args_offsets.append(offset)
    # Indexes of projected PARTITION BY expressions (see UDTFWithPartitions below).
    num_partition_child_indexes = read_int(infile)
    partition_child_indexes = [read_int(infile) for _ in range(num_partition_child_indexes)]
    has_pickled_analyze_result = read_bool(infile)
    if has_pickled_analyze_result:
        pickled_analyze_result = pickleSer._read_with_length(infile)
    else:
        pickled_analyze_result = None
    handler = read_command(pickleSer, infile)
    if not isinstance(handler, type):
        raise PySparkRuntimeError(
            f"Invalid UDTF handler type. Expected a class (type 'type'), but "
            f"got an instance of {type(handler).__name__}."
        )

    return_type = _parse_datatype_json_string(utf8_deserializer.loads(infile))
    if not isinstance(return_type, StructType):
        raise PySparkRuntimeError(
            f"The return type of a UDTF must be a struct type, but got {type(return_type)}."
        )
    udtf_name = utf8_deserializer.loads(infile)

    # Decide how to construct UDTF instances from the constructor signature. Note that
    # `udtf_init_args.args` includes 'self'.
    # - If an AnalyzeResult was provided, the constructor may take either no extra argument or
    #   exactly one. With one, each new instance receives its own copy of the AnalyzeResult.
    # - If no AnalyzeResult was provided, the constructor must take no extra argument.
    udtf_init_args = inspect.getfullargspec(handler)
    if has_pickled_analyze_result:
        if len(udtf_init_args.args) > 2:
            raise PySparkRuntimeError(
                errorClass="UDTF_CONSTRUCTOR_INVALID_IMPLEMENTS_ANALYZE_METHOD",
                messageParameters={"name": udtf_name},
            )
        elif len(udtf_init_args.args) == 2:
            prev_handler = handler

            def construct_udtf():
                # Here we pass the AnalyzeResult to the UDTF's __init__ method.
                # `dataclasses.replace` with no changes yields a new (shallow) copy, so
                # instances do not share one AnalyzeResult object.
                return prev_handler(dataclasses.replace(pickled_analyze_result))

            handler = construct_udtf
    elif len(udtf_init_args.args) > 1:
        raise PySparkRuntimeError(
            errorClass="UDTF_CONSTRUCTOR_INVALID_NO_ANALYZE_METHOD",
            messageParameters={"name": udtf_name},
        )

    class UDTFWithPartitions:
        """
        This implements the logic of a UDTF that accepts an input TABLE argument with one or more
        PARTITION BY expressions.

        For example, let's assume we have a table like:
            CREATE TABLE t (c1 INT, c2 INT) USING delta;
        Then for the following queries:
            SELECT * FROM my_udtf(TABLE (t) PARTITION BY c1, c2);
            The partition_child_indexes will be: 0, 1.
            SELECT * FROM my_udtf(TABLE (t) PARTITION BY c1, c2 + 4);
            The partition_child_indexes will be: 0, 2 (where we add a projection for "c2 + 4").
        """

        def __init__(self, create_udtf: Callable, partition_child_indexes: list):
            """
            Creates a new instance of this class to wrap the provided UDTF with another one that
            checks the values of projected partitioning expressions on consecutive rows to figure
            out when the partition boundaries change.

            Parameters
            ----------
            create_udtf: function
                Function to create a new instance of the UDTF to be invoked.
            partition_child_indexes: list
                List of integers identifying zero-based indexes of the columns of the input table
                that contain projected partitioning expressions. This class will inspect these
                values for each pair of consecutive input rows. When they change, this indicates
                the boundary between two partitions, and we will invoke the 'terminate' method on
                the UDTF class instance and then destroy it and create a new one to implement the
                desired partitioning semantics.
            """
            self._create_udtf: Callable = create_udtf
            self._udtf = create_udtf()
            self._prev_arguments: list = list()
            self._partition_child_indexes: list = partition_child_indexes
            self._eval_raised_skip_rest_of_input_table: bool = False

        def eval(self, *args, **kwargs) -> Iterator:
            """
            Yield the output rows for one input row.

            On a partition boundary, first yield the rows from the current instance's
            'terminate' and replace the instance. Then call the wrapped 'eval' on the
            arguments with the PARTITION BY columns removed.
            """
            changed_partitions = self._check_partition_boundaries(
                list(args) + list(kwargs.values())
            )
            if changed_partitions:
                if hasattr(self._udtf, "terminate"):
                    result = self._udtf.terminate()
                    if result is not None:
                        for row in result:
                            yield row
                # NOTE(review): the replaced instance's 'cleanup' is not called here; only the
                # last instance is cleaned up via UDTFWithPartitions.cleanup. Confirm this is
                # intended.
                self._udtf = self._create_udtf()
                self._eval_raised_skip_rest_of_input_table = False
            if self._udtf.eval is not None and not self._eval_raised_skip_rest_of_input_table:
                # Filter the arguments to exclude projected PARTITION BY values added by Catalyst.
                filtered_args = [self._remove_partition_by_exprs(arg) for arg in args]
                filtered_kwargs = {
                    key: self._remove_partition_by_exprs(value) for (key, value) in kwargs.items()
                }
                try:
                    result = self._udtf.eval(*filtered_args, **filtered_kwargs)
                    if result is not None:
                        for row in result:
                            yield row
                except SkipRestOfInputTableException:
                    # If the 'eval' method raised this exception, then we should skip the rest of
                    # the rows in the current partition. Set this field to True here and then for
                    # each subsequent row in the partition, we will skip calling the 'eval' method
                    # until we see a change in the partition boundaries.
                    self._eval_raised_skip_rest_of_input_table = True

        def terminate(self) -> Iterator:
            """Call the current instance's 'terminate' if defined, else return an empty iterator."""
            if hasattr(self._udtf, "terminate"):
                return self._udtf.terminate()
            return iter(())

        def cleanup(self) -> None:
            """Call the current instance's 'cleanup' if it defines one."""
            if hasattr(self._udtf, "cleanup"):
                self._udtf.cleanup()

        def _check_partition_boundaries(self, arguments: list) -> bool:
            """
            Return True if any PARTITION BY value differs from the previous row's values.

            Always returns False for the first row. Always records `arguments` as the
            previous row.
            """
            result = False
            if len(self._prev_arguments) > 0:
                cur_table_arg = self._get_table_arg(arguments)
                prev_table_arg = self._get_table_arg(self._prev_arguments)
                cur_partitions_args = []
                prev_partitions_args = []
                for i in self._partition_child_indexes:
                    cur_partitions_args.append(cur_table_arg[i])
                    prev_partitions_args.append(prev_table_arg[i])
                result = any(k != v for k, v in zip(cur_partitions_args, prev_partitions_args))
            self._prev_arguments = arguments
            return result

        def _get_table_arg(self, inputs: list) -> Row:
            """Return the first input whose type is exactly Row (presumably the TABLE argument)."""
            return [x for x in inputs if type(x) is Row][0]

        def _remove_partition_by_exprs(self, arg: Any) -> Any:
            """
            Return `arg` with the PARTITION BY columns removed.

            For a Row, this builds a new Row without the columns at the partition indexes.
            Any other value is returned unchanged.
            """
            if isinstance(arg, Row):
                new_row_keys = []
                new_row_values = []
                for i, (key, value) in enumerate(zip(arg.__fields__, arg)):
                    if i not in self._partition_child_indexes:
                        new_row_keys.append(key)
                        new_row_values.append(value)
                return _create_row(new_row_keys, new_row_values)
            else:
                return arg

    # Instantiate the UDTF class.
    try:
        if len(partition_child_indexes) > 0:
            udtf = UDTFWithPartitions(handler, partition_child_indexes)
        else:
            udtf = handler()
    except Exception as e:
        raise PySparkRuntimeError(
            errorClass="UDTF_EXEC_ERROR",
            messageParameters={"method_name": "__init__", "error": str(e)},
        )

    # Validate the UDTF. With PARTITION BY, `udtf` is the UDTFWithPartitions wrapper, which
    # always defines 'eval', so check the wrapped user instance instead.
    user_udtf = udtf._udtf if isinstance(udtf, UDTFWithPartitions) else udtf
    if not hasattr(user_udtf, "eval"):
        raise PySparkRuntimeError(
            "Failed to execute the user defined table function because it has not "
            "implemented the 'eval' method. Please add the 'eval' method and try "
            "the query again."
        )

    # Check that the arguments provided to the UDTF call match the expected parameters defined
    # in the 'eval' method signature.
    try:
        inspect.signature(udtf.eval).bind(*args_offsets, **kwargs_offsets)
    except TypeError as e:
        raise PySparkRuntimeError(
            errorClass="UDTF_EVAL_METHOD_ARGUMENTS_DO_NOT_MATCH_SIGNATURE",
            messageParameters={"name": udtf_name, "reason": str(e)},
        ) from None

    def build_null_checker(return_type: StructType) -> Optional[Callable[[Any], None]]:
        """
        Build a function that raises PySparkRuntimeError for disallowed None values in a row.

        A None is disallowed when it appears where `return_type` declares a non-nullable
        value. This includes nested array elements, map values and struct fields; map keys
        are never allowed to be None.

        Returns None when the schema needs no checking. Rows or nested structs that are not
        tuples, and arrays/maps that are not list/dict, are not inspected.
        """

        def raise_(result_column_index):
            raise PySparkRuntimeError(
                errorClass="UDTF_EXEC_ERROR",
                messageParameters={
                    "method_name": "eval' or 'terminate",
                    "error": f"Column {result_column_index} within a returned row had a "
                    + "value of None, either directly or within array/struct/map "
                    + "subfields, but the corresponding column type was declared as "
                    + "non-nullable; please update the UDTF to return a non-None value at "
                    + "this location or otherwise declare the column type as nullable.",
                },
            )

        def checker(data_type: DataType, result_column_index: int):
            # Returns a checker for values of `data_type`, or None if no check is needed.
            if isinstance(data_type, ArrayType):
                element_checker = checker(data_type.elementType, result_column_index)
                contains_null = data_type.containsNull

                if element_checker is None and contains_null:
                    return None

                def check_array(arr):
                    if isinstance(arr, list):
                        for e in arr:
                            if e is None:
                                if not contains_null:
                                    raise_(result_column_index)
                            elif element_checker is not None:
                                element_checker(e)

                return check_array

            elif isinstance(data_type, MapType):
                # Map keys can never be None, so a map always needs a checker.
                key_checker = checker(data_type.keyType, result_column_index)
                value_checker = checker(data_type.valueType, result_column_index)
                value_contains_null = data_type.valueContainsNull

                if value_checker is None and value_contains_null:

                    def check_map(map):
                        if isinstance(map, dict):
                            for k, v in map.items():
                                if k is None:
                                    raise_(result_column_index)
                                elif key_checker is not None:
                                    key_checker(k)

                else:

                    def check_map(map):
                        if isinstance(map, dict):
                            for k, v in map.items():
                                if k is None:
                                    raise_(result_column_index)
                                elif key_checker is not None:
                                    key_checker(k)
                                if v is None:
                                    if not value_contains_null:
                                        raise_(result_column_index)
                                elif value_checker is not None:
                                    value_checker(v)

                return check_map

            elif isinstance(data_type, StructType):
                field_checkers = [checker(f.dataType, result_column_index) for f in data_type]
                nullables = [f.nullable for f in data_type]

                if all(c is None for c in field_checkers) and all(nullables):
                    return None

                def check_struct(struct):
                    if isinstance(struct, tuple):
                        for value, checker, nullable in zip(struct, field_checkers, nullables):
                            if value is None:
                                if not nullable:
                                    raise_(result_column_index)
                            elif checker is not None:
                                checker(value)

                return check_struct

            else:
                return None

        field_checkers = [
            checker(f.dataType, result_column_index=i) for i, f in enumerate(return_type)
        ]
        nullables = [f.nullable for f in return_type]

        if all(c is None for c in field_checkers) and all(nullables):
            return None

        def check(row):
            if isinstance(row, tuple):
                for i, (value, checker, nullable) in enumerate(zip(row, field_checkers, nullables)):
                    if value is None:
                        if not nullable:
                            raise_(i)
                    elif checker is not None:
                        checker(value)

        return check

    check_output_row_against_schema = build_null_checker(return_type)

    if eval_type == PythonEvalType.SQL_ARROW_TABLE_UDF and legacy_pandas_conversion:

        def wrap_arrow_udtf(f, return_type):
            import pandas as pd

            arrow_return_type = to_arrow_type(
                return_type, prefers_large_types=prefers_large_var_types
            )
            return_type_size = len(return_type)

            def verify_result(result):
                if not isinstance(result, pd.DataFrame):
                    raise PySparkTypeError(
                        errorClass="INVALID_ARROW_UDTF_RETURN_TYPE",
                        messageParameters={
                            "return_type": type(result).__name__,
                            "value": str(result),
                            "func": f.__name__,
                        },
                    )

                # Validate the output schema when the result dataframe has either output
                # rows or columns. Note that we avoid using `df.empty` here because the
                # result dataframe may contain an empty row. For example, when a UDTF is
                # defined as follows: def eval(self): yield tuple().
                if len(result) > 0 or len(result.columns) > 0:
                    if len(result.columns) != return_type_size:
                        raise PySparkRuntimeError(
                            errorClass="UDTF_RETURN_SCHEMA_MISMATCH",
                            messageParameters={
                                "expected": str(return_type_size),
                                "actual": str(len(result.columns)),
                                "func": f.__name__,
                            },
                        )

                # Verify the type and the schema of the result.
                verify_pandas_result(
                    result, return_type, assign_cols_by_name=False, truncate_return_schema=False
                )
                return result

            # Wrap the exception thrown from the UDTF in a PySparkRuntimeError.
            def func(*a: Any) -> Any:
                try:
                    return f(*a)
                except SkipRestOfInputTableException:
                    raise
                except Exception as e:
                    raise PySparkRuntimeError(
                        errorClass="UDTF_EXEC_ERROR",
                        messageParameters={"method_name": f.__name__, "error": str(e)},
                    )

            def check_return_value(res):
                # Check whether the result of an arrow UDTF is iterable before
                # using it to construct a pandas DataFrame.
                if res is not None:
                    if not isinstance(res, Iterable):
                        raise PySparkRuntimeError(
                            errorClass="UDTF_RETURN_NOT_ITERABLE",
                            messageParameters={
                                "type": type(res).__name__,
                                "func": f.__name__,
                            },
                        )
                    if check_output_row_against_schema is not None:
                        for row in res:
                            if row is not None:
                                check_output_row_against_schema(row)
                            yield row
                    else:
                        yield from res

            def evaluate(*args: pd.Series):
                if len(args) == 0:
                    res = func()
                    yield verify_result(pd.DataFrame(check_return_value(res))), arrow_return_type
                else:
                    # Create tuples from the input pandas Series, each tuple
                    # represents a row across all Series.
                    row_tuples = zip(*args)
                    for row in row_tuples:
                        res = func(*row)
                        yield verify_result(
                            pd.DataFrame(check_return_value(res))
                        ), arrow_return_type

            return evaluate

        eval_func_kwargs_support, args_kwargs_offsets = wrap_kwargs_support(
            getattr(udtf, "eval"), args_offsets, kwargs_offsets
        )
        eval = wrap_arrow_udtf(eval_func_kwargs_support, return_type)

        if hasattr(udtf, "terminate"):
            terminate = wrap_arrow_udtf(getattr(udtf, "terminate"), return_type)
        else:
            terminate = None

        cleanup = getattr(udtf, "cleanup") if hasattr(udtf, "cleanup") else None

        def mapper(_, it):
            try:
                for a in it:
                    # The eval function yields an iterator. Each element produced by this
                    # iterator is a tuple in the form of (pandas.DataFrame, arrow_return_type).
                    yield from eval(*[a[o] for o in args_kwargs_offsets])
                if terminate is not None:
                    yield from terminate()
            except SkipRestOfInputTableException:
                if terminate is not None:
                    yield from terminate()
            finally:
                if cleanup is not None:
                    cleanup()

        return mapper, None, ser, ser

    elif eval_type == PythonEvalType.SQL_ARROW_TABLE_UDF and not legacy_pandas_conversion:

        def wrap_arrow_udtf(f, return_type):
            import pyarrow as pa

            arrow_return_type = to_arrow_type(
                return_type, prefers_large_types=prefers_large_var_types
            )
            return_type_size = len(return_type)

            def verify_result(result):
                if not isinstance(result, pa.RecordBatch):
                    raise PySparkTypeError(
                        errorClass="INVALID_ARROW_UDTF_RETURN_TYPE",
                        messageParameters={
                            "return_type": type(result).__name__,
                            "value": str(result),
                            "func": f.__name__,
                        },
                    )

                # Validate the output schema when the result batch has either output rows or
                # columns. A batch may have rows but no columns, e.g. when a UDTF is defined
                # as follows: def eval(self): yield tuple().
                if len(result) > 0 or len(result.columns) > 0:
                    if len(result.columns) != return_type_size:
                        raise PySparkRuntimeError(
                            errorClass="UDTF_RETURN_SCHEMA_MISMATCH",
                            messageParameters={
                                "expected": str(return_type_size),
                                "actual": str(len(result.columns)),
                                "func": f.__name__,
                            },
                        )

                # Verify the type and the schema of the result. The expected types must be
                # built with the same large-var-type preference as `arrow_return_type` (and
                # the batches produced by `convert_to_arrow`). Otherwise e.g. large_string
                # columns would be reported as mismatching string columns.
                verify_arrow_result(
                    pa.Table.from_batches([result], schema=pa.schema(list(arrow_return_type))),
                    assign_cols_by_name=False,
                    expected_cols_and_types=[
                        (
                            col.name,
                            to_arrow_type(
                                col.dataType, prefers_large_types=prefers_large_var_types
                            ),
                        )
                        for col in return_type.fields
                    ],
                )
                return result

            # Wrap the exception thrown from the UDTF in a PySparkRuntimeError.
            def func(*a: Any) -> Any:
                try:
                    return f(*a)
                except SkipRestOfInputTableException:
                    raise
                except Exception as e:
                    raise PySparkRuntimeError(
                        errorClass="UDTF_EXEC_ERROR",
                        messageParameters={"method_name": f.__name__, "error": str(e)},
                    )

            def check_return_value(res):
                # Check whether the result of an arrow UDTF is iterable before
                # converting it to Arrow record batches.
                if res is not None:
                    if not isinstance(res, Iterable):
                        raise PySparkRuntimeError(
                            errorClass="UDTF_RETURN_NOT_ITERABLE",
                            messageParameters={
                                "type": type(res).__name__,
                                "func": f.__name__,
                            },
                        )
                    if check_output_row_against_schema is not None:
                        for row in res:
                            if row is not None:
                                check_output_row_against_schema(row)
                            yield row
                    else:
                        yield from res

            def convert_to_arrow(data: Iterable):
                # Convert the rows returned by the UDTF into a list of Arrow record batches.
                data = list(check_return_value(data))
                if len(data) == 0:
                    return [
                        pa.RecordBatch.from_pylist(data, schema=pa.schema(list(arrow_return_type)))
                    ]
                try:
                    ret = LocalDataToArrowConversion.convert(
                        data, return_type, prefers_large_var_types
                    ).to_batches()
                    if len(return_type.fields) == 0:
                        # No output columns: emit one column-less batch with one row per result.
                        return [pa.RecordBatch.from_struct_array(pa.array([{}] * len(data)))]
                    return ret
                except Exception as e:
                    raise PySparkRuntimeError(
                        errorClass="UDTF_ARROW_TYPE_CONVERSION_ERROR",
                        messageParameters={
                            "data": str(data),
                            "schema": return_type.simpleString(),
                            "arrow_schema": str(arrow_return_type),
                        },
                    ) from e

            def evaluate(*args: pa.ChunkedArray):
                if len(args) == 0:
                    for batch in convert_to_arrow(func()):
                        yield verify_result(batch), arrow_return_type

                else:
                    # Rebuild Python rows from the Arrow input columns and call the UDTF once
                    # per row.
                    list_args = list(args)
                    names = [f"_{n}" for n in range(len(list_args))]
                    t = pa.Table.from_arrays(list_args, names=names)
                    schema = from_arrow_schema(t.schema, prefers_large_var_types)
                    rows = ArrowTableToRowsConversion.convert(t, schema=schema)
                    for row in rows:
                        row = tuple(row)  # type: ignore[assignment]
                        for batch in convert_to_arrow(func(*row)):
                            yield verify_result(batch), arrow_return_type

            return evaluate

        eval_func_kwargs_support, args_kwargs_offsets = wrap_kwargs_support(
            getattr(udtf, "eval"), args_offsets, kwargs_offsets
        )
        eval = wrap_arrow_udtf(eval_func_kwargs_support, return_type)

        if hasattr(udtf, "terminate"):
            terminate = wrap_arrow_udtf(getattr(udtf, "terminate"), return_type)
        else:
            terminate = None

        cleanup = getattr(udtf, "cleanup") if hasattr(udtf, "cleanup") else None

        def mapper(_, it):
            try:
                for a in it:
                    # The eval function yields an iterator. Each element produced by this
                    # iterator is a tuple in the form of (pyarrow.RecordBatch, arrow_return_type).
                    yield from eval(*[a[o] for o in args_kwargs_offsets])
                if terminate is not None:
                    yield from terminate()
            except SkipRestOfInputTableException:
                if terminate is not None:
                    yield from terminate()
            finally:
                if cleanup is not None:
                    cleanup()

        return mapper, None, ser, ser
    else:

        def wrap_udtf(f, return_type):
            assert return_type.needConversion()
            toInternal = return_type.toInternal
            return_type_size = len(return_type)

            def verify_and_convert_result(result):
                # Validate one output row and convert it with `return_type.toInternal`.
                if result is not None:
                    if hasattr(result, "__len__") and len(result) != return_type_size:
                        raise PySparkRuntimeError(
                            errorClass="UDTF_RETURN_SCHEMA_MISMATCH",
                            messageParameters={
                                "expected": str(return_type_size),
                                "actual": str(len(result)),
                                "func": f.__name__,
                            },
                        )

                    if not (isinstance(result, (list, dict, tuple)) or hasattr(result, "__dict__")):
                        raise PySparkRuntimeError(
                            errorClass="UDTF_INVALID_OUTPUT_ROW_TYPE",
                            messageParameters={
                                "type": type(result).__name__,
                                "func": f.__name__,
                            },
                        )
                    if check_output_row_against_schema is not None:
                        check_output_row_against_schema(result)
                return toInternal(result)

            # Evaluate the function and return a tuple back to the executor.
            def evaluate(*a) -> tuple:
                try:
                    res = f(*a)
                except SkipRestOfInputTableException:
                    raise
                except Exception as e:
                    raise PySparkRuntimeError(
                        errorClass="UDTF_EXEC_ERROR",
                        messageParameters={"method_name": f.__name__, "error": str(e)},
                    )

                if res is None:
                    # If the function returns None or does not have an explicit return statement,
                    # an empty tuple is returned to the executor.
                    # This is because directly constructing tuple(None) results in an exception.
                    return tuple()

                if not isinstance(res, Iterable):
                    raise PySparkRuntimeError(
                        errorClass="UDTF_RETURN_NOT_ITERABLE",
                        messageParameters={
                            "type": type(res).__name__,
                            "func": f.__name__,
                        },
                    )

                # If the function returns a result, we map it to the internal representation and
                # returns the results as a tuple.
                return tuple(map(verify_and_convert_result, res))

            return evaluate

        eval_func_kwargs_support, args_kwargs_offsets = wrap_kwargs_support(
            getattr(udtf, "eval"), args_offsets, kwargs_offsets
        )
        eval = wrap_udtf(eval_func_kwargs_support, return_type)

        if hasattr(udtf, "terminate"):
            terminate = wrap_udtf(getattr(udtf, "terminate"), return_type)
        else:
            terminate = None

        cleanup = getattr(udtf, "cleanup") if hasattr(udtf, "cleanup") else None

        # Return an iterator of iterators.
        def mapper(_, it):
            try:
                for a in it:
                    yield eval(*[a[o] for o in args_kwargs_offsets])
                if terminate is not None:
                    yield terminate()
            except SkipRestOfInputTableException:
                if terminate is not None:
                    yield terminate()
            finally:
                if cleanup is not None:
                    cleanup()

        return mapper, None, ser, ser