# maga_transformer/cpp/model_rpc/model_rpc_client.py
import logging
import os
from typing import AsyncGenerator, Optional

import grpc
from grpc import StatusCode

from maga_transformer.config.exceptions import ExceptionType, FtRuntimeException
from maga_transformer.config.gpt_init_model_parameters import GptInitModelParameters
from maga_transformer.cpp.proto.model_rpc_service_pb2 import (
    ErrorDetailsPB,
    GenerateInputPB,
    GenerateOutputsPB,
    MultimodalInputPB,
)
from maga_transformer.cpp.proto.model_rpc_service_pb2_grpc import RpcServiceStub
from maga_transformer.distribute.gang_info import get_gang_info
from maga_transformer.distribute.worker_info import g_parallel_info, g_worker_info
from maga_transformer.models.base_model import AuxInfo, GenerateInput, GenerateOutput, GenerateOutputs
from maga_transformer.utils.grpc_util import trans_option, trans_option_cast, trans_tensor
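
# Upper bound on a single streaming RPC; used when neither the request nor the
# model config supplies a timeout.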
MAX_GRPC_TIMEOUT_SECONDS = 3600


def trans_input(input_py: GenerateInput):
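    """Build a GenerateInputPB from a Python-side GenerateInput.

    Required generate_config fields are copied one by one; optional fields are
    forwarded through the trans_option / trans_option_cast helpers.
    """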
input_pb = GenerateInputPB()
input_pb.request_id = input_py.request_id
input_pb.token_ids.extend(input_py.token_ids.reshape(-1).tolist())
trans_multimodal_input(input_py, input_pb)
generate_config_pb = input_pb.generate_config
generate_config_pb.max_new_tokens = input_py.generate_config.max_new_tokens
generate_config_pb.max_thinking_tokens = input_py.generate_config.max_thinking_tokens
generate_config_pb.end_think_token_ids.extend(input_py.generate_config.end_think_token_ids)
generate_config_pb.in_think_mode = input_py.generate_config.in_think_mode
generate_config_pb.num_beams = input_py.generate_config.num_beams
generate_config_pb.num_return_sequences = input_py.generate_config.num_return_sequences
generate_config_pb.min_new_tokens = input_py.generate_config.min_new_tokens
generate_config_pb.top_k = input_py.generate_config.top_k
generate_config_pb.top_p = input_py.generate_config.top_p
generate_config_pb.temperature = input_py.generate_config.temperature
generate_config_pb.sp_edit = input_py.generate_config.sp_edit
generate_config_pb.force_disable_sp_run = input_py.generate_config.force_disable_sp_run
generate_config_pb.repetition_penalty = input_py.generate_config.repetition_penalty
trans_option(generate_config_pb, input_py.generate_config, "no_repeat_ngram_size")
trans_option(generate_config_pb, input_py.generate_config, "random_seed")
trans_option(generate_config_pb, input_py.generate_config, "top_p_decay")
trans_option(generate_config_pb, input_py.generate_config, "top_p_min")
trans_option(generate_config_pb, input_py.generate_config, "top_p_reset_ids")
trans_option(generate_config_pb, input_py.generate_config, "adapter_name")
    trans_option_cast(generate_config_pb, input_py.generate_config, "task_id", str)
generate_config_pb.select_tokens_id.extend(input_py.generate_config.select_tokens_id)
generate_config_pb.calculate_loss = input_py.generate_config.calculate_loss
generate_config_pb.return_logits = input_py.generate_config.return_logits
generate_config_pb.return_incremental = input_py.generate_config.return_incremental
generate_config_pb.return_hidden_states = input_py.generate_config.return_hidden_states
generate_config_pb.is_streaming = input_py.generate_config.is_streaming
generate_config_pb.timeout_ms = input_py.generate_config.timeout_ms
if input_py.generate_config.sp_advice_prompt_token_ids:
generate_config_pb.sp_advice_prompt_token_ids.extend(input_py.generate_config.sp_advice_prompt_token_ids)
generate_config_pb.return_cum_log_probs = input_py.generate_config.return_cum_log_probs
generate_config_pb.return_all_probs = input_py.generate_config.return_all_probs
generate_config_pb.return_softmax_probs = input_py.generate_config.return_softmax_probs
generate_config_pb.can_use_pd_separation = input_py.generate_config.can_use_pd_separation
generate_config_pb.gen_timeline = input_py.generate_config.gen_timeline
generate_config_pb.global_request_id = input_py.generate_config.global_request_id
    for stop_words in input_py.generate_config.stop_words_list:
        stop_words_pb = generate_config_pb.stop_words_list.rows.add()
        stop_words_pb.values.extend(stop_words)
    return input_pb


def trans_multimodal_input(input_py: GenerateInput, input_pb: GenerateInputPB):
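    """Copy each multimodal input (url, type and preprocess config) from input_py onto input_pb."""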
for mm_input in input_py.mm_inputs:
mm_input_pb = MultimodalInputPB()
mm_input_pb.multimodal_url = mm_input.url
mm_input_pb.multimodal_type = mm_input.mm_type
mm_preprocess_config_pb = mm_input_pb.mm_preprocess_config
mm_preprocess_config_pb.width = mm_input.config.width
mm_preprocess_config_pb.height = mm_input.config.height
mm_preprocess_config_pb.min_pixels = mm_input.config.min_pixels
mm_preprocess_config_pb.max_pixels = mm_input.config.max_pixels
mm_preprocess_config_pb.fps = mm_input.config.fps
mm_preprocess_config_pb.min_frames = mm_input.config.min_frames
mm_preprocess_config_pb.max_frames = mm_input.config.max_frames
        input_pb.multimodal_inputs.append(mm_input_pb)


def trans_output(input_py: GenerateInput, outputs_pb: GenerateOutputsPB) -> GenerateOutputs:
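    """Convert a streamed GenerateOutputsPB into Python-side GenerateOutputs.

    Timing fields arrive in microseconds and are converted to milliseconds;
    tensor payloads are materialized via trans_tensor.
    """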
    logging.debug("outputs_pb = %s", outputs_pb)
outputs_py = GenerateOutputs()
for output_pb in outputs_pb.generate_outputs:
output_py = GenerateOutput()
output_py.finished = output_pb.finished
output_py.aux_info = AuxInfo(cost_time=output_pb.aux_info.cost_time_us / 1000.0,
first_token_cost_time=output_pb.aux_info.first_token_cost_time_us / 1000.0,
wait_time=output_pb.aux_info.wait_time_us / 1000.0,
iter_count=output_pb.aux_info.iter_count,
input_len=output_pb.aux_info.input_len,
reuse_len=output_pb.aux_info.reuse_len,
prefix_len=output_pb.aux_info.prefix_len,
output_len=output_pb.aux_info.output_len,
step_output_len=output_pb.aux_info.step_output_len,
fallback_tokens=output_pb.aux_info.fallback_tokens,
fallback_times=output_pb.aux_info.fallback_times,
pd_sep=output_pb.aux_info.pd_sep)
        # TODO(xinfei.sxf) cum_log_probs is not right; ignore it for now
if output_pb.aux_info.HasField('cum_log_probs'):
output_py.aux_info.cum_log_probs = trans_tensor(output_pb.aux_info.cum_log_probs).tolist()
if output_pb.aux_info.HasField('softmax_probs'):
output_py.aux_info.softmax_probs = trans_tensor(output_pb.aux_info.softmax_probs).tolist()
output_py.output_ids = trans_tensor(output_pb.output_ids)
output_py.input_ids = input_py.token_ids.reshape(1, -1)
if output_pb.HasField('hidden_states'):
output_py.hidden_states = trans_tensor(output_pb.hidden_states)
if output_pb.HasField('loss'):
            # when calculate_loss == 1, the returned loss tensor holds a single element
if input_py.generate_config.calculate_loss == 1:
output_py.loss = trans_tensor(output_pb.loss)[0]
else:
output_py.loss = trans_tensor(output_pb.loss)
if output_pb.HasField('logits'):
output_py.logits = trans_tensor(output_pb.logits)
if output_pb.HasField('all_probs'):
output_py.all_probs = trans_tensor(output_pb.all_probs)
outputs_py.generate_outputs.append(output_py)
    return outputs_py


class ModelRpcClient:
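    """Asynchronous gRPC client that forwards generation requests to the model RPC server.

    Under data parallelism it shards requests across one entry point per
    tensor-parallel group; otherwise it talks to a single local address.
    """
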
def __init__(self, config: GptInitModelParameters, address: Optional[str] = None):
        # Create the connection to the server
if not address:
address = f'localhost:{g_worker_info.rpc_server_port}'
self._addresses = []
        # For testing only: force a single RPC entry point even when dp_size > 1.
        hack_ep_single_entry = bool(int(os.environ.get('HACK_EP_SINGLE_ENTRY', '0')))
logging.info(f"hack ep single entry: {hack_ep_single_entry}")
        if g_parallel_info.dp_size > 1 and not hack_ep_single_entry:
            members_info_str = (f"[world_rank: {g_parallel_info.world_rank}]"
                                f"[tp_size: {g_parallel_info.tp_size}] all members: {{")
            members = get_gang_info().members
            for member in members:
                members_info_str += f"{member}\n"
                # One RPC entry point per tensor-parallel group: its rank-0 member.
                if member.local_rank % g_parallel_info.tp_size == 0:
                    self._addresses.append(f'{member.ip}:{member.rpc_server_port}')
            members_info_str += "}"
            logging.info(members_info_str)
else:
self._addresses = [address]
logging.info(f"client connect to rpc addresses: {self._addresses}")
self.model_config = config
async def enqueue(self, input_py: GenerateInput) -> AsyncGenerator[GenerateOutputs, None]:
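        """Send one generation request to the RPC server and yield streamed GenerateOutputs.

        The gRPC deadline is taken from the request's timeout_ms if it is set
        and positive; otherwise from the model config's max_rpc_timeout_ms,
        falling back to MAX_GRPC_TIMEOUT_SECONDS.
        """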
request_timeout_ms = input_py.generate_config.timeout_ms
rpc_timeout_ms = self.model_config.max_rpc_timeout_ms \
if self.model_config.max_rpc_timeout_ms > 0 else MAX_GRPC_TIMEOUT_SECONDS * 1000
        if request_timeout_ms is None or request_timeout_ms <= 0:
grpc_timeout_seconds = rpc_timeout_ms / 1000
else:
grpc_timeout_seconds = request_timeout_ms / 1000
        # Write the resolved deadline back so it is serialized into the request PB.
        input_py.generate_config.timeout_ms = int(grpc_timeout_seconds * 1000)
input_pb = trans_input(input_py)
response_iterator = None
try:
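            # Shard requests across the known entry points by request id; a fresh
            # channel is opened per call.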
async with grpc.aio.insecure_channel(self._addresses[input_py.request_id % len(self._addresses)]) as channel:
stub = RpcServiceStub(channel)
response_iterator = stub.GenerateStreamCall(input_pb, timeout=grpc_timeout_seconds)
                # Call the server method and consume the streaming responses.
                async for response in response_iterator:
                    yield trans_output(input_py, response)
except grpc.RpcError as e:
            # TODO(xinfei.sxf) non-streaming requests cannot be cancelled anymore
if response_iterator:
response_iterator.cancel()
error_details = ErrorDetailsPB()
metadata = e.trailing_metadata()
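            # 'grpc-status-details-bin' is the conventional key for rich error payloads;
            # ParseFromString returns the number of bytes parsed (recent protobuf
            # versions), so an empty payload is treated as "no details".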
if 'grpc-status-details-bin' in metadata and error_details.ParseFromString(metadata['grpc-status-details-bin']):
logging.error(f"request: [{input_pb.request_id}] RPC failed: "
f"{e.code()}, {e.details()}, detail error code is "
f"{ExceptionType.from_value(error_details.error_code)}")
raise FtRuntimeException(ExceptionType(error_details.error_code), error_details.error_message)
else:
logging.error(f"request: [{input_pb.request_id}] RPC failed: "
f"error code is {e.code()}, detail is {e.details()}")
if e.code() == StatusCode.DEADLINE_EXCEEDED:
raise FtRuntimeException(ExceptionType.GENERATE_TIMEOUT, e.details())
elif e.code() == StatusCode.CANCELLED:
raise FtRuntimeException(ExceptionType.CANCELLED_ERROR, e.details())
else:
raise FtRuntimeException(ExceptionType.UNKNOWN_ERROR, e.details())
except Exception as e:
            logging.error(f'rpc unknown error: {e}')
            raise
finally:
if response_iterator:
response_iterator.cancel()
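

# Minimal usage sketch (assumes a running RPC server and a GenerateInput built
# elsewhere; `make_generate_input` is a hypothetical helper, not part of this
# module):
#
#     import asyncio
#
#     async def consume(client: ModelRpcClient, request: GenerateInput):
#         async for outputs in client.enqueue(request):
#             for output in outputs.generate_outputs:
#                 print(output.finished, output.aux_info.output_len)
#
#     client = ModelRpcClient(config)
#     asyncio.run(consume(client, make_generate_input()))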