map_processor.py (125 lines of code) (raw):
# Copyright 2025 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Chaining of a sequence of 1:many function across a stream concurrently.
To execute all functions concurrently we build up a tree of results.
Given:
- A content stream of c={c0, c1, .., cN}
- A sequence of functions f={f0, f1, .., fM}
We have the following tree with streaming content on the x-axis and functions
on the y-axis.
-> Content Stream ->
c0 c1 c2 .. cN
f0 o o p ..
|\ | |
f1 p o o o
| |
f2 o o
| |
..
| |
fM r r
o = complete node (all children fetched)
p = pending node (potentially more children)
r = result node (has a result to be yielded)
For the tree above the following has happened:
- f0(c0) has yielded 2 children
- f1(f0(c0)[0]) is executing and hasn't yielded any children yet
- fM(... f0(c0)[1]) has executed all functions and has a result to be returned
- fM(... f0(c1)[0]) has executed all functions and has a result to be returned
- f0(c2) has yielded 1 child and is pending (may yield more children)
- f1(f0(c2)[0]) has finished and returned no children
Currently no results have been returned to the user as the left-most node is
still pending. Once the 'f1(f0(c0)[0])' sub-tree completes, its results
(along with the two that are buffered) will be returned in order.
"""
import asyncio
from collections.abc import AsyncIterable, Callable, Iterable, Sequence
import functools
from typing import TypeAlias, TypeVar
from genai_processors import context
from genai_processors import streams
# Generic part type flowing through the processors.
_T = TypeVar('_T')

# Models a part function (i.e. part processors): T -> stream[T]
PartFn: TypeAlias = Callable[[_T], AsyncIterable[_T]]

# Models a function that returns True if the part should be processed. When a
# part should not be processed, the part processor will not be called and
# the part will be passed as is for chains, and apply, or will be dropped for
# parallel.
MatchFn: TypeAlias = Callable[[_T], bool]

# A part function paired with the predicate deciding whether it runs on a part.
PartWithMatchFn: TypeAlias = tuple[PartFn, MatchFn]

# Models a stream function (i.e. processors): stream[T] -> stream[T]
StreamFn: TypeAlias = Callable[[AsyncIterable[_T]], AsyncIterable[_T]]
def apply_sync(fn: StreamFn, content: Iterable[_T]) -> list[_T]:
  """Runs a stream function over `content` and blocks until completion.

  Args:
    fn: the stream function to apply to the content.
    content: a collection of inputs/parts on which to apply the function.

  Returns:
    All outputs of `fn` over `content`, gathered into a list.
  """

  async def _gather() -> list[_T]:
    # Ensure a processor context is active for the duration of the run.
    async with context.context():
      input_stream = streams.stream_content(content)
      return await streams.gather_stream(fn(input_stream))

  return asyncio.run(_gather())
def map_part_function(
    fn: PartFn,
    match_fn: MatchFn | None = None,
) -> StreamFn:
  """Lifts a part function into a function over streams of parts.

  The returned stream function applies `fn` concurrently to each part of the
  input stream while preserving part order in the output.

  Args:
    fn: a function that can be applied on a single part.
    match_fn: optional predicate; parts for which it returns False are not
      given to `fn` and are passed through unchanged. Defaults to matching
      every part.

  Returns:
    A function that is applied concurrently across the parts of the input
    stream.
  """
  if match_fn is None:
    match_fn = lambda _: True
  return functools.partial(_apply_part_function, (fn, match_fn))
def _to_tuple_fns(
fns: Sequence[PartFn], match_fns: Sequence[MatchFn] | None
) -> Sequence[PartWithMatchFn]:
"""Converts a sequence of functions to a sequence of function tuples."""
match_fns = match_fns or [lambda _: True for _ in fns]
if len(fns) != len(match_fns):
raise ValueError(
'fns and match_fns must be the same length. Got'
f' {len(fns)} != {len(match_fns)}.'
)
return list(zip(fns, match_fns))
def chain_part_functions(
    fns: Sequence[PartFn],
    match_fns: Sequence[MatchFn] | None = None,
) -> PartFn:
  """Composes `fns` into one part function, executing them concurrently.

  See the module docstring for how the tree of work is scheduled.

  Args:
    fns: sequence of part functions to chain, applied in order.
    match_fns: optional predicates, one per function. A function whose
      predicate returns False for a part is skipped and the part is passed
      through unchanged. When not provided, every function processes every
      part.

  Returns:
    Part function that is a chain of the provided sequence of functions.

  Raises:
    ValueError: if the length of fns and match_fns is not the same (when
      match_fns is provided).
  """
  part_with_match_fns = _to_tuple_fns(fns, match_fns)
  return functools.partial(_chain_part_functions, part_with_match_fns)
def parallel_part_functions(
    fns: Sequence[PartFn],
    match_fns: Sequence[MatchFn] | None = None,
    with_default_output: bool = False,
    with_always_output: bool = False,
) -> PartFn:
  """Combines `fns` so that each one runs on the same input part in parallel.

  Args:
    fns: sequence of part functions to run in parallel.
    match_fns: optional predicates, one per function. A function whose
      predicate returns False for a part is not called and yields nothing for
      that part. When not provided, every function processes every part.
    with_default_output: when True, the input part is yielded as a fallback
      if the fns do not produce any output part.
    with_always_output: when True, the input part is always yielded,
      independent of the output of the fns. This is a stronger condition than
      `with_default_output`; when True, `with_default_output` is basically
      ignored.

  Returns:
    Part function that runs all functions 'fns' in parallel. The output stream
    will keep the order of the input parts:
      f_0(c) = c00, c01
      f_1(c) = c10, c11, c12, c14
      f_0(c) // f_1(c) = c00, c01, c10, c11, c12

  Raises:
    ValueError: if the length of fns and match_fns is not the same (when
      match_fns is provided).
  """
  part_with_match_fns = _to_tuple_fns(fns, match_fns)
  return functools.partial(
      _parallel_part_functions,
      part_with_match_fns,
      with_default_output=with_default_output,
      with_always_output=with_always_output,
  )
# -------- Part Function Methods ----------
class _Finished:
"""A constant that represents that a generator end has been reached."""
_FinishedT = type[_Finished]
def _eager_run_fn(
    fn: PartFn,
    part: _T,
) -> AsyncIterable[_T]:
  """Starts `fn(part)` as a task and exposes its output as an AsyncIterable.

  Must be called in an async context: a task consuming the whole of
  `fn(part)` is scheduled on the event loop immediately, and its outputs are
  buffered in a queue until retrieved from the returned AsyncIterable.

  Args:
    fn: the part function to execute on the part.
    part: the part to execute the function on.

  Returns:
    An AsyncIterable yielding the results of `fn` on `part` in order.

  NOTE: this method is non-blocking.
  """
  buffer = asyncio.Queue[_T | _FinishedT]()

  async def produce():
    async for item in fn(part):
      buffer.put_nowait(item)
    # Sentinel: no more items will be produced.
    buffer.put_nowait(_Finished)

  # Schedules the producer on the event loop right away.
  context.create_task(produce())

  async def consume():
    while True:
      item = await buffer.get()
      if item is _Finished:
        return
      yield item

  return consume()
async def _passthrough(part: _T) -> AsyncIterable[_T]:
  """Forwards `part` unchanged, as a single-element stream."""
  yield part
async def _result_aiter(
    q: asyncio.Queue[AsyncIterable[_T] | _FinishedT],
) -> AsyncIterable[_T]:
  """Yields the items of each aiter taken from `q` until _Finished is seen."""
  while True:
    sub_iter = await q.get()
    if sub_iter is _Finished:
      break
    async for item in sub_iter:
      yield item
def _chain_part_functions(
    fns: Sequence[PartWithMatchFn],
    part: _T,
) -> AsyncIterable[_T]:
  """Runs the composition of `fns` on `part`, scheduling work eagerly.

  This executes a tree of work as described in the module-level comment.
  Composing 2 PartFns `(f, g)` on a part `c` creates a new PartFn:

  ```
  c -> flatten(g(r) for r in f(c))
  ```

  which is a tree of work of depth 2; this function supports an arbitrary
  number of functions by recursing on the tail of `fns`.

  Must be called in an async context. Tasks for the tree of work are
  scheduled on the event loop immediately; results can be retrieved via the
  returned AsyncIterable.

  Args:
    fns: (part function, match function) tuples to execute on the part. When
      a match function returns False, the corresponding part function is not
      called and the part is passed on as is, saving the creation of a new
      task.
    part: the part to execute the function on.

  Returns:
    An AsyncIterable that can be used to retrieve the results of running the
    composition of `fns` on `part` in order.

  NOTE: this method is non-blocking.
  """
  (head_fn, head_match), *rest = fns
  if not rest:
    # Base case: last function of the chain.
    if head_match(part):
      return _eager_run_fn(head_fn, part)
    # No new task is spawned for an unmatched part.
    return _passthrough(part)
  if not head_match(part):
    # Skip this function and chain the remaining ones directly on the part.
    return _chain_part_functions(rest, part)
  # Recursive case: each child produced by `head_fn` becomes the root of a
  # sub-tree running the remaining functions; sub-tree results are flattened
  # in order.
  subtrees = asyncio.Queue[AsyncIterable[_T] | _FinishedT]()

  async def expand():
    async for child in head_fn(part):
      subtrees.put_nowait(_chain_part_functions(rest, child))
    subtrees.put_nowait(_Finished)

  context.create_task(expand())
  return _result_aiter(subtrees)
def _apply_part_function(
    fn: PartWithMatchFn, content: AsyncIterable[_T]
) -> AsyncIterable[_T]:
  """Applies a part function concurrently across a stream of parts.

  Matching parts are handed to the part function via eagerly-scheduled tasks;
  non-matching parts are forwarded unchanged. Output order follows input
  order.
  """
  part_fn, match_fn = fn
  per_part_iters = asyncio.Queue[AsyncIterable[_T] | _FinishedT]()

  async def schedule():
    async for part in content:
      if match_fn(part):
        per_part_iters.put_nowait(_eager_run_fn(part_fn, part))
      else:
        per_part_iters.put_nowait(_passthrough(part))
    per_part_iters.put_nowait(_Finished)

  # Starts consuming the input stream on the event loop right away.
  context.create_task(schedule())
  return _result_aiter(per_part_iters)
def _parallel_part_functions(
    fns: Sequence[PartWithMatchFn],
    part: _T,
    with_default_output: bool = False,
    with_always_output: bool = False,
) -> AsyncIterable[_T]:
  """Executes each part function on the same part concurrently.

  Similar to `_chain_part_functions` except that every fn is executed on
  exactly `part` instead of being chained together. The AsyncIterables
  returned by the fns are concatenated in the provided fns order.

  Must be called in an async context: one task per matching fn is scheduled
  on the event loop immediately.

  Args:
    fns: the (part function, match function) tuples to execute on the part.
      Functions whose match function returns False for the part are skipped
      entirely.
    part: the part to execute the function on.
    with_default_output: when True, `part` is yielded if the resulting
      iterable would otherwise be empty.
    with_always_output: when True, `part` is always yielded regardless of the
      output of the fns. This is a stronger condition than
      `with_default_output`; when True, `with_default_output` is basically
      ignored.

  Returns:
    An AsyncIterable that can be used to retrieve the results.

  NOTE: this method is non-blocking.
  """
  branches = [
      _eager_run_fn(part_fn, part)
      for part_fn, match_fn in fns
      if match_fn(part)
  ]

  async def merged():
    produced_any = False
    for branch in branches:
      async for item in branch:
        produced_any = True
        yield item
    if with_always_output or (with_default_output and not produced_any):
      yield part

  return merged()