# maga_transformer/openai/renderers/basic_renderer.py
from typing import Optional, List, Dict, Any, Union, Callable, AsyncGenerator
import logging
import torch
import os
from functools import lru_cache
from packaging import version
import json
from transformers import PreTrainedTokenizerBase
from dataclasses import dataclass, field
import jinja2
from jinja2.exceptions import TemplateError
from jinja2.sandbox import ImmutableSandboxedEnvironment
from maga_transformer.openai.renderers.custom_renderer import CustomChatRenderer, \
RendererParams, StreamResponseObject, RenderedInputs, RendererInfo
from maga_transformer.models.base_model import GenerateOutput
from maga_transformer.openai.api_datatype import ChatMessage, GPTFunctionDefinition, RoleEnum, \
ChatCompletionRequest, ChatCompletionResponseStreamChoice, DeltaMessage, FinisheReason, UsageInfo
from maga_transformer.utils.multimodal_util import MMUrlType, MMPreprocessConfig

DEFAULT_CHAT_API_TEMPLATE = (
"{% for message in messages %}"
"{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
"{% endfor %}"
"{% if add_generation_prompt %}"
"{{ '<|im_start|>assistant\n' }}"
"{% endif %}"
)
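
# For illustration: with add_generation_prompt=True, the default template above
# renders messages=[{"role": "user", "content": "Hello"}] as:
#
#   <|im_start|>user
#   Hello<|im_end|>
#   <|im_start|>assistant
#
# i.e. the ChatML format used by Qwen-style models.
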
@dataclass
class PromptWithMMInput:
prompt: str
urls: List[str]
mm_types: List[MMUrlType] = field(default_factory=list)
preprocess_configs: List[MMPreprocessConfig] = field(default_factory=list)
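
# For illustration (hypothetical values): a request carrying one image might be
# represented as
#   PromptWithMMInput(prompt="<img/>describe this image",
#                     urls=["https://example.com/cat.jpg"],
#                     mm_types=[MMUrlType.IMAGE])
# where `MMUrlType.IMAGE` is an assumed member of the MMUrlType enum.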


# This class is designed to replace the `PreTrainedTokenizerBase.apply_chat_template`
# functionality while providing more room to customize the template.
# Specifically, it allows templates to use a `functions` field, following the
# OpenAI chat API format. All other template elements remain compatible with
# `PreTrainedTokenizerBase.apply_chat_template`.
class BasicRenderer(CustomChatRenderer):
def __init__(self,
tokenizer: PreTrainedTokenizerBase,
renderer_params: RendererParams,
):
super().__init__(tokenizer, renderer_params)
        if version.parse(jinja2.__version__) < version.parse("3.0.0"):
            raise ImportError(
                "apply_chat_template requires jinja2>=3.0.0 to be installed. "
                f"Your version is {jinja2.__version__}."
            )
self.add_generation_prompt = True
self.chat_template = None
self.special_tokens_map = {}
        self.chat_template = getattr(tokenizer, "chat_template", None)
        if self.chat_template is None:
            self.chat_template = getattr(tokenizer, "default_chat_template", None)
        if self.chat_template is None:
            logging.info(f"tokenizer {tokenizer} has neither a chat_template nor a "
                         "default_chat_template attribute. Falling back to the default template.")
            self.chat_template = DEFAULT_CHAT_API_TEMPLATE
            # The default template is ChatML, so stop generation at its end-of-turn marker.
            self.add_extra_stop_words(["<|im_end|>"])
        # Register the tokenizer's special tokens as extra stop words.
        special_tokens_map = getattr(tokenizer, "special_tokens_map", None)
        if special_tokens_map:
            self.special_tokens_map = special_tokens_map
            for k, v in self.special_tokens_map.items():
                logging.info(f"special token [{v}]({k}) added as a stop word.")
                if isinstance(v, str):
                    self.add_extra_stop_words([v])
                elif isinstance(v, list):
                    self.add_extra_stop_words(v)
        additional_special_tokens = getattr(tokenizer, "additional_special_tokens", None)
        if additional_special_tokens:
            logging.info(f"additional special tokens {additional_special_tokens} "
                         "added as stop words.")
            self.add_extra_stop_words(additional_special_tokens)
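        # For illustration, a typical Hugging Face tokenizer exposes something like
        #   special_tokens_map == {"bos_token": "<s>", "eos_token": "</s>",
        #                          "additional_special_tokens": ["<|im_start|>", "<|im_end|>"]}
        # so each string (or list entry) above becomes a stop word; the exact
        # contents are model-dependent.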
logging.info(f"found chat template to use: {self.chat_template}")
self.default_template_key = os.environ.get("DEFAULT_CHAT_TEMPLATE_KEY", "default")
self.default_tool_use_template_key = os.environ.get("DEFAULT_TOOL_USE_TEMPLATE_KEY", "tool_use")
self.compiled_template_map: Dict[str, jinja2.Template] = {}
if isinstance(self.chat_template, dict):
for key, template in self.chat_template.items():
self.compiled_template_map[key] = self._compile_jinja_template(template)
elif isinstance(self.chat_template, str):
self.compiled_template_map[self.default_template_key] = self._compile_jinja_template(self.chat_template)
        else:
            raise TypeError(f"chat template [{self.chat_template}] "
                            f"of type [{type(self.chat_template)}] is not supported.")
        if self.default_template_key not in self.compiled_template_map:
            raise ValueError(f"default template key [{self.default_template_key}] not found "
                             f"in chat templates: [{list(self.compiled_template_map.keys())}]")
if self.default_tool_use_template_key not in self.compiled_template_map:
self.default_tool_use_template_key = self.default_template_key
logging.info(f"compiled chat templates to use: {self.compiled_template_map.keys()}")
def get_renderer_info(self) -> RendererInfo:
renderer_info = super().get_renderer_info()
renderer_info.template = self.chat_template
return renderer_info

    # NOTE: lru_cache on an instance method also keys the cache on `self` and
    # keeps the instance alive for the cache's lifetime; acceptable for
    # long-lived renderer objects, but worth knowing when reusing this pattern.
    @lru_cache
    def _compile_jinja_template(self, chat_template: str) -> jinja2.Template:
def raise_exception(message):
raise TemplateError(message)
jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
jinja_env.globals["raise_exception"] = raise_exception
return jinja_env.from_string(chat_template)
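
    # For illustration: the `raise_exception` global lets a template fail fast on
    # malformed input. A (hypothetical) template snippet such as
    #   {% if messages | length == 0 %}{{ raise_exception("messages is empty") }}{% endif %}
    # surfaces as a jinja2 TemplateError at render time.
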
def _get_template(self, request: ChatCompletionRequest) -> jinja2.Template:
if request.user_template:
if request.template_key:
raise ValueError("template_key and user_template can not be used together.")
return self._compile_jinja_template(request.user_template)
template_key = self.default_tool_use_template_key \
if request.functions else self.default_template_key
template_key = request.template_key or template_key
return self.compiled_template_map[template_key]
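
    # Resolution order, as implemented above: an inline `user_template` on the
    # request wins (and is mutually exclusive with `template_key`); otherwise an
    # explicit `template_key` is honored; otherwise the tool-use template is
    # chosen when `functions` are present, and the default template otherwise.
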
def render_chat(self, request: ChatCompletionRequest) -> RenderedInputs:
template = self._get_template(request)
request_dict = json.loads(request.model_dump_json(exclude_none=True))
render_args = {
"messages": request_dict['messages'],
"json": json,
"add_generation_prompt": self.add_generation_prompt,
}
render_args.update(self.special_tokens_map)
        # A `functions` field with a None value may raise an exception in the llama3
        # template, so it is only passed through when non-empty.
if request_dict.get('functions', None):
render_args['functions'] = request_dict['functions']
        rendered = template.render(**render_args)
        logging.debug(f"request [{request.model_dump_json(indent=4)}] rendered string: [{rendered}]")
return RenderedInputs(input_ids=self.tokenizer.encode(rendered))
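

# --- Usage sketch (illustrative only; not executed by this module) ---
# Assuming `tokenizer` is a loaded PreTrainedTokenizerBase and `params` is a
# RendererParams configured elsewhere in maga_transformer:
#
#   renderer = BasicRenderer(tokenizer, params)
#   request = ChatCompletionRequest(
#       messages=[ChatMessage(role=RoleEnum.user, content="Hello!")],
#   )
#   inputs = renderer.render_chat(request)  # RenderedInputs holding prompt token ids
#
# Anything about ChatCompletionRequest beyond the fields this file references
# (`messages`, `functions`, `template_key`, `user_template`) is an assumption
# about its schema, as is the `RoleEnum.user` member name.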