# maga_transformer/openai/renderers/llama_template_renderer.py
from dataclasses import dataclass
from typing import Optional, List, Tuple

from transformers import PreTrainedTokenizerBase

from maga_transformer.openai.api_datatype import ChatMessage, ChatCompletionRequest, \
    RoleEnum, RendererInfo
from maga_transformer.openai.renderers.custom_renderer import CustomChatRenderer, RendererParams, \
    RenderedInputs
from maga_transformer.openai.renderers.llama_template import get_template_and_fix_tokenizer


@dataclass
class LlamaTemplateArgs:
    """Inputs for one templated turn: the current user query, an optional
    partial response, prior (user, assistant) turns, and an optional system
    prompt."""
    query: str
    resp: str = ""
    history: Optional[List[Tuple[str, str]]] = None
    system: Optional[str] = None
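
# Illustrative example of the dataclass (hypothetical values):
#   LlamaTemplateArgs(query="What is 2 + 2?",
#                     history=[("Hello", "Hi! How can I help?")],
#                     system="You are a math tutor.")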


class LlamaTemplateRenderer(CustomChatRenderer):
    """Renders OpenAI-style chat requests using a llama chat template selected
    by the model type in RendererParams."""

    def __init__(self, tokenizer: PreTrainedTokenizerBase, renderer_params: RendererParams):
        super().__init__(tokenizer, renderer_params)
        model_name = renderer_params.model_type
        self.template = get_template_and_fix_tokenizer(model_name, tokenizer)
        # Register the template's stop words so generation halts at
        # template-defined boundaries.
        self.add_extra_stop_words(self.template.stop_words)

    def get_renderer_info(self) -> RendererInfo:
        renderer_info = super().get_renderer_info()
        renderer_info.template = str(self.template)
        return renderer_info

    def _extract_history(self, messages: List[ChatMessage]) -> LlamaTemplateArgs:
        # Messages must be formatted as follows (an example follows this list):
        # 1. An optional system message may appear, and only as the very first
        #    message; no other system messages are allowed.
        # 2. The remaining messages must alternate: [user, assistant, user, assistant, ...]
        # 3. The last message must be from the user.
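        # For example, the role sequence [system, user, assistant, user] yields
        # system="...", one (user, assistant) history pair, and the final user
        # message as the query.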
        assert messages, "messages must not be empty"
        messages = list(messages)  # work on a copy so pop() does not mutate the caller's list
        history = []
        system = None
        if messages[0].role == RoleEnum.system:
            system = messages[0].content
            assert isinstance(system, str)
            messages = messages[1:]
        query_message = messages.pop()
        assert query_message.role == RoleEnum.user
        query = query_message.content
        assert isinstance(query, str)
        assert len(messages) % 2 == 0
        for idx in range(0, len(messages), 2):
            user_message = messages[idx]
            assistant_message = messages[idx + 1]
            assert user_message.role == RoleEnum.user
            assert assistant_message.role == RoleEnum.assistant
            history.append((user_message.content, assistant_message.content))
        return LlamaTemplateArgs(query=query, history=history, system=system)

    def render_chat(self, request: ChatCompletionRequest) -> RenderedInputs:
        template_args = self._extract_history(request.messages)
        assert isinstance(self.tokenizer, PreTrainedTokenizerBase)
        # encode_oneturn returns a (prompt ids, response ids) pair; only the
        # prompt side is needed when rendering the request.
        encoded_ids = self.template.encode_oneturn(
            self.tokenizer,
            query=template_args.query,
            resp=template_args.resp,
            history=template_args.history,
            system=template_args.system,
        )[0]
        return RenderedInputs(input_ids=encoded_ids)
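

# A minimal usage sketch. Assumptions: the model name and message contents are
# illustrative, and RendererParams may require constructor arguments beyond
# model_type that are omitted here.
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
#   params = RendererParams(model_type="llama", ...)  # other required fields omitted
#   renderer = LlamaTemplateRenderer(tokenizer, params)
#   request = ChatCompletionRequest(messages=[
#       ChatMessage(role=RoleEnum.system, content="You are a helpful assistant."),
#       ChatMessage(role=RoleEnum.user, content="Hello!"),
#   ])
#   input_ids = renderer.render_chat(request).input_ids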