# maga_transformer/cpp/model_rpc/model_rpc_client.py

import sys
import os
from typing import Any, Optional, AsyncGenerator
import asyncio
import numpy as np
import functools
import logging
import torch
import grpc
from grpc import StatusCode
from maga_transformer.utils.util import AtomicCounter
from maga_transformer.cpp.proto.model_rpc_service_pb2_grpc import RpcServiceStub
from maga_transformer.models.base_model import GenerateInput, GenerateOutput, GenerateOutputs, AuxInfo
from maga_transformer.cpp.proto.model_rpc_service_pb2 import TensorPB
from maga_transformer.cpp.proto.model_rpc_service_pb2 import MMPreprocessConfigPB
from maga_transformer.cpp.proto.model_rpc_service_pb2 import MultimodalInputPB
from maga_transformer.cpp.proto.model_rpc_service_pb2 import GenerateInputPB
from maga_transformer.cpp.proto.model_rpc_service_pb2 import GenerateOutputsPB
from maga_transformer.cpp.proto.model_rpc_service_pb2 import ErrorDetailsPB
from maga_transformer.distribute.worker_info import g_master_info, WorkerInfo
from maga_transformer.distribute.worker_info import g_worker_info, g_parallel_info
from maga_transformer.config.exceptions import FtRuntimeException, ExceptionType
from maga_transformer.config.gpt_init_model_parameters import GptInitModelParameters
from maga_transformer.utils.grpc_util import trans_option, trans_option_cast, trans_tensor
from maga_transformer.distribute.gang_info import get_gang_info, GangInfo
from maga_transformer.utils.concurrency_controller import ConcurrencyException, get_global_controller

MAX_GRPC_TIMEOUT_SECONDS = 3600


def trans_input(input_py: GenerateInput):
    # Convert the Python-side GenerateInput into its protobuf representation.
    input_pb = GenerateInputPB()
    input_pb.request_id = input_py.request_id
    input_pb.token_ids.extend(input_py.token_ids.reshape(-1).tolist())
    trans_multimodal_input(input_py, input_pb)

    generate_config_pb = input_pb.generate_config
    generate_config_pb.max_new_tokens = input_py.generate_config.max_new_tokens
    generate_config_pb.max_thinking_tokens = input_py.generate_config.max_thinking_tokens
    generate_config_pb.end_think_token_ids.extend(input_py.generate_config.end_think_token_ids)
    generate_config_pb.in_think_mode = input_py.generate_config.in_think_mode
    generate_config_pb.num_beams = input_py.generate_config.num_beams
    generate_config_pb.num_return_sequences = input_py.generate_config.num_return_sequences
    generate_config_pb.min_new_tokens = input_py.generate_config.min_new_tokens
    generate_config_pb.top_k = input_py.generate_config.top_k
    generate_config_pb.top_p = input_py.generate_config.top_p
    generate_config_pb.temperature = input_py.generate_config.temperature
    generate_config_pb.sp_edit = input_py.generate_config.sp_edit
    generate_config_pb.force_disable_sp_run = input_py.generate_config.force_disable_sp_run
    generate_config_pb.repetition_penalty = input_py.generate_config.repetition_penalty
    # Optional fields are only copied onto the protobuf when they are set.
    trans_option(generate_config_pb, input_py.generate_config, "no_repeat_ngram_size")
    trans_option(generate_config_pb, input_py.generate_config, "random_seed")
    trans_option(generate_config_pb, input_py.generate_config, "top_p_decay")
    trans_option(generate_config_pb, input_py.generate_config, "top_p_min")
    trans_option(generate_config_pb, input_py.generate_config, "top_p_reset_ids")
    trans_option(generate_config_pb, input_py.generate_config, "adapter_name")
    trans_option_cast(generate_config_pb, input_py.generate_config, "task_id", functools.partial(str))
    generate_config_pb.select_tokens_id.extend(input_py.generate_config.select_tokens_id)
    generate_config_pb.calculate_loss = input_py.generate_config.calculate_loss
    generate_config_pb.return_logits = input_py.generate_config.return_logits
    generate_config_pb.return_incremental = input_py.generate_config.return_incremental
    generate_config_pb.return_hidden_states = input_py.generate_config.return_hidden_states
    generate_config_pb.is_streaming = input_py.generate_config.is_streaming
    generate_config_pb.timeout_ms = input_py.generate_config.timeout_ms
    if input_py.generate_config.sp_advice_prompt_token_ids:
        generate_config_pb.sp_advice_prompt_token_ids.extend(input_py.generate_config.sp_advice_prompt_token_ids)
    generate_config_pb.return_cum_log_probs = input_py.generate_config.return_cum_log_probs
    generate_config_pb.return_all_probs = input_py.generate_config.return_all_probs
    generate_config_pb.return_softmax_probs = input_py.generate_config.return_softmax_probs
    generate_config_pb.can_use_pd_separation = input_py.generate_config.can_use_pd_separation
    generate_config_pb.gen_timeline = input_py.generate_config.gen_timeline
    generate_config_pb.global_request_id = input_py.generate_config.global_request_id
    for stop_words_row in input_py.generate_config.stop_words_list:
        stop_words = generate_config_pb.stop_words_list.rows.add()
        stop_words.values.extend(stop_words_row)
    return input_pb
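
# Illustrative sketch only (an assumption, not the real implementation): trans_option
# and trans_option_cast are imported from maga_transformer.utils.grpc_util. They are
# assumed to copy an optional generate_config attribute onto the protobuf only when
# the value is present, with trans_option_cast applying a conversion first, roughly:
#
#   def trans_option(pb, config, name):
#       value = getattr(config, name, None)
#       if value is not None:
#           getattr(pb, name).value = value  # assumes proto optional wrapper fields
#
#   def trans_option_cast(pb, config, name, cast):
#       value = getattr(config, name, None)
#       if value is not None:
#           getattr(pb, name).value = cast(value)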

def trans_multimodal_input(input_py: GenerateInput, input_pb: GenerateInputPB):
    # Copy each multimodal input (url, type, preprocess config) onto the protobuf.
    for mm_input in input_py.mm_inputs:
        mm_input_pb = MultimodalInputPB()
        mm_input_pb.multimodal_url = mm_input.url
        mm_input_pb.multimodal_type = mm_input.mm_type
        mm_preprocess_config_pb = mm_input_pb.mm_preprocess_config
        mm_preprocess_config_pb.width = mm_input.config.width
        mm_preprocess_config_pb.height = mm_input.config.height
        mm_preprocess_config_pb.min_pixels = mm_input.config.min_pixels
        mm_preprocess_config_pb.max_pixels = mm_input.config.max_pixels
        mm_preprocess_config_pb.fps = mm_input.config.fps
        mm_preprocess_config_pb.min_frames = mm_input.config.min_frames
        mm_preprocess_config_pb.max_frames = mm_input.config.max_frames
        input_pb.multimodal_inputs.append(mm_input_pb)


def trans_output(input_py: GenerateInput, outputs_pb: GenerateOutputsPB) -> GenerateOutputs:
    # Convert the streamed protobuf outputs back into Python-side GenerateOutputs.
    logging.debug("outputs_pb = " + str(outputs_pb))
    outputs_py = GenerateOutputs()
    for output_pb in outputs_pb.generate_outputs:
        output_py = GenerateOutput()
        output_py.finished = output_pb.finished
        # Server-side timings are reported in microseconds; convert to milliseconds.
        output_py.aux_info = AuxInfo(
            cost_time=output_pb.aux_info.cost_time_us / 1000.0,
            first_token_cost_time=output_pb.aux_info.first_token_cost_time_us / 1000.0,
            wait_time=output_pb.aux_info.wait_time_us / 1000.0,
            iter_count=output_pb.aux_info.iter_count,
            input_len=output_pb.aux_info.input_len,
            reuse_len=output_pb.aux_info.reuse_len,
            prefix_len=output_pb.aux_info.prefix_len,
            output_len=output_pb.aux_info.output_len,
            step_output_len=output_pb.aux_info.step_output_len,
            fallback_tokens=output_pb.aux_info.fallback_tokens,
            fallback_times=output_pb.aux_info.fallback_times,
            pd_sep=output_pb.aux_info.pd_sep)
        # TODO(xinfei.sxf) cum_log_probs is not right, ignore it temporarily
        if output_pb.aux_info.HasField('cum_log_probs'):
            output_py.aux_info.cum_log_probs = trans_tensor(output_pb.aux_info.cum_log_probs).tolist()
        if output_pb.aux_info.HasField('softmax_probs'):
            output_py.aux_info.softmax_probs = trans_tensor(output_pb.aux_info.softmax_probs).tolist()
        output_py.output_ids = trans_tensor(output_pb.output_ids)
        output_py.input_ids = input_py.token_ids.reshape(1, -1)
        if output_pb.HasField('hidden_states'):
            output_py.hidden_states = trans_tensor(output_pb.hidden_states)
        if output_pb.HasField('loss'):
            # when calculate_loss == 1, the result should be a single element
            if input_py.generate_config.calculate_loss == 1:
                output_py.loss = trans_tensor(output_pb.loss)[0]
            else:
                output_py.loss = trans_tensor(output_pb.loss)
        if output_pb.HasField('logits'):
            output_py.logits = trans_tensor(output_pb.logits)
        if output_pb.HasField('all_probs'):
            output_py.all_probs = trans_tensor(output_pb.all_probs)
        outputs_py.generate_outputs.append(output_py)
    return outputs_py
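
# Illustrative sketch only (an assumption, not the real implementation): trans_tensor
# is imported from maga_transformer.utils.grpc_util and is assumed to decode a
# TensorPB message into a torch.Tensor, along the lines of:
#
#   def trans_tensor(t: TensorPB) -> torch.Tensor:
#       # hypothetical field names; the actual TensorPB schema lives in
#       # model_rpc_service_pb2
#       array = np.frombuffer(t.data, dtype=np.dtype(t.dtype)).reshape(tuple(t.shape))
#       return torch.from_numpy(array.copy())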

class ModelRpcClient(object):

    def __init__(self, config: GptInitModelParameters, address: Optional[str] = None):
        # Establish the connection target(s) for the RPC server.
        if not address:
            address = f'localhost:{g_worker_info.rpc_server_port}'
        self._addresses = []
        # for test usage
        hack_ep_single_entry = bool(int(os.environ.get('HACK_EP_SINGLE_ENTRY', 0)))
        logging.info(f"hack ep single entry: {hack_ep_single_entry}")
        if (g_parallel_info.dp_size > 1) and (not hack_ep_single_entry):
            # With data parallelism, collect one RPC address per tensor-parallel group
            # (the member whose local_rank is a multiple of tp_size).
            members_info_str = f"[world_rank: {g_parallel_info.world_rank}]" + \
                f"[tp_size: {g_parallel_info.tp_size}] all members: " + "{"
            members = get_gang_info().members
            for member in members:
                members_info_str += f"{member}\n"
                if member.local_rank % g_parallel_info.tp_size == 0:
                    self._addresses.append(f'{member.ip}:{member.rpc_server_port}')
            members_info_str += "}"
            logging.info(f"{members_info_str}")
        else:
            self._addresses = [address]
        logging.info(f"client connect to rpc addresses: {self._addresses}")
        self.model_config = config

    async def enqueue(self, input_py: GenerateInput) -> AsyncGenerator[GenerateOutputs, None]:
        # Use the per-request timeout if given; otherwise fall back to the configured
        # (or default) RPC timeout.
        request_timeout_ms = input_py.generate_config.timeout_ms
        rpc_timeout_ms = self.model_config.max_rpc_timeout_ms \
            if self.model_config.max_rpc_timeout_ms > 0 else MAX_GRPC_TIMEOUT_SECONDS * 1000
        if request_timeout_ms is None or request_timeout_ms <= 0:
            grpc_timeout_seconds = rpc_timeout_ms / 1000
        else:
            grpc_timeout_seconds = request_timeout_ms / 1000
        input_py.generate_config.timeout_ms = int(grpc_timeout_seconds * 1000)
        input_pb = trans_input(input_py)
        response_iterator = None
        try:
            # Requests are sharded across the known addresses by request id.
            async with grpc.aio.insecure_channel(self._addresses[input_py.request_id % len(self._addresses)]) as channel:
                stub = RpcServiceStub(channel)
                response_iterator = stub.GenerateStreamCall(input_pb, timeout=grpc_timeout_seconds)
                # Invoke the server method and consume the streaming responses.
                count = 0
                async for response in response_iterator:
                    count += 1
                    yield trans_output(input_py, response)
        except grpc.RpcError as e:
            # TODO(xinfei.sxf) non-streaming requests can no longer be cancelled
            if response_iterator:
                response_iterator.cancel()
            error_details = ErrorDetailsPB()
            metadata = e.trailing_metadata()
            if 'grpc-status-details-bin' in metadata and error_details.ParseFromString(metadata['grpc-status-details-bin']):
                logging.error(f"request: [{input_pb.request_id}] RPC failed: "
                              f"{e.code()}, {e.details()}, detail error code is "
                              f"{ExceptionType.from_value(error_details.error_code)}")
                raise FtRuntimeException(ExceptionType(error_details.error_code), error_details.error_message)
            else:
                logging.error(f"request: [{input_pb.request_id}] RPC failed: "
                              f"error code is {e.code()}, detail is {e.details()}")
                if e.code() == StatusCode.DEADLINE_EXCEEDED:
                    raise FtRuntimeException(ExceptionType.GENERATE_TIMEOUT, e.details())
                elif e.code() == StatusCode.CANCELLED:
                    raise FtRuntimeException(ExceptionType.CANCELLED_ERROR, e.details())
                else:
                    raise FtRuntimeException(ExceptionType.UNKNOWN_ERROR, e.details())
        except Exception as e:
            logging.error(f'rpc unknown error: {str(e)}')
            raise
        finally:
            if response_iterator:
                response_iterator.cancel()
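
# Usage sketch (hypothetical, not part of the original file): drives a single
# request through the client and logs per-stream progress. Assumes the caller
# already has a fully initialized GptInitModelParameters `config` and a
# GenerateInput `gen_input`.
async def _demo_enqueue(config: GptInitModelParameters, gen_input: GenerateInput):
    client = ModelRpcClient(config)
    async for outputs in client.enqueue(gen_input):
        for output in outputs.generate_outputs:
            logging.info(f"finished={output.finished}, output_len={output.aux_info.output_len}")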