core/event_detection.py (144 lines of code) (raw):

# Copyright 2025 DeepMind Technologies Limited. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== r"""Event detection processor. This processor detects events from the images it receives in its input stream: * All inputs are passed to the output. * The last `max_images` encountered in the input are used to detect event. Non-image content is ignored. * If an event is detected, the `ProcessorParts` corresponding to the current transition of states (set in the `output_dict` passed in the constructor) are injected into the output. * Event detection is asynchronous so events may not necessarily be injected after the `ProcessorPart` on which it has been observed. This processor keeps making calls to the model passed in the constructor to detect events in the images. When the model is generating a response, the processor keeps collecting up to `max_images` coming up in the input stream and sends them all together to the model in the next call. `max_images` is a parameter that controls the number of images to keep in the input stream and should be tuned based on the image frequency in the input stream and the latency of the model to generate a response. With a FPS of 1, a `max_images` of 5 is a reasonable value. The set of events should be defined by an enum.StrEnum class. The model response will be a string that matches one of the values of the Enum. ```py import enum # Only use the enum.auto() function to define the values of the enum. Or ensure # all the string values are lower case. The empty string is used to define the # default start state. class EventState(enum.StrEnum): SUNNY = enum.auto() RAINING = enum.auto() RAINING_WITH_MEATBALLS = enum.auto() ``` The config passed to the constructor should contain the response schema for the event detection model. ```python config = { system_instruction=( "Determine the weather conditions under which these images have been" f" taken. Respond with '{EventName.SUNNY}' if the sun is shining, " f"'{EventName.RAINING}' if it is raining and" f"'{EventName.RAINING_WITH_MEATBALLS}' if the rain contains \meatballs." " You can classify any edible objects as meatballs.", "response_mime_type": "text/x.enum", "response_schema": EventName ... } ``` Each event state transition is associated to an output that the processor will generate when the event transitionis detected. The output should be a `ProcessorContentTypes` or None and is passed to the constructor via the `output_dict` argument. When the output is None, the transition is detected but no output is generated. By default, self-transitions are ignored, i.e. if the transition is from a state to the same state, no output will be generated. ```python output_dict = { # The "*" wild card can be used to define all states, i.e. transitions from # any state (including the start state) to the event state. ("*", EventState.EVENT_1): ProcessorPart( text="event_1 is detected", role="USER", end_of_turn=True, ), (EventState.EVENT_1, EventSate.EVENT_2): ProcessorPart( text="event_2 is detected", role="USER", end_of_turn=True, ), # No output for this transition. (EventState.EVENT_2, EventState.EVENT_1): None, } ``` If an observed event transition is not included in the `output_dict`, the processor will not generate any output for that transition and the new event state will be ignored, i.e. the next transition starting state will be the same as the previous one. Lastly, the sensitivity dictionary passed to the constructor is used to define the number of detections in a row that should happen before the event detection processor sends a detection output. The keys of this dictionary should be the event transitions and the values should be the number of detection in a row. This is helpful in noisy situations where you want to have confirmation of a detection before you generate an output for the event detection. """ import asyncio import collections import time from typing import AsyncIterable, Optional, TypeAlias from absl import logging from genai_processors import content_api from genai_processors import processor from genai_processors.core import timestamp from google import genai from google.genai import types as genai_types ProcessorPart = content_api.ProcessorPart # Name of the default start state, i.e. no event detected. START_STATE = "" # A transition between two events. A state is represented here by a string, # which is the name of the event (events should be defined by an enum.StrEnum). EventTransition: TypeAlias = tuple[str, str] class EventDetection(processor.Processor): """Event detection processor.""" def __init__( self, api_key: str, model: str, config: genai_types.GenerateContentConfig, output_dict: dict[ EventTransition, content_api.ProcessorContentTypes | None ], sensitivity: Optional[dict[EventTransition, int]] = None, max_images: int = 5, ): """Initializes the event detection processor. Args: api_key: The API key to use for the event detection model. model: The model to use for the event detection. config: The configuration to use for the event detection model. This configuration should contain the response schema for the event detection model. output_dict: A dictionary of transitions between events to the output to return when the transition is detected. A transition is a pair of event names `(from_event_state, to_event_state)`, where `from_event_state` can be the start state `START_STATE`, an event name, or the wild card `"*"` to define all transitions from any state (including the start state) to the event state. When the output is None, the transition is detected but no output is returned. sensitivity: A dictionary of transitions to the number of detection in a row that should happen before the event detection processor sends a detection output. By default, the sensitivity is 1 for a transition. max_images: The maximum number of images to keep in the input stream. """ self._client = genai.Client(api_key=api_key) self._model = model self._config = config if not config.response_schema: raise ValueError("Response schema is required for event detection.") self._sensitivity = sensitivity self._output_dict = {} self._init_output_dict(output_dict) self._last_transition = (START_STATE, START_STATE) self._transition_counter = (self._last_transition, 0) # deque of (image, timestamp) tuples. self._images = collections.deque[tuple[ProcessorPart, float]]( maxlen=max_images ) def _init_output_dict( self, output_dict: dict[EventTransition, content_api.ProcessorContentTypes], ): """Initializes the output dictionary.""" # Collect the list of event states from the output_dict (including start # state). event_states = [START_STATE] for transition in output_dict: event_states.append(transition[1]) # Add all wild card transitions. for transition, output in output_dict.items(): if transition[0] == "*": for event_state in event_states: self._output_dict[(event_state, transition[1])] = output # Add the other transitions: it will override the wild card transitions by # more specific transitions. for transition, output in output_dict.items(): if transition[0] != "*": self._output_dict[transition] = output async def detect_event( self, output_queue: asyncio.Queue[ProcessorPart], ): """Detects an event in the image.""" images_with_timestamp = [] start_time = None for image, t in self._images: if start_time is None: start_time = t images_with_timestamp.append(image) images_with_timestamp.append( content_api.ProcessorPart(timestamp.to_timestamp(t - start_time)) ) response = await self._client.aio.models.generate_content( model=self._model, config=self._config, contents=images_with_timestamp, ) logging.debug( "%s - Event detection response: %s / last transition: %s / transition" " counter: %s", time.perf_counter(), response.text, self._last_transition, self._transition_counter, ) if not response.text: logging.debug( "%s - No text response from the event detection model", time.perf_counter(), ) return event_name = response.text.lower() current_transition = (self._last_transition[1], event_name) if current_transition == self._transition_counter[0]: self._transition_counter = ( current_transition, self._transition_counter[1] + 1, ) else: self._transition_counter = (current_transition, 1) is_valid_transition = ( current_transition in self._output_dict or current_transition[0] == START_STATE ) is_valid_transition = ( is_valid_transition and current_transition[1] != current_transition[0] ) is_sensitivity_reached = ( current_transition not in self._sensitivity or self._transition_counter[1] > self._sensitivity[current_transition] ) if is_valid_transition and is_sensitivity_reached: logging.debug( "%s - New event transition: %s", time.perf_counter(), current_transition, ) if ( current_transition in self._output_dict and self._output_dict[current_transition] is not None ): for part in self._output_dict[current_transition]: output_queue.put_nowait(part) self._last_transition = current_transition self._transition_counter = (current_transition, 1) async def call( self, content: AsyncIterable[ProcessorPart] ) -> AsyncIterable[ProcessorPart]: """Run the event detection processor.""" output_queue = asyncio.Queue() async def consume_content(): image_detection_task = None async for part in content: output_queue.put_nowait(part) if content_api.is_image(part.mimetype): self._images.append((part.part, time.perf_counter())) if image_detection_task is None or image_detection_task.done(): # Only run one image detection task at a time when the previous # one is done. Use a single image to minimize detection latency. image_detection_task = processor.create_task( self.detect_event(output_queue) ) output_queue.put_nowait(None) if image_detection_task is not None: try: image_detection_task.cancel() await image_detection_task except asyncio.CancelledError: pass consume_content_task = processor.create_task(consume_content()) while chunk := await output_queue.get(): yield chunk await consume_content_task