switch.py (106 lines of code) (raw):

# Copyright 2025 DeepMind Technologies Limited. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Switch processors to route parts to different processorss.""" from __future__ import annotations import asyncio from collections.abc import AsyncIterable, Callable from typing import Generic, Self, TypeAlias, TypeVar from genai_processors import content_api from genai_processors import processor from genai_processors import streams _T = TypeVar('_T') ProcessorPart: TypeAlias = content_api.ProcessorPart PartProcessor: TypeAlias = processor.PartProcessor Processor: TypeAlias = processor.Processor class Switch(Processor, Generic[_T]): """Switch between processors. Convenient way to create a processor that route the parts of the input stream to different processors based on a condition (aka a case). The condition can be: 1. a function that takes a `ProcessorPart` and returns a boolean. Example: ```python switch_processor = ( switch.Switch() .case(content_api.is_audio, audio_processor) .case(content_api.is_video, video_processor) .default(processor.passthrough()) ) ``` 2. a function that takes any value returned by the match_fn passed in the constructor and returns a boolean. We have a shortcut for boolean functions that tests for equality, e.g. `lambda x: x == "a"`. They can be replace with the value itself , e.g. `"a"`. Example: ```python # The match_fn is applied to the input part and the result is compared to # the value passed in the case() method. switch_processor = ( switch.Switch(content_api.get_substream_name) .case("a", p) # equivalent to .case(lambda x: x == "a") .case("b", q) # equivalent to .case(lambda x: x == "b") .default(processor.passthrough()) ) ``` The order of the parts in the output and input streams is only kept for parts returned by the same processor, i.e. two parts matching two different cases are not guaranteed to be in the same order in the input and output stream. PartProcessors can be used instead of Processor, they will be converted to Processor automatically. The `processor.passthrough()` (a PartProcessor) in the examples above is equivalent to passing `processor.passthrough().to_processor()`. If the cases involve `PartProcessor`s only, it is best to use the `PartSwitch` class, which is optimized for concurrent processing of parts. """ def __init__( self, match_fn: Callable[[ProcessorPart], _T] | None = None, ): self._cases: list[tuple[Callable[[_T], bool], Processor]] = [] self._match = match_fn self._default_set = False async def call( self, content: AsyncIterable[ProcessorPart] ) -> AsyncIterable[content_api.ProcessorPartTypes]: input_queues = [asyncio.Queue() for _ in range(len(self._cases))] async def _triage(): """Triage the input parts to the correct input queue.""" async for part in content: for i, (filter_fn, _) in enumerate(self._cases): if filter_fn(part): await input_queues[i].put(part) break for q in input_queues: await q.put(None) triage_task = processor.create_task(_triage()) # Process the parts in the input queues and merge all results. output_streams = [ self._cases[i][1](streams.dequeue(queue)) for i, queue in enumerate(input_queues) ] async for part in streams.merge(output_streams): yield part await triage_task def case( self, v: _T | Callable[[_T], bool], p: Processor | PartProcessor, ) -> Self: if self._default_set: raise ValueError( f'This case is added after the default processor is set: {v}' ) if self._match is None: self._match = lambda x: x if isinstance(p, PartProcessor): case_processor = p.to_processor() else: case_processor = p if isinstance(v, Callable): self._cases.append((lambda x: v(self._match(x)), case_processor)) else: self._cases.append((lambda x: v == self._match(x), case_processor)) return self def default(self, p: Processor | PartProcessor) -> Self: if self._default_set: raise ValueError('The default processor is already set.') if isinstance(p, PartProcessor): self._cases.append((lambda x: True, p.to_processor())) else: self._cases.append((lambda x: True, p)) self._default_set = True return self class PartSwitch(PartProcessor, Generic[_T]): """Switch between part processors. Convenient way to create a switch processor that runs the first case that matches. A case is defined by a condition and a PartProcessor. The condition can be a function that takes a `ProcessorPart` and returns a boolean or a value. In the latter case, the condition is compared to the result of the match_fn passed to the constructor. By default, the switch processor does not return any part when no case matches. To return a part in that case, the default() method should be called after all cases have been added. Example usage: Simple match conditions based on equality: ```python # Applies content_api.as_text on the input part and checks equality with # the string "a" or "b". switch_processor = ( switch.Switch(content_api.as_text) .case("a", p) .case("b", q) .default(processor.passthrough()) ``` Note that you can also compare parts directly by using the `lambda x: x` function in `Switch` but it is not recommended. For more complex match conditions, you can use cases operating on `part` directly as follows: ```python # Applies the lambda function defined in `case` on the input part to check # which case is valid. switch_processor = ( switch.Switch() .case(lambda x: x.text.startswith("a"), p) .case(lambda x: x.text.startswith("b"), q) .default() ) ``` """ def __init__( self, match_fn: Callable[[ProcessorPart], _T] | None = None, ): self._cases: list[tuple[Callable[[_T], bool], PartProcessor]] = [] self._match = match_fn self._default_set = False async def call( self, part: ProcessorPart ) -> AsyncIterable[content_api.ProcessorPartTypes]: for filter_fn, p in self._cases: if filter_fn(part): async for c in p(part): yield c break def case( self, v: _T | Callable[[_T], bool], p: PartProcessor, ) -> Self: if self._default_set: raise ValueError( f'This case is added after the default processor is set: {v}' ) if self._match is None: self._match = lambda x: x if isinstance(v, Callable): self._cases.append((lambda x: v(self._match(x)), p)) else: self._cases.append((lambda x: v == self._match(x), p)) return self def default(self, p: PartProcessor) -> Self: if self._default_set: raise ValueError('The default processor is already set.') self._cases.append((lambda x: True, p)) self._default_set = True return self