# maga_transformer/pipeline/pipeline.py
import os
import logging
import torch
import asyncio
import threading
import platform
import queue
import json
from typing import Any, List, Union, Iterator, Tuple, Callable, Optional, Dict, Generator, AsyncGenerator
from concurrent.futures import Future
from torch.nn.utils.rnn import pad_sequence
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from maga_transformer.utils.util import AtomicCounter
from maga_transformer.utils.time_util import current_time_ms
from maga_transformer.config.exceptions import ExceptionType, FtRuntimeException
from maga_transformer.config.generate_config import GenerateConfig
from maga_transformer.metrics import kmonitor, GaugeMetrics
from maga_transformer.models.base_model import BaseModel, GenerateOutput, GenerateOutputs, GenerateResponse, GenerateInput
from maga_transformer.utils.multimodal_util import MultimodalInput
from maga_transformer.model_factory import ModelFactory, AsyncModel, ModelConfig
from maga_transformer.async_decoder_engine.backend_rpc_server_visitor import BackendRPCServerVisitor
from maga_transformer.pipeline.pipeline_custom_func import PipelineCustomFunc, get_piple_custom_func
from maga_transformer.utils.word_util import remove_padding_eos, get_stop_word_slices, \
truncate_response_with_stop_words, truncate_token_with_stop_word_id, match_stop_words
from maga_transformer.utils.tokenizer_utils import DecodingState
from maga_transformer.utils.weight_type import WEIGHT_TYPE
from maga_transformer.utils.mm_process_engine import MMProcessEngine
from maga_transformer.config.gpt_init_model_parameters import GptInitModelParameters
request_counter = AtomicCounter()
class Pipeline(object):
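    """End-to-end text generation pipeline.

    Wraps tokenization, the model's custom prompt/response hooks, stop-word
    handling and streaming generation through the backend RPC server.
    """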
    def __init__(self, model_cls: Union[AsyncModel, BaseModel],
                 model_config: GptInitModelParameters, tokenizer: Optional[PreTrainedTokenizerBase]):
self.model_cls = model_cls
self.model_config = model_config
self.tokenizer = tokenizer
        self._special_tokens: Any = self.model_config.special_tokens
self._mm_token: str = self.model_config.mm_related_params.special_tokens.get('default_mm_token', '')
self.piple_funcs: PipelineCustomFunc = get_piple_custom_func(self.model_cls)
self.backend_rpc_server_visitor = BackendRPCServerVisitor(model_config)
def stop(self):
if isinstance(self.model_cls, AsyncModel):
logging.info("async model stop")
self.model_cls.stop()
def encode(self, prompt: str):
assert self.tokenizer is not None
return self.tokenizer.encode(prompt)
def decode(self, token_id: int):
assert self.tokenizer is not None
return self.tokenizer.decode([token_id])
@staticmethod
def create_generate_config(generate_config: Union[GenerateConfig, Dict[str, Any]], vocab_size: int,
special_tokens: Any, tokenizer: PreTrainedTokenizerBase, **kwargs: Any) -> GenerateConfig:
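        """Build a validated GenerateConfig.

        A dict is converted via GenerateConfig.create_generate_config; an existing
        GenerateConfig (e.g. passed through from frontend_worker) is used as-is.
        Special tokens, select tokens, thinking params and string stop ids are
        attached before validation.
        """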
if isinstance(generate_config, dict):
config = GenerateConfig.create_generate_config(generate_config, **kwargs)
else:
            # assume it was passed in from frontend_worker and has already been processed, no need to do it again
config = generate_config
config.add_special_tokens(special_tokens)
config.convert_select_tokens(vocab_size, tokenizer)
config.add_thinking_params(tokenizer)
config.add_stop_ids_from_str(tokenizer)
config.validate()
return config
def __call__(self, prompt: str, urls: Optional[List[str]] = None, **kwargs: Any) -> Iterator[GenerateResponse]:
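        """Convenience wrapper; equivalent to calling pipeline() directly."""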
        # for a non-multimodal model, urls can simply be left as None
return self.pipeline(prompt, urls = urls, **kwargs)
def pipeline(self,
prompt: str,
                 request_id: Optional[int] = None,
urls: Optional[List[str]] = None,
**kwargs: Any) -> Iterator[GenerateResponse]:
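        """Synchronous wrapper around pipeline_async.

        Runs the async generator on a background event-loop thread and bridges
        results (or a raised exception) back through a queue.
        """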
q = queue.Queue()
async def generator():
res = None
try:
res = self.pipeline_async(prompt, request_id, urls, **kwargs)
async for x in res:
q.put(x)
q.put(None)
except Exception as e:
q.put(e)
finally:
                # if the pipeline breaks early, call aclose() to remove the async generator task from the loop
                if res is not None:
                    await res.aclose()
def start_loop():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(generator())
        background_thread = threading.Thread(target=start_loop)
        background_thread.start()
try:
while True:
try:
r = q.get(timeout=0.01)
if r is None:
break
if isinstance(r, Exception):
raise r
yield r
except queue.Empty:
continue
finally:
            background_thread.join()
@torch.inference_mode()
def pipeline_async( # type: ignore
self,
prompt: str,
            request_id: Optional[int] = None,
urls: Optional[List[str]] = None,
**kwargs: Any
) -> AsyncGenerator[GenerateResponse, None]:
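        """Prepare a generation request and return its async response stream.

        Builds the GenerateConfig, applies the model's custom prompt hooks
        (including multimodal prompt expansion), encodes the prompt, reports
        pre-pipeline metrics and delegates to generate_stream.
        """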
begin_time = current_time_ms()
        if request_id is None:
request_id = request_counter.increment()
generate_config_json = kwargs.pop("generate_config", {})
generate_config = self.create_generate_config(generate_config_json, self.model_config.vocab_size,
self.model_config.special_tokens, self.tokenizer, **kwargs)
        # for deleting stop words from the output
prompt = self.piple_funcs.modify_prompt_func(prompt, generate_config=generate_config.model_dump(), **kwargs)
mm_inputs = []
if self.model_config.is_multimodal:
prompt, mm_inputs = self.piple_funcs.multimodal_modify_prompt_func(prompt, urls, self._mm_token,
generate_config=generate_config.model_dump(), **kwargs)
token_ids = self.piple_funcs.process_encode_func(prompt,
generate_config=generate_config.model_dump(),
tokenizer=self.tokenizer,
add_special_tokens=self.model_config.add_special_tokens,
special_tokens=self._special_tokens,
**kwargs)
if generate_config.sp_advice_prompt != "":
generate_config.sp_advice_prompt_token_ids = self.tokenizer.encode(generate_config.sp_advice_prompt)
kmonitor.report(GaugeMetrics.PRE_PIPELINE_RT_METRIC, current_time_ms() - begin_time)
kmonitor.report(GaugeMetrics.NUM_BEAMS_METRIC, generate_config.num_beams)
kmonitor.report(GaugeMetrics.INPUT_TOKEN_SIZE_METRIC, len(token_ids))
return self.generate_stream(request_id, token_ids, mm_inputs, generate_config, **kwargs)
def process_stop_id(self,
generate_config: GenerateConfig,
generate_output: GenerateOutput,
                        tokens: List[int],
stop_word_ids: List[List[int]],
stop_word_id_slices: List[List[int]]):
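        """Truncate output token ids at stop-word ids unless print_stop_words is set.

        Unfinished outputs are truncated with stop-word slices so that a
        partially generated stop word is not emitted.
        """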
if not generate_config.print_stop_words:
if not generate_output.finished:
tokens = truncate_token_with_stop_word_id(tokens, stop_word_id_slices)
else:
tokens = truncate_token_with_stop_word_id(tokens, stop_word_ids)
return tokens
def process_stop_str(self,
generate_config: GenerateConfig,
generate_output: GenerateOutput,
text: str, all_text: str,
stop_word_str_list: List[str],
stop_word_str_slices: List[str],
token_buffer: str,
**kwargs: Any):
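        """Apply string stop words to the decoded text.

        Marks the output finished when a stop word (or the custom
        stop_generate_func) matches, truncates the text accordingly and, in
        incremental mode, keeps a possible partial stop word in token_buffer.
        """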
generate_output.finished = self.piple_funcs.stop_generate_func(all_text, **kwargs) or generate_output.finished
if stop_word_str_list and not generate_output.finished and match_stop_words(all_text, stop_word_str_list):
generate_output.finished = True
if not generate_config.print_stop_words:
if not generate_config.return_incremental:
if not generate_output.finished:
text = truncate_response_with_stop_words(text, stop_word_str_slices, generate_config.is_streaming)
else:
text = truncate_response_with_stop_words(text, stop_word_str_list, generate_config.is_streaming)
else:
if not generate_output.finished:
text = token_buffer + text
trunc_text = truncate_response_with_stop_words(text, stop_word_str_slices, generate_config.is_streaming)
token_buffer = text[len(trunc_text):]
text = trunc_text
else:
text = truncate_response_with_stop_words(token_buffer + text, stop_word_str_list, generate_config.is_streaming)
return text, token_buffer
def decode_tokens(self,
generate_config: GenerateConfig,
generate_outputs: GenerateOutputs,
stop_word_str_list: List[str],
stop_word_str_slices: List[str],
                      stop_word_ids: List[List[int]],
                      stop_word_id_slices: List[List[int]],
decoding_states: List[DecodingState],
token_buffers: List[str],
                      output_tokens_list: List[torch.Tensor],
**kwargs: Any) -> Tuple[List[str], List[int], List[DecodingState], List[str], List[torch.Tensor]]:
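        """Decode one batch of streamed outputs into text.

        Accumulates output ids per output (single-beam case), strips padding and
        EOS, applies stop-word handling and the model's response hook, and returns
        the texts together with the per-output state (decoding states, token
        buffers, accumulated tokens) to carry into the next iteration.
        """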
texts = []
all_texts = []
output_lens = []
if len(decoding_states) == 0:
if generate_config.num_beams == 1 and generate_config.is_streaming:
decoding_states = [DecodingState() for _ in range(len(generate_outputs.generate_outputs))]
else:
                # when num_beams != 1, incremental decoding is not possible because previously returned token ids may change
decoding_states = [None] * len(generate_outputs.generate_outputs)
if len(token_buffers) == 0:
token_buffers = [""] * len(generate_outputs.generate_outputs)
        if len(output_tokens_list) == 0:
            output_tokens_list = [torch.empty(0, dtype=torch.int32) for _ in range(len(generate_outputs.generate_outputs))]
# TODO(xinfei.sxf) remove i
i = 0
for generate_output in generate_outputs.generate_outputs:
            # all models return output_ids incrementally when num_beams == 1, so accumulate them here
            if generate_config.num_beams == 1:
                output_tokens_list[i] = torch.cat((output_tokens_list[i], generate_output.output_ids), dim=1)
                generate_output.output_ids = output_tokens_list[i]
tokens = generate_output.output_ids
tokens = remove_padding_eos(tokens, self._special_tokens.eos_token_id)
output_lens.append(tokens.nelement())
tokens = self.process_stop_id(generate_config, generate_output, tokens.tolist(), stop_word_ids, stop_word_id_slices)
text, all_text = self.piple_funcs.process_decode_func(tokens,
generate_config=generate_config.model_dump(),
tokenizer=self.tokenizer,
decoding_state=decoding_states[i],
return_incremental=generate_config.return_incremental,
**kwargs)
text, token_buffers[i] = self.process_stop_str(generate_config, generate_output, text, all_text, stop_word_str_list,
stop_word_str_slices, token_buffers[i], **kwargs)
text = self.piple_funcs.modify_response_func(
text, hidden_states=generate_output.hidden_states,
generate_config=generate_config.model_dump(),
**kwargs)
texts.append(text)
all_texts.append(all_text)
i += 1
        return texts, output_lens, decoding_states, token_buffers, output_tokens_list
@torch.inference_mode()
async def generate_stream(self, request_id: int, token_ids: List[int], mm_inputs: List[MultimodalInput],
generate_config: GenerateConfig, **kwargs: Any) -> AsyncGenerator[GenerateResponse, None]:
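        """Enqueue the request on the backend and stream GenerateResponse objects.

        Keeps a cache of outputs so that entries of a batch that have already
        finished are not overwritten by later iterations; stops once every
        output has finished.
        """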
token_type_ids = []
token_ids = torch.tensor(token_ids, dtype=torch.int)
input = GenerateInput(request_id=request_id,
token_ids=token_ids,
mm_inputs=mm_inputs,
generate_config=generate_config,
tokenizer=self.tokenizer,
token_type_ids=token_type_ids)
stop_word_strs = generate_config.stop_words_str
stop_word_str_slices = get_stop_word_slices(stop_word_strs)
stop_word_ids = generate_config.stop_words_list
stop_word_id_slices = get_stop_word_slices(stop_word_ids)
stream: AsyncGenerator[GenerateOutputs, None] = self.backend_rpc_server_visitor.enqueue(input)
decoding_states: List[DecodingState] = []
        output_tokens_list: List[torch.Tensor] = []
token_buffers: List[str] = []
generate_outputs_cache = GenerateOutputs()
# TODO(xinfei.sxf) add batch and stop test
async for generate_outputs in stream:
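            # merge into the cache: keep outputs that have already finished, refresh the rest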
if not generate_outputs_cache.generate_outputs:
generate_outputs_cache.generate_outputs = generate_outputs.generate_outputs
else:
generate_outputs_cache.generate_outputs = [out if out.finished else generate_outputs.generate_outputs[i]
for i, out in enumerate(generate_outputs_cache.generate_outputs)]
assert len(generate_outputs_cache.generate_outputs) == len(generate_outputs.generate_outputs)
begin_time = current_time_ms()
            generate_texts, output_lens, decoding_states, token_buffers, output_tokens_list = self.decode_tokens(
                generate_config, generate_outputs_cache, stop_word_strs, stop_word_str_slices,
                stop_word_ids, stop_word_id_slices, decoding_states, token_buffers, output_tokens_list, **kwargs)
kmonitor.report(GaugeMetrics.POST_PIPELINE_RT_METRIC, current_time_ms() - begin_time)
yield GenerateResponse(generate_outputs=generate_outputs_cache, generate_texts=generate_texts)
if all(output.finished for output in generate_outputs_cache.generate_outputs):
kmonitor.report(GaugeMetrics.FT_ITERATE_COUNT_METRIC, generate_outputs_cache.generate_outputs[0].aux_info.iter_count)
for l in output_lens:
kmonitor.report(GaugeMetrics.OUTPUT_TOKEN_SIZE_METRIC, l)
break
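
# Illustrative usage sketch (not part of the original module; the model, model_config and
# tokenizer objects below are hypothetical and would normally be built via ModelFactory):
#
#     pipeline = Pipeline(model, model_config, tokenizer)
#     for response in pipeline("hello", generate_config={"max_new_tokens": 16}):
#         print(response.generate_texts)
#     pipeline.stop()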