# maga_transformer/openai/renderers/custom_renderer.py
from typing import Optional, List, Dict, Any, Union, Tuple, Callable, AsyncGenerator
import functools
import os
import json
import torch
import asyncio
import logging
from dataclasses import dataclass, field
from PIL import Image
from concurrent.futures import Future
from transformers import PreTrainedTokenizerBase
from maga_transformer.models.base_model import GenerateOutput, BaseModel, GenerateInput, GenerateOutputs, AuxInfo
from maga_transformer.config.generate_config import GenerateConfig
from maga_transformer.config.gpt_init_model_parameters import TemplateType
from maga_transformer.utils.mm_process_engine import MMProcessEngine
from maga_transformer.openai.api_datatype import ChatMessage, GPTFunctionDefinition, UsageInfo, \
ChatCompletionRequest, ChatCompletionResponseStreamChoice, DeltaMessage, FinisheReason, \
RoleEnum, RendererInfo, ChatCompletionStreamResponse, CompletionTokensDetails, PromptTokensDetails, \
ChatCompletionTokenLogprob, TopLogprob, ChoiceLogprobs, \
ChatCompletionResponseChoice, ChatCompletionResponse, DebugInfo
from maga_transformer.async_decoder_engine.async_model import AsyncModel
from maga_transformer.utils.word_util import get_stop_word_slices, truncate_response_with_stop_words, is_truncated
from maga_transformer.utils.util import has_overlap, has_overlap_kmp
from maga_transformer.utils.multimodal_util import MMUrlType, MultimodalInput, MMPreprocessConfig
from maga_transformer.async_decoder_engine.backend_rpc_server_visitor import BackendRPCServerVisitor
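# Configuration for optional "think" (reasoning) tag handling, read once at import time.
# THINK_MODE enables splitting the text between THINK_START_TAG and THINK_END_TAG out as
# reasoning_content; escape sequences in the tag values (e.g. "\n") are unescaped.
# Illustrative shell setup (values are deployment-specific, not library defaults):
#   THINK_MODE=1 THINK_START_TAG='<think>\n' THINK_END_TAG='</think>\n\n'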
think_mode = bool(int(os.environ.get("THINK_MODE", 0)))
think_start_tag = os.environ.get("THINK_START_TAG", "<think>\n").encode('utf-8').decode('unicode_escape')
think_end_tag = os.environ.get("THINK_END_TAG", "</think>\n\n").encode('utf-8').decode('unicode_escape')
class StreamStatus:
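    """Per-choice state for the asynchronous streaming path.

    Output ids accumulate across GenerateOutput chunks; each `update_output` call
    re-cleans the full sequence, re-checks the finish reason, and strips stop-word ids.
    """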
index: int = 0
request: ChatCompletionRequest
output: Optional[GenerateOutput] = None
origin_output_ids: torch.Tensor = torch.empty(0, dtype=torch.int32)
output_ids: torch.Tensor = torch.empty(0, dtype=torch.int32)
last_output_ids: List[int] = []
last_token_length: int = 0
finish_reason = None
tokenizer = None
responded_string = ""
delta_output_string = ""
def __init__(self, request: ChatCompletionRequest):
self.request = request
def update_output(self,
output: GenerateOutput,
clean_output_func,
check_finish_func,
remove_stop_word_ids_func):
self.index += 1
self.output = output
self.origin_output_ids = torch.cat((self.origin_output_ids, output.output_ids), dim=1)
self.output_ids = clean_output_func(self.origin_output_ids)
self.finish_reason = check_finish_func(self.output_ids, self.input_token_length)
self.output_ids = remove_stop_word_ids_func(self.output_ids)
def update_result(self):
self.last_token_length = len(self.output_ids) - len(self.last_output_ids)
self.last_output_ids = self.output_ids
self.responded_string += self.delta_output_string
@property
def output_token_length(self):
return len(self.output_ids)
@property
def input_token_length(self):
return self.output.aux_info.input_len
@property
def reuse_length(self):
return self.output.aux_info.reuse_len
@property
def prev_token_id(self):
return self.last_output_ids[-self.last_token_length:]
@property
def tokens_to_decode(self):
return self.prev_token_id + self.output_ids[len(self.last_output_ids):]
class StreamStatusSync:
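    """Per-choice state for the synchronous path, where output ids and aux-info fields
    are passed in piecewise (e.g. from the C++ server) instead of as GenerateOutput."""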
index: int = 0
request: ChatCompletionRequest
origin_output_ids: torch.Tensor = torch.empty(0, dtype=torch.int32)
output_ids: torch.Tensor = torch.empty(0, dtype=torch.int32)
last_output_ids: List[int] = []
last_token_length: int = 0
finish_reason = None
tokenizer = None
responded_string = ""
delta_output_string = ""
def __init__(self, request: ChatCompletionRequest):
self.request = request
def update_output_sync(self,
output_ids,
input_len,
clean_output_func,
check_finish_func,
remove_stop_word_ids_func):
self.index += 1
self.origin_output_ids = torch.cat((self.origin_output_ids, output_ids), dim=1)
self.output_ids = clean_output_func(self.origin_output_ids)
self.finish_reason = check_finish_func(self.output_ids, input_len)
self.output_ids = remove_stop_word_ids_func(self.output_ids)
def update_result(self):
self.last_token_length = len(self.output_ids) - len(self.last_output_ids)
self.last_output_ids = self.output_ids
self.responded_string += self.delta_output_string
@property
def prev_token_id(self):
return self.last_output_ids[-self.last_token_length:]
@property
def tokens_to_decode(self):
return self.prev_token_id + self.output_ids[len(self.last_output_ids):]
@dataclass
class StreamResponseObject:
choices: List[ChatCompletionResponseStreamChoice] = field(default_factory=list)
usage: Optional[UsageInfo] = None
aux_info: Optional[AuxInfo] = None
@dataclass
class ResponseObject:
choices: List[ChatCompletionResponseChoice] = field(default_factory=list)
usage: Optional[UsageInfo] = None
aux_info: Optional[AuxInfo] = None
@dataclass
class RendererParams:
model_type: str
max_seq_len: int
eos_token_id: int
stop_word_ids_list: List[List[int]]
template_type: TemplateType = TemplateType.chat
ckpt_path: str = ""
@dataclass
class OutputDelta():
output_str: Union[str, DeltaMessage]
logprobs: Optional[ChatCompletionTokenLogprob]
input_length: int
output_length: int
reuse_length: int
@dataclass
class ThinkStatus():
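    """Mutable per-request state used while splitting reasoning ("think") text from content."""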
in_think_mode: int = 0
think_buffer: str = ""
think_tokens: int = 0
is_streaming: bool = False
class RenderedInputs:
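    """Output of render_chat: prompt token ids, the rendered prompt string, and any multimodal inputs."""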
input_ids: List[int] = []
multimodal_inputs: List[MultimodalInput] = []
rendered_prompt: str = ""
def __init__(self, input_ids: List[int], rendered_prompt: str = "", input_urls: List[str] = [], input_urls_type: List[MMUrlType] = [], preprocess_configs: List[MMPreprocessConfig] = []):
self.input_ids = input_ids
self.rendered_prompt = rendered_prompt
self.multimodal_inputs = []
if len(input_urls_type) == 0:
input_urls_type = [MMUrlType.DEFAULT] * len(input_urls)
elif len(input_urls_type) != len(input_urls):
raise Exception(f"the number of multimodal input types must match url, now types {len(input_urls_type)} urls {len(input_urls)}")
if len(preprocess_configs) == 0:
preprocess_configs = [MMPreprocessConfig()] * len(input_urls)
        elif len(preprocess_configs) != len(input_urls):
            raise Exception(f"the number of multimodal preprocess configs ({len(preprocess_configs)}) must match the number of urls ({len(input_urls)})")
for url, type, config in zip(input_urls, input_urls_type, preprocess_configs):
self.multimodal_inputs.append(MultimodalInput(url, type, config))
class CustomChatRenderer():
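    """Base chat renderer: turns a ChatCompletionRequest into model inputs and renders
    engine outputs back into OpenAI-style streaming and non-streaming responses.

    Subclasses implement `render_chat`; a minimal sketch follows that method below.
    """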
def __init__(self,
tokenizer: PreTrainedTokenizerBase,
renderer_params: RendererParams,
):
self.tokenizer = tokenizer
self.model_type = renderer_params.model_type
self.max_seq_len = renderer_params.max_seq_len
self.eos_token_id = renderer_params.eos_token_id
self.stop_words_id_list = renderer_params.stop_word_ids_list
self.stop_words_str_list = [
self.tokenizer.decode(stop_word_ids) for stop_word_ids in self.stop_words_id_list
]
self.ckpt_path = renderer_params.ckpt_path
# NOTE: stop words or their ids only need to be added to one of these two lists.
self.extra_stop_words: List[str] = []
self.extra_stop_word_ids_list: List[List[int]] = []
def __str__(self) -> str:
return str(self.get_renderer_info())
def __repr__(self) -> str:
return self.__str__()
def get_renderer_info(self) -> RendererInfo:
extra_stop_word_ids_list = self.get_all_extra_stop_word_ids_list()
extra_stop_words_list = [
self.tokenizer.decode(stop_word_ids) for stop_word_ids in extra_stop_word_ids_list
]
if len(extra_stop_words_list) and isinstance(extra_stop_words_list[0], list):
extra_stop_words_list = [l[0] for l in extra_stop_words_list]
return RendererInfo(
class_name=self.__class__.__name__,
renderer_model_type=self.model_type,
extra_stop_word_ids_list=extra_stop_word_ids_list,
extra_stop_words_list=extra_stop_words_list,
)
def add_extra_stop_words(self, extra_stop_words: List[str]):
self.extra_stop_words.extend(extra_stop_words)
def add_extra_stop_word_ids(self, extra_stop_word_ids: List[List[int]]):
self.extra_stop_word_ids_list.extend(extra_stop_word_ids)
def tokenize_words(self, words: List[str]) -> List[List[int]]:
ids_list = []
for word in words:
if isinstance(self.tokenizer, PreTrainedTokenizerBase):
token_id = self.tokenizer.convert_tokens_to_ids(word)
if isinstance(token_id, int):
ids_list.append([token_id])
elif isinstance(token_id, list):
ids_list.append(token_id)
else:
ids_list.append(self.tokenizer.encode(word, add_special_tokens=True))
else:
ids_list.append(self.tokenizer.encode(word))
return ids_list
def get_all_extra_stop_word_ids_list(self) -> List[List[int]]:
ids_list_from_words = self.tokenize_words(self.extra_stop_words)
return self.extra_stop_word_ids_list + ids_list_from_words
def _check_all_finished(self, status_list) -> bool:
for s in status_list:
if s.finish_reason == None:
return False
return True
def getRequest(self, request: str) -> ChatCompletionRequest:
return ChatCompletionRequest(**(json.loads(request)))
def render_chat(self, request: ChatCompletionRequest) -> RenderedInputs:
raise NotImplementedError
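    # Minimal subclass sketch (illustrative only; real renderers typically apply the
    # model's chat template and handle multimodal urls):
    #
    #   class SimpleRenderer(CustomChatRenderer):
    #       def render_chat(self, request: ChatCompletionRequest) -> RenderedInputs:
    #           prompt = "".join(f"{m.role}: {m.content}\n" for m in request.messages)
    #           return RenderedInputs(input_ids=self.tokenizer.encode(prompt),
    #                                 rendered_prompt=prompt)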
async def generate_choice(
self,
request_id: int,
input_ids: List[int],
mm_inputs: List[MultimodalInput],
generate_config: GenerateConfig,
backend_rpc_server_visitor: BackendRPCServerVisitor,
request: ChatCompletionRequest
) -> AsyncGenerator[StreamResponseObject, None]:
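        """Enqueue one rendered request on the backend and yield rendered stream responses."""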
token_type_ids = []
input_id_tensor = torch.Tensor(input_ids).int().unsqueeze(0)
output_generator: AsyncGenerator[GenerateOutput, None] = backend_rpc_server_visitor.enqueue(
GenerateInput(
request_id=request_id,
token_ids=input_id_tensor,
mm_inputs=mm_inputs,
generate_config=generate_config,
tokenizer=self.tokenizer,
token_type_ids=token_type_ids
)
)
async for response in self.render_response_stream(output_generator,
request,
generate_config):
yield response
async def _create_empty_delta(self, aux_info: AuxInfo):
return OutputDelta(
output_str="",
logprobs=None,
input_length=aux_info.input_len,
output_length=aux_info.output_len,
reuse_length=aux_info.reuse_len
)
async def _generate_log_probs(self, status: StreamStatus, output: Optional[GenerateOutput]) -> Optional[ChatCompletionTokenLogprob]:
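        """Build an OpenAI-style logprob entry for the latest token, including top_logprobs."""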
assert output is not None
if not status.request.logprobs:
return None
prob_return_num = status.request.top_logprobs or 1
all_probs = output.all_probs
output_id = output.output_ids
        if output_id is None:
            return None
        selected_id = output_id[-1].item()
        if all_probs is None:
            raise Exception("all_probs is None while logprobs is requested; this is likely an internal bug.")
all_probs = all_probs.squeeze()
probs, tokens = all_probs.sort(descending=True)
non_zero_size = probs.nonzero().shape[0]
log_values = probs.log()
prob_return_num = min(prob_return_num, non_zero_size)
selected_token = self.tokenizer.decode([selected_id])
chat_logprob = ChatCompletionTokenLogprob(
token=selected_token,
bytes=list(selected_token.encode("utf-8", errors="replace")),
logprob=all_probs[output_id].log().item(),
top_logprobs=[]
)
for i in range(prob_return_num):
token = self.tokenizer.decode(tokens[i].item())
chat_logprob.top_logprobs.append(TopLogprob(
token=token,
logprob=log_values[i].item(),
bytes=list(token.encode("utf-8", errors="replace")),
))
logging.debug(f"chat_logprob: {chat_logprob.model_dump_json(indent=4)}")
return chat_logprob
async def _update_single_status(self, status: StreamStatus, output: GenerateOutput, max_new_tokens: int, stop_words_str: List[str], stop_word_slice_list: List[str], is_streaming: bool) -> OutputDelta:
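        """Fold a new engine output into `status` and return the text delta to emit.

        An empty delta is returned while the decoded tail is still ambiguous
        (an incomplete UTF-8 sequence or a possible stop-word prefix).
        """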
if status.finish_reason != None:
return await self._create_empty_delta(status.output.aux_info)
status.update_output(output,
self._clean_output_ids,
functools.partial(self._check_finish_reason, max_new_tokens=max_new_tokens),
self._remove_stop_word_ids)
decoded_prev_token = self.tokenizer.decode(status.prev_token_id)
decoded_string = self.tokenizer.decode(status.tokens_to_decode)
        # For some tokenizers (e.g. ChatGLM), decoding a single token differs from decoding a list of tokens.
if is_streaming:
if len(decoded_string) > 0 and u'\uFFFD' == decoded_string[-1]:
return await self._create_empty_delta(output.aux_info)
else:
while (len(decoded_string) > 0) and (u'\uFFFD' == decoded_string[-1]):
decoded_string = decoded_string[:-1]
status.delta_output_string = decoded_string[len(decoded_prev_token):]
if is_truncated(status.delta_output_string, stop_words_str, is_streaming):
status.finish_reason = FinisheReason.stop
return await self._create_empty_delta(output.aux_info)
if not is_truncated(status.delta_output_string, stop_word_slice_list, is_streaming):
status.update_result()
delta = OutputDelta(
output_str=status.delta_output_string,
logprobs=await self._generate_log_probs(status, output),
input_length=output.aux_info.input_len,
output_length=output.aux_info.output_len,
reuse_length=output.aux_info.reuse_len)
status.delta_output_string = ""
return delta
else:
return await self._create_empty_delta(output.aux_info)
async def _generate_first(self, n: int):
return StreamResponseObject(
choices=[ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(
role=RoleEnum.assistant,
content="",
),
) for i in range(n)]
)
def _split_reasoning_text_and_content(self, item: OutputDelta, think_status: ThinkStatus):
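        """Split a text delta into reasoning_content (inside the think tags) and content.

        Text is buffered while it may still be a prefix of the start tag or overlap
        the end tag, so partial tags are never emitted to the client.
        """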
if isinstance(item.output_str, str):
processing_index, output_len = 0, len(item.output_str)
if output_len == 0:
return DeltaMessage(content="")
reasoning_text, content = "", ""
update_think_tokens = think_status.in_think_mode
while processing_index < output_len:
if think_status.in_think_mode:
think_status.think_buffer += item.output_str[processing_index]
if think_status.think_buffer.startswith(think_start_tag):
think_status.think_buffer = think_status.think_buffer[len(think_start_tag):]
if think_status.think_buffer.endswith(think_end_tag):
reasoning_text = think_status.think_buffer[:-len(think_end_tag)]
think_status.think_buffer = ""
think_status.in_think_mode = False
elif has_overlap_kmp(think_status.think_buffer, think_end_tag) \
or think_start_tag.startswith(think_status.think_buffer):
pass
else:
reasoning_text = think_status.think_buffer
processing_index += 1
else:
content += item.output_str[processing_index:]
processing_index = output_len
if think_status.in_think_mode:
if has_overlap_kmp(think_status.think_buffer, think_end_tag) \
or think_start_tag.startswith(think_status.think_buffer):
reasoning_text = ""
else:
think_status.think_buffer = ""
if think_mode and update_think_tokens:
if not think_status.is_streaming:
think_status.think_tokens = item.output_length - len(self.tokenizer.tokenize(content or ""))
else:
think_status.think_tokens = item.output_length
return DeltaMessage(reasoning_content=reasoning_text or "", content=content or "")
elif isinstance(item.output_str, DeltaMessage):
return item.output_str
else:
raise Exception(f'undefined output_str type[{type(item.output_str)}]')
async def _generate_stream_response(self, items: List[OutputDelta], think_status: ThinkStatus) -> StreamResponseObject:
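        """Wrap per-choice deltas into one StreamResponseObject with usage accounting."""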
if len(items) == 0:
raise Exception("output items length should not be 0")
input_lengths = items[0].input_length
output_lengths = sum([x.output_length for x in items])
reuse_lengths = items[0].reuse_length
all_choices = []
for i, item in enumerate(items):
delta = self._split_reasoning_text_and_content(item, think_status)
all_choices.append(ChatCompletionResponseStreamChoice(
index=i,
delta=delta,
logprobs=ChoiceLogprobs(
content=[item.logprobs] if item.logprobs != None else None,
refusal=None
) if item.logprobs != None else None
))
return StreamResponseObject(
choices=all_choices,
usage=UsageInfo(
prompt_tokens=input_lengths,
total_tokens=input_lengths + output_lengths,
completion_tokens=output_lengths,
completion_tokens_details=CompletionTokensDetails(reasoning_tokens=think_status.think_tokens) if think_mode > 0 else None,
prompt_tokens_details=PromptTokensDetails(cached_tokens=reuse_lengths) if reuse_lengths > 0 else None
)
)
async def _flush_buffer(self, buffer_list: List[StreamStatus], stop_words_str: List[str], is_streaming: bool, think_status: ThinkStatus):
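        """Emit any per-choice text still buffered, truncated at stop words."""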
output_items: List[OutputDelta] = []
for buffer in buffer_list:
if buffer.output is None:
raise Exception("last output should not be None")
aux_info = buffer.output.aux_info
trunc_string = truncate_response_with_stop_words(buffer.delta_output_string, stop_words_str, is_streaming)
output_items.append(OutputDelta(
trunc_string,
await self._generate_log_probs(buffer, buffer.output),
aux_info.input_len,
aux_info.output_len,
aux_info.reuse_len))
return await self._generate_stream_response(output_items, think_status)
async def _generate_final(self, buffer_list: List[StreamStatus], request: ChatCompletionRequest, think_status: ThinkStatus):
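        """Emit the terminating chunk: per-choice finish reasons plus final usage (and aux_info if requested)."""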
input_token_length = 0
output_token_length = 0
reuse_length = 0
aux_info = None
for i, buffer in enumerate(buffer_list):
if buffer.output is None:
raise Exception("buffer last output should not be None")
            # Imported lazily to avoid a circular import.
            from maga_transformer.openai.renderers.qwen_tool_renderer import QwenToolStreamStatus
            # Check whether this buffer is still in the middle of generating a tool call.
if isinstance(buffer, QwenToolStreamStatus) and buffer.generating_tool_call:
buffer.finish_reason = FinisheReason.tool_calls
if buffer.finish_reason == None:
logging.debug(f"output {i} found no stop reason! use stop as default.")
buffer.finish_reason = FinisheReason.stop
if i == 0:
input_token_length = buffer.output.aux_info.input_len
reuse_length = buffer.output.aux_info.reuse_len
aux_info = buffer.output.aux_info if request.aux_info else None
output_token_length += buffer.output.aux_info.output_len
return StreamResponseObject(
choices=[ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(
content="",
),
finish_reason=buffer.finish_reason
) for i, buffer in enumerate(buffer_list)],
usage=UsageInfo(
prompt_tokens=input_token_length,
total_tokens=input_token_length + output_token_length,
completion_tokens=output_token_length,
completion_tokens_details=CompletionTokensDetails(reasoning_tokens=think_status.think_tokens) if think_mode > 0 else None,
prompt_tokens_details=PromptTokensDetails(cached_tokens=reuse_length) if reuse_length > 0 else None
),
aux_info=aux_info
)
async def _create_status_list(self, n: int, request: ChatCompletionRequest) -> List[StreamStatus]:
return [StreamStatus(request) for _ in range(n)]
async def render_response_stream(
self,
output_generator: AsyncGenerator[GenerateOutputs, None],
request: ChatCompletionRequest,
generate_config: GenerateConfig
) -> AsyncGenerator[StreamResponseObject, None]:
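        """Render engine outputs into a stream of response objects.

        Emission order: an initial role-only chunk, one chunk per engine step, a flush of
        any buffered text, then a final chunk carrying finish_reason and usage.

        Consumption sketch (illustrative names only):
            async for chunk in renderer.render_response_stream(gen, request, config):
                for choice in chunk.choices:
                    print(choice.delta.content or "", end="")
        """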
stop_word_slice_list = get_stop_word_slices(generate_config.stop_words_str)
num_return_sequences = request.n if request.n is not None else 1
status_list = await self._create_status_list(num_return_sequences, request)
index = 0
think_status = ThinkStatus(in_think_mode=think_mode, think_buffer="", think_tokens=0, is_streaming=generate_config.is_streaming)
async for outputs in output_generator:
if index == 0:
yield await self._generate_first(num_return_sequences)
index += 1
if len(outputs.generate_outputs) != num_return_sequences:
raise Exception("output num != num_return_sequences")
delta_list: List[OutputDelta] = []
for status, output in zip(status_list, outputs.generate_outputs):
delta_list.append(await self._update_single_status(
status, output, generate_config.max_new_tokens, generate_config.stop_words_str,
stop_word_slice_list, generate_config.is_streaming))
yield await self._generate_stream_response(delta_list, think_status)
if self._check_all_finished(status_list):
break
if index != 0:
yield await self._flush_buffer(status_list, generate_config.stop_words_str, generate_config.is_streaming, think_status)
yield await self._generate_final(status_list, request, think_status)
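    # ---- Synchronous rendering path ----
    # The *_sync helpers below mirror the async methods above, but take pre-extracted
    # fields (aux_info lengths, all_probs, output_ids) so they can be driven directly
    # from the C++ server without an event loop.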
def _create_empty_delta_sync(self, input_len, output_len, reuse_len):
return OutputDelta(
output_str="",
logprobs=None,
input_length=input_len,
output_length=output_len,
reuse_length=reuse_len
)
def _generate_log_probs_sync(self,
status: StreamStatusSync,
all_probs: torch.Tensor,
output_ids: torch.Tensor) -> Optional[ChatCompletionTokenLogprob]:
if not status.request.logprobs:
return None
prob_return_num = status.request.top_logprobs or 1
        output_id = output_ids
        if output_id is None:
            return None
        selected_id = output_id[-1].item()
        if all_probs is None:
            raise Exception("all_probs is None while logprobs is requested; this is likely an internal bug.")
all_probs = all_probs.squeeze()
probs, tokens = all_probs.sort(descending=True)
non_zero_size = probs.nonzero().shape[0]
log_values = probs.log()
prob_return_num = min(prob_return_num, non_zero_size)
selected_token = self.tokenizer.decode([selected_id])
chat_logprob = ChatCompletionTokenLogprob(
token=selected_token,
bytes=list(selected_token.encode("utf-8", errors="replace")),
logprob=all_probs[output_id].log().item(),
top_logprobs=[]
)
for i in range(prob_return_num):
token = self.tokenizer.decode(tokens[i].item())
chat_logprob.top_logprobs.append(TopLogprob(
token=token,
logprob=log_values[i].item(),
bytes=list(token.encode("utf-8", errors="replace")),
))
logging.debug(f"chat_logprob: {chat_logprob.model_dump_json(indent=4)}")
return chat_logprob
def _update_single_status_sync(self,
status: StreamStatusSync,
input_len, # output.aux_info
output_len, # output.aux_info
reuse_len, # output.aux_info
all_probs: torch.Tensor,
output_ids: torch.Tensor,
max_new_tokens: int,
stop_words_str: List[str],
stop_word_slice_list: List[str],
is_streaming: bool) -> OutputDelta:
if status.finish_reason != None:
return self._create_empty_delta_sync(input_len, output_len, reuse_len)
status.update_output_sync(output_ids, input_len,
self._clean_output_ids,
functools.partial(self._check_finish_reason, max_new_tokens=max_new_tokens),
self._remove_stop_word_ids)
decoded_prev_token = self.tokenizer.decode(status.prev_token_id)
decoded_string = self.tokenizer.decode(status.tokens_to_decode)
        # For some tokenizers (e.g. ChatGLM), decoding a single token differs from decoding a list of tokens.
if is_streaming:
if len(decoded_string) > 0 and u'\uFFFD' == decoded_string[-1]:
return self._create_empty_delta_sync(input_len, output_len, reuse_len)
else:
while (len(decoded_string) > 0) and (u'\uFFFD' == decoded_string[-1]):
decoded_string = decoded_string[:-1]
status.delta_output_string = decoded_string[len(decoded_prev_token):]
if is_truncated(status.delta_output_string, stop_words_str, is_streaming):
status.finish_reason = FinisheReason.stop
return self._create_empty_delta_sync(input_len, output_len, reuse_len)
if not is_truncated(status.delta_output_string, stop_word_slice_list, is_streaming):
status.update_result()
delta = OutputDelta(
output_str=status.delta_output_string,
logprobs=self._generate_log_probs_sync(status, all_probs, output_ids),
input_length=input_len,
output_length=output_len,
reuse_length=reuse_len)
status.delta_output_string = ""
return delta
else:
return self._create_empty_delta_sync(input_len, output_len, reuse_len)
def _generate_first_sync(self, n: int):
return StreamResponseObject(
choices=[ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(
role=RoleEnum.assistant,
content="",
),
) for i in range(n)]
)
def _generate_stream_response_sync(self, items: List[OutputDelta]) -> StreamResponseObject:
if len(items) == 0:
raise Exception("output items length should not be 0")
input_lengths = items[0].input_length
output_lengths = sum([x.output_length for x in items])
reuse_lengths = items[0].reuse_length
return StreamResponseObject(
choices=[ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(
content=item.output_str,
) if isinstance(item.output_str, str) else item.output_str,
logprobs=ChoiceLogprobs(
content=[item.logprobs] if item.logprobs != None else None,
refusal=None
) if item.logprobs != None else None
) for i, item in enumerate(items)],
usage=UsageInfo(
prompt_tokens=input_lengths,
total_tokens=input_lengths + output_lengths,
completion_tokens=output_lengths,
prompt_tokens_details=PromptTokensDetails(cached_tokens=reuse_lengths) if reuse_lengths > 0 else None
)
)
def _flush_buffer_sync(self,
buffer_list: List[StreamStatusSync],
input_len_list, output_len_list, reuse_len_list,
all_probs_list, output_ids_list,
stop_words_str: List[str],
is_streaming: bool):
output_items: List[OutputDelta] = []
for buffer, input_len, output_len, reuse_len, all_probs, output_ids in zip(
buffer_list,
input_len_list, output_len_list, reuse_len_list,
all_probs_list, output_ids_list
):
trunc_string = truncate_response_with_stop_words(buffer.delta_output_string, stop_words_str, is_streaming)
output_items.append(OutputDelta(
trunc_string,
self._generate_log_probs_sync(buffer, all_probs, output_ids),
input_len,
output_len,
reuse_len))
return self._generate_stream_response_sync(output_items)
def _generate_final_sync(self,
buffer_list: List[StreamStatusSync],
input_len_list, output_len_list, reuse_len_list):
input_token_length = 0
output_token_length = 0
reuse_length = 0
aux_info = None
for i, (buffer, input_len, output_len, reuse_len) in enumerate(zip(buffer_list,
input_len_list,
output_len_list,
reuse_len_list)):
if buffer.finish_reason == None:
logging.debug(f"output {i} found no stop reason! use stop as default.")
buffer.finish_reason = FinisheReason.stop
if i == 0:
input_token_length = input_len
reuse_length = reuse_len
output_token_length += output_len
return StreamResponseObject(
choices=[ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(
content="",
),
finish_reason=buffer.finish_reason
) for i, buffer in enumerate(buffer_list)],
usage=UsageInfo(
prompt_tokens=input_token_length,
total_tokens=input_token_length + output_token_length,
completion_tokens=output_token_length,
prompt_tokens_details=PromptTokensDetails(cached_tokens=reuse_length) if reuse_length > 0 else None
),
aux_info=aux_info
)
def _create_status_list_sync(self, n: int, body: str) -> List[StreamStatusSync]:
request = self.getRequest(body)
return [StreamStatusSync(request) for _ in range(n)]
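    # The render_stream_response_* methods below serialize sync results into
    # ChatCompletionStreamResponse JSON; the *_blocking variants return the
    # StreamResponseObject itself so the caller can assemble the final response.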
def render_stream_response_first(self, n: int, debug_info: str):
stream_response = self._generate_first_sync(n)
chat_response = ChatCompletionStreamResponse(
choices=stream_response.choices,
usage=stream_response.usage,
aux_info=stream_response.aux_info,
debug_info=debug_info
)
return chat_response.model_dump_json(exclude_none=True)
def render_stream_response_refactor(self,
                                        status_list: List[StreamStatusSync], # pass in from cpp
input_len_list, # output.aux_info
output_len_list, # output.aux_info
reuse_len_list, # output.aux_info
all_probs_list, # GenerateOutput
output_ids_list, # GenerateOutput
max_new_tokens, # GenerateConfig
stop_words_str, # GenerateConfig
is_streaming):
stop_word_slice_list = get_stop_word_slices(stop_words_str) # move into cpp, then pass in
delta_list: List[OutputDelta] = []
for status, input_len, output_len, reuse_len, all_probs, output_ids in zip(
status_list,
input_len_list, output_len_list, reuse_len_list, # AuxInfo
all_probs_list, output_ids_list # GenerateOutput
):
delta_list.append(self._update_single_status_sync(status,
input_len, output_len, reuse_len,
all_probs, output_ids,
max_new_tokens, stop_words_str,
stop_word_slice_list,
is_streaming))
stream_response = self._generate_stream_response_sync(delta_list)
chat_response = ChatCompletionStreamResponse(
choices=stream_response.choices,
usage=stream_response.usage,
aux_info=stream_response.aux_info
)
return chat_response.model_dump_json(exclude_none=True)
def render_stream_response_flush(self,
status_list,
input_len_list, output_len_list, reuse_len_list,
all_probs_list, output_ids_list,
stop_words_str,
is_streaming):
stream_response = self._flush_buffer_sync(status_list,
input_len_list, output_len_list, reuse_len_list,
all_probs_list, output_ids_list,
stop_words_str,
is_streaming)
chat_response = ChatCompletionStreamResponse(
choices=stream_response.choices,
usage=stream_response.usage,
aux_info=stream_response.aux_info
)
return chat_response.model_dump_json(exclude_none=True)
def render_stream_response_final(self,
status_list,
input_len_list, output_len_list, reuse_len_list):
stream_response = self._generate_final_sync(status_list,
input_len_list, output_len_list, reuse_len_list)
chat_response = ChatCompletionStreamResponse(
choices=stream_response.choices,
usage=stream_response.usage,
aux_info=stream_response.aux_info
)
return chat_response.model_dump_json(exclude_none=True)
def render_stream_response_first_blocking(self, n: int):
stream_response = self._generate_first_sync(n)
return stream_response
def render_stream_response_blocking(self,
                                        status_list: List[StreamStatusSync], # pass in from cpp
input_len_list, # output.aux_info
output_len_list, # output.aux_info
reuse_len_list, # output.aux_info
all_probs_list, # GenerateOutput
output_ids_list, # GenerateOutput
max_new_tokens, # GenerateConfig
stop_words_str, # GenerateConfig
is_streaming
):
stop_word_slice_list = get_stop_word_slices(stop_words_str) # move into cpp, then pass in
delta_list: List[OutputDelta] = []
for status, input_len, output_len, reuse_len, all_probs, output_ids in zip(
status_list,
input_len_list, output_len_list, reuse_len_list, # AuxInfo
all_probs_list, output_ids_list # GenerateOutput
):
delta_list.append(self._update_single_status_sync(status,
input_len, output_len, reuse_len,
all_probs, output_ids,
max_new_tokens, stop_words_str,
stop_word_slice_list,
is_streaming))
stream_response = self._generate_stream_response_sync(delta_list)
return stream_response
def render_stream_response_flush_blocking(self,
status_list,
input_len_list, output_len_list, reuse_len_list,
all_probs_list, output_ids_list,
stop_words_str,
is_streaming):
stream_response = self._flush_buffer_sync(status_list,
input_len_list, output_len_list, reuse_len_list,
all_probs_list, output_ids_list,
stop_words_str,
is_streaming)
return stream_response
def render_stream_response_final_blocking(self,
status_list,
input_len_list, output_len_list, reuse_len_list):
stream_response = self._generate_final_sync(status_list,
input_len_list, output_len_list, reuse_len_list)
return stream_response
def collect_complete_response(self, choice_generator):
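        """Fold a sequence of stream responses into one non-streaming ChatCompletionResponse JSON string."""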
all_choices = []
usage = None
aux_info = None
def split_think_tag(text: Optional[str]):
if text is None:
return None, None
text_results = text.split(think_end_tag, 1)
reasoning_content = text_results[0] if len(text_results) == 2 else None
content = text_results[1] if len(text_results) == 2 else text
return content, reasoning_content
for response in choice_generator:
if len(response.choices) != len(all_choices):
if (all_choices == []):
for i, choice in enumerate(response.choices):
content, reasoning_content = split_think_tag(choice.delta.content)
all_choices.append(ChatCompletionResponseChoice(
index=i,
message=ChatMessage(
role=choice.delta.role or RoleEnum.assistant,
content=content or None,
reasoning_content=reasoning_content or None,
function_call=choice.delta.function_call or None,
),
finish_reason=choice.finish_reason,
logprobs=choice.logprobs,
)
)
else:
raise ValueError(f"response.choices has different length! "
f"[{response.choices}] vs [{all_choices}].")
else:
for i in range(len(all_choices)):
if all_choices[i].message.content == None:
all_choices[i].message.content = (response.choices[i].delta.content or None)
else:
all_choices[i].message.content += (response.choices[i].delta.content or "")
content, reasoning_content = split_think_tag(all_choices[i].message.content)
all_choices[i].message.content = content
all_choices[i].message.reasoning_content = reasoning_content
all_choices[i].message.role = response.choices[i].delta.role or all_choices[i].message.role
all_choices[i].message.function_call = response.choices[i].delta.function_call or all_choices[i].message.function_call
all_choices[i].finish_reason = response.choices[i].finish_reason or all_choices[i].finish_reason
if all_choices[i].logprobs != None:
if response.choices[i].logprobs != None:
all_choices[i].logprobs.content += response.choices[i].logprobs.content
else:
all_choices[i].logprobs = response.choices[i].logprobs
usage = response.usage or usage
aux_info = response.aux_info or aux_info
if (usage == None):
logging.warning(f"No usage returned from stream response. use empty value.")
usage = UsageInfo(
prompt_tokens=0,
total_tokens=0,
completion_tokens=0
)
chat_response = ChatCompletionResponse(
choices=all_choices,
usage=usage,
aux_info=aux_info,
model="AsyncModel",
)
return chat_response.model_dump_json(exclude_none=True)
def _check_finish_reason(self, token_ids: List[int], input_token_length: int, max_new_tokens: int = -1) -> Optional[FinisheReason]:
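        """Return FinisheReason.length / .stop when a length limit, EOS, or stop word is hit; otherwise None."""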
stop_word_ids_list_all = self.get_all_extra_stop_word_ids_list() + self.stop_words_id_list
if max_new_tokens > 0 and len(token_ids) >= max_new_tokens:
return FinisheReason.length
if len(token_ids) + input_token_length >= self.max_seq_len:
return FinisheReason.length
if token_ids and token_ids[-1] == self.eos_token_id:
return FinisheReason.stop
for stop_word_ids in stop_word_ids_list_all:
if (len(token_ids) >= len(stop_word_ids)) and (token_ids[-len(stop_word_ids):] == stop_word_ids):
return FinisheReason.stop
return None
def _remove_stop_word_ids(self, output_ids: List[int]) -> List[int]:
stop_word_ids_list_all = self.get_all_extra_stop_word_ids_list() + self.stop_words_id_list
for stop_word_ids in stop_word_ids_list_all:
            # Match from the longest possible span downwards.
            # Stop word ids may contain repeats, e.g. [144575, 14098, 144575];
            # matching from length 1 upwards could strip only the trailing 144575 and exit early.
for i in range(len(stop_word_ids) + 1, 1, -1):
if output_ids[-i:] == stop_word_ids[:i]:
output_ids = output_ids[:-i]
break
return output_ids
def _clean_output_ids(self, output_ids_tensor: torch.Tensor) -> list[int]:
output_ids_tensor = output_ids_tensor.cpu().reshape([-1])
# TODO(wangyin): This slicing shouldn't be done here.
# model should return output length, ids should be sliced with output length.
output_ids = output_ids_tensor[output_ids_tensor != self.eos_token_id].tolist()
return output_ids