switch.py (106 lines of code) (raw):
# Copyright 2025 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Switch processors to route parts to different processorss."""
from __future__ import annotations
import asyncio
from collections.abc import AsyncIterable, Callable
from typing import Generic, Self, TypeAlias, TypeVar
from genai_processors import content_api
from genai_processors import processor
from genai_processors import streams
# Type of the value produced by a switch's `match_fn` and tested by its cases.
_T = TypeVar('_T')
# Short local aliases for the library types used throughout this module.
ProcessorPart: TypeAlias = content_api.ProcessorPart
PartProcessor: TypeAlias = processor.PartProcessor
Processor: TypeAlias = processor.Processor
class Switch(Processor, Generic[_T]):
  """Switch between processors.

  Convenient way to create a processor that routes the parts of the input
  stream to different processors based on a condition (aka a case). The
  condition can be:

  1. a function that takes a `ProcessorPart` and returns a boolean. Example:

  ```python
  switch_processor = (
      switch.Switch()
      .case(content_api.is_audio, audio_processor)
      .case(content_api.is_video, video_processor)
      .default(processor.passthrough())
  )
  ```

  2. a function that takes any value returned by the match_fn passed in the
  constructor and returns a boolean. We have a shortcut for boolean functions
  that test for equality, e.g. `lambda x: x == "a"`. They can be replaced
  with the value itself, e.g. `"a"`. Example:

  ```python
  # The match_fn is applied to the input part and the result is compared to
  # the value passed in the case() method.
  switch_processor = (
      switch.Switch(content_api.get_substream_name)
      .case("a", p)  # equivalent to .case(lambda x: x == "a")
      .case("b", q)  # equivalent to .case(lambda x: x == "b")
      .default(processor.passthrough())
  )
  ```

  Cases are tested in the order they were added and a part is sent to the
  first case that matches. A part that matches no case (and for which no
  default was set) is dropped.

  The order of the parts in the output and input streams is only kept for parts
  returned by the same processor, i.e. two parts matching two different cases
  are not guaranteed to be in the same order in the input and output stream.

  PartProcessors can be used instead of Processor, they will be converted to
  Processor automatically. The `processor.passthrough()` (a PartProcessor) in
  the examples above is equivalent to passing
  `processor.passthrough().to_processor()`.

  If the cases involve `PartProcessor`s only, it is best to use the
  `PartSwitch` class, which is optimized for concurrent processing of parts.
  """

  def __init__(
      self,
      match_fn: Callable[[ProcessorPart], _T] | None = None,
  ):
    """Initializes a switch without any case.

    Args:
      match_fn: function applied to every input part; its result is what the
        case conditions receive. When None, the part itself is handed to the
        conditions (an identity function is installed by the first `case()`).
    """
    # Ordered (condition, processor) pairs; the first matching one wins.
    self._cases: list[tuple[Callable[[_T], bool], Processor]] = []
    self._match = match_fn
    self._default_set = False

  async def call(
      self, content: AsyncIterable[ProcessorPart]
  ) -> AsyncIterable[content_api.ProcessorPartTypes]:
    """Routes every part to its first matching case and merges the outputs.

    Args:
      content: the input stream of parts.

    Yields:
      The parts produced by the case processors, merged into one stream.
    """
    # One queue per case: it feeds the input stream of that case's processor.
    input_queues = [asyncio.Queue() for _ in self._cases]

    async def _triage():
      """Triage the input parts to the correct input queue."""
      try:
        async for part in content:
          for queue, (filter_fn, _) in zip(input_queues, self._cases):
            if filter_fn(part):
              await queue.put(part)
              break
      finally:
        # `None` is the end marker put on every queue once the input is
        # consumed. Send it even when triaging fails or is cancelled:
        # otherwise the per-case streams would presumably never end and the
        # error would not be re-raised by `await triage_task` below.
        # The queues are unbounded, so `put_nowait` cannot fail here.
        for queue in input_queues:
          queue.put_nowait(None)

    triage_task = processor.create_task(_triage())
    # Process the parts in the input queues and merge all results.
    output_streams = [
        case_processor(streams.dequeue(queue))
        for (_, case_processor), queue in zip(self._cases, input_queues)
    ]
    async for part in streams.merge(output_streams):
      yield part
    # Surfaces any exception raised while triaging.
    await triage_task

  def case(
      self,
      v: _T | Callable[[_T], bool],
      p: Processor | PartProcessor,
  ) -> Self:
    """Adds a case to the switch.

    Args:
      v: either a callable receiving the result of `match_fn` and returning a
        boolean, or a value compared for equality with the result of
        `match_fn`.
      p: processor receiving the parts matching this case. A `PartProcessor`
        is converted with `to_processor()`.

    Returns:
      This switch, so that calls can be chained.

    Raises:
      ValueError: if the default processor has already been set.
    """
    if self._default_set:
      raise ValueError(
          f'This case is added after the default processor is set: {v}'
      )
    if self._match is None:
      # No match_fn given: conditions are applied to the part itself.
      self._match = lambda x: x
    if isinstance(p, PartProcessor):
      case_processor = p.to_processor()
    else:
      case_processor = p
    if isinstance(v, Callable):
      self._cases.append((lambda x: v(self._match(x)), case_processor))
    else:
      self._cases.append((lambda x: v == self._match(x), case_processor))
    return self

  def default(self, p: Processor | PartProcessor) -> Self:
    """Sets the processor receiving the parts that match no other case.

    Args:
      p: processor used as the fallback. A `PartProcessor` is converted with
        `to_processor()`.

    Returns:
      This switch, so that calls can be chained.

    Raises:
      ValueError: if a default processor has already been set.
    """
    if self._default_set:
      raise ValueError('The default processor is already set.')
    default_processor = p.to_processor() if isinstance(p, PartProcessor) else p
    # The default is stored as a last case whose condition always holds.
    self._cases.append((lambda x: True, default_processor))
    self._default_set = True
    return self
class PartSwitch(PartProcessor, Generic[_T]):
  """Switch between part processors.

  Runs the first case that matches the input part. A case is a condition
  paired with a `PartProcessor`. The condition is either a function returning
  a boolean or a plain value; both are checked against the result of the
  `match_fn` passed to the constructor (a function is called with that result,
  a value is compared to it for equality).

  When no case matches, nothing is returned for the part. To return something
  in that situation, call `default()` once all cases have been added.

  Example usage:

  Simple match conditions based on equality:

  ```python
  # Applies content_api.as_text on the input part and checks equality with
  # the string "a" or "b".
  switch_processor = (
      switch.PartSwitch(content_api.as_text)
      .case("a", p)
      .case("b", q)
      .default(processor.passthrough())
  )
  ```

  Without a `match_fn` the part itself is what the conditions receive.
  Comparing parts directly to a value is possible but not recommended. For
  more complex match conditions, use cases operating on the part as follows:

  ```python
  # Applies the lambda function defined in `case` on the input part to check
  # which case is valid.
  switch_processor = (
      switch.PartSwitch()
      .case(lambda x: x.text.startswith("a"), p)
      .case(lambda x: x.text.startswith("b"), q)
      .default(processor.passthrough())
  )
  ```
  """

  def __init__(
      self,
      match_fn: Callable[[ProcessorPart], _T] | None = None,
  ):
    """Creates a switch with no case.

    Args:
      match_fn: function applied to the input part before the conditions are
        evaluated. If None, the first `case()` call installs an identity
        function, i.e. conditions see the part itself.
    """
    self._match = match_fn
    self._default_set = False
    # (condition, part processor) pairs, evaluated in insertion order.
    self._cases: list[tuple[Callable[[_T], bool], PartProcessor]] = []

  async def call(
      self, part: ProcessorPart
  ) -> AsyncIterable[content_api.ProcessorPartTypes]:
    """Yields the output of the first case whose condition accepts `part`."""
    for accepts, case_processor in self._cases:
      if not accepts(part):
        continue
      async for output in case_processor(part):
        yield output
      # Only the first matching case runs.
      return

  def case(
      self,
      v: _T | Callable[[_T], bool],
      p: PartProcessor,
  ) -> Self:
    """Appends a case to the switch.

    Args:
      v: a callable taking the result of `match_fn` and returning a boolean, or
        a value tested for equality against the result of `match_fn`.
      p: part processor run on the parts matching this case.

    Returns:
      This switch, to allow chaining.

    Raises:
      ValueError: if `default()` was already called.
    """
    if self._default_set:
      raise ValueError(
          f'This case is added after the default processor is set: {v}'
      )
    if self._match is None:
      self._match = lambda x: x

    if isinstance(v, Callable):

      def condition(part):
        return v(self._match(part))

    else:

      def condition(part):
        return v == self._match(part)

    self._cases.append((condition, p))
    return self

  def default(self, p: PartProcessor) -> Self:
    """Registers the part processor used when no case matches.

    Args:
      p: part processor run on the parts that no case accepted.

    Returns:
      This switch, to allow chaining.

    Raises:
      ValueError: if a default was already registered.
    """
    if self._default_set:
      raise ValueError('The default processor is already set.')

    def always(_):
      return True

    self._default_set = True
    self._cases.append((always, p))
    return self