core/jinja_template.py (73 lines of code) (raw):

# Copyright 2025 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Processor for rendering Jinja templates with multimodal contents."""

from collections.abc import AsyncIterable
from typing import Any
import uuid

from genai_processors import content_api
from genai_processors import processor
from genai_processors import streams
import jinja2


class JinjaTemplate(processor.Processor):
  """Processor for rendering a Jinja template with multimodal contents.

  Example usage:

  ```python
  from genai_processors import content_api
  from genai_processors import processor
  from genai_processors.core import jinja_template

  p = jinja_template.JinjaTemplate(
      template_str="Hello {{ name }}, answer this question: {{ content }}",
      content_varname="content",
      role=content_api.Roles.USER,
      name="World",
  )

  output = processor.apply_sync(
      p,
      [
          content_api.ProcessorPart(
              "What is this landmark?",
              mimetype="text/plain",
          ),
          content_api.ProcessorPart(
              <image_bytes>,
              mimetype="image/png",
          ),
      ],
  )

  print(content_api.as_text(output))
  ```
  """

  def __init__(
      self,
      template_str: str,
      content_varname: str = "content",
      role: str = "user",
      *args,
      **kwargs,
  ) -> None:
    """Initializes the processor.

    Accepts the same args and kwargs as Jinja's `render()` method
    https://jinja.palletsprojects.com/en/stable/api/#jinja2.Template.render.

    The content variable must be output verbatim, e.g. `{{ content }}`. The
    processor locates the content by rendering a unique placeholder in its
    place. Applying a filter or transformation to it (e.g.
    `{{ content | upper }}`) alters that placeholder. The content is then not
    found in the rendered template and is not inserted.

    Args:
      template_str: The Jinja template string.
      content_varname: The name of the Jinja variable to render the content.
      role: The role to use when outputting the rendered template.
      *args: Positional arguments to pass to Jinja's `render()` method.
      **kwargs: Keyword arguments to pass to Jinja's `render()` method.

    Raises:
      ValueError: If `content_varname` is passed as a template variable,
        either in **kwargs or in a mapping passed through *args.
    """
    # `jinja2.Template.render()` accepts the same arguments as the `dict`
    # constructor. Merge them the same way here so that a content variable
    # supplied through a positional mapping is detected too, not only one
    # supplied through **kwargs.
    template_vars = dict(*args, **kwargs)
    if content_varname in template_vars:
      raise ValueError(
          f"'{content_varname}' is set to render the processor's content and"
          " must not be passed as a variable to the Jinja template."
      )

    # Render the template using a placeholder value for the processor's
    # content variable, so the content's location can be found in the next
    # step. A UUID is used so the placeholder is not already present in the
    # template.
    content_placeholder = str(uuid.uuid4())
    template_vars[content_varname] = content_placeholder
    rendered_template = jinja2.Template(template_str).render(template_vars)

    # Split the rendered template using the placeholder as a delimiter. The
    # processor's content is inserted between consecutive elements. Splitting
    # allows injecting not only text but also multi-part and multimodal
    # content.
    self._template_split = rendered_template.split(content_placeholder)
    self._role = role

  async def call(
      self,
      content: AsyncIterable[content_api.ProcessorPart],
  ) -> AsyncIterable[content_api.ProcessorPartTypes]:
    """Yields the template parts interleaved with the input content.

    Args:
      content: The input stream, inserted wherever the content variable was
        rendered in the template.

    Yields:
      Text parts for the template (with the configured role) and the input
      parts at each location of the content variable.
    """
    # If the template split into a single element, the template did not
    # render the content variable and is returned as is.
    if len(self._template_split) == 1:
      yield content_api.ProcessorPart(
          self._template_split[0],
          role=self._role,
      )
      return

    # `content` is a stream that can only be iterated once, so it is
    # duplicated into identical streams. `content_streams[i]` is inserted
    # between `self._template_split[i]` and `self._template_split[i+1]`.
    content_streams = streams.split(
        content,
        n=len(self._template_split) - 1,
        with_copy=False,
    )
    for i, template_part in enumerate(self._template_split):
      # Empty parts are skipped. They occur where the content variable sits
      # at the start or end of the template, or between two adjacent
      # occurrences of it.
      if template_part:
        yield content_api.ProcessorPart(
            template_part,
            role=self._role,
        )
      # Yield the processor's content between two consecutive elements of
      # the template split.
      if i < len(content_streams):
        async for part in content_streams[i]:
          yield part


class RenderDataClass(processor.PartProcessor):
  r"""PartProcessor for rendering a dataclass part using a Jinja template.

  The dataclass object must be referenced by the name `data` in the Jinja
  template, i.e. `{{ data.first_name }}`.

  Example usage:

  ```python
  @dataclasses_json.dataclass_json
  @dataclasses.dataclass(frozen=True)
  class ExampleDataClass:
    first_name: str
    last_name: str

  shopping_list = ["A", "B", "C"]

  p = jinja_template.RenderDataClass(
      template_str=(
          "Hello {{ data.first_name }} {{ data.last_name }},\n"
          "This is your shopping list:\n"
          "{% for item in your_list %}This is item: {{ item }}\n"
          "{% endfor %}"
      ),
      data_class=ExampleDataClass,
      your_list=shopping_list,
  )

  output = processor.apply_sync(
      p,
      [
          content_api.ProcessorPart.from_dataclass(
              dataclass=ExampleDataClass(first_name="John", last_name="Doe")
          )
      ],
  )

  print(content_api.as_text(output))
  ```

  The expected output is:

  ```
  Hello John Doe,
  This is your shopping list:
  This is item: A
  This is item: B
  This is item: C
  ```
  """

  def __init__(
      self,
      template_str: str,
      data_class: type[Any],
      **kwargs,
  ):
    """Initializes the processor.

    Args:
      template_str: The Jinja template string.
      data_class: The type of the dataclass to render.
      **kwargs: Keyword arguments to pass to the Jinja template. They are
        registered as globals of the template's environment.

    Raises:
      ValueError: If `data` is passed in **kwargs. That name is reserved for
        the dataclass being rendered, and the render-time `data` variable
        would shadow it.
    """
    if "data" in kwargs:
      raise ValueError(
          "'data' is set to render the processor's dataclass part and must"
          " not be passed as a variable to the Jinja template."
      )
    self._environment = jinja2.Environment()
    self._environment.globals.update(**kwargs)
    self._template = self._environment.from_string(template_str)
    self._data_class = data_class

  def match(self, part: content_api.ProcessorPart) -> bool:
    """Returns whether `part` holds a dataclass of the configured type."""
    return content_api.is_dataclass(part.mimetype, self._data_class)

  async def call(
      self, part: content_api.ProcessorPart
  ) -> AsyncIterable[content_api.ProcessorPart]:
    """Renders a dataclass part in a Jinja template.

    The output part keeps the input part's role, metadata and substream name.
    """
    yield content_api.ProcessorPart(
        self._template.render(data=part.get_dataclass(self._data_class)),
        role=part.role,
        metadata=part.metadata,
        substream_name=part.substream_name,
    )