def read_udfs()

in python/pyspark/worker.py [0:0]


def read_udfs(pickleSer, infile, eval_type):
    """
    Read the UDF setup for ``eval_type`` from ``infile`` and build the evaluation function.

    Everything is read from ``infile`` in a fixed order, so statements that consume the
    stream must not be reordered:

    1. Only for the Arrow/pandas based eval types listed below: the runner conf (an int
       count followed by that many utf-8 key/value string pairs), then eval-type specific
       extras:

       - ``SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE``: the state object schema (JSON string).
       - ``SQL_TRANSFORM_WITH_STATE_*``: the state server port as an int (when it is -1, a
         utf-8 string is read and used in its place), then the grouping key schema (JSON).
       - ``SQL_ARROW_BATCHED_UDF``: the input types (JSON datatype string).

    2. A bool profiling flag, followed by the profiler name when the flag is set.
    3. The number of UDFs, followed by each UDF as read by ``read_single_udf``.

    For the remaining eval types no conf is read. The serializer is then a
    ``BatchedSerializer(CPickleSerializer(), batch_size)``, with ``batch_size`` taken from
    the ``PYTHON_UDF_BATCH_SIZE`` environment variable (default 100).

    Parameters
    ----------
    pickleSer :
        Serializer forwarded to ``read_single_udf``.
    infile :
        Stream the worker reads the UDF setup from.
    eval_type :
        One of the ``PythonEvalType`` values. It selects the serializer and how input
        columns are mapped onto the UDF arguments.

    Returns
    -------
    tuple
        ``(func, None, ser, ser)``:

        - ``func(_, iterator)`` applies the UDF(s) to ``iterator`` and ignores its first
          argument.
        - ``ser`` is the serializer chosen for ``eval_type``.
        - The second element is always ``None``, because profiling is not supported for
          UDFs here.
    """
    runner_conf = {}

    state_server_port = None
    key_schema = None
    if eval_type in (
        PythonEvalType.SQL_ARROW_BATCHED_UDF,
        PythonEvalType.SQL_SCALAR_PANDAS_UDF,
        PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
        PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
        PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
        PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
        PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
        PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
        PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF,
        PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
        PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF,
        PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF,
        PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF,
        PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF,
        PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF,
        PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF,
    ):
        # Load conf used for pandas_udf evaluation
        num_conf = read_int(infile)
        for _ in range(num_conf):
            # Key is read before value; keep these as two statements so the read order
            # stays explicit.
            k = utf8_deserializer.loads(infile)
            v = utf8_deserializer.loads(infile)
            runner_conf[k] = v

        state_object_schema = None
        if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE:
            state_object_schema = StructType.fromJson(json.loads(utf8_deserializer.loads(infile)))
        elif (
            eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF
            or eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF
            or eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF
            or eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF
        ):
            state_server_port = read_int(infile)
            if state_server_port == -1:
                # -1 means a utf-8 string follows and is used instead of an int port.
                # NOTE(review): presumably a socket path/address; confirm on the JVM side.
                state_server_port = utf8_deserializer.loads(infile)
            key_schema = StructType.fromJson(json.loads(utf8_deserializer.loads(infile)))

        # NOTE: if timezone is set here, that implies respectSessionTimeZone is True
        timezone = runner_conf.get("spark.sql.session.timeZone", None)
        prefers_large_var_types = use_large_var_types(runner_conf)
        safecheck = (
            runner_conf.get("spark.sql.execution.pandas.convertToArrowArraySafely", "false").lower()
            == "true"
        )
        _assign_cols_by_name = assign_cols_by_name(runner_conf)

        def _max_records_per_batch():
            # Parsed on demand so that only the stateful serializers that consume this
            # setting ever convert it to int.
            return int(runner_conf.get("spark.sql.execution.arrow.maxRecordsPerBatch", 10000))

        if eval_type == PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF:
            ser = CogroupArrowUDFSerializer(_assign_cols_by_name)
        elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
            ser = CogroupPandasUDFSerializer(timezone, safecheck, _assign_cols_by_name)
        elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE:
            ser = ApplyInPandasWithStateSerializer(
                timezone,
                safecheck,
                _assign_cols_by_name,
                state_object_schema,
                _max_records_per_batch(),
                prefers_large_var_types,
            )
        elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF:
            ser = TransformWithStateInPandasSerializer(
                timezone, safecheck, _assign_cols_by_name, _max_records_per_batch()
            )
        elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF:
            ser = TransformWithStateInPandasInitStateSerializer(
                timezone, safecheck, _assign_cols_by_name, _max_records_per_batch()
            )
        elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF:
            ser = TransformWithStateInPySparkRowSerializer(_max_records_per_batch())
        elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF:
            ser = TransformWithStateInPySparkRowInitStateSerializer(_max_records_per_batch())
        elif eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF:
            ser = ArrowStreamUDFSerializer()
        elif eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF:
            ser = ArrowStreamGroupUDFSerializer(_assign_cols_by_name)
        else:
            # Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of
            # pandas Series. See SPARK-27240.
            df_for_struct = (
                eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF
                or eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
                or eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
            )
            # Arrow-optimized Python UDF takes a struct type argument as a Row
            struct_in_pandas = (
                "row" if eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF else "dict"
            )
            ndarray_as_list = eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF
            # Arrow-optimized Python UDF uses explicit Arrow cast for type coercion
            arrow_cast = eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF
            # Arrow-optimized Python UDF takes input types
            input_types = (
                [f.dataType for f in _parse_datatype_json_string(utf8_deserializer.loads(infile))]
                if eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF
                else None
            )

            ser = ArrowStreamPandasUDFSerializer(
                timezone,
                safecheck,
                _assign_cols_by_name,
                df_for_struct,
                struct_in_pandas,
                ndarray_as_list,
                arrow_cast,
                input_types,
            )
    else:
        batch_size = int(os.environ.get("PYTHON_UDF_BATCH_SIZE", "100"))
        ser = BatchedSerializer(CPickleSerializer(), batch_size)

    is_profiling = read_bool(infile)
    if is_profiling:
        profiler = utf8_deserializer.loads(infile)
    else:
        profiler = None

    num_udfs = read_int(infile)

    is_scalar_iter = eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
    is_map_pandas_iter = eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
    is_map_arrow_iter = eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF

    if is_scalar_iter or is_map_pandas_iter or is_map_arrow_iter:
        if is_scalar_iter:
            assert num_udfs == 1, "One SCALAR_ITER UDF expected here."
        if is_map_pandas_iter:
            assert num_udfs == 1, "One MAP_PANDAS_ITER UDF expected here."
        if is_map_arrow_iter:
            assert num_udfs == 1, "One MAP_ARROW_ITER UDF expected here."

        arg_offsets, udf = read_single_udf(
            pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler
        )

        def func(_, iterator):
            num_input_rows = 0

            def map_batch(batch):
                # Select the UDF's argument columns and count the input rows as the
                # UDF pulls batches from the iterator.
                nonlocal num_input_rows

                udf_args = [batch[offset] for offset in arg_offsets]
                num_input_rows += len(udf_args[0])
                if len(udf_args) == 1:
                    return udf_args[0]
                else:
                    return tuple(udf_args)

            iterator = map(map_batch, iterator)
            result_iter = udf(iterator)

            num_output_rows = 0
            for result_batch, result_type in result_iter:
                num_output_rows += len(result_batch)
                # This check is for Scalar Iterator UDF to fail fast.
                # The length of the entire input can only be explicitly known
                # by consuming the input iterator in user side. Therefore,
                # it's very unlikely the output length is higher than
                # input length.
                if is_scalar_iter and num_output_rows > num_input_rows:
                    raise PySparkRuntimeError(
                        errorClass="PANDAS_UDF_OUTPUT_EXCEEDS_INPUT_ROWS", messageParameters={}
                    )
                yield (result_batch, result_type)

            if is_scalar_iter:
                # Once the results are exhausted, the input iterator must be exhausted too;
                # a remaining batch means the UDF stopped consuming its input early.
                try:
                    next(iterator)
                except StopIteration:
                    pass
                else:
                    raise PySparkRuntimeError(
                        errorClass="STOP_ITERATION_OCCURRED_FROM_SCALAR_ITER_PANDAS_UDF",
                        messageParameters={},
                    )

                if num_output_rows != num_input_rows:
                    raise PySparkRuntimeError(
                        errorClass="RESULT_LENGTH_MISMATCH_FOR_SCALAR_ITER_PANDAS_UDF",
                        messageParameters={
                            "output_length": str(num_output_rows),
                            "input_length": str(num_input_rows),
                        },
                    )

        # profiling is not supported for UDF
        return func, None, ser, ser

    def extract_key_value_indexes(grouped_arg_offsets):
        """
        Helper function to extract the key and value indexes from arg_offsets for the grouped and
        cogrouped pandas udfs. See BasePandasGroupExec.resolveArgOffsets for equivalent scala code.

        Parameters
        ----------
        grouped_arg_offsets:  list
            List containing the key and value indexes of columns of the
            DataFrames to be passed to the udf. It consists of n repeating groups where n is the
            number of DataFrames.  Each group has the following format
            (``n_group = group[0]``, ``n_keys = group[1]``, slices are half-open):
                group[0]: number of entries that follow in this group
                group[1]: number of key indexes
                group[2 : n_keys + 2]: key attribute indexes
                group[n_keys + 2 : n_group + 1]: value attribute indexes

        Returns
        -------
        list
            One ``[key_indexes, value_indexes]`` pair per group, in input order.
        """
        parsed = []
        idx = 0
        while idx < len(grouped_arg_offsets):
            offsets_len = grouped_arg_offsets[idx]
            idx += 1
            offsets = grouped_arg_offsets[idx : idx + offsets_len]
            split_index = offsets[0] + 1
            offset_keys = offsets[1:split_index]
            offset_values = offsets[split_index:]
            parsed.append([offset_keys, offset_values])
            idx += offsets_len
        return parsed

    if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
        # We assume there is only one UDF here because grouped map doesn't
        # support combining multiple UDFs.
        assert num_udfs == 1

        # See FlatMapGroupsInPandasExec for how arg_offsets are used to
        # distinguish between grouping attributes and data attributes
        arg_offsets, f = read_single_udf(
            pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler
        )
        parsed_offsets = extract_key_value_indexes(arg_offsets)

        # Create function like this:
        #   mapper a: f([a[0]], [a[0], a[1]])
        def mapper(a):
            keys = [a[o] for o in parsed_offsets[0][0]]
            vals = [a[o] for o in parsed_offsets[0][1]]
            return f(keys, vals)

    elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF:
        # We assume there is only one UDF here because grouped map doesn't
        # support combining multiple UDFs.
        assert num_udfs == 1

        # See TransformWithStateInPandasExec for how arg_offsets are used to
        # distinguish between grouping attributes and data attributes
        arg_offsets, f = read_single_udf(
            pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler
        )
        parsed_offsets = extract_key_value_indexes(arg_offsets)
        ser.key_offsets = parsed_offsets[0][0]
        stateful_processor_api_client = StatefulProcessorApiClient(state_server_port, key_schema)

        def mapper(a):
            mode = a[0]

            if mode == TransformWithStateInPySparkFuncMode.PROCESS_DATA:
                key = a[1]

                def values_gen():
                    for x in a[2]:
                        retVal = [x[1][o] for o in parsed_offsets[0][1]]
                        yield retVal

                # This must stay a lazy generator - do not materialize.
                return f(stateful_processor_api_client, mode, key, values_gen())
            else:
                # mode == PROCESS_TIMER or mode == COMPLETE
                return f(stateful_processor_api_client, mode, None, iter([]))

    elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF:
        # We assume there is only one UDF here because grouped map doesn't
        # support combining multiple UDFs.
        assert num_udfs == 1

        # See TransformWithStateInPandasExec for how arg_offsets are used to
        # distinguish between grouping attributes and data attributes
        arg_offsets, f = read_single_udf(
            pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler
        )
        # parsed offsets:
        # [
        #     [groupingKeyOffsets, dedupDataOffsets],
        #     [initStateGroupingOffsets, dedupInitDataOffsets]
        # ]
        parsed_offsets = extract_key_value_indexes(arg_offsets)
        ser.key_offsets = parsed_offsets[0][0]
        ser.init_key_offsets = parsed_offsets[1][0]
        stateful_processor_api_client = StatefulProcessorApiClient(state_server_port, key_schema)

        def mapper(a):
            mode = a[0]

            if mode == TransformWithStateInPySparkFuncMode.PROCESS_DATA:
                key = a[1]

                def values_gen():
                    for x in a[2]:
                        retVal = [x[1][o] for o in parsed_offsets[0][1]]
                        initVal = [x[2][o] for o in parsed_offsets[1][1]]
                        yield retVal, initVal

                # This must stay a lazy generator - do not materialize.
                return f(stateful_processor_api_client, mode, key, values_gen())
            else:
                # mode == PROCESS_TIMER or mode == COMPLETE
                return f(stateful_processor_api_client, mode, None, iter([]))

    elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF:
        # We assume there is only one UDF here because grouped map doesn't
        # support combining multiple UDFs.
        assert num_udfs == 1

        # See TransformWithStateInPySparkExec for how arg_offsets are used to
        # distinguish between grouping attributes and data attributes
        arg_offsets, f = read_single_udf(
            pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler
        )
        parsed_offsets = extract_key_value_indexes(arg_offsets)
        ser.key_offsets = parsed_offsets[0][0]
        stateful_processor_api_client = StatefulProcessorApiClient(state_server_port, key_schema)

        def mapper(a):
            mode = a[0]

            if mode == TransformWithStateInPySparkFuncMode.PROCESS_DATA:
                key = a[1]
                values = a[2]

                # Forward the serializer's values unchanged - do not materialize.
                return f(stateful_processor_api_client, mode, key, values)
            else:
                # mode == PROCESS_TIMER or mode == COMPLETE
                return f(stateful_processor_api_client, mode, None, iter([]))

    elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF:
        # We assume there is only one UDF here because grouped map doesn't
        # support combining multiple UDFs.
        assert num_udfs == 1

        # See TransformWithStateInPandasExec for how arg_offsets are used to
        # distinguish between grouping attributes and data attributes
        arg_offsets, f = read_single_udf(
            pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler
        )
        # parsed offsets:
        # [
        #     [groupingKeyOffsets, dedupDataOffsets],
        #     [initStateGroupingOffsets, dedupInitDataOffsets]
        # ]
        parsed_offsets = extract_key_value_indexes(arg_offsets)
        ser.key_offsets = parsed_offsets[0][0]
        ser.init_key_offsets = parsed_offsets[1][0]
        stateful_processor_api_client = StatefulProcessorApiClient(state_server_port, key_schema)

        def mapper(a):
            mode = a[0]

            if mode == TransformWithStateInPySparkFuncMode.PROCESS_DATA:
                key = a[1]
                values = a[2]

                # Forward the serializer's values unchanged - do not materialize.
                return f(stateful_processor_api_client, mode, key, values)
            else:
                # mode == PROCESS_TIMER or mode == COMPLETE
                return f(stateful_processor_api_client, mode, None, iter([]))

    elif eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF:
        import pyarrow as pa

        # We assume there is only one UDF here because grouped map doesn't
        # support combining multiple UDFs.
        assert num_udfs == 1

        # See FlatMapGroupsInPandasExec for how arg_offsets are used to
        # distinguish between grouping attributes and data attributes
        arg_offsets, f = read_single_udf(
            pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler
        )
        parsed_offsets = extract_key_value_indexes(arg_offsets)

        def batch_from_offset(batch, offsets):
            return pa.RecordBatch.from_arrays(
                arrays=[batch.columns[o] for o in offsets],
                names=[batch.schema.names[o] for o in offsets],
            )

        def table_from_batches(batches, offsets):
            return pa.Table.from_batches([batch_from_offset(batch, offsets) for batch in batches])

        def mapper(a):
            keys = table_from_batches(a, parsed_offsets[0][0])
            vals = table_from_batches(a, parsed_offsets[0][1])
            return f(keys, vals)

    elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE:
        # We assume there is only one UDF here because grouped map doesn't
        # support combining multiple UDFs.
        assert num_udfs == 1

        # See FlatMapGroupsInPandas(WithState)Exec for how arg_offsets are used to
        # distinguish between grouping attributes and data attributes
        arg_offsets, f = read_single_udf(
            pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler
        )
        parsed_offsets = extract_key_value_indexes(arg_offsets)

        def mapper(a):
            """
            The function receives (iterator of data, state) and performs extraction of key and
            value from the data, with retaining lazy evaluation.

            See `load_stream` in `ApplyInPandasWithStateSerializer` for more details on the input
            and see `wrap_grouped_map_pandas_udf_with_state` for more details on how output will
            be used.
            """
            from itertools import tee

            state = a[1]
            data_gen = (x[0] for x in a[0])

            # We know there should be at least one item in the iterator/generator.
            # We want to peek the first element to construct the key, hence applying
            # tee to construct the key while we retain another iterator/generator
            # for values.
            keys_gen, values_gen = tee(data_gen)
            keys_elem = next(keys_gen)
            keys = [keys_elem[o] for o in parsed_offsets[0][0]]

            # This must be generator comprehension - do not materialize.
            vals = ([x[o] for o in parsed_offsets[0][1]] for x in values_gen)

            return f(keys, vals, state)

    elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
        # We assume there is only one UDF here because cogrouped map doesn't
        # support combining multiple UDFs.
        assert num_udfs == 1
        arg_offsets, f = read_single_udf(
            pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler
        )

        parsed_offsets = extract_key_value_indexes(arg_offsets)

        def mapper(a):
            df1_keys = [a[0][o] for o in parsed_offsets[0][0]]
            df1_vals = [a[0][o] for o in parsed_offsets[0][1]]
            df2_keys = [a[1][o] for o in parsed_offsets[1][0]]
            df2_vals = [a[1][o] for o in parsed_offsets[1][1]]
            return f(df1_keys, df1_vals, df2_keys, df2_vals)

    elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF:
        import pyarrow as pa

        # We assume there is only one UDF here because cogrouped map doesn't
        # support combining multiple UDFs.
        assert num_udfs == 1
        arg_offsets, f = read_single_udf(
            pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler
        )

        parsed_offsets = extract_key_value_indexes(arg_offsets)

        def batch_from_offset(batch, offsets):
            return pa.RecordBatch.from_arrays(
                arrays=[batch.columns[o] for o in offsets],
                names=[batch.schema.names[o] for o in offsets],
            )

        def table_from_batches(batches, offsets):
            return pa.Table.from_batches([batch_from_offset(batch, offsets) for batch in batches])

        def mapper(a):
            df1_keys = table_from_batches(a[0], parsed_offsets[0][0])
            df1_vals = table_from_batches(a[0], parsed_offsets[0][1])
            df2_keys = table_from_batches(a[1], parsed_offsets[1][0])
            df2_vals = table_from_batches(a[1], parsed_offsets[1][1])
            return f(df1_keys, df1_vals, df2_keys, df2_vals)

    else:
        udfs = []
        for i in range(num_udfs):
            udfs.append(
                read_single_udf(
                    pickleSer, infile, eval_type, runner_conf, udf_index=i, profiler=profiler
                )
            )

        def mapper(a):
            result = tuple(f(*[a[o] for o in arg_offsets]) for arg_offsets, f in udfs)
            # In the special case of a single UDF this will return a single result rather
            # than a tuple of results; this is the format that the JVM side expects.
            if len(result) == 1:
                return result[0]
            else:
                return result

    def func(_, it):
        return map(mapper, it)

    # profiling is not supported for UDF
    return func, None, ser, ser