maga_transformer/cpp/model_rpc/PrefillRpcServer.cc (408 lines of code) (raw):

#include "autil/TimeUtility.h" #include "maga_transformer/cpp/utils/Cm2Config.h" #include "maga_transformer/cpp/model_rpc/QueryConverter.h" #include "maga_transformer/cpp/model_rpc/PrefillRpcServer.h" #include "maga_transformer/cpp/devices/utils/DebugUtils.h" #include <cstring> #include <memory> #include <unistd.h> #include <limits.h> using namespace std; using namespace autil::legacy; using grpc::Status; using grpc::ClientContext; namespace rtp_llm { #define CLIENT_GRPC_RET_IF_ERROR(prefill_context, state, error_code_value) \ if (!(state)) { \ auto new_error_code = error_code_value; \ string new_error_msg = "decode addr is " + prefill_context.decode_addr + ", "; \ new_error_msg += "execute time is " + std::to_string(prefill_context.executeTimeMs()) + "ms, "; \ new_error_msg += "request timeout is " + std::to_string(prefill_context.request_timeout_ms) + "ms, "; \ if (prefill_context.getStream()) { \ auto first_token_rt_ms = prefill_context.getStream()->getTimeInfo().first_token_rt_us / 1000; \ if (first_token_rt_ms) { \ new_error_msg += "stream first token rt is " + std::to_string(first_token_rt_ms) + "ms, "; \ } \ auto wait_time_ms = prefill_context.getStream()->getTimeInfo().wait_time_us / 1000; \ if (wait_time_ms) { \ new_error_msg += "stream wait time is " + std::to_string(first_token_rt_ms) + "ms, "; \ } \ } \ auto status = prefill_context.closeGrpcStream(); \ if (!status.ok()) { \ const auto& error_msg = status.error_message(); \ if (error_msg.find("Connect Failed") != std::string::npos) { \ new_error_code = ErrorCode::CONNECT_FAILED; \ prefill_context.closeGrpcConnection(); \ } else if (error_msg.find("No route to host") != std::string::npos) { \ new_error_code = ErrorCode::CONNECT_FAILED; \ prefill_context.closeGrpcConnection(); \ } else if (error_msg.find("Connection reset by peer") != std::string::npos) { \ new_error_code = ErrorCode::CONNECTION_RESET_BY_PEER; \ prefill_context.closeGrpcConnection(); \ } else if (error_msg.find("Connection timed out") != std::string::npos) { \ new_error_code = ErrorCode::CONNECT_TIMEOUT; \ prefill_context.closeGrpcConnection(); \ } else if (error_msg.find("Deadline Exceeded") != std::string::npos) { \ new_error_code = ErrorCode::DEADLINE_EXCEEDED; \ prefill_context.closeGrpcConnection(); \ } \ new_error_msg += error_msg; \ if (status.error_code() == grpc::StatusCode::RESOURCE_EXHAUSTED) { \ new_error_code = ErrorCode::DECODE_MALLOC_FAILED; \ } \ } else { \ if (prefill_context.client_stream) { \ new_error_msg += "server disconnected with status::ok"; \ } \ } \ if (prefill_context.getStream()) { \ prefill_context.getStream()->setStop(new_error_code, new_error_msg); \ } \ prefill_context.error_info = ErrorInfo(new_error_code, new_error_msg); \ prefill_context.error_status = serializeErrorMsg( \ prefill_context.request_key, prefill_context.error_info); \ return; \ } grpc::Status PrefillRpcServer::init(const EngineInitParams& maga_init_params, py::object mm_process_engine, std::unique_ptr<rtp_llm::ProposeModelEngineInitParams> propose_params) { meta_.reset(new PrefillRpcServerRuntimeMeta()); RTP_LLM_CHECK_WITH_INFO(maga_init_params.gpt_init_parameter.pd_separation_, "prefill's pd_separation must be true"); auto ret = RemoteRpcServer::init(maga_init_params, mm_process_engine, std::move(propose_params)); if (!ret.ok()) { return ret; } initLoadBalancer(); return grpc::Status::OK; } void PrefillRpcServer::initLoadBalancer() { auto config = makeConfig(); if (maga_init_params_.gpt_init_parameter.load_balance_policy_name_ == "RR") { load_balancer_ = std::make_shared<RRLoadBalancer>(); } else { load_balancer_ = std::make_shared<WRRLoadBalancer>(); } RTP_LLM_CHECK_WITH_INFO(load_balancer_->init(config), "load_balancer init failed"); RTP_LLM_LOG_INFO("load balancer init success"); } LoadBalancerInitParams PrefillRpcServer::makeConfig() { char* use_local_env = std::getenv("USE_LOCAL"); SubscribeServiceConfig subscribe_config; if (use_local_env) { // fake test char* remote_rpc_server_ip_env = std::getenv("REMOTE_RPC_SERVER_IP"); RTP_LLM_CHECK_WITH_INFO(remote_rpc_server_ip_env, "rpc server ip must be not empty"); vector<string> remote_addrs = split(string(remote_rpc_server_ip_env), ','); RTP_LLM_CHECK_WITH_INFO(!remote_addrs.empty(), "REMOTE_RPC_SERVER_IP contains no valid addresses"); decode_cluster_name_ = "LOCAL"; LocalSubscribeServiceConfig local_config; if (remote_addrs.size() > 1) { for (const string& addr : remote_addrs) { auto [ip, port_str] = split_ip_port(addr); RTP_LLM_CHECK_WITH_INFO(!ip.empty() && !port_str.empty(), "Invalid address format in REMOTE_RPC_SERVER_IP_LIST: " + addr); uint32_t port = parse_port(port_str); RTP_LLM_LOG_INFO("Adding remote rpc server addr: %s:%u", ip.c_str(), port); // rpc port, http port local_config.nodes.emplace_back(decode_cluster_name_, ip, port, port + 4); } } else { const auto& addr = remote_addrs.front(); auto [ip, port_str] = split_ip_port(addr); uint32_t port; if (ip.empty() || port_str.empty()) { RTP_LLM_LOG_WARNING("Using Deprecated method to get remote rpc server addr"); ip = remote_addrs.front(); port = maga_init_params_.gpt_init_parameter.remote_rpc_server_port_; } else { port = parse_port(port_str); } RTP_LLM_LOG_INFO("Adding remote rpc server addr: %s:%u", ip.c_str(), port); // rpc port, http port local_config.nodes.emplace_back(decode_cluster_name_, ip, port, port + 4); } subscribe_config.local_configs.push_back(local_config); } else { char* decode_cm2_config_env = std::getenv("RTP_LLM_DECODE_CM2_CONFIG"); RTP_LLM_CHECK_WITH_INFO(decode_cm2_config_env, "decode_cm2_config_env must be not empty"); string decode_cm2_config_str = string(decode_cm2_config_env); Cm2ClusterConfig decode_cm2_config; try { FromJsonString(decode_cm2_config, decode_cm2_config_str); } catch (autil::legacy::ExceptionBase &e) { RTP_LLM_CHECK_WITH_INFO("create json from str[%s] failed", decode_cm2_config_str.c_str()); } decode_cluster_name_ = decode_cm2_config.cluster_name; CM2SubscribeServiceConfig cm2_service_config; cm2_service_config.zk_host = decode_cm2_config.zk_host; cm2_service_config.zk_path = decode_cm2_config.zk_path; cm2_service_config.zk_timeout_ms = 10 * 1000; cm2_service_config.clusters = {decode_cm2_config.cluster_name}; subscribe_config.cm2_configs.push_back(cm2_service_config); } LoadBalancerInitParams params; params.subscribe_config = subscribe_config; params.update_interval_ms = 100; params.sync_status_interval_ms = maga_init_params_.gpt_init_parameter.sync_status_interval_ms_; return params; } ErrorInfo PrefillRpcServer::waitStreamBeforeRun(std::shared_ptr<GenerateStream> stream) { static int max_wait_timeout_us = maga_init_params_.gpt_init_parameter.prefill_max_wait_timeout_ms_; auto begin_time_us = currentTimeUs(); while (stream->waiting()) { usleep(100); auto current_time_us = currentTimeUs(); auto cost_time_us = current_time_us - begin_time_us; if (cost_time_us > max_wait_timeout_us) { string new_error_msg = "wait to run timeout, timeout is " + std::to_string(max_wait_timeout_us) + " us"; stream->setStop(ErrorCode::WAIT_TO_RUN_TIMEOUT, new_error_msg); return ErrorInfo(ErrorCode::WAIT_TO_RUN_TIMEOUT, new_error_msg); } } if (stream->stopped()) { return stream->statusInfo(); } return ErrorInfo::OkStatus(); } void PrefillRpcServer::getRpcConnection(PrefillGenerateContext& prefill_context) { RTP_LLM_LOG_DEBUG("request [%ld] get rpc connection", prefill_context.request_id); auto host = load_balancer_->chooseHost(decode_cluster_name_, prefill_context.rpc_context.request->generate_config().global_request_id()); if (!host || host->ip.empty()) { prefill_context.error_info = ErrorInfo(ErrorCode::GET_HOST_FAILED, "get host for decode cluster " + decode_cluster_name_ + " failed"); prefill_context.error_status = serializeErrorMsg(prefill_context.request_key, prefill_context.error_info); return; } auto decode_addr = host->ip + ":" + std::to_string(host->rpc_port); auto connect_status = resource_.rpc_pool.getConnection(decode_addr); if (!connect_status.ok()) { prefill_context.error_info = ErrorInfo(ErrorCode::GET_CONNECTION_FAILED, "get grpc connection for decode addr " + decode_addr + " failed"); prefill_context.error_status = serializeErrorMsg(prefill_context.request_key, prefill_context.error_info); return; } prefill_context.decode_addr = decode_addr; prefill_context.grpc_connection = connect_status.value(); RTP_LLM_LOG_DEBUG("request [%ld] get rpc connection done", prefill_context.request_id); } void PrefillRpcServer::multimodalProcess(PrefillGenerateContext& prefill_context) { auto input = QueryConverter::transQuery(prefill_context.rpc_context.request); input->generate_config->pd_separation = true; input->generate_config->force_disable_sp_run = true; prefill_context.generate_input = input; if (mm_processor_ != nullptr && input->multimodal_inputs) { auto result = mm_processor_->updateMultimodalFeatures(input); CLIENT_GRPC_RET_IF_ERROR(prefill_context, result.ok(), result.code()); auto mutable_request = const_cast<GenerateInputPB*>(prefill_context.rpc_context.request); mutable_request->clear_token_ids(); // TODO(xinfei.sxf) optimize copy for (size_t i = 0; i < input->input_ids->size(); i++) { mutable_request->add_token_ids(*input->input_ids->dataWithOffset<int32_t>(i)); } } } void PrefillRpcServer::remoteAllocateResource(PrefillGenerateContext& prefill_context) { RTP_LLM_LOG_DEBUG("request [%ld] start to remote allocate resource", prefill_context.request_id); prefill_context.client_context.reset(new ClientContext()); auto request_timeout_ms = prefill_context.request_timeout_ms; auto max_rpc_timeout_ms = maga_init_params_.gpt_init_parameter.max_rpc_timeout_ms_; auto final_timeout_ms = max_rpc_timeout_ms > 0 ? max_rpc_timeout_ms : MAX_GRPC_TIMEOUT_MS; final_timeout_ms = request_timeout_ms > 0 ? request_timeout_ms : final_timeout_ms; auto deadline = std::chrono::system_clock::now() + std::chrono::milliseconds(final_timeout_ms); prefill_context.client_context->set_deadline(deadline); prefill_context.client_stream = std::move( prefill_context.grpc_connection.stub->RemoteGenerate(prefill_context.client_context.get())); auto& client_stream = prefill_context.client_stream; GenerateRequestPB alloc_request; alloc_request.set_stage(RemoteStage::ALLOCATE); alloc_request.set_client_id(process_id_); alloc_request.set_request_id(prefill_context.request_id); // TODO(xinfei.sxf) reduce copy GenerateInputPB* new_request = new GenerateInputPB(*prefill_context.rpc_context.request); alloc_request.set_allocated_input(new_request); for(auto& addrs : prefill_context.prefill_worker_cache_store_addrs) { alloc_request.add_peer_addrs(addrs); } CLIENT_GRPC_RET_IF_ERROR(prefill_context, client_stream->Write(alloc_request), ErrorCode::REMOTE_ALLOCATE_RESOURCE_WRITE_FAILED); GenerateOutputsPB allocate_response; CLIENT_GRPC_RET_IF_ERROR(prefill_context, client_stream->Read(&allocate_response), ErrorCode::REMOTE_ALLOCATE_RESOURCE_READ_FAILED); RTP_LLM_LOG_DEBUG("request [%ld] remote allocate resource done", prefill_context.request_id); } void PrefillRpcServer::enqueueRequest(PrefillGenerateContext& prefill_context) { RTP_LLM_LOG_DEBUG("request [%ld] trans query", prefill_context.request_id); auto lora_guard = lora::LoraResourceGuard(engine_->getLoraManager(), prefill_context.generate_input->generate_config->adapter_name); RTP_LLM_LOG_DEBUG("request [%ld] trans to stream success", prefill_context.request_id); auto stream = engine_->enqueue(prefill_context.generate_input); prefill_context.setStream(stream); RTP_LLM_LOG_DEBUG("request [%ld] enqueue success", prefill_context.request_id); } void PrefillRpcServer::remoteLoadCacheStart(PrefillGenerateContext& prefill_context) { RTP_LLM_LOG_DEBUG("request [%ld] remote load cache", prefill_context.request_id); prefill_context.error_info = waitStreamBeforeRun(prefill_context.getStream()); if (prefill_context.error_info.hasError()) { prefill_context.error_status = serializeErrorMsg(prefill_context.request_key, prefill_context.error_info); return; } AtomicGuard request_guard(loading_cache_requests_); GenerateRequestPB load_request; load_request.set_client_id(process_id_); load_request.set_request_id(prefill_context.request_id); load_request.set_start_time(currentTimeUs()); CLIENT_GRPC_RET_IF_ERROR(prefill_context, prefill_context.client_stream->Write(load_request), ErrorCode::REMOTE_LOAD_KV_CACHE_FAILED); } void PrefillRpcServer::pollLocalOutput(PrefillGenerateContext& prefill_context) { RTP_LLM_LOG_DEBUG("request [%ld] start to poll local output", prefill_context.request_id); auto first_status = pollStreamOutput(prefill_context.server_context, prefill_context.request_key, prefill_context.rpc_context.writer, prefill_context.getStream()); if (!first_status.ok()) { prefill_context.error_status = first_status; return; } RTP_LLM_LOG_DEBUG("request [%ld] poll local output end", prefill_context.request_id); if (prefill_context.getStream()->finished()) { prefill_context.finished = true; prefill_context.error_status = grpc::Status::OK; } } void PrefillRpcServer::remoteLoadCacheEnd(PrefillGenerateContext& prefill_context) { GenerateOutputsPB load_response; CLIENT_GRPC_RET_IF_ERROR(prefill_context, prefill_context.client_stream->Read(&load_response), ErrorCode::REMOTE_LOAD_KV_CACHE_FAILED); auto error_code = transRPCErrorCode(load_response.error_info().error_code()); CLIENT_GRPC_RET_IF_ERROR(prefill_context, error_code == ErrorCode::NONE_ERROR, error_code); RTP_LLM_LOG_DEBUG("request [%ld] remote load cache done", prefill_context.request_id); prefill_context.getStream()->releaseResource(); } void PrefillRpcServer::remoteGenerate(PrefillGenerateContext& prefill_context) { RTP_LLM_LOG_DEBUG("request [%ld] start to remote generate", prefill_context.request_id); auto first_token = prefill_context.getStream()->currentExecuteTokens()[0]; GenerateRequestPB generate_request; generate_request.set_client_id(process_id_); generate_request.set_request_id(prefill_context.request_id); generate_request.set_first_generate_token_id(first_token); generate_request.set_stage(RemoteStage::GENERATE); CLIENT_GRPC_RET_IF_ERROR(prefill_context, prefill_context.client_stream->Write(generate_request), ErrorCode::REMOTE_GENERATE_FAILED); } void PrefillRpcServer::pollRemoteOutput(PrefillGenerateContext& prefill_context) { RTP_LLM_LOG_DEBUG("request [%ld] start to poll remote output", prefill_context.request_id); auto& request_id = prefill_context.request_id; GenerateOutputsPB response; auto initial_reuse_len = prefill_context.getStream()->initialReuseLength(); auto first_token_rt_us = prefill_context.getStream()->getTimeInfo().first_token_rt_us; while (prefill_context.client_stream->Read(&response)) { if (prefill_context.server_context->IsCancelled()) { RTP_LLM_LOG_WARNING("request [%ld] cancel by user", request_id); prefill_context.error_status = grpc::Status(grpc::StatusCode::CANCELLED, "request cancelled"); return; } if (response.generate_outputs_size() == 0) { RTP_LLM_LOG_ERROR("request [%ld] generate output size is 0", request_id); break; } for (size_t i = 0; i < response.generate_outputs_size(); i++) { response.mutable_generate_outputs(i)->mutable_aux_info()->set_pd_sep(true); } int64_t cost_time_us = currentTimeUs() - prefill_context.request_begin_time_us; for (size_t i = 0; i < response.generate_outputs_size(); i++) { response.mutable_generate_outputs(i)->mutable_aux_info()->set_first_token_cost_time_us(first_token_rt_us); response.mutable_generate_outputs(i)->mutable_aux_info()->set_cost_time_us(cost_time_us); response.mutable_generate_outputs(i)->mutable_aux_info()->set_reuse_len(initial_reuse_len); } if (!prefill_context.rpc_context.writer->Write(response)) { RTP_LLM_LOG_WARNING("request [%ld] write outputs pb failed", request_id); prefill_context.error_status = grpc::Status(grpc::StatusCode::INTERNAL, "request write outputs pb failed"); return; } } CLIENT_GRPC_RET_IF_ERROR(prefill_context, prefill_context.closeGrpcStream().ok(), ErrorCode::REMOTE_GENERATE_FAILED); prefill_context.getStream()->setFinishedWithoutLock(); } grpc::Status PrefillRpcServer::prepareAllocateResource(PrefillGenerateContext& prefill_context) { EXECUTE_STAGE_FUNC(getRpcConnection, prefill_context); EXECUTE_STAGE_FUNC(multimodalProcess, prefill_context); EXECUTE_STAGE_FUNC(remoteAllocateResource, prefill_context); return grpc::Status::OK; } EngineScheduleInfo PrefillRpcServer::getEngineScheduleInfo() { auto info = meta_->getEngineScheduleInfo(); auto last_schedule_time = engine_->getLastScheduleTime(); // in case last_schedule_delta is negative info.last_schedule_delta = std::max((int64_t)0, autil::TimeUtility::currentTimeInMilliSeconds() - last_schedule_time); return info; } grpc::Status PrefillRpcServer::GenerateStreamCall(grpc::ServerContext* server_context, const GenerateInputPB* request, grpc::ServerWriter<GenerateOutputsPB>* writer) { RTP_LLM_LOG_DEBUG("request [%ld] start generate stream call", request->request_id()); auto pd_separation = request->generate_config().max_new_tokens() > 1 && request->generate_config().num_beams() <= 1 && request->generate_config().num_return_sequences() <= 1 && request->generate_config().can_use_pd_separation(); if (!pd_separation) { return LocalRpcServer::GenerateStreamCall(server_context, request, writer); } AtomicGuardPtr request_guard = make_shared<AtomicGuard>(onflight_requests_); RPCContext rpc_context{request, writer}; auto prefill_context = PrefillGenerateContext(&this->resource(), rpc_context, request->generate_config().timeout_ms(), server_context, metrics_reporter_, meta_); prefill_context.onflight_requests = onflight_requests_; prefill_context.loading_cache_requests = loading_cache_requests_; auto max_retry_times = maga_init_params_.gpt_init_parameter.prefill_retry_times_; auto max_retry_timeout_ms = maga_init_params_.gpt_init_parameter.prefill_retry_timeout_ms_; try { EXECUTE_WITH_RETRY(prepareAllocateResource, prefill_context, max_retry_times, max_retry_timeout_ms); if (prefill_context.hasError()) { RTP_LLM_LOG_WARNING("request [%ld] prepare allocate resource failed after retry [%d] times, cost time ms [%ld], " "max retry time [%ld], max retry timeout ms [%ld]", prefill_context.request_id, prefill_context.retry_times, prefill_context.retry_cost_time_ms, max_retry_times + 1, max_retry_timeout_ms); if (maga_init_params_.gpt_init_parameter.pd_sep_enable_fallback_) { RTP_LLM_LOG_WARNING("request [%ld] fallback to local server"); request_guard.reset(); return LocalRpcServer::GenerateStreamCall(server_context, request, writer); } return grpc::Status::OK; } EXECUTE_STAGE_FUNC(enqueueRequest, prefill_context); EXECUTE_STAGE_FUNC(remoteLoadCacheStart, prefill_context); EXECUTE_STAGE_FUNC(pollLocalOutput, prefill_context); meta_->dequeue(prefill_context.request_id, prefill_context.getStream()); EXECUTE_STAGE_FUNC(remoteLoadCacheEnd, prefill_context); EXECUTE_STAGE_FUNC(remoteGenerate, prefill_context); EXECUTE_STAGE_FUNC(pollRemoteOutput, prefill_context); prefill_context.stat_info.nextStage(); } catch (const std::exception& e) { auto error_msg = "request [" + prefill_context.request_key + "] catch exception [" + e.what() + "]"; prefill_context.error_status = grpc::Status(grpc::StatusCode::INTERNAL, error_msg); return prefill_context.error_status; } catch (...) { auto error_msg = "request [" + prefill_context.request_key + "] catch unknown exception"; prefill_context.error_status = grpc::Status(grpc::StatusCode::INTERNAL, error_msg); return prefill_context.error_status; } RTP_LLM_LOG_DEBUG("request [%ld] all done", prefill_context.request_id); return grpc::Status::OK; } bool PrefillRpcServer::ready() { if (maga_init_params_.gpt_init_parameter.pd_sep_enable_fallback_) { return true; } if (!load_balancer_) { RTP_LLM_LOG_INFO("load balance is nullptr, server is not ready"); return false; } auto ret = load_balancer_->isReady(decode_cluster_name_); if (!ret) { RTP_LLM_LOG_INFO("load balancer is not ready now"); } return ret; } grpc::Status PrefillRpcServer::RemoteFinish(grpc::ServerContext* ontext, const RemoteFinishRequestPB* request, EmptyPB* response) { auto request_id = request->request_id(); resource_.cache_store->markRequestEnd(std::to_string(request_id)); return grpc::Status::OK; } } // namespace rtp_llm