in python/datafusion/udf.py [0:0]
def udaf(*args: Any, **kwargs: Any):  # noqa: D417
    """Create a new User-Defined Aggregate Function (UDAF).

    This lets you define an **aggregate function** that can be used in
    data aggregation or window function calls.

    Usage:
        - **As a function**: Call ``udaf(accum, input_types, return_type,
          state_type, volatility, name)``. ``accum`` may be passed
          positionally or as the keyword argument ``accum=``.
        - **As a decorator**: Use ``@udaf(input_types, return_type,
          state_type, volatility, name)``.
          When using ``udaf`` as a decorator, **do not pass ``accum``
          explicitly**.

    **Function example:**

    If your :py:class:`Accumulator` can be instantiated with no arguments, you
    can simply pass its type as ``accum``. If you need to pass additional
    arguments to its constructor, you can define a lambda or a factory method.
    During runtime the :py:class:`Accumulator` will be constructed for every
    instance in which this UDAF is used. The following examples are all valid.

    ```
    import pyarrow as pa
    import pyarrow.compute as pc

    class Summarize(Accumulator):
        def __init__(self, bias: float = 0.0):
            self._sum = pa.scalar(bias)

        def state(self) -> list[pa.Scalar]:
            return [self._sum]

        def update(self, values: pa.Array) -> None:
            self._sum = pa.scalar(self._sum.as_py() + pc.sum(values).as_py())

        def merge(self, states: list[pa.Array]) -> None:
            self._sum = pa.scalar(self._sum.as_py() + pc.sum(states[0]).as_py())

        def evaluate(self) -> pa.Scalar:
            return self._sum

    def sum_bias_10() -> Summarize:
        return Summarize(10.0)

    udaf1 = udaf(Summarize, pa.float64(), pa.float64(), [pa.float64()],
        "immutable")
    udaf2 = udaf(sum_bias_10, pa.float64(), pa.float64(), [pa.float64()],
        "immutable")
    udaf3 = udaf(lambda: Summarize(20.0), pa.float64(), pa.float64(),
        [pa.float64()], "immutable")
    ```

    **Decorator example:**

    ```
    @udaf(pa.float64(), pa.float64(), [pa.float64()], "immutable")
    def udaf4() -> Summarize:
        return Summarize(10.0)
    ```

    Args:
        accum: A callable taking no arguments that returns an
            :py:class:`Accumulator` (the accumulator class itself, a factory
            function, or a lambda). It is invoked once at creation time to
            validate its result. **Only needed when calling as a
            function. Skip this argument when using ``udaf`` as a decorator.**
        input_types: The data type, or list of data types, of the arguments
            to the aggregate function. A single type is wrapped in a list.
        return_type: The data type of the return value.
        state_type: The data types of the intermediate accumulation.
        volatility: See :py:class:`Volatility` for allowed values.
        name: A descriptive name for the function. When omitted, the
            lower-cased qualified class name of the accumulator produced by
            ``accum`` is used.

    Returns:
        A user-defined aggregate function, which can be used in either data
        aggregation or window function calls. In decorator form, a decorator
        is returned instead; applying it yields a wrapper that forwards its
        arguments to the created aggregate function.

    Raises:
        TypeError: If ``accum`` is not callable, or if calling it does not
            produce an :py:class:`Accumulator` instance.
    """

    def _function(
        accum: Callable[[], Accumulator],
        input_types: pa.DataType | list[pa.DataType],
        return_type: pa.DataType,
        state_type: list[pa.DataType],
        volatility: Volatility | str,
        name: Optional[str] = None,
    ) -> AggregateUDF:
        """Validate ``accum`` and build the aggregate UDF (function form)."""
        if not callable(accum):
            msg = "`accum` must be callable."
            raise TypeError(msg)

        # Build a single probe instance and reuse it for both the type check
        # and the default name, so the factory is not invoked twice.
        probe = accum()
        if not isinstance(probe, Accumulator):
            msg = "Accumulator must implement the abstract base class Accumulator"
            raise TypeError(msg)
        if name is None:
            name = probe.__class__.__qualname__.lower()

        # Accept a bare data type as shorthand for a one-element list.
        if isinstance(input_types, pa.DataType):
            input_types = [input_types]

        return AggregateUDF(
            name=name,
            accumulator=accum,
            input_types=input_types,
            return_type=return_type,
            state_type=state_type,
            volatility=volatility,
        )

    def _decorator(
        input_types: pa.DataType | list[pa.DataType],
        return_type: pa.DataType,
        state_type: list[pa.DataType],
        volatility: Volatility | str,
        name: Optional[str] = None,
    ) -> Callable[..., Callable[..., Expr]]:
        """Capture the UDAF settings and return the actual decorator."""

        def decorator(accum: Callable[[], Accumulator]) -> Callable[..., Expr]:
            udaf_caller = AggregateUDF.udaf(
                accum, input_types, return_type, state_type, volatility, name
            )

            # Keep the decorated factory's metadata on the returned wrapper.
            @functools.wraps(accum)
            def wrapper(*args: Any, **kwargs: Any) -> Expr:
                return udaf_caller(*args, **kwargs)

            return wrapper

        return decorator

    # Case 1: Used as a function. The accumulator factory is either the first
    # positional argument (which must be callable) or the ``accum`` keyword.
    # Without the keyword check, ``udaf(accum=..., ...)`` would be misrouted to
    # the decorator path and fail with an unexpected-keyword error.
    called_as_function = (bool(args) and callable(args[0])) or "accum" in kwargs
    if called_as_function:
        return _function(*args, **kwargs)

    # Case 2: Used as a decorator with parameters
    return _decorator(*args, **kwargs)