core/jinja_template.py (73 lines of code) (raw):
# Copyright 2025 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Processor for rendering Jinja templates with multimodal contents."""
from collections.abc import AsyncIterable
from typing import Any
import uuid
from genai_processors import content_api
from genai_processors import processor
from genai_processors import streams
import jinja2
class JinjaTemplate(processor.Processor):
  """Processor for rendering a Jinja template with multimodal contents.

  The template is rendered once, at construction time. At call time, the
  processor's input stream is inserted wherever the content variable (named by
  `content_varname`) appears in the rendered template. Because the input is
  inserted as parts rather than as text, it may contain non-text (e.g. image)
  parts.

  Note: the content variable must be rendered verbatim (e.g. `{{ content }}`).
  Applying a filter to it (e.g. `{{ content | upper }}`) alters the internal
  placeholder value, so that occurrence is not detected and the content is not
  inserted there.

  Example usage:

  ```python
  from genai_processors import content_api
  from genai_processors import processor
  from genai_processors.core import jinja_template

  p = jinja_template.JinjaTemplate(
      template_str="Hello {{ name }}, answer this question: {{ content }}",
      content_varname="content",
      role=content_api.Roles.USER,
      name="World",
  )
  output = processor.apply_sync(
      p,
      [
          content_api.ProcessorPart(
              "What is this landmark?",
              mimetype="text/plain",
          ),
          content_api.ProcessorPart(
              <image_bytes>,
              mimetype="image/png",
          ),
      ],
  )
  print(content_api.as_text(output))
  ```
  """

  def __init__(
      self,
      template_str: str,
      content_varname: str = "content",
      role: str = "user",
      *args,
      **kwargs,
  ) -> None:
    """Initializes the processor.

    Accepts the same args and kwargs as Jinja's `render()` method
    https://jinja.palletsprojects.com/en/stable/api/#jinja2.Template.render.

    Args:
      template_str: The Jinja template string.
      content_varname: The name of the Jinja variable to render the content.
      role: The role to use when outputting the rendered template.
      *args: Positional arguments to pass to Jinja's `render()` method (at most
        one mapping or iterable of key/value pairs, as accepted by `dict()`).
      **kwargs: Keyword arguments to pass to Jinja's `render()` method.

    Raises:
      ValueError: If `content_varname` is passed as a template variable, either
        in **kwargs or in a mapping passed through *args.
    """
    # Merge the template variables exactly as `jinja2.Template.render` does
    # (`dict(*args, **kwargs)`) so that `content_varname` is detected even when
    # it comes from a positional mapping. Otherwise the caller's value would be
    # silently overwritten by the placeholder below.
    template_vars = dict(*args, **kwargs)
    if content_varname in template_vars:
      raise ValueError(
          f"'{content_varname}' is set to render the processor's content and"
          " must not be passed as a variable to the Jinja template."
      )
    # Render the template using a placeholder value for the processor's
    # content variable so the content's location(s) can be found in the next
    # step. A random UUID makes it very unlikely that the placeholder already
    # occurs elsewhere in the rendered text.
    content_placeholder = str(uuid.uuid4())
    template_vars[content_varname] = content_placeholder
    rendered_template = jinja2.Template(template_str).render(template_vars)
    # Split the rendered template on the placeholder. The processor's content
    # is inserted between each pair of consecutive segments, which allows
    # injecting multi-part and multimodal content, not only text.
    self._template_split = rendered_template.split(content_placeholder)
    self._role = role

  async def call(
      self,
      content: AsyncIterable[content_api.ProcessorPart],
  ) -> AsyncIterable[content_api.ProcessorPartTypes]:
    """Yields the rendered template with `content` inserted at each occurrence.

    Args:
      content: The input parts to insert wherever the content variable appeared
        in the template.

    Yields:
      The non-empty template segments as parts with the configured role,
      interleaved with the parts of `content`. If the template does not
      reference the content variable, only the rendered template is yielded
      and `content` is not consumed.
    """
    # A single segment means the template did not render the content variable,
    # so the template is returned as is.
    if len(self._template_split) == 1:
      # Skip an empty rendered template rather than emitting an empty part,
      # consistent with the empty-segment handling in the loop below.
      if self._template_split[0]:
        yield content_api.ProcessorPart(
            self._template_split[0],
            role=self._role,
        )
      return
    # `content` is a stream that can only be iterated once, so it is split into
    # one stream per occurrence of the content variable: `content_streams[i]`
    # is inserted between `self._template_split[i]` and
    # `self._template_split[i + 1]`.
    # NOTE(review): with `with_copy=False` the same part objects are presumably
    # shared across the streams rather than copied; confirm in `streams.split`.
    content_streams = streams.split(
        content,
        n=len(self._template_split) - 1,
        with_copy=False,
    )
    for i, template_part in enumerate(self._template_split):
      # Empty segments occur when the content variable is at the very start or
      # end of the rendered template, or when two occurrences are adjacent.
      # They carry no text, so they are skipped.
      if template_part:
        yield content_api.ProcessorPart(
            template_part,
            role=self._role,
        )
      # Insert the processor's content after every segment except the last.
      if i < len(content_streams):
        async for part in content_streams[i]:
          yield part
class RenderDataClass(processor.PartProcessor):
  r"""PartProcessor for rendering a dataclass part using a Jinja template.

  The dataclass object must be referenced by the name `data` in the Jinja
  template, i.e. `{{ data.first_name }}`. Extra keyword arguments given to the
  constructor are exposed to the template as global variables.

  Example usage:

  ```python
  @dataclasses_json.dataclass_json
  @dataclasses.dataclass(frozen=True)
  class ExampleDataClass:
    first_name: str
    last_name: str

  shopping_list = ["A", "B", "C"]
  p = jinja_template.RenderDataClass(
      template_str=(
          "Hello {{ data.first_name }} {{ data.last_name }},\n"
          "This is your shopping list:\n"
          "{% for item in your_list %}This is item: {{ item }}\n"
          "{% endfor %}"
      ),
      data_class=ExampleDataClass,
      your_list=shopping_list,
  )
  output = processor.apply_sync(
      p,
      [
          content_api.ProcessorPart.from_dataclass(
              dataclass=ExampleDataClass(first_name="John", last_name="Doe")
          )
      ],
  )
  print(content_api.as_text(output))
  ```

  The expected output is:

  ```
  Hello John Doe,
  This is your shopping list:
  This is item: A
  This is item: B
  This is item: C
  ```
  """

  def __init__(
      self,
      template_str: str,
      data_class: type[Any],
      **kwargs,
  ) -> None:
    """Initializes the processor.

    Args:
      template_str: The Jinja template string.
      data_class: The type of the dataclass to render.
      **kwargs: Keyword arguments to pass to the Jinja template. They are
        registered as template globals.

    Raises:
      ValueError: If `data` is passed in **kwargs. That name is reserved for the
        dataclass instance being rendered.
    """
    # `call` renders with `data=<dataclass instance>`, and render-time
    # variables take precedence over environment globals. A `data` global
    # would therefore be silently ignored, so it is rejected up front.
    if "data" in kwargs:
      raise ValueError(
          "'data' is set to render the processor's dataclass part and must not"
          " be passed as a variable to the Jinja template."
      )
    self._environment = jinja2.Environment()
    self._environment.globals.update(**kwargs)
    # Compile once; the template is re-rendered for every matching part.
    self._template = self._environment.from_string(template_str)
    self._data_class = data_class

  def match(self, part: content_api.ProcessorPart) -> bool:
    """Returns whether `part` holds an instance of the configured dataclass.

    The check is delegated to `content_api.is_dataclass` on the part's
    mimetype. Presumably the mimetype encodes the dataclass type; confirm in
    `content_api`.
    """
    return content_api.is_dataclass(part.mimetype, self._data_class)

  async def call(
      self, part: content_api.ProcessorPart
  ) -> AsyncIterable[content_api.ProcessorPart]:
    """Renders a dataclass part in a Jinja template.

    Args:
      part: A part accepted by `match`, holding an instance of the configured
        dataclass.

    Yields:
      A single part containing the rendered template text. It carries over the
      input part's role, metadata and substream name.
    """
    yield content_api.ProcessorPart(
        self._template.render(data=part.get_dataclass(self._data_class)),
        role=part.role,
        metadata=part.metadata,
        substream_name=part.substream_name,
    )