# maga_transformer/openai/renderers/fast_chat_renderer.py
from typing import Optional, List, Dict, Any, Union, Callable, Tuple, AsyncGenerator
from dataclasses import dataclass
from .conversation import Conversation, get_conv_template
from transformers import PreTrainedTokenizerBase
from maga_transformer.openai.api_datatype import ChatMessage, GPTFunctionDefinition, \
ChatCompletionRequest, RoleEnum, RendererInfo
from maga_transformer.openai.renderers.llama_template import Template, get_template_and_fix_tokenizer
from maga_transformer.openai.renderers.custom_renderer import CustomChatRenderer, RendererParams, \
StreamResponseObject, RenderedInputs
from maga_transformer.openai.api_datatype import ChatMessage, GPTFunctionDefinition, RoleEnum, \
ChatCompletionRequest, ChatCompletionResponseStreamChoice
class FastChatRenderer(CustomChatRenderer):
    """Chat renderer backed by FastChat conversation templates.

    Resolves the conversation template for ``renderer_params.model_type`` and
    uses it to assemble a prompt from OpenAI-style chat messages, registering
    the template's stop strings / stop token ids with the base renderer.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase, renderer_params: RendererParams):
        super().__init__(tokenizer, renderer_params)
        self.conv_template = get_conv_template(renderer_params.model_type)
        # Map OpenAI roles onto the template's (user, assistant) role labels.
        # System messages are handled separately via set_system_message().
        self.roles_map = {
            RoleEnum.user: self.conv_template.roles[0],
            RoleEnum.assistant: self.conv_template.roles[1],
        }
        # Templates may declare stop_str as a single string or a list of strings.
        if isinstance(self.conv_template.stop_str, list):
            self.add_extra_stop_words(self.conv_template.stop_str)
        elif isinstance(self.conv_template.stop_str, str):
            self.add_extra_stop_words([self.conv_template.stop_str])
        if self.conv_template.stop_token_ids:
            # add_extra_stop_word_ids expects a list of token-id sequences,
            # so wrap each single stop token id in its own list.
            self.add_extra_stop_word_ids([[id] for id in self.conv_template.stop_token_ids])

    def get_renderer_info(self) -> RendererInfo:
        """Return the base renderer info with the conversation template attached."""
        renderer_info = super().get_renderer_info()
        renderer_info.template = str(self.conv_template)
        return renderer_info

    def render_chat(self, request: ChatCompletionRequest) -> RenderedInputs:
        """Render the request's messages into prompt token ids via the template.

        Args:
            request: OpenAI-style chat completion request; each message's
                content must be a plain string (multimodal content is not
                supported by this renderer).

        Returns:
            RenderedInputs carrying the encoded prompt token ids.

        Raises:
            AssertionError: if any message content is not a string.
        """
        conversation = self.conv_template.copy()
        for message in request.messages:
            assert isinstance(message.content, str), \
                f"expected string message content, got {type(message.content)}"
            if message.role == RoleEnum.system:
                conversation.set_system_message(message.content)
            else:
                conversation.append_message(self.roles_map[message.role], message.content)
        # Open an empty assistant turn so the template emits its generation
        # prompt, unless the caller already supplied a trailing assistant turn.
        # The empty-list guard avoids the IndexError the unconditional
        # messages[-1] access used to raise on an empty request.
        if not request.messages or request.messages[-1].role != RoleEnum.assistant:
            conversation.append_message(self.roles_map[RoleEnum.assistant], "")
        prompt = conversation.get_prompt()
        input_ids = self.tokenizer.encode(prompt)
        return RenderedInputs(input_ids=input_ids)