maga_transformer/cpp/th_op/multi_gpu_gpt/RtpLLMOp.cc
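// RtpLLMOp is the pybind11-exposed entry point that wires a Python-side model into the
// C++ runtime: it builds the engine init params, launches the gRPC model-rpc service and
// the HTTP API server, and exposes LoRA management, load-balance / schedule queries, and
// graceful shutdown to Python.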

#include <cstddef>
#include <memory>
#include <tuple>
#include "autil/Log.h"
#include "c10/util/intrusive_ptr.h"
#include <grpcpp/grpcpp.h>
#include <grpcpp/resource_quota.h>
#include "maga_transformer/cpp/dataclass/EngineInitParameter.h"
#include "maga_transformer/cpp/dataclass/LoadBalance.h"
#include "maga_transformer/cpp/metrics/RtpLLMMetrics.h"
#include "maga_transformer/cpp/utils/AssertUtils.h"
#include "maga_transformer/cpp/th_op/GptInitParameter.h"
#include "maga_transformer/cpp/th_op/multi_gpu_gpt/RtpLLMOp.h"

using namespace std;
namespace th = torch;

namespace torch_ext {

RtpLLMOp::RtpLLMOp() {}

void RtpLLMOp::init(py::object model,
                    py::object mm_process_engine,
                    py::object propose_model,
                    py::object token_processor) {
    RTP_LLM_LOG_DEBUG(__PRETTY_FUNCTION__);
    rtp_llm::EngineInitParams params = initModel(model);
    std::unique_ptr<rtp_llm::ProposeModelEngineInitParams> propose_params = initProposeModel(propose_model);
    pybind11::gil_scoped_release release;
    grpc_server_thread_ = std::thread(&RtpLLMOp::initRPCServer,
                                      this,
                                      std::move(params),
                                      std::move(mm_process_engine),
                                      std::move(propose_params),
                                      std::move(token_processor));
    grpc_server_thread_.detach();
    while (!is_server_ready_) {
        sleep(1);  // wait 1s for server ready
    }
}

rtp_llm::EngineInitParams RtpLLMOp::initModel(py::object model) {
    try {
        auto [gpt_init_params, gpt_weight] = rtp_llm::prepareEngineInitParams(model);
        rtp_llm::EngineInitParams params(gpt_init_params, std::move(*gpt_weight));
        if (gpt_init_params.tp_rank_ == 0) {
            // kmon metric init
            (void)rtp_llm::initKmonitorFactory();
            auto kmon_tags = rtp_llm::getHippoTags();
            kmon_tags.AddTag("dp_rank", std::to_string(gpt_init_params.dp_rank_));
            params.metrics_reporter.reset(new kmonitor::MetricsReporter("", "", kmon_tags));
        }
        return params;
    } catch (const std::exception& e) {
        RTP_LLM_FAIL("init engine params failed, error msg: %s", e.what());
        return rtp_llm::EngineInitParams();
    }
}

std::unique_ptr<rtp_llm::ProposeModelEngineInitParams> RtpLLMOp::initProposeModel(py::object propose_model) {
    try {
        if (propose_model.is_none()) {
            return nullptr;
        }
        std::unique_ptr<rtp_llm::ProposeModelEngineInitParams> params = nullptr;
        std::string sp_type            = propose_model.attr("sp_type").cast<std::string>();
        size_t      gen_num_per_circle = propose_model.attr("gen_num_per_circle").cast<size_t>();
        if (sp_type == "vanilla") {
            auto [gpt_init_params, gpt_weight] = rtp_llm::prepareEngineInitParams(propose_model, true);
            params = std::make_unique<rtp_llm::ProposeModelEngineInitParams>(
                sp_type, gen_num_per_circle, gpt_init_params, std::move(*gpt_weight));
        } else if (sp_type == "mtp") {
            params = rtp_llm::prepareMTPEngineInitParams(propose_model);
        } else if (sp_type == "deterministic") {
            params = std::make_unique<rtp_llm::ProposeModelEngineInitParams>(sp_type, gen_num_per_circle);
        } else if (sp_type == "eagle") {
            RTP_LLM_FAIL("sp_type %s not supported", sp_type.c_str());
        } else {
            RTP_LLM_FAIL("sp_type %s not supported", sp_type.c_str());
        }
        return params;
    } catch (const std::exception& e) {
        RTP_LLM_FAIL("init propose engine params failed, error msg: %s", e.what());
        return nullptr;
    }
}

void RtpLLMOp::addLora(const std::string& adapter_name, py::object py_lora_a_weights, py::object py_lora_b_weights) {
    auto convert        = rtp_llm::WeightsConverter(true);
    auto lora_a_weights = convert.convertLayerWeights_(py_lora_a_weights);
    auto lora_b_weights = convert.convertLayerWeights_(py_lora_b_weights);
    model_rpc_service_->addLora(adapter_name, *lora_a_weights, *lora_b_weights);
}
void RtpLLMOp::removeLora(const std::string& adapter_name) {
    model_rpc_service_->removeLora(adapter_name);
}

rtp_llm::LoadBalanceInfo RtpLLMOp::getLoadBalanceInfo() {
    return model_rpc_service_->getLoadBalanceInfo();
}

rtp_llm::EngineScheduleInfo RtpLLMOp::getEngineScheduleInfo() {
    return model_rpc_service_->getEngineScheduleInfo();
}

void RtpLLMOp::initRPCServer(const rtp_llm::EngineInitParams                         maga_init_params,
                             py::object                                              mm_process_engine,
                             std::unique_ptr<rtp_llm::ProposeModelEngineInitParams>  propose_params,
                             py::object                                              token_processor) {
    auto http_port       = maga_init_params.gpt_init_parameter.http_port_;
    auto model_rpc_port  = maga_init_params.gpt_init_parameter.model_rpc_port_;
    auto use_cache_store = maga_init_params.gpt_init_parameter.use_cache_store_;
    std::string server_address("0.0.0.0:" + std::to_string(model_rpc_port));
    {
        pybind11::gil_scoped_acquire acquire;
        if (use_cache_store) {
            model_rpc_service_.reset(new rtp_llm::RemoteRpcServiceImpl());
        } else {
            model_rpc_service_.reset(new rtp_llm::LocalRpcServiceImpl());
        }
        grpc::Status grpc_status =
            model_rpc_service_->init(maga_init_params, std::move(mm_process_engine), std::move(propose_params));
        if (!grpc_status.ok()) {
            RTP_LLM_FAIL("init rpc server failed, error msg: %s", grpc_status.error_message().c_str());
        }
        std::string http_server_address("tcp:0.0.0.0:" + std::to_string(http_port));
        http_server_.reset(new rtp_llm::HttpApiServer(model_rpc_service_->getEngine(),
                                                      model_rpc_service_->getMultimodalProcessor(),
                                                      http_server_address,
                                                      maga_init_params,
                                                      token_processor));
        if (model_rpc_port < 0) {
            is_server_ready_ = true;
            return;
        }
    }
    grpc::ServerBuilder builder;
    builder.AddChannelArgument(GRPC_ARG_MAX_CONCURRENT_STREAMS, 100000);
    builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
    builder.RegisterService(model_rpc_service_.get());
    grpc_server_ = builder.BuildAndStart();
    RTP_LLM_CHECK_WITH_INFO(grpc_server_ != nullptr, "grpc server start failed at address " + server_address);
    RTP_LLM_LOG_INFO("Server listening on %s", server_address.c_str());
    is_server_ready_ = true;
    grpc_server_->Wait();
    RTP_LLM_LOG_INFO("Server exited on %s", server_address.c_str());
}

bool RtpLLMOp::ready() {
    return model_rpc_service_->ready();
}

void RtpLLMOp::startHttpServer(py::object model_weights_loader,
                               py::object lora_infos,
                               py::object gang_info,
                               py::object tokenizer,
                               py::object render) {
    if (http_server_ == nullptr) {
        RTP_LLM_FAIL("normal HTTP Server nullptr error.");
        return;
    }
    if (http_server_->start(model_weights_loader, lora_infos, gang_info, tokenizer, render)) {
        RTP_LLM_LOG_INFO("normal HTTP Server listening on %s", http_server_->getListenAddr().c_str());
    } else {
        RTP_LLM_FAIL("normal HTTP Server start failed.");
    }
}

void RtpLLMOp::updateSchedulerInfo(const std::string& scheduler_info) {
    model_rpc_service_->getEngine()->getScheduler().updateSchedulerInfo(scheduler_info);
}
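// Graceful shutdown: wait (up to STOP_TIMEOUT_MS) for onflight RPC requests to drain,
// shut the gRPC server down, stop the rpc service with the GIL released, then tear down
// the HTTP server and the kmonitor factory. Guarded by is_server_shutdown_, so calling
// it more than once (e.g. explicitly and again from the destructor) is safe.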
void RtpLLMOp::stop() {
    int64_t STOP_TIMEOUT_MS = 60 * 1000;
    if (!is_server_shutdown_) {
        if (grpc_server_) {
            auto begin_wait_us = autil::TimeUtility::currentTimeInMicroSeconds();
            while (auto onflight_request = model_rpc_service_->onflightRequestNum()) {
                RTP_LLM_LOG_INFO("rpc service has [%lu] onflight requests, waiting 1s", onflight_request);
                sleep(1);
                if (autil::TimeUtility::currentTimeInMicroSeconds() - begin_wait_us > STOP_TIMEOUT_MS * 1000) {
                    RTP_LLM_LOG_INFO("rpc service wait timeout, no more waiting");
                    break;
                }
            }
            RTP_LLM_LOG_INFO("Server shutting down");
            grpc_server_->Shutdown();
            grpc_server_.reset();
        }
        {
            pybind11::gil_scoped_release release;
            model_rpc_service_->stop();
            pybind11::gil_scoped_acquire acquire;
        }
        model_rpc_service_.reset();
        if (http_server_) {
            http_server_->stop();
            http_server_.reset();
        }
        is_server_shutdown_ = true;
        rtp_llm::stopKmonitorFactory();
    }
}

RtpLLMOp::~RtpLLMOp() {
    stop();
}

void registerRtpLLMOp(const py::module& m) {
    pybind11::class_<torch_ext::RtpLLMOp>(m, "RtpLLMOp")
        .def(pybind11::init<>())
        .def("init",
             &torch_ext::RtpLLMOp::init,
             py::arg("model"),
             py::arg("mm_process_engine"),
             py::arg("propose_model"),
             py::arg("token_processor"))
        .def("start_http_server",
             &torch_ext::RtpLLMOp::startHttpServer,
             py::arg("model_weights_loader"),
             py::arg("lora_infos"),
             py::arg("gang_info"),
             py::arg("tokenizer"),
             py::arg("render"))
        .def("add_lora",
             &torch_ext::RtpLLMOp::addLora,
             py::arg("adapter_name"),
             py::arg("lora_a_weights"),
             py::arg("lora_b_weights"))
        .def("remove_lora", &torch_ext::RtpLLMOp::removeLora, py::arg("adapter_name"))
        .def("get_load_balance_info", &torch_ext::RtpLLMOp::getLoadBalanceInfo)
        .def("get_engine_schedule_info", &torch_ext::RtpLLMOp::getEngineScheduleInfo)
        .def("update_scheduler_info", &torch_ext::RtpLLMOp::updateSchedulerInfo)
        .def("stop", &torch_ext::RtpLLMOp::stop)
        .def("ready", &torch_ext::RtpLLMOp::ready);
}

} // namespace torch_ext
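// A rough sketch of how the binding registered above might be driven from Python.
// The importing module name (`rtp_llm_ops` here) and the objects passed in are
// assumptions for illustration; they are not defined in this file:
//
//   op = rtp_llm_ops.RtpLLMOp()
//   op.init(model, mm_process_engine, propose_model, token_processor)  # returns once the RPC server is ready
//   op.start_http_server(model_weights_loader, lora_infos, gang_info, tokenizer, render)
//   assert op.ready()
//   info = op.get_load_balance_info()
//   op.stop()  # drains onflight requests, then shuts both servers down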