# maga_transformer/openai/renderers/qwen_renderer.py
import copy
import json
import re
import logging
import torch
from dataclasses import dataclass, field
from typing import Optional, List, Dict, Any, Union, Callable, Tuple, AsyncGenerator
import functools
from maga_transformer.models.base_model import GenerateOutput, GenerateOutputs
from maga_transformer.config.generate_config import GenerateConfig
from maga_transformer.openai.renderers.qwen_tool_renderer import QwenToolRenderer, QwenToolStreamStatus
from maga_transformer.tokenizer.tokenization_qwen import QWenTokenizer
from transformers import Qwen2Tokenizer
from maga_transformer.openai.api_datatype import ChatMessage, GPTFunctionDefinition, \
ChatCompletionRequest, RoleEnum, FunctionCall, ChatCompletionResponseStreamChoice, \
DeltaMessage, FinisheReason, UsageInfo, RendererInfo, PromptTokensDetails
from maga_transformer.openai.renderers.custom_renderer import CustomChatRenderer, RendererParams, \
StreamResponseObject, RenderedInputs, StreamStatus, StreamStatusSync, OutputDelta, ThinkStatus
from maga_transformer.openai.renderers.basic_renderer import BasicRenderer
from maga_transformer.openai.renderer_factory_register import register_renderer
from maga_transformer.utils.word_util import get_stop_word_slices, truncate_response_with_stop_words, is_truncated
QwenTokenizerTypes = Union[QWenTokenizer, Qwen2Tokenizer]
TOOL_DESC = """{name_for_model}: Call this tool to interact with the {name_for_human} API. What is the {name_for_human} API useful for? {description_for_model} Parameters: {parameters}"""
REACT_INSTRUCTION = """Answer the following questions as best you can. You have access to the following APIs:
{tools_text}
Use the following format:
Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tools_name_text}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question
Begin!"""
DUMMY_THOUGHT = {
"en": "\nThought: I now know the final answer.\nFinal answer: ",
"zh": "\nThought: 我会作答了。\nFinal answer: ",
}
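# Sentinel returned by parse_messages when the conversation does not end with a
# user turn (e.g. it ends with a function Observation); render_chat then lets the
# model continue the last assistant message via text_complete_last_message.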
_TEXT_COMPLETION_CMD = object()
class QwenStreamStatus(StreamStatus):
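    """Streaming state for Qwen ReAct-style function calling.

    Tracks the full decoded output so far and whether the model has started
    emitting a ReAct "Action:" block, in which case further text is buffered
    until the function call can be parsed in _flush_buffer.
    """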
generating_function_call: bool = False
total_output_string: str = ""
def __init__(self, request: ChatCompletionRequest):
super().__init__(request)
def update_result(self):
self.last_token_length = len(self.output_ids) - len(self.last_output_ids)
self.last_output_ids = self.output_ids
        # Only the text up to the held-back "\nAction:" window has actually been
        # streamed, so track that prefix as the responded string.
        self.responded_string = self.total_output_string[: - len('\nAction:')]
@property
def responded_length(self):
return len(self.responded_string)
@property
def output_length(self):
return len(self.total_output_string)
def check_stop_reason(self):
        if self.finish_reason is None:
logging.debug(f"output [{self.responded_string}] found no stop reason! use stop as default.")
self.finish_reason = FinisheReason.stop
class QwenStreamStatusSync(StreamStatusSync):
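    """Synchronous counterpart of QwenStreamStatus, used by the non-async path."""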
generating_function_call: bool = False
total_output_string: str = ""
def __init__(self, request: ChatCompletionRequest):
super().__init__(request)
def update_result(self):
self.responded_string = self.total_output_string[: - len('\nAction:')]
@property
def responded_length(self):
return len(self.responded_string)
@property
def output_length(self):
return len(self.total_output_string)
def check_stop_reason(self):
        if self.finish_reason is None:
logging.debug(f"output [{self.responded_string}] found no stop reason! use stop as default.")
self.finish_reason = FinisheReason.stop
@dataclass
class ProcessedOutput:
output_str: str
output_token_length: int
finish_reason: Optional[FinisheReason]
# TODO(wangyin): pass `max_window_size` to here.
def make_context(
tokenizer: QwenTokenizerTypes,
query: str,
history: List[Tuple[str, str]] = [],
system: str = "",
max_window_size: int = 6144,
):
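    """Build a ChatML prompt and its token ids.

    Assumes a tokenizer exposing im_start_id / im_end_id (the legacy QWenTokenizer).
    The rendered prompt has the form:

        <|im_start|>system
        {system}<|im_end|>
        <|im_start|>user
        {turn query}<|im_end|>
        <|im_start|>assistant
        {turn response}<|im_end|>
        ...
        <|im_start|>user
        {query}<|im_end|>
        <|im_start|>assistant

    History turns are added newest-first until max_window_size tokens would be
    exceeded; older turns are dropped.
    """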
history = copy.deepcopy(history)
im_start, im_end = "<|im_start|>", "<|im_end|>"
im_start_tokens = [tokenizer.im_start_id]
im_end_tokens = [tokenizer.im_end_id]
nl_tokens = tokenizer.encode("\n")
    def _tokenize_str(role, content):
        # allowed_special is a tiktoken-style kwarg understood by the legacy QWenTokenizer.
        return f"{role}\n{content}", tokenizer.encode(
            role, allowed_special=set()
        ) + nl_tokens + tokenizer.encode(content, allowed_special=set())
system_text, system_tokens_part = _tokenize_str("system", system)
system_tokens = im_start_tokens + system_tokens_part + im_end_tokens
raw_text = ""
context_tokens = []
for turn_query, turn_response in reversed(history):
query_text, query_tokens_part = _tokenize_str("user", turn_query)
query_tokens = im_start_tokens + query_tokens_part + im_end_tokens
response_text, response_tokens_part = _tokenize_str(
"assistant", turn_response
)
response_tokens = im_start_tokens + response_tokens_part + im_end_tokens
next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens
prev_chat = (
f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
)
current_context_size = (
len(system_tokens) + len(next_context_tokens) + len(context_tokens)
)
if current_context_size < max_window_size:
context_tokens = next_context_tokens + context_tokens
raw_text = prev_chat + raw_text
else:
break
context_tokens = system_tokens + context_tokens
raw_text = f"{im_start}{system_text}{im_end}" + raw_text
context_tokens += (
nl_tokens
+ im_start_tokens
+ _tokenize_str("user", query)[1]
+ im_end_tokens
+ nl_tokens
+ im_start_tokens
+ tokenizer.encode("assistant")
+ nl_tokens
)
raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n"
return raw_text, context_tokens
class QwenRenderer(CustomChatRenderer):
def __init__(self, tokenizer: QwenTokenizerTypes, renderer_params: RendererParams):
super().__init__(tokenizer, renderer_params)
        self.add_extra_stop_word_ids([[37763, 367, 25], [151643]])  # "Observation:" and Qwen <|endoftext|>
self.qwen_tool_renderer = QwenToolRenderer(tokenizer, renderer_params)
self.template_chat_renderer: Optional[BasicRenderer] = None
try:
            if tokenizer.chat_template is not None:
logging.info(f"qwen model has chat_template [{tokenizer.chat_template}], "
"which will be used for non-function call dialogue.")
self.template_chat_renderer = BasicRenderer(tokenizer, renderer_params)
except AttributeError:
pass
def render_chat(self, request: ChatCompletionRequest) -> RenderedInputs:
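        """Render a chat request into prompt token ids.

        Dispatch order: requests with tools go to QwenToolRenderer; requests
        without functions use the tokenizer's chat_template when available;
        otherwise a ReAct-style prompt is built via parse_messages plus
        make_context (or text completion of the last assistant message).
        """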
if request.tools:
return self.qwen_tool_renderer.render_chat(request)
        if self.template_chat_renderer is not None and not request.functions:
return self.template_chat_renderer.render_chat(request)
query, history, system = self.parse_messages(request.messages, request.functions)
logging.debug(f"parsed query: {query}, history: {history}, system: {system}")
input_ids = []
        if query is _TEXT_COMPLETION_CMD:
            input_ids = self.text_complete_last_message(history)
        else:
            assert isinstance(query, str)
input_ids = make_context(self.tokenizer, query, history, system)[1]
return RenderedInputs(input_ids=input_ids)
def text_complete_last_message(self, history):
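        """Encode the history as a ChatML prompt with the final <|im_end|> stripped,
        so the model simply continues the last assistant message."""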
im_start = "<|im_start|>"
im_end = "<|im_end|>"
prompt = f"{im_start}system\nYou are a helpful assistant.{im_end}"
        for query, response in history:
query = query.lstrip("\n").rstrip()
response = response.lstrip("\n").rstrip()
prompt += f"\n{im_start}user\n{query}{im_end}"
prompt += f"\n{im_start}assistant\n{response}{im_end}"
        # Drop the trailing <|im_end|> so the model continues the last assistant message.
        prompt = prompt[: -len(im_end)]
return self.tokenizer.encode(prompt)
def parse_messages(
self,
messages: List[ChatMessage],
functions: Optional[List[GPTFunctionDefinition]] = None
):
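        """Convert OpenAI-style messages into (query, history, system).

        Function results are folded back into the preceding assistant message as
        ReAct "Observation:" lines, and when functions are declared the ReAct
        instruction is prepended to the last user question. Returns
        _TEXT_COMPLETION_CMD as the query when the conversation does not end
        with a user turn.
        """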
if all(m.role != "user" for m in messages):
raise ValueError("At least one message must be from user.")
messages = copy.deepcopy(messages)
if messages[0].role == 'system':
system = messages.pop(0).content.lstrip('\n').rstrip()
else:
system = 'You are a helpful assistant.'
if functions:
tools_text = []
tools_name_text = []
for func_info in functions:
name = func_info.name
name_m = func_info.name_for_model or name
name_h = func_info.name_for_human or name
desc = func_info.description
desc_m = func_info.description_for_model or desc
tool = TOOL_DESC.format(
name_for_model=name_m,
name_for_human=name_h,
# Hint: You can add the following format requirements in description:
# "Format the arguments as a JSON object."
# "Enclose the code within triple backticks (`) at the beginning and end of the code."
description_for_model=desc_m,
parameters=json.dumps(func_info.parameters, ensure_ascii=False),
)
tools_text.append(tool)
tools_name_text.append(name_m)
tools_text = "\n\n".join(tools_text)
tools_name_text = ", ".join(tools_name_text)
instruction = (REACT_INSTRUCTION.format(
tools_text=tools_text,
tools_name_text=tools_name_text,
).lstrip('\n').rstrip())
else:
instruction = ''
messages_with_fncall = messages
messages = []
for m_idx, m in enumerate(messages_with_fncall):
role, content, func_call = m.role, m.content, m.function_call
content = content or ""
content = content.lstrip("\n").rstrip()
if role == "function":
if (len(messages) == 0) or (messages[-1].role != "assistant"):
raise ValueError(f"Invalid request: Expecting role assistant before role function.")
messages[-1].content += f'\nObservation: {content}'
if m_idx == len(messages_with_fncall) - 1:
# add a prefix for text completion
messages[-1].content += '\nThought:'
elif role == 'assistant':
if len(messages) == 0:
raise ValueError(f"Invalid request: Expecting role user before role assistant.")
if func_call is None:
if functions:
content = f'Thought: I now know the final answer.\nFinal Answer: {content}'
else:
f_name, f_args = func_call.name, func_call.arguments
if not content.startswith('Thought:'):
content = f'Thought: {content}'
content = f'{content}\nAction: {f_name}\nAction Input: {f_args}'
if messages[-1].role == 'user':
messages.append(
ChatMessage(role=RoleEnum.assistant, content=content.lstrip('\n').rstrip())
)
else:
messages[-1].content += '\n' + content
            elif role == 'user':
                messages.append(
                    ChatMessage(role=RoleEnum.user, content=content.lstrip('\n').rstrip()))
else:
raise ValueError(f"Invalid request: Incorrect role {role}.")
query = _TEXT_COMPLETION_CMD
if messages[-1].role == 'user':
query = messages[-1].content
messages = messages[:-1]
history = [] # [(Q1, A1), (Q2, A2), ..., (Q_last_turn, A_last_turn)]
for i in range(0, len(messages), 2):
if messages[i].role == 'user' and messages[i + 1].role == 'assistant':
usr_msg = messages[i].content.lstrip('\n').rstrip()
bot_msg = messages[i + 1].content.lstrip('\n').rstrip()
if instruction and (i == len(messages) - 2):
usr_msg = f'{instruction}\n\nQuestion: {usr_msg}'
instruction = ''
history.append([usr_msg, bot_msg])
else:
raise ValueError(
"Invalid request: Expecting exactly one user (or function) role before every assistant role."
)
if instruction:
assert query is not _TEXT_COMPLETION_CMD
query = f'{instruction}\n\nQuestion: {query}'
return query, history, system
def _parse_function_response(self, response: str) -> Optional[DeltaMessage]:
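        """Parse a ReAct-formatted completion into a function-call DeltaMessage.

        Illustrative example (tool name and arguments are hypothetical): given

            Thought: I need to look this up.
            Action: get_weather
            Action Input: {"city": "Paris"}

        this returns a DeltaMessage whose content is the leading Thought text and
        whose function_call has name "get_weather" and arguments '{"city": "Paris"}'.
        Returns None when no Action block is found.
        """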
func_name, func_args = "", ""
i = response.rfind("\nAction:")
j = response.rfind("\nAction Input:")
k = response.rfind("\nObservation:")
if 0 <= i < j: # If the text has `Action` and `Action input`,
if k < j: # but does not contain `Observation`,
# then it is likely that `Observation` is omitted by the LLM,
# because the output text may have discarded the stop word.
response = response.rstrip() + "\nObservation:" # Add it back.
k = response.rfind("\nObservation:")
func_name = response[i + len("\nAction:") : j].strip()
func_args = response[j + len("\nAction Input:") : k].strip()
logging.info(f"parsed function from response: [{response}]: {func_name}, {func_args}")
if func_name:
return DeltaMessage(
content=response[:i],
function_call=FunctionCall(name=func_name, arguments=func_args),
)
return None
# z = response.rfind("\nFinal Answer: ")
# if z >= 0:
# response = response[z + len("\nFinal Answer: ") :]
async def _update_single_status(self, status: StreamStatus,
output: GenerateOutput, max_new_tokens: int,
stop_words_str: List[str], stop_word_slice_list: List[str], is_streaming: bool) -> OutputDelta:
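        """Produce the incremental OutputDelta for one streamed response.

        Plain text is streamed out while the tail that could become an "Action:"
        marker is held back; once a function call starts, output is buffered and
        parsed later in _flush_buffer.
        """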
if status.request.tools:
return await self.qwen_tool_renderer._update_single_status(
status, output, max_new_tokens, stop_words_str, stop_word_slice_list, is_streaming)
if not isinstance(status, QwenStreamStatus):
return await super()._update_single_status(status, output, max_new_tokens, stop_words_str, stop_word_slice_list, is_streaming)
        if status.finish_reason is not None:
return await self._create_empty_delta(status.output.aux_info)
status.update_output(output,
self._clean_output_ids,
functools.partial(self._check_finish_reason, max_new_tokens=max_new_tokens),
self._remove_stop_word_ids)
        status.total_output_string = self.tokenizer.decode(status.output_ids).strip()
        # Re-decode the full id list each step: for some tokenizers (e.g. ChatGLM),
        # decoding a single token differs from decoding a list of tokens.
        # A trailing U+FFFD means the last token ends in the middle of a multi-byte
        # character, so wait for more tokens before emitting anything.
        if status.total_output_string and status.total_output_string[-1] == u'\uFFFD':
            return await self._create_empty_delta(output.aux_info)
        if status.total_output_string.endswith("\nAction:"):
            # The model has started a ReAct function call; buffer the rest until
            # _flush_buffer can parse it.
            status.generating_function_call = True
            return await self._create_empty_delta(output.aux_info)
        if status.generating_function_call:
            return await self._create_empty_delta(output.aux_info)
if is_truncated(status.total_output_string, stop_words_str, is_streaming):
status.finish_reason = FinisheReason.stop
return await self._create_empty_delta(output.aux_info)
        # Hold back the last len('\nAction:') characters so a function-call marker
        # that is still being generated is never streamed to the client.
        if len(status.total_output_string) > status.responded_length + len('\nAction:'):
            status.delta_output_string = status.total_output_string[status.responded_length : status.output_length - len('\nAction:')]
            if is_truncated(status.delta_output_string, stop_word_slice_list, is_streaming):
                return await self._create_empty_delta(output.aux_info)
            else:
                status.update_result()
return OutputDelta(
status.delta_output_string,
await self._generate_log_probs(status, output),
status.input_token_length,
status.output_token_length,
status.reuse_length)
return await self._create_empty_delta(output.aux_info)
#override
async def _create_status_list(self, n: int, request: ChatCompletionRequest) -> List[StreamStatus]:
if request.tools:
return [QwenToolStreamStatus(request) for _ in range(n)]
if request.functions and (len(request.functions) > 0):
return [QwenStreamStatus(request) for _ in range(n)]
else:
return [StreamStatus(request) for _ in range(n)]
#override
async def _flush_buffer(self, buffer_list: List[StreamStatus], stop_words_str: List[str], is_streaming: bool, think_status: ThinkStatus):
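        """Flush buffered output at the end of a stream: parse a pending ReAct
        function call if one was started, otherwise emit the remaining text with
        stop words truncated."""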
if buffer_list[0].request.tools:
return await self.qwen_tool_renderer._flush_buffer(buffer_list, stop_words_str, is_streaming, think_status)
if (not isinstance(buffer_list[0], QwenStreamStatus)):
return await super()._flush_buffer(buffer_list, stop_words_str, is_streaming, think_status)
output_items: List[OutputDelta] = []
for status in buffer_list:
if status.generating_function_call:
function_message = self._parse_function_response(status.total_output_string[status.responded_length:])
                if function_message is None:
logging.warning(f"output [{status.total_output_string}] failed to parse function from [{status.responded_length}]. "
"regarded as normal output.")
function_message = ""
else:
status.finish_reason = FinisheReason.function_call
output_items.append(OutputDelta(
function_message,
await self._generate_log_probs(status, status.output),
status.input_token_length,
status.output_token_length,
status.reuse_length))
else:
trunc_string = truncate_response_with_stop_words(status.total_output_string[status.responded_length:], stop_words_str, is_streaming)
output_items.append(OutputDelta(
trunc_string,
await self._generate_log_probs(status, status.output),
status.input_token_length,
status.output_token_length,
status.reuse_length))
return await self._generate_stream_response(output_items, think_status)
#override
def _update_single_status_sync(self,
status: StreamStatusSync,
input_len, # output.aux_info
output_len, # output.aux_info
reuse_len, # output.aux_info
all_probs: torch.Tensor,
output_ids: torch.Tensor,
max_new_tokens: int,
stop_words_str: List[str],
stop_word_slice_list: List[str],
is_streaming: bool) -> OutputDelta:
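        """Synchronous counterpart of _update_single_status, driven by raw lengths
        and tensors from the engine instead of GenerateOutput objects."""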
        # Function-call handling is disabled when logprobs are requested (see _create_status_list_sync).
if not isinstance(status, QwenStreamStatusSync):
return super()._update_single_status_sync(status,
input_len, output_len, reuse_len,
all_probs, output_ids,
max_new_tokens, stop_words_str,
stop_word_slice_list,
is_streaming)
        if status.finish_reason is not None:
return self._create_empty_delta_sync(input_len, output_len, reuse_len)
status.update_output_sync(output_ids, input_len,
self._clean_output_ids,
functools.partial(self._check_finish_reason, max_new_tokens=max_new_tokens),
self._remove_stop_word_ids)
        status.total_output_string = self.tokenizer.decode(status.output_ids).strip()
        # Same logic as the async _update_single_status above: wait out incomplete
        # UTF-8 sequences, then buffer output once a ReAct "\nAction:" marker appears.
        if status.total_output_string and status.total_output_string[-1] == u'\uFFFD':
            return self._create_empty_delta_sync(input_len, output_len, reuse_len)
        if status.total_output_string.endswith("\nAction:"):
            status.generating_function_call = True
            return self._create_empty_delta_sync(input_len, output_len, reuse_len)
        if status.generating_function_call:
            return self._create_empty_delta_sync(input_len, output_len, reuse_len)
if is_truncated(status.total_output_string, stop_words_str, is_streaming):
status.finish_reason = FinisheReason.stop
return self._create_empty_delta_sync(input_len, output_len, reuse_len)
        # Hold back a possible partial "\nAction:" marker, as in the async path.
        if len(status.total_output_string) > status.responded_length + len('\nAction:'):
            status.delta_output_string = status.total_output_string[status.responded_length : status.output_length - len('\nAction:')]
            if is_truncated(status.delta_output_string, stop_word_slice_list, is_streaming):
                return self._create_empty_delta_sync(input_len, output_len, reuse_len)
            else:
                status.update_result()
return OutputDelta(
output_str=status.delta_output_string,
logprobs=self._generate_log_probs_sync(status, all_probs, output_ids),
input_length=input_len,
output_length=output_len,
reuse_length=reuse_len)
return self._create_empty_delta_sync(input_len, output_len, reuse_len)
#override
def _create_status_list_sync(self, n: int, body: str) -> List[StreamStatusSync]:
request = self.getRequest(body)
        # Function-call parsing is skipped when logprobs are requested, so the plain
        # StreamStatusSync is used in that case (see _update_single_status_sync).
        if request.logprobs:
return [StreamStatusSync(request) for _ in range(n)]
else:
return [QwenStreamStatusSync(request) for _ in range(n)]
#override
def _flush_buffer_sync(self,
buffer_list: List[StreamStatusSync],
input_len_list, output_len_list, reuse_len_list,
all_probs_list, output_ids_list,
stop_words_str: List[str],
is_streaming: bool):
if (not isinstance(buffer_list[0], QwenStreamStatusSync)):
return super()._flush_buffer_sync(buffer_list,
input_len_list, output_len_list, reuse_len_list,
all_probs_list, output_ids_list,
stop_words_str,
is_streaming)
output_items: List[OutputDelta] = []
for status, input_len, output_len, reuse_len, all_probs, output_ids in zip(
buffer_list,
input_len_list, output_len_list, reuse_len_list,
all_probs_list, output_ids_list
):
if status.generating_function_call:
function_message = self._parse_function_response(status.total_output_string[status.responded_length:])
                if function_message is None:
logging.warning(f"output [{status.total_output_string}] failed to parse function from [{status.responded_length}]. "
"regarded as normal output.")
function_message = ""
else:
status.finish_reason = FinisheReason.function_call
output_items.append(OutputDelta(
function_message,
self._generate_log_probs_sync(status, all_probs, output_ids),
input_len,
output_len,
reuse_len))
else:
trunc_string = truncate_response_with_stop_words(status.total_output_string[status.responded_length:], stop_words_str, is_streaming)
output_items.append(OutputDelta(
trunc_string,
self._generate_log_probs_sync(status, all_probs, output_ids),
input_len,
output_len,
reuse_len))
return self._generate_stream_response_sync(output_items)
def get_renderer_info(self) -> RendererInfo:
renderer_info = super().get_renderer_info()
if self.template_chat_renderer:
renderer_info.template = self.template_chat_renderer.chat_template
return renderer_info
register_renderer('qwen', QwenRenderer)
register_renderer('qwen_7b', QwenRenderer)
register_renderer('qwen_13b', QwenRenderer)
register_renderer('qwen_1b8', QwenRenderer)
register_renderer('qwen_2', QwenRenderer)
register_renderer('qwen_2_moe', QwenRenderer)