# maga_transformer/openai/openai_endpoint.py
from fastapi import Request
import torch
from typing import Optional, AsyncGenerator
import os
import json
import logging
from functools import partial
from transformers import PreTrainedTokenizerBase
from maga_transformer.utils.util import str_to_bool
from maga_transformer.utils.complete_response_async_generator import CompleteResponseAsyncGenerator
from maga_transformer.openai.api_datatype import ModelCard, ModelList, ChatMessage, RoleEnum, \
ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice, UsageInfo, \
ChatCompletionStreamResponse, \
DebugInfo
from maga_transformer.openai.renderers.custom_renderer import RendererParams, \
StreamResponseObject, RenderedInputs, CustomChatRenderer
from maga_transformer.openai.renderer_factory import ChatRendererFactory
from maga_transformer.openai.renderers.basic_renderer import BasicRenderer
from maga_transformer.config.generate_config import GenerateConfig
from maga_transformer.utils.mm_process_engine import MMProcessEngine
from maga_transformer.config.gpt_init_model_parameters import GptInitModelParameters
from maga_transformer.async_decoder_engine.backend_rpc_server_visitor import BackendRPCServerVisitor
class OpenaiEndopoint():
def __init__(self, model_config: GptInitModelParameters,
tokenizer: PreTrainedTokenizerBase,
backend_rpc_server_visitor: BackendRPCServerVisitor):
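        """Set up the tokenizer, chat renderer and stop-word lists shared by all chat completion requests."""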
self.model_config = model_config
self.max_seq_len = self.model_config.max_seq_len
        if tokenizer is None:
            raise AttributeError("tokenizer is None!")
self.tokenizer: PreTrainedTokenizerBase = tokenizer
self.backend_rpc_server_visitor = backend_rpc_server_visitor
self.eos_token_id = None
        if isinstance(tokenizer, PreTrainedTokenizerBase):
            self.eos_token_id = tokenizer.eos_token_id
        if self.eos_token_id is None:
self.eos_token_id = self.model_config.special_tokens.eos_token_id
self.stop_words_id_list = self.model_config.special_tokens.stop_words_id_list
render_params = RendererParams(
model_type=os.environ["MODEL_TYPE"],
max_seq_len=self.max_seq_len,
eos_token_id=self.eos_token_id,
stop_word_ids_list=self.stop_words_id_list,
template_type=self.model_config.template_type,
ckpt_path=self.model_config.ckpt_path
)
self.chat_renderer: CustomChatRenderer = ChatRendererFactory.get_renderer(self.tokenizer, render_params)
logging.info(f"Finally openai endpoint uses renderer: {self.chat_renderer} ")
        self.template_renderer: CustomChatRenderer = self.chat_renderer \
            if isinstance(self.chat_renderer, BasicRenderer) \
            else BasicRenderer(self.tokenizer, render_params)
logging.info(f"chat_renderer [{self.chat_renderer}] is created.")
extra_stop_word_ids_list = self.chat_renderer.get_all_extra_stop_word_ids_list()
self.stop_words_id_list.extend(extra_stop_word_ids_list)
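        # Decode the stop-word id sequences so they can also be matched as strings.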
self.stop_words_str_list = []
for stop_word_ids in self.stop_words_id_list:
word = self.tokenizer.decode(stop_word_ids)
if len(word):
self.stop_words_str_list.append(word)
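        # STOP_WORDS_STR / STOP_WORDS_LIST extend the stop words; FORCE_STOP_WORDS replaces them entirely.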
env_stop_words_str = os.environ.get('STOP_WORDS_STR', None)
env_stop_words_id = os.environ.get('STOP_WORDS_LIST', None)
env_stop_words_str_list = json.loads(env_stop_words_str) if env_stop_words_str else []
env_stop_words_id_list = json.loads(env_stop_words_id) if env_stop_words_id else []
env_force_stop = os.environ.get('FORCE_STOP_WORDS', None)
if env_force_stop and str_to_bool(env_force_stop):
self.stop_words_str_list = env_stop_words_str_list
self.stop_words_id_list = env_stop_words_id_list
else:
self.stop_words_str_list = self.stop_words_str_list + env_stop_words_str_list
self.stop_words_id_list = self.stop_words_id_list + env_stop_words_id_list
logging.info(f"use stop_words_str_list [{self.stop_words_str_list}], " \
f"stop_words_id_list [{self.stop_words_id_list}]")
    async def list_models(self):
        model_card = ModelCard(id=self.model_config.model_name)
        return ModelList(data=[model_card])
def _extract_generation_config(self, request: ChatCompletionRequest) -> GenerateConfig:
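        """Map the fields of an OpenAI ChatCompletionRequest onto the engine's GenerateConfig."""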
# TODO(wangyin): implement this
config = request.extra_configs or GenerateConfig()
        if request.stream is not None:
            config.is_streaming = request.stream
        if request.temperature is not None:
            config.temperature = request.temperature
        if request.top_p is not None:
            config.top_p = request.top_p
        if request.max_tokens is not None:
            config.max_new_tokens = request.max_tokens
        if request.n is not None:
            config.num_return_sequences = request.n
        request_stop_words_list = request.stop if request.stop is not None else []
        if isinstance(request_stop_words_list, str):
            request_stop_words_list = [request_stop_words_list]
        config.stop_words_str = self.stop_words_str_list + request_stop_words_list
        config.stop_words_list = self.stop_words_id_list + self.chat_renderer.tokenize_words(request_stop_words_list)
        if request.chat_id is not None:
            config.chat_id = request.chat_id
        if request.seed is not None:
            config.random_seed = request.seed
        if request.logprobs is not None:
            config.return_all_probs = request.logprobs
        if request.logprobs or request.functions:
            config.is_streaming = True
config.add_special_tokens(self.model_config.special_tokens)
config.convert_select_tokens(self.model_config.vocab_size, self.tokenizer)
        if request.extend_fields and "max_thinking_tokens" in request.extend_fields \
                and isinstance(request.extend_fields["max_thinking_tokens"], int):
config.max_thinking_tokens = request.extend_fields["max_thinking_tokens"]
config.add_thinking_params(self.tokenizer)
return config
    async def _collect_complete_response(
            self,
            choice_generator: AsyncGenerator[StreamResponseObject, None],
            debug_info: Optional[DebugInfo]) -> ChatCompletionResponse:
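        """Aggregate a stream of incremental choices into a single non-streaming ChatCompletionResponse."""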
all_choices = []
usage = None
aux_info = None
async for response in choice_generator:
if len(response.choices) != len(all_choices):
                if not all_choices:
all_choices = [
ChatCompletionResponseChoice(
index=i,
message=ChatMessage(
role=choice.delta.role or RoleEnum.assistant,
content=choice.delta.content or None,
function_call=choice.delta.function_call or None,
tool_calls=choice.delta.tool_calls or None,
),
finish_reason=choice.finish_reason,
logprobs=choice.logprobs,
) for i, choice in enumerate(response.choices)
]
else:
raise ValueError(f"response.choices has different length! "
f"[{response.choices}] vs [{all_choices}].")
else:
for i in range(len(all_choices)):
                    if all_choices[i].message.content is None:
                        all_choices[i].message.content = (response.choices[i].delta.content or None)
                    else:
                        all_choices[i].message.content += (response.choices[i].delta.content or "")
                    if all_choices[i].message.reasoning_content is None:
                        all_choices[i].message.reasoning_content = (response.choices[i].delta.reasoning_content or None)
                    else:
                        all_choices[i].message.reasoning_content += (response.choices[i].delta.reasoning_content or "")
all_choices[i].message.role = response.choices[i].delta.role or all_choices[i].message.role
all_choices[i].message.function_call = response.choices[i].delta.function_call or all_choices[i].message.function_call
all_choices[i].message.tool_calls = (
response.choices[i].delta.tool_calls
or all_choices[i].message.tool_calls
)
all_choices[i].finish_reason = response.choices[i].finish_reason or all_choices[i].finish_reason
                    if all_choices[i].logprobs is not None:
                        if response.choices[i].logprobs is not None:
all_choices[i].logprobs.content += response.choices[i].logprobs.content
else:
all_choices[i].logprobs = response.choices[i].logprobs
usage = response.usage or usage
aux_info = response.aux_info or aux_info
        if usage is None:
            logging.warning("No usage returned from stream response; using empty value.")
usage = UsageInfo(
prompt_tokens=0,
total_tokens=0,
completion_tokens=0
)
return ChatCompletionResponse(
choices=all_choices,
usage=usage,
aux_info=aux_info,
model=self.model_config.model_name,
debug_info=debug_info,
)
def _complete_stream_response(
self, choice_generator: AsyncGenerator[StreamResponseObject, None],
debug_info: Optional[DebugInfo]
) -> CompleteResponseAsyncGenerator:
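        """Wrap the choice stream into ChatCompletionStreamResponse chunks; debug info is attached to the first chunk only."""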
async def response_generator():
debug_info_responded = False
async for response in choice_generator:
yield ChatCompletionStreamResponse(
choices=response.choices,
usage=response.usage,
aux_info=response.aux_info,
debug_info=debug_info if not debug_info_responded else None
)
debug_info_responded = True
complete_response_collect_func = partial(self._collect_complete_response, debug_info=debug_info)
return CompleteResponseAsyncGenerator(response_generator(), complete_response_collect_func)
    def _get_debug_info(self, renderer: CustomChatRenderer,
                        rendered_input: RenderedInputs, gen_config: GenerateConfig) -> DebugInfo:
        if rendered_input.rendered_prompt != "":
            prompt = rendered_input.rendered_prompt
        else:
            prompt = self.tokenizer.decode(rendered_input.input_ids)
        return DebugInfo(
            input_prompt=prompt,
            input_ids=rendered_input.input_ids,
            input_urls=[mm_input.url for mm_input in rendered_input.multimodal_inputs],
tokenizer_info=str(self.tokenizer),
max_seq_len=self.max_seq_len,
eos_token_id=self.eos_token_id,
stop_word_ids_list=self.stop_words_id_list,
stop_words_list=self.stop_words_str_list,
renderer_info=renderer.get_renderer_info(),
generate_config=gen_config
)
    def render_chat(self, chat_request: ChatCompletionRequest) -> RenderedInputs:
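        """Render the chat messages into prompt text and input ids; a trailing `partial` message is appended verbatim to pre-fill the assistant reply."""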
renderer = self.template_renderer if chat_request.user_template else self.chat_renderer
prepopulate_str = ""
if len(chat_request.messages) > 0 and chat_request.messages[-1].partial:
prepopulate_str = str(chat_request.messages[-1].content)
chat_request.messages.pop()
rendered_input = renderer.render_chat(chat_request)
if prepopulate_str != "":
rendered_input.rendered_prompt += prepopulate_str
rendered_input.input_ids += self.tokenizer.encode(prepopulate_str)
return rendered_input
def chat_completion(
self, request_id: int, chat_request: ChatCompletionRequest, raw_request: Request
) -> CompleteResponseAsyncGenerator:
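        """Render the request, build the generation config and return a streaming response generator."""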
renderer = self.template_renderer if chat_request.user_template else self.chat_renderer
rendered_input = self.render_chat(chat_request)
generate_config = self._extract_generation_config(chat_request)
        mm_inputs = rendered_input.multimodal_inputs if self.model_config.is_multimodal else []
if generate_config.sp_advice_prompt != "":
generate_config.sp_advice_prompt_token_ids = self.tokenizer.encode(generate_config.sp_advice_prompt)
debug_info = self._get_debug_info(renderer, rendered_input, generate_config) \
if chat_request.debug_info else None
choice_generator = renderer.generate_choice(
request_id,
rendered_input.input_ids,
mm_inputs,
generate_config,
self.backend_rpc_server_visitor,
chat_request
)
return self._complete_stream_response(choice_generator, debug_info)
def chat_render(self, chat_request: ChatCompletionRequest) -> DebugInfo:
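        """Render the request and return the resulting DebugInfo without running generation."""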
renderer = self.template_renderer if chat_request.user_template else self.chat_renderer
rendered_input = renderer.render_chat(chat_request)
generate_config = self._extract_generation_config(chat_request)
debug_info = self._get_debug_info(renderer, rendered_input, generate_config)
return debug_info