// maga_transformer/cpp/speculative_engine/SpeculativeEngine.cc
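// Speculative decoding engine: a propose (draft) executor generates candidate
// tokens, a score executor verifies them in a batched forward pass, a
// speculative sampler accepts or rejects the proposals, and an updater writes
// the accepted tokens back to each stream. Supports a vanilla draft model as
// well as an MTP-style propose path (see noPrefillProposeStep below).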

#include <algorithm>
#include <cstdint>

#include "maga_transformer/cpp/speculative_engine/SpeculativeEngine.h"
#include "maga_transformer/cpp/utils/StatusUtil.h"
#include "maga_transformer/cpp/stream/StreamCacheResource.h"
#include "maga_transformer/cpp/normal_engine/NormalGenerateStream.h"
#include "maga_transformer/cpp/cache/CacheConfigCreator.h"
#include "maga_transformer/cpp/speculative_engine/SpeculativeOnlineAdaptor.h"
#include "maga_transformer/cpp/speculative_engine/SpeculativeScheduler.h"
#include "maga_transformer/cpp/speculative_engine/propose_executor/VanillaExecutor.h"
#include "maga_transformer/cpp/speculative_engine/propose_executor/MTPExecutor.h"
#include "maga_transformer/cpp/speculative_engine/score_executor/ScoreExecutor.h"
#include "maga_transformer/cpp/system_prompt/SystemPromptConstructor.h"
#include "maga_transformer/cpp/utils/Logger.h"
#include "maga_transformer/cpp/devices/utils/DebugUtils.h"

using namespace std;

namespace rtp_llm {

SpeculativeEngine::SpeculativeEngine(const EngineInitParams&                       engine_init_params,
                                     std::unique_ptr<ProposeModelEngineInitParams> propose_model_engine_init_params):
    EngineBase(engine_init_params),
    metrics_reporter_(engine_init_params.metrics_reporter),
    propose_model_params_(std::move(propose_model_engine_init_params)),
    score_model_params_(std::move(engine_init_params)),
    sp_type_(propose_model_params_->sp_type) {}

SpeculativeEngine::~SpeculativeEngine() {
    RTP_LLM_LOG_INFO("destroy speculative engine");
    (void)stop();
}

std::shared_ptr<GenerateStream> SpeculativeEngine::enqueueMinFakeQuery(int32_t max_new_tokens, bool fake_hidden_states) {
    RTP_LLM_LOG_DEBUG("enqueue min fake query");
    std::shared_ptr<GenerateInput> fake_input = make_shared<GenerateInput>();
    fake_input->input_ids = device_->allocateBuffer(
        {rtp_llm::DataType::TYPE_INT32, {(size_t)1}, rtp_llm::AllocationType::HOST});
    std::default_random_engine generator;
    std::uniform_int_distribution<int> distribution(0, score_model_params_.gpt_init_parameter.vocab_size_ - 1);
    for (size_t i = 0; i < fake_input->input_ids->size(); ++i) {
        *fake_input->input_ids->dataWithOffset<int32_t>(i) = distribution(generator);
    }
    fake_input->generate_config = make_shared<GenerateConfig>();
    if (fake_hidden_states) {
        fake_input->generate_config->max_new_tokens = max_new_tokens + 1;
    } else {
        fake_input->generate_config->max_new_tokens = max_new_tokens;
    }
    fake_input->generate_config->top_k = 1;
    fake_input->begin_time_us = autil::TimeUtility::currentTimeInMicroSeconds();
    fake_input->fake_query = true;
    auto stream = makeStream(fake_input);
    stream->setMetricsReporter(nullptr);
    if (fake_hidden_states) {
        auto dtype = rtp_llm::getDataType(score_model_params_.gpt_init_parameter.data_type_);
        auto fake_hidden_states = device_->allocateBuffer(
            {dtype, {1, (size_t)score_model_params_.gpt_init_parameter.hidden_size_}, rtp_llm::AllocationType::DEVICE});
        stream->setReturnLastHiddenStates(true);
        BufferPtr new_tokens = device_->allocateBuffer(
            {rtp_llm::DataType::TYPE_INT32, {(size_t)1, 1}, rtp_llm::AllocationType::HOST});
        *new_tokens->dataWithOffset<int32_t>(0) = distribution(generator);
        StreamUpdateInfo update_info{
            new_tokens, (int)1, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, fake_hidden_states};
        stream->update(update_info);
        stream->setIsContextStream(false);
    }
    enqueue(stream);
    return stream;
}
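// init(): optionally runs a warm-up pass to measure peak device memory, then
// builds the cache manager(s), the propose/score executors, the scheduler and
// the speculative sampler/updater, and finally starts the engine loop.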
absl::Status SpeculativeEngine::init() {
    RTP_LLM_LOG_INFO(__PRETTY_FUNCTION__);
    std::optional<WarmUpResult> warm_up_result = std::nullopt;
    if (score_model_params_.gpt_init_parameter.warm_up_) {
        // warm up
        const rtp_llm::GptInitParameter& score_gpt_params = score_model_params_.gpt_init_parameter;
        RTP_LLM_LOG_INFO("warm up (max_context_batch_size %d, max_seq_len %d calculate_loss %d) query begin",
                         score_gpt_params.max_context_batch_size_,
                         score_gpt_params.max_seq_len_,
                         int(score_gpt_params.warm_up_with_loss_));
        warm_up_result = warmUp();
        RTP_LLM_LOG_INFO("warm up done, max runtime used memory: %ld bytes (%ld MiB), device reserved memory: %ld bytes (%ld MiB)",
                         warm_up_result->max_used_memory,
                         warm_up_result->max_used_memory / 1024 / 1024,
                         warm_up_result->device_reserved_bytes,
                         warm_up_result->device_reserved_bytes / 1024 / 1024);
    }
    RETURN_IF_STATUS_ERROR(initCacheManager(warm_up_result));
    RTP_LLM_LOG_INFO("create cache manager done");
    propose_executor_ = createProposeExecutor(score_model_params_,
                                              propose_model_params_,
                                              device_,
                                              resource_context_.propose_cache_manager,
                                              resource_context_.mtp_cache_managers,
                                              getLoraManager());
    RTP_LLM_LOG_INFO("create speculative executor done");
    score_executor_.reset(
        new ScoreExecutor(score_model_params_, device_, resource_context_.cache_manager, getLoraManager()));
    scheduler_.reset(new SpeculativeScheduler(
        score_model_params_.gpt_init_parameter, resource_context_.cache_manager, metrics_reporter_));
    RTP_LLM_LOG_INFO("create fifo scheduler done");
    online_adaptor_.reset(new SpeculativeOnlineAdaptor());
    RTP_LLM_LOG_INFO("create online adaptor");
    speculative_sampler_ = createSpeculativeSampler(propose_model_params_, device_);
    RTP_LLM_LOG_INFO("create speculative sampler");
    speculative_updater_.reset(
        new SpeculativeUpdater(resource_context_, createSpeculativeUpdaterConfig(propose_model_params_)));
    RETURN_IF_STATUS_ERROR(startLoop());
    if (device_->getDeviceProperties().tp_rank == 0) {
        initLoadBalance();
    }
    return absl::OkStatus();
}

void SpeculativeEngine::initLoadBalance() {
    RTP_LLM_LOG_INFO("init load balance start");
    auto stream = enqueueMinFakeQuery(3, false);
    while (!stream->finished() && !stream->stopped()) {
        RTP_LLM_LOG_INFO("wait load balance init run over for 1s");
        this_thread::sleep_for(std::chrono::seconds(1));
    }
    RTP_LLM_LOG_INFO("init load balance done (StepPerMin: %ld, StepLatencyUs: %ld)",
                     step_recorder_.getStepPerMin(),
                     step_recorder_.getStepLatency());
}

absl::StatusOr<GenerateStreamPtr> SpeculativeEngine::preRun(const std::shared_ptr<GenerateInput>& generate_input,
                                                            preRunMode                            mode) {
    std::shared_ptr<GenerateStream> score_stream = std::make_shared<NormalGenerateStream>(
        generate_input, score_model_params_.gpt_init_parameter, resource_context_, nullptr);
    std::shared_ptr<GenerateStream> propose_stream = nullptr;
    if (mode == preRunMode::prefill_warm_up) {
        score_stream->setPerfTest(true);
    } else if (mode == preRunMode::decode_warm_up) {
        score_stream->setIsContextStream(false);
    } else if (mode == preRunMode::build_system_prompt) {
        THROW_IF_STATUSOR_ERROR(score_stream->initKVBlock(0, 0));
    }
    if (propose_model_params_->gpt_model()) {
        propose_stream = std::make_shared<NormalGenerateStream>(*score_stream);
    }
    std::list<GenerateStreamPtr> score_streams{score_stream};
    THROW_IF_STATUS_ERROR(score_executor_->normalProcess(score_streams));
    if (propose_model_params_->gpt_model()) {
        THROW_IF_STATUS_ERROR(propose_executor_->normalProcess({propose_stream}));
    }
    return score_streams.front();
}
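// initCacheManager(): builds the KV-cache manager(s). With a propose model a
// separate cache config is created for scorer and proposer; for MTP one
// single-layer cache manager is created per MTP layer. The warm-up result,
// when present, is forwarded to CacheConfigCreator (presumably so cache sizing
// can account for the measured runtime memory).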
absl::Status SpeculativeEngine::initCacheManager(std::optional<WarmUpResult> warm_up_result) {
    if (propose_model_params_->gpt_model()) {
        const auto& config = CacheConfigCreator::createSpConfig(score_model_params_.gpt_init_parameter,
                                                                propose_model_params_->getGptInitParameter(),
                                                                warm_up_result,
                                                                propose_model_params_->isMTP());
        auto scorer_cache_config = std::get<0>(config);
        auto proposer_cache_config = std::get<1>(config);
        resource_context_.cache_manager =
            make_shared<CacheManager>(scorer_cache_config, device_, false, metrics_reporter_);
        if (propose_model_params_->isMTP()) {
            auto layer_num = propose_model_params_->getGptInitParameter().num_layers_;
            RTP_LLM_LOG_INFO("mtp cache manager init uses layer num: %d", layer_num);
            for (int i = 0; i < layer_num; i++) {
                RTP_LLM_CHECK(proposer_cache_config.layer_num == 1);
                resource_context_.mtp_cache_managers.push_back(
                    std::make_shared<CacheManager>(proposer_cache_config, device_, false, metrics_reporter_));
            }
        } else {
            resource_context_.propose_cache_manager =
                make_shared<CacheManager>(proposer_cache_config, device_, false, metrics_reporter_);
        }
    } else {
        const auto& config = CacheConfigCreator::createConfig(score_model_params_.gpt_init_parameter, warm_up_result);
        resource_context_.cache_manager = make_shared<CacheManager>(config, device_, false, metrics_reporter_);
    }
    return absl::OkStatus();
}

WarmUpResult SpeculativeEngine::warmUp() {
    const rtp_llm::GptInitParameter& score_gpt_params = score_model_params_.gpt_init_parameter;
    std::shared_ptr<GenerateInput> fake_input = make_shared<GenerateInput>();
    fake_input->input_ids = device_->allocateBuffer(
        {rtp_llm::DataType::TYPE_INT32, {(size_t)score_gpt_params.max_seq_len_ - 1}, rtp_llm::AllocationType::HOST});
    std::memset(fake_input->input_ids->data(), 0, fake_input->input_ids->sizeBytes());
    fake_input->generate_config = make_shared<GenerateConfig>();
    fake_input->generate_config->num_return_sequences = score_gpt_params.max_context_batch_size_;
    fake_input->generate_config->calculate_loss = int(score_gpt_params.warm_up_with_loss_);
    fake_input->generate_config->top_k = 2;
    fake_input->begin_time_us = autil::TimeUtility::currentTimeInMicroSeconds();
    device_->setTraceMemory(true);
    score_executor_.reset(new ScoreExecutor(score_model_params_, device_, nullptr, nullptr, true));
    if (propose_model_params_->isVanilla()) {
        propose_executor_.reset(new VanillaExecutor(propose_model_params_, device_, nullptr, nullptr, true));
    } else if (propose_model_params_->isMTP()) {
        propose_executor_.reset(new MTPExecutor(propose_model_params_, device_, {nullptr}, nullptr, true));
    }
    THROW_IF_STATUSOR_ERROR(preRun(fake_input, preRunMode::prefill_warm_up));
    const auto device_status = device_->getDeviceStatus();
    device_->setTraceMemory(false);
    (void)score_executor_.reset(nullptr);
    if (propose_model_params_->gpt_model()) {
        (void)propose_executor_.reset(nullptr);
    }
    return WarmUpResult({device_status.device_memory_status.preserved_bytes,
                         device_status.device_memory_status.max_consumed_bytes});
}

absl::Status SpeculativeEngine::initSystemPrompt() {
    resource_context_.reuse_cache = score_model_params_.gpt_init_parameter.reuse_cache_;
    if (!score_model_params_.gpt_init_parameter.multi_task_prompt_tokens_.empty()) {
        resource_context_.reuse_cache = true;
        CHECK_AND_RETURN_REF(system_prompt_param,
                             SystemPromptConstructor::construct(score_model_params_.gpt_init_parameter,
                                                                this,
                                                                resource_context_.cache_manager.get(),
                                                                device_->getDeviceProperties().tp_rank == 0));
        resource_context_.system_prompt.reset(new SystemPrompt(system_prompt_param));
    }
    return absl::OkStatus();
}
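// getLoadBalanceInfo(): snapshot of step latency/throughput from the step
// recorder plus current KV-cache availability and the number of streams still
// in flight in the scheduler; presumably consumed by the serving layer for
// load balancing.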
LoadBalanceInfo SpeculativeEngine::getLoadBalanceInfo() {
    auto kv_cache_info = resource_context_.cache_manager->getKVCacheInfo();
    return LoadBalanceInfo{(int64_t)step_recorder_.getStepLatency(),
                           (int64_t)step_recorder_.getStepCount(),
                           (int64_t)step_recorder_.getStepPerMin(),
                           (int64_t)kv_cache_info.available_kv_cache,
                           (int64_t)kv_cache_info.total_kv_cache,
                           (int64_t)scheduler_->onflightStreams()};
}

absl::Status SpeculativeEngine::startLoop() {
    RTP_LLM_LOG_INFO("start init system prompt");
    THROW_IF_STATUS_ERROR(initSystemPrompt());
    RTP_LLM_LOG_INFO("init system prompt done");
    RTP_LLM_LOG_INFO("start speculative engine loop");
    running_ = true;
    loop_thread_ = std::thread(&SpeculativeEngine::loop, this);
    return absl::OkStatus();
}

absl::Status SpeculativeEngine::stop() {
    RTP_LLM_LOG_INFO("stop speculative engine");
    running_ = false;
    RETURN_IF_STATUS_ERROR(scheduler_->stop());
    if (loop_thread_.joinable()) {
        loop_thread_.join();
    }
    return absl::OkStatus();
}

void SpeculativeEngine::loop() {
    RTP_LLM_LOG_INFO("loop begin");
    device_->preRun();
    while (running_) {
        auto status = step();
        if (!status.ok()) {
            RTP_LLM_LOG_ERROR("step running error: %s", status.ToString().c_str());
            THROW_IF_STATUS_ERROR(trySaveStepError());
        }
    }
}

absl::Status SpeculativeEngine::trySaveStepError() const {
    return absl::UnimplementedError("can not save yet!");
}

std::shared_ptr<GenerateStream> SpeculativeEngine::makeStream(const std::shared_ptr<GenerateInput>& input) {
    std::shared_ptr<GenerateStream> stream = std::make_shared<NormalGenerateStream>(
        input, score_model_params_.gpt_init_parameter, resource_context_, metrics_reporter_);
    return stream;
}

void SpeculativeEngine::enqueue(std::shared_ptr<GenerateStream>& stream) {
    (void)scheduler_->enqueue(stream);
}

std::shared_ptr<GenerateStream> SpeculativeEngine::enqueue(const std::shared_ptr<GenerateInput>& input) {
    std::shared_ptr<GenerateStream> stream = std::make_shared<NormalGenerateStream>(
        input, score_model_params_.gpt_init_parameter, resource_context_, metrics_reporter_);
    (void)scheduler_->enqueue(stream);
    return stream;
}

void SpeculativeEngine::tpSyncDisableSPRun(bool& all_streams_disable_sp_run) {
    if (device_->getDeviceProperties().tp_size <= 1) {
        return;
    }
    auto disable_sp_run = device_->allocateBuffer({rtp_llm::DataType::TYPE_INT32, {1}, rtp_llm::AllocationType::HOST});
    auto disable_sp_run_ptr = disable_sp_run->data<int32_t>();
    disable_sp_run_ptr[(size_t)0] = all_streams_disable_sp_run;
    device_->broadcast({{disable_sp_run}, 0});
    device_->syncCommunication(false);
    device_->syncAndCheck();
    all_streams_disable_sp_run = disable_sp_run_ptr[(size_t)0];
}

void SpeculativeEngine::dpAndTpSyncNeedHiddenStates(bool& need_hidden_states) {
    const auto properties = device_->getDeviceProperties();
    size_t world_size = properties.dp_size;
    if (world_size <= 1) {
        return;
    }
    size_t local_rank = properties.dp_rank;
    RTP_LLM_LOG_DEBUG("local_rank is %d", local_rank);
    auto flag = device_->allocateBuffer({rtp_llm::DataType::TYPE_INT32, {world_size}, rtp_llm::AllocationType::HOST});
    auto flag_ptr = flag->data<int32_t>();
    flag_ptr[(size_t)local_rank] = need_hidden_states;
    printBufferData(*flag, "before dp flag");
    device_->allGather({{flag}, ParallelMode::DP});
    device_->syncCommunication(false);
    device_->syncAndCheck();
    printBufferData(*flag, "after dp flag");
    need_hidden_states = (std::accumulate(flag_ptr, flag_ptr + world_size, 0) >= 1);
}
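// step(): one scheduling iteration. Rank 0 pulls streams from the scheduler
// (enqueuing a minimal fake query when DP is enabled and nothing is runnable),
// then dispatches: if every scheduled stream has speculative run disabled the
// engine falls back to normStep(); otherwise the "mtp" propose path goes
// through noPrefillProposeStep() and other propose models through
// prefillProposeStep().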
absl::Status SpeculativeEngine::step() {
    list<GenerateStreamPtr> streams;
    if (device_->getDeviceProperties().tp_rank == 0) {
        if (scheduler_->empty() || step_recorder_.empty()) {
            step_recorder_.reset();
            step_recorder_.registerStep(autil::TimeUtility::currentTimeInMicroSeconds(),
                                        propose_executor_->reserveStep() / 2);
        }
        auto reserve_step = propose_executor_->reserveStep() + 1;
        CHECK_AND_ASSIGN(streams, scheduler_->schedule(reserve_step));
        if (streams.empty()) {
            if (score_model_params_.gpt_init_parameter.dp_size_ > 1) {
                if (score_model_params_.gpt_init_parameter.pd_separation_ == 1) {
                    enqueueMinFakeQuery(1, false);
                } else {
                    enqueueMinFakeQuery(1, true);
                }
            }
            return absl::OkStatus();
        }
        if (score_model_params_.gpt_init_parameter.dp_size_ > 1
            && score_model_params_.gpt_init_parameter.pd_separation_ == 0) {
            bool has_hidden_states = false;
            for (auto stream : streams) {
                if (stream->getLastHiddenStates() != nullptr) {
                    has_hidden_states = true;
                    break;
                }
            }
            if (!has_hidden_states) {
                enqueueMinFakeQuery(1, true);
                return absl::OkStatus();
            }
        }
    }
    for (auto& stream : streams) {
        RTP_LLM_LOG_DEBUG("pre stream[%d]: %s", stream->streamId(), stream->debugString().c_str());
    }
    bool all_streams_disable_sp_run =
        !streams.empty()
        && std::all_of(streams.begin(), streams.end(), [](const auto& stream) { return stream->disableSpRun(); });
    tpSyncDisableSPRun(all_streams_disable_sp_run);
    if (all_streams_disable_sp_run) {
        if (sp_type_ == "mtp") {
            for (auto& stream : streams) {
                stream->setReturnLastHiddenStates(true);
            }
        }
        return normStep(streams);
    }
    if (sp_type_ == "mtp") {
        // Make sure each stream keeps the hidden states from every forward pass.
        for (auto& stream : streams) {
            stream->setReturnLastHiddenStates(true);
        }
        return noPrefillProposeStep(streams);
    } else {
        return prefillProposeStep(streams);
    }
}

absl::Status SpeculativeEngine::normStep(std::list<GenerateStreamPtr>& streams) {
    int64_t propose_begin_time_us = autil::TimeUtility::currentTimeInMicroSeconds();
    int64_t score_begin_time_us = 0;
    int64_t sampler_begin_time_us = 0;
    int64_t update_begin_time_us = 0;
    int64_t total_propose_token_num = 0;
    int64_t total_accepted_token_num = 0;
    score_begin_time_us = autil::TimeUtility::currentTimeInMicroSeconds();
    THROW_IF_STATUS_ERROR(score_executor_->normalProcess(streams));
    sampler_begin_time_us = autil::TimeUtility::currentTimeInMicroSeconds();
    update_begin_time_us = sampler_begin_time_us;
    total_propose_token_num = 0;
    total_accepted_token_num = streams.size();
    for (auto& stream : streams) {
        stream->setReuseLength(stream->seqLength() - 1);
        stream->setFallbackPrefixLength(stream->reuseLength());
        stream->setSpEditRun(false);
        RTP_LLM_LOG_DEBUG("stream [%d], topk = [%d], topp = [%f], propose_tokens = 0, accept_tokens = 1",
                          stream->streamId(),
                          stream->generateConfig()->top_k,
                          stream->generateConfig()->top_p);
    }
    for (auto& stream : streams) {
        RTP_LLM_LOG_DEBUG("post stream[%d]: %s", stream->streamId(), stream->debugString().c_str());
    }
    if (device_->getDeviceProperties().tp_rank == 0) {
        reportMetrics(propose_begin_time_us,
                      score_begin_time_us,
                      sampler_begin_time_us,
                      update_begin_time_us,
                      total_propose_token_num,
                      total_accepted_token_num);
        for (auto& stream : streams) {
            if (stream->finished()) {
                step_recorder_.addStepCount(stream->iterCount());
            }
        }
        step_recorder_.registerStep(autil::TimeUtility::currentTimeInMicroSeconds(),
                                    total_accepted_token_num / streams.size());
    }
    return absl::OkStatus();
}
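// prefillProposeStep(): the draft-model path. The propose executor drafts
// candidate tokens for the scheduled streams; if nothing was proposed the step
// degenerates to a normal forward pass, otherwise the score executor verifies
// the candidates, the speculative sampler accepts a prefix of them, and the
// updater writes the accepted tokens back to the streams.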
absl::Status SpeculativeEngine::prefillProposeStep(std::list<GenerateStreamPtr>& streams) {
    int64_t propose_begin_time_us = autil::TimeUtility::currentTimeInMicroSeconds();
    int64_t score_begin_time_us = 0;
    int64_t sampler_begin_time_us = 0;
    int64_t update_begin_time_us = 0;
    int64_t total_propose_token_num = 0;
    int64_t total_accepted_token_num = 0;
    ProposeOutput propose_output;
    CHECK_AND_ASSIGN(propose_output, propose_executor_->propose(streams));
    RTP_LLM_LOG_DEBUG("propose_output: %s", propose_output.debugString().c_str());
    if (propose_output.hasNoPropose()) {
        score_begin_time_us = autil::TimeUtility::currentTimeInMicroSeconds();
        THROW_IF_STATUS_ERROR(score_executor_->normalProcess(streams));
        sampler_begin_time_us = autil::TimeUtility::currentTimeInMicroSeconds();
        update_begin_time_us = sampler_begin_time_us;
        total_propose_token_num = 0;
        total_accepted_token_num = streams.size();
        for (auto& stream : streams) {
            stream->setReuseLength(stream->seqLength() - 1);
            stream->setFallbackPrefixLength(stream->reuseLength());
            stream->setSpEditRun(false);
            RTP_LLM_LOG_DEBUG("stream [%d], topk = [%d], topp = [%f], propose_tokens = 0, accept_tokens = 1",
                              stream->streamId(),
                              stream->generateConfig()->top_k,
                              stream->generateConfig()->top_p);
        }
    } else {
        score_begin_time_us = autil::TimeUtility::currentTimeInMicroSeconds();
        CHECK_AND_RETURN_REF(score_output, score_executor_->score(streams, propose_output));
        RTP_LLM_LOG_DEBUG("score_output: %s", score_output.debugString().c_str());
        if (device_->getDeviceProperties().tp_rank == 0) {
            sampler_begin_time_us = autil::TimeUtility::currentTimeInMicroSeconds();
            CHECK_AND_RETURN_REF(sampler_output, speculative_sampler_->sample(streams, propose_output, score_output));
            RTP_LLM_LOG_DEBUG("sampler_output: %s", sampler_output.debugString().c_str());
            update_begin_time_us = autil::TimeUtility::currentTimeInMicroSeconds();
            RETURN_IF_STATUS_ERROR(speculative_updater_->update(streams, sampler_output));
            for (const auto& output : sampler_output.outputs) {
                total_propose_token_num += output.propose_step;
                total_accepted_token_num += output.accepted_token_nums;
            }
        }
    }
    for (auto& stream : streams) {
        RTP_LLM_LOG_DEBUG("post stream[%d]: %s", stream->streamId(), stream->debugString().c_str());
    }
    if (device_->getDeviceProperties().tp_rank == 0) {
        reportMetrics(propose_begin_time_us,
                      score_begin_time_us,
                      sampler_begin_time_us,
                      update_begin_time_us,
                      total_propose_token_num,
                      total_accepted_token_num);
        for (auto& stream : streams) {
            if (stream->finished()) {
                step_recorder_.addStepCount(stream->iterCount());
            }
        }
        step_recorder_.registerStep(autil::TimeUtility::currentTimeInMicroSeconds(),
                                    total_accepted_token_num / streams.size());
    }
    return absl::OkStatus();
}

bool SpeculativeEngine::checkAllHasHiddenStates(std::list<GenerateStreamPtr>& streams) {
    bool flag = true;
    for (auto& stream : streams) {
        if (stream->getLastHiddenStates() == nullptr) {
            flag = false;
        }
    }
    flag = !streams.empty() && flag;
    tpSyncDisableSPRun(flag);
    return flag;
}

std::list<GenerateStreamPtr> SpeculativeEngine::extractFirstPrefillStreams(std::list<GenerateStreamPtr>& streams) {
    std::list<GenerateStreamPtr> need_prefill;
    for (auto& stream : streams) {
        if (stream->getLastHiddenStates() == nullptr) {
            need_prefill.push_back(stream);
        }
    }
    return need_prefill;
}
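// noPrefillProposeStep(): the MTP path. Streams that already carry last hidden
// states from the base model go through the propose executor, while streams
// still waiting for their first (prefill) forward pass get an empty proposal;
// the base model then scores everything in one batch and sampling/updating
// proceed as in prefillProposeStep().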
absl::Status SpeculativeEngine::noPrefillProposeStep(std::list<GenerateStreamPtr>& streams) {
    std::list<GenerateStreamPtr> propose_streams;
    std::list<GenerateStreamPtr> prefill_streams;
    if (device_->getDeviceProperties().tp_rank == 0) {
        for (auto& stream : streams) {
            if (stream->getLastHiddenStates() != nullptr) {
                propose_streams.emplace_back(stream);
            } else {
                prefill_streams.emplace_back(stream);
            }
        }
        for (auto& stream : propose_streams) {
            RTP_LLM_LOG_DEBUG("pre propose stream[%d]: %s", stream->streamId(), stream->debugString().c_str());
        }
        for (auto& stream : prefill_streams) {
            RTP_LLM_LOG_DEBUG("pre prefill stream[%d]: %s", stream->streamId(), stream->debugString().c_str());
        }
    }

    // The base model generates the current hidden states; the MTP model then
    // generates one token from the base model's last hidden states.
    // TODO(lidongjin) support multi mtp model.
    int64_t propose_begin_time_us = autil::TimeUtility::currentTimeInMicroSeconds();
    int64_t total_propose_token_num = 0;
    int64_t total_accepted_token_num = 0;
    int64_t score_begin_time_us = 0;
    int64_t sampler_begin_time_us = 0;
    int64_t update_begin_time_us = 0;
    ProposeOutput propose_output;
    {
        bool skip_propose = propose_streams.empty();
        tpSyncDisableSPRun(skip_propose);
        if (!skip_propose) {
            RTP_LLM_LOG_DEBUG("propose step");
            CHECK_AND_ASSIGN(propose_output, propose_executor_->propose(propose_streams));
            RTP_LLM_LOG_DEBUG("propose_output: %s", propose_output.debugString().c_str());
        } else {
            RTP_LLM_LOG_DEBUG("skip propose");
        }
        for (const GenerateStreamPtr& stream : prefill_streams) {
            size_t stream_id = stream->streamId();
            propose_output.outputs[stream_id] = std::make_shared<SpeculativeExecutorStreamOutput>();
            propose_output.outputs[stream_id]->propose_step = 0;
        }
    }
    // The base model scores the proposed tokens.
    {
        RTP_LLM_LOG_DEBUG("score step");
        score_begin_time_us = autil::TimeUtility::currentTimeInMicroSeconds();
        CHECK_AND_RETURN_REF(score_output, score_executor_->score(streams, propose_output));
        RTP_LLM_LOG_DEBUG("score_output: %s", score_output.debugString().c_str());
        if (device_->getDeviceProperties().tp_rank == 0) {
            RTP_LLM_LOG_DEBUG("sample step");
            sampler_begin_time_us = autil::TimeUtility::currentTimeInMicroSeconds();
            CHECK_AND_RETURN_REF(sampler_output, speculative_sampler_->sample(streams, propose_output, score_output));
            RTP_LLM_LOG_DEBUG("sampler_output: %s", sampler_output.debugString().c_str());
            RTP_LLM_LOG_DEBUG("update step");
            update_begin_time_us = autil::TimeUtility::currentTimeInMicroSeconds();
            RETURN_IF_STATUS_ERROR(speculative_updater_->update(streams, sampler_output));
            for (const auto& output : sampler_output.outputs) {
                total_propose_token_num += output.propose_step;
                total_accepted_token_num += output.accepted_token_nums;
            }
        }
        for (auto& stream : streams) {
            RTP_LLM_LOG_DEBUG("post stream[%d]: %s", stream->streamId(), stream->debugString().c_str());
        }
    }
    if (device_->getDeviceProperties().tp_rank == 0) {
        reportMetrics(propose_begin_time_us,
                      score_begin_time_us,
                      sampler_begin_time_us,
                      update_begin_time_us,
                      total_propose_token_num,
                      total_accepted_token_num);
        for (auto& stream : streams) {
            if (stream->finished()) {
                step_recorder_.addStepCount(stream->iterCount());
            }
        }
        step_recorder_.registerStep(autil::TimeUtility::currentTimeInMicroSeconds(),
                                    total_accepted_token_num / streams.size());
    }
    return absl::OkStatus();
}
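// reportMetrics(): derives per-phase durations (propose/score/sample/update)
// from the begin timestamps and reports them together with the proposed and
// accepted token counts; a no-op when no metrics reporter is configured.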
void SpeculativeEngine::reportMetrics(int64_t propose_begin_time_us,
                                      int64_t score_begin_time_us,
                                      int64_t sampler_begin_time_us,
                                      int64_t update_begin_time_us,
                                      int64_t total_propose_token_num,
                                      int64_t total_accepted_token_num) {
    if (!metrics_reporter_) {
        return;
    }
    int64_t current_time = autil::TimeUtility::currentTimeInMicroSeconds();
    int64_t propose_time = score_begin_time_us - propose_begin_time_us;
    int64_t score_time = sampler_begin_time_us - score_begin_time_us;
    int64_t sampler_time = update_begin_time_us - sampler_begin_time_us;
    int64_t update_time = current_time - update_begin_time_us;
    int64_t total_step_time = current_time - propose_begin_time_us;
    RTP_LLM_LOG_DEBUG("total_step_time: %ld, propose_time: %ld, score_time: %ld, sampler_time: %ld, update_time: %ld",
                      total_step_time,
                      propose_time,
                      score_time,
                      sampler_time,
                      update_time);
    RtpLLMSpeculativeEngineMetricsCollector collector{total_step_time,
                                                      propose_time,
                                                      score_time,
                                                      sampler_time,
                                                      update_time,
                                                      total_propose_token_num,
                                                      total_accepted_token_num};
    metrics_reporter_->report<RtpLLMSpeculativeEngineMetrics, RtpLLMSpeculativeEngineMetricsCollector>(nullptr,
                                                                                                       &collector);
}

}  // namespace rtp_llm