// maga_transformer/cpp/stream/GenerateStream.cc
#include <algorithm>
#include <cstddef>
#include <cstring>
#include <memory>
#include <mutex>
#include <sstream>
#include "autil/EnvUtil.h"
#include "maga_transformer/cpp/stream/GenerateStream.h"
#include "maga_transformer/cpp/dataclass/Query.h"
#include "maga_transformer/cpp/utils/AssertUtils.h"
#include "maga_transformer/cpp/metrics/RtpLLMMetrics.h"
#include "maga_transformer/cpp/core/Buffer.h"
#include "maga_transformer/cpp/core/Types.h"
#include "maga_transformer/cpp/core/torch_utils/BufferTorchUtils.h"
#include "maga_transformer/cpp/devices/DeviceFactory.h"
#include "maga_transformer/cpp/th_op/GptInitParameter.h"
using namespace std;
namespace rtp_llm {
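// Builds the stream for a single generate request: applies any system-prompt
// prefix, allocates host buffers for loss, softmax probs and cumulative log
// probs according to the generate config, initializes the complete-token-id
// buffer and per-tile generate status, and binds the KV-cache resource.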
GenerateStream::GenerateStream(const shared_ptr<GenerateInput>& input,
const rtp_llm::GptInitParameter& params,
const ResourceContext& resource_context,
kmonitor::MetricsReporterPtr metrics_reporter)
: generate_input_(input)
, max_seq_len_(params.max_seq_len_)
, vocab_size_(params.vocab_size_)
, stream_cache_resource_(this, resource_context, input->need_release_resource)
, need_release_resource_(input->need_release_resource)
, enable_fast_gen_(params.enable_fast_gen_)
, use_cache_store_(params.use_cache_store_)
, gen_timeline_(input->generate_config->gen_timeline)
, metrics_reporter_(metrics_reporter)
, special_tokens_(params.special_tokens_)
, output_mutex_(std::make_shared<std::mutex>())
, mm_position_ids_style_(PositionIdsStyle(params.mm_position_ids_style_))
, dtype_(rtp_llm::getDataType(params.data_type_))
, hidden_size_(params.hidden_size_)
{
if (!updatePrefix(resource_context.system_prompt)) {
return;
}
begin_time_us_ = input->begin_time_us;
device_ = rtp_llm::DeviceFactory::getDefaultDevice();
if (generate_input_->generate_config->calculate_loss && inputLength() > 1) {
loss_ = device_->allocateBuffer(
{rtp_llm::DataType::TYPE_FP32, {(size_t)inputLength() - 1}, rtp_llm::AllocationType::HOST}, {});
}
if (generate_input_->generate_config->return_softmax_probs) {
softmax_probs_ = device_->allocateBuffer(
{rtp_llm::DataType::TYPE_FP32, {(size_t)tileNum(), (size_t)max_seq_len_}, rtp_llm::AllocationType::HOST}, {});
memset(softmax_probs_->data(), 0, softmax_probs_->sizeBytes());
}
complete_token_ids_ = std::make_shared<CompleteTokenIds>(device_, tileNum(), max_seq_len_, params.seq_size_per_block_);
complete_token_ids_->init(input);
last_output_pos_ = seqLength();
max_chunk_len_ = seqLength();
cum_log_probs_ =
device_->allocateBuffer({rtp_llm::DataType::TYPE_FP32, {(size_t)tileNum()}, rtp_llm::AllocationType::HOST}, {});
memset(cum_log_probs_->data(), 0, cum_log_probs_->sizeBytes());
generate_status_.status = StreamState::WAITING;
sub_generate_status_.clear();
sub_generate_status_.resize(tileNum());
for (int i = 0; i < tileNum(); ++i) {
sub_generate_status_[i].status = StreamState::WAITING;
}
stream_cache_resource_.init(tileNum());
perf_test_ = autil::EnvUtil::getEnv("PERF_TEST", false);
// TODO(xinfei.sxf): need fix context block copy
perf_test_ = true;
setReturnAllProbs(generate_input_->generate_config->return_all_probs);
for (size_t i = 0; i < batchSize(); i++) {
StreamThinkInfo think_info(generate_input_->generate_config->in_think_mode,
generate_input_->generate_config->max_thinking_tokens,
generate_input_->generate_config->end_think_token_ids,
std::make_shared<StringContainDFA<size_t, int>>(generate_input_->generate_config->end_think_token_ids));
think_infos_.push_back(think_info);
}
}
void GenerateStream::resetBeginTime(int64_t begin_time_us) {
begin_time_us_ = begin_time_us;
}
bool GenerateStream::hasCacheKeys() const {
return stream_cache_resource_.hasCacheKeys();
}
const std::vector<int64_t>& GenerateStream::cacheKeys(int32_t batch_id) const {
return stream_cache_resource_.cacheKeys(batch_id);
}
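// Consumes part of the caller-provided token budget and returns how many
// tokens this stream will run with. A chunk stream (chunked prefill) takes up
// to `token_capacity` prompt tokens and records the previous chunk boundary in
// last_chunk_len_; a decode stream always takes exactly one token. Calling
// this on a plain context stream is unexpected.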
absl::StatusOr<int> GenerateStream::acquireCapacity(int token_capacity) {
if (token_capacity <= 0) {
return absl::InternalError("token_capacity is <= 0");
}
if (isChunkStream()) {
// TODO(xinfei.sxf) add min_chunk_len ?
if (current_chunk_len_ == 0) {
current_chunk_len_ = reuse_length_;
}
auto remaining_token = max_chunk_len_ - current_chunk_len_;
last_chunk_len_ = current_chunk_len_;
if (token_capacity > remaining_token) {
current_chunk_len_ = max_chunk_len_;
return remaining_token;
} else {
current_chunk_len_ += token_capacity;
return token_capacity;
}
} else if (!isContextStream()) {
return 1;
}
RTP_LLM_CHECK(false);
return absl::InternalError("unexpected call");
}
void GenerateStream::cancel() {
setStop(ErrorCode::CANCELLED, "cancel stream");
}
absl::StatusOr<int> GenerateStream::initKVBlock(int token_capacity, size_t reserve_step) {
if (generate_status_.status == StreamState::WAITING) {
wait_time_us_ = autil::TimeUtility::currentTimeInMicroSeconds() - begin_time_us_;
} else if (generate_status_.status == StreamState::PAUSED) {
pause_time_us_ += autil::TimeUtility::currentTimeInMicroSeconds() - last_pause_us_;
}
return stream_cache_resource_.initKVBlock(token_capacity, reserve_step);
}
void GenerateStream::fakeInitKVBlock() {
stream_cache_resource_.fakeInitKVBlock();
}
absl::StatusOr<int> GenerateStream::incrKVBlock(int token_capacity, size_t reserve_step) {
return stream_cache_resource_.incrKVBlock(token_capacity, reserve_step);
}
int GenerateStream::tryReleaseKVBlock(int nums) {
auto release_blocks = stream_cache_resource_.tryReleaseKVBlock(nums);
incrFallbackBlock(release_blocks);
return release_blocks;
}
void GenerateStream::releaseResource() {
stream_cache_resource_.releaseResource();
}
void GenerateStream::setNeedReleaseResource(bool need_release_resource) {
need_release_resource_ = need_release_resource;
stream_cache_resource_.setNeedReleaseResource(need_release_resource);
}
int GenerateStream::nextNeedBlockNums(size_t reserve_step) const {
// TODO: maybe need fix when context and reuse
return stream_cache_resource_.singleBatchNeedBlocks(seqLength() + reserve_step) * batchSize();
}
void GenerateStream::incrFallbackBlock(int fallback_blocks) {
fallback_blocks_ += fallback_blocks;
fallback_times_ += 1;
}
std::shared_ptr<GenerateInput> GenerateStream::generateInput() const {
return generate_input_;
}
std::shared_ptr<GenerateConfig>& GenerateStream::generateConfig() const {
return generate_input_->generate_config;
}
bool GenerateStream::isStreaming() const {
if (numBeams() > 1) {
return false;
}
return generate_input_->generate_config->is_streaming;
}
int64_t GenerateStream::streamId() const {
return generate_input_->request_id;
}
int GenerateStream::loraId() const {
return generate_input_->lora_id;
}
std::string GenerateStream::adapterName() const {
return generate_input_->generate_config->adapter_name;
}
rtp_llm::SpecialTokens GenerateStream::specialTokens() const {
return special_tokens_;
}
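// tileNum() is the number of hypotheses kept in parallel for this request
// (beam width or num_return_sequences, whichever is larger). batchSize()
// stays at 1 while the stream is still in its pure prefill phase
// (seqLength() == inputLength()), unless perf-test mode forces full tiling.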
int GenerateStream::tileNum() const {
return std::max(numBeams(), numReturnSequences());
}
int GenerateStream::batchSize() const {
return seqLength() == inputLength() && !perf_test_ ? 1 : tileNum();
}
int GenerateStream::numBeams() const {
return generate_input_->generate_config->num_beams;
}
int GenerateStream::numReturnSequences() const {
return generate_input_->generate_config->num_return_sequences;
}
bool GenerateStream::calculateLoss() const {
return loss_ && loss_index_ < inputLength() - 1;
}
bool GenerateStream::calculateSoftmaxProbs() const {
return generate_input_->generate_config->return_softmax_probs;
}
bool GenerateStream::returnLogits() const {
return generate_input_->generate_config->return_logits;
}
bool GenerateStream::returnCumLogProbs() const {
return generate_input_->generate_config->return_cum_log_probs;
}
bool GenerateStream::genTimeline() const {
return seqLength() <= inputLength() + 1 ? gen_timeline_ : false;
}
bool GenerateStream::updatePrefix(const std::shared_ptr<SystemPrompt>& system_prompt) {
if (system_prompt) {
auto prefix_param = system_prompt->getPromptParams(*generate_input_->generate_config);
if (!prefix_param.prompt_tokens.empty()) {
auto total_input_len = inputLength() + prefix_param.prompt_tokens.size();
if (total_input_len >= max_seq_len_) {
setStop(ErrorCode::LONG_PROMPT_ERROR, "after update prefix, total input len " + std::to_string(total_input_len)
+ " is greater than max seq len " + std::to_string(max_seq_len_));
return false;
}
generate_input_->updatePrefix(prefix_param.prompt_tokens);
}
}
return true;
}
size_t GenerateStream::maxSeqLen() const {
return max_seq_len_;
}
int GenerateStream::inputLength() const {
return generate_input_->inputLength();
}
int GenerateStream::currentChunkLen() const {
return current_chunk_len_;
}
void GenerateStream::resetChunkLen(int chunk_len, int max_chunk_len) {
last_chunk_len_ = 0;
current_chunk_len_ = chunk_len;
max_chunk_len_ = max_chunk_len;
}
int GenerateStream::seqLength() const {
return complete_token_ids_->seqLength();
}
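// Common (shared) sequence length across tiles: the full sequence for a
// single-tile stream, otherwise the input length rounded down to a whole
// number of KV-cache blocks.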
int GenerateStream::adjustedCommonLen() const {
return tileNum() == 1 ? seqLength() : inputLength() / seqSizePerBlock() * seqSizePerBlock();
}
int GenerateStream::seqSizePerBlock() const {
return stream_cache_resource_.seqSizePerBlock();
}
int GenerateStream::contextLength() const {
int begin_pos = prefixLength();
int end_pos = isChunkStream() ? currentChunkLen() : seqLength();
return end_pos - begin_pos;
}
int GenerateStream::inputPrefixLength() const {
return generate_input_->prefix_length;
}
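// Length of the prefix whose KV cache already exists for the next step.
// Precedence: the fallback prefix (after a KV-cache fallback), then the
// previous chunk boundary of a chunked prefill, then the reused prefix.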
int GenerateStream::prefixLength() const {
if (fallback_prefix_length_) {
return fallback_prefix_length_;
} else if (last_chunk_len_) {
return last_chunk_len_;
}
return reuse_length_;
}
int GenerateStream::reuseLength() const {
return reuse_length_;
}
int GenerateStream::initialReuseLength() const {
return initial_reuse_length_;
}
void GenerateStream::setReuseLength(int reuse_length) {
reuse_length_ = reuse_length;
if (generate_input_->mm_locs) {
auto& locs = generate_input_->mm_locs.value();
for (int i = locs->size() - 1; i >= 0; --i) {
if (reuse_length_ > *locs->dataWithOffset<int32_t>(i)) {
reuse_mm_length_ = i + 1;
break;
}
}
}
}
void GenerateStream::setInitialReuseLength(int initial_reuse_length) {
initial_reuse_length_ = initial_reuse_length;
}
int GenerateStream::fallbackPrefixLength() const {
return fallback_prefix_length_;
}
void GenerateStream::setFallbackPrefixLength(int fallback_prefix_length) {
fallback_prefix_length_ = fallback_prefix_length;
}
void GenerateStream::incLastOutputPos() {
last_output_pos_++;
}
bool GenerateStream::isContextStream() const {
return is_context_stream_;
}
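// A chunk stream is a fast-gen prefill that has not yet covered its whole
// context (current_chunk_len_ < max_chunk_len_).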
bool GenerateStream::isChunkStream() const {
return enable_fast_gen_ && current_chunk_len_ < max_chunk_len_;
}
const rtp_llm::BufferPtr& GenerateStream::cumLogProbs() const {
return cum_log_probs_;
}
const rtp_llm::BufferPtr& GenerateStream::completeTokenIds() {
return complete_token_ids_->completeTokenIds();
}
std::vector<int> GenerateStream::completeTokenIdsVec(int batch_idx) {
RTP_LLM_CHECK(batch_idx < tileNum());
return complete_token_ids_->completeTokenIdsVec(batch_idx);
}
std::vector<int> GenerateStream::commonCompleteTokenIdsVec(int batch_idx) {
RTP_LLM_CHECK(batch_idx < tileNum());
return complete_token_ids_->commonCompleteTokenIdsVec(batch_idx);
}
int GenerateStream::currentExecuteTokenSize() {
return currentExecuteTokens(0).size() * batchSize();
}
std::vector<torch::Tensor> GenerateStream::multimodalFeatures() const {
if (generate_input_->multimodal_features) {
auto& features = generate_input_->multimodal_features.value();
return std::vector<torch::Tensor>(features.begin() + reuse_mm_length_, features.end());
} else {
return std::vector<torch::Tensor>();
}
}
int GenerateStream::multimodalFeaturesLength() const {
return multimodalFeatures().size() * batchSize();
}
rtp_llm::BufferPtr GenerateStream::multimodalLocations() const {
if (!generate_input_->mm_locs) {
return nullptr;
}
auto& mm_locs = generate_input_->mm_locs.value();
return mm_locs->slice(reuse_mm_length_, mm_locs->size() - reuse_mm_length_);
}
vector<vector<int>> GenerateStream::multimodalIntervals() const {
    // Both locations and features are required; bail out if either is missing
    // so the value() accesses below are safe.
    if (!generate_input_->mm_locs || !generate_input_->multimodal_features) {
        return {};
    }
    vector<vector<int>> res;
    auto locs = generate_input_->mm_locs.value();
    auto features = generate_input_->multimodal_features.value();
    for (int i = 0; i < locs->size(); ++i) {
res.emplace_back(vector<int>({*locs->dataWithOffset<int>(i), int(features[i].sizes()[0])}));
}
return res;
}
vector<int> GenerateStream::textTokensMask() const {
if (!generate_input_->text_tokens_mask) {
return {};
}
auto token_masks = rtp_llm::buffer2vector<int>(*generate_input_->text_tokens_mask.value());
if (reuseLength() > 0) {
return vector<int>(token_masks.begin() + reuseLength(), token_masks.end());
} else {
return token_masks;
}
}
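// Builds position ids for the context tokens, taking the multimodal
// position-id style and per-feature locations into account when present.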
rtp_llm::BufferPtr GenerateStream::generateContextPositionIds(rtp_llm::DeviceBase* device) {
optional<vector<rtp_llm::BufferPtr>> position_ids_buffer = nullopt;
if (generate_input_->mm_position_ids.has_value()) {
position_ids_buffer = rtp_llm::torchTensorVec2BufferVec(generate_input_->mm_position_ids.value());
}
context_position_ids_ = PositionIdsGenerator::generatePositionIds(device, generate_input_->inputLength(),
mm_position_ids_style_, generate_input_->mm_locs, position_ids_buffer);
return context_position_ids_.value();
}
void GenerateStream::generateNextPositionId(int32_t* now_pos) {
if (!context_position_ids_) {
return;
}
PositionIdsGenerator::generateNextPositionId(now_pos, seqLength(), mm_position_ids_style_, context_position_ids_.value());
}
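// Tokens this stream feeds to the next model step: the pending context slice
// during prefill, or the most recently generated token(s) during decode.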
vector<int> GenerateStream::currentExecuteTokens(int batch_idx) const {
    // TODO(xinfei.sxf): this is incorrect when part of the query has fallen back and is re-run
if (isContextStream()) {
return complete_token_ids_->contextTokens(batch_idx, prefixLength(), contextLength());
} else {
return complete_token_ids_->currentExecuteTokens(batch_idx);
}
}
void GenerateStream::step() {
// iter_count represents the times of the stream participates in running
iter_count_++;
if (isContextStream()) {
setFallbackPrefixLength(0);
}
}
int64_t GenerateStream::getTimeoutMs() const {
return generate_input_->generate_config->timeout_ms;
}
void GenerateStream::checkTimeout() {
auto running_time_ms = (autil::TimeUtility::currentTimeInMicroSeconds() - begin_time_us_) / 1000;
auto timeout_ms = getTimeoutMs();
if (timeout_ms > 0 && timeout_ms < running_time_ms) {
stopAndRelease(ErrorCode::GENERATE_TIMEOUT,
"query has been running " + std::to_string(running_time_ms) + " ms, "
+ "timeout_ms = " + std::to_string(timeout_ms) + ", it's timeout");
}
}
void GenerateStream::setStopWithoutLock(ErrorCode error_code, const std::string& error_msg) {
auto cost_time_ms = (autil::TimeUtility::currentTimeInMicroSeconds() - begin_time_us_) / 1000;
RTP_LLM_LOG_WARNING("stop stream [%d], error msg: [%s], current state [%s], "
"input len [%d], seq len [%d], timeout [%ld] ms, running [%ld] ms",
streamId(), error_msg.c_str(), StreamStateToString(generate_status_.status).c_str(),
inputLength(), seqLength(), getTimeoutMs(), cost_time_ms);
generate_status_.status = StreamState::STOPPED;
generate_status_.error_info = ErrorInfo(error_code, error_msg);
}
void GenerateStream::setStop(ErrorCode error_code, const std::string& error_msg) {
std::lock_guard<std::mutex> lock(*output_mutex_);
setStopWithoutLock(error_code, error_msg);
}
void GenerateStream::stopAndRelease(ErrorCode error_code, const std::string& error_msg) {
setStop(error_code, error_msg);
releaseResource();
}
ErrorInfo GenerateStream::statusInfo() {
std::lock_guard<std::mutex> lock(*output_mutex_);
return generate_status_.error_info;
}
void GenerateStream::setPaused() {
// TODO(xinfei.sxf) fix mutex name
std::lock_guard<std::mutex> lock(*output_mutex_);
if (stoppedWithoutLock()) {
return;
}
is_context_stream_ = true;
generate_status_.status = StreamState::PAUSED;
last_pause_us_ = autil::TimeUtility::currentTimeInMicroSeconds();
}
bool GenerateStream::setRunning() {
std::lock_guard<std::mutex> lock(*output_mutex_);
if (stoppedWithoutLock()) {
return false;
}
generate_status_.status = StreamState::RUNNING;
return true;
}
void GenerateStream::setFinishedWithoutLock() {
generate_status_.status = StreamState::FINISHED;
for (int i = 0; i < tileNum(); ++i) {
sub_generate_status_[i].status = StreamState::FINISHED;
}
}
bool GenerateStream::stoppedWithoutLock() {
return generate_status_.status == StreamState::STOPPED;
}
bool GenerateStream::stopped() {
std::lock_guard<std::mutex> lock(*output_mutex_);
return generate_status_.status == StreamState::STOPPED;
}
bool GenerateStream::waiting() {
std::lock_guard<std::mutex> lock(*output_mutex_);
return generate_status_.status == StreamState::WAITING;
}
bool GenerateStream::paused() {
std::lock_guard<std::mutex> lock(*output_mutex_);
return generate_status_.status == StreamState::PAUSED;
}
std::string GenerateStream::stopReason() {
std::lock_guard<std::mutex> lock(*output_mutex_);
return generate_status_.error_info.ToString();
}
bool GenerateStream::finishedWithoutLock() {
return generate_status_.status == StreamState::FINISHED;
}
bool GenerateStream::running() {
return generate_status_.status == StreamState::RUNNING;
}
void GenerateStream::cancelIfNotRunning() {
std::lock_guard<std::mutex> lock(*output_mutex_);
if (generate_status_.status == StreamState::WAITING
|| generate_status_.status == StreamState::REMOTE_RUNNING) {
auto cost_time_ms = (autil::TimeUtility::currentTimeInMicroSeconds() - begin_time_us_) / 1000;
RTP_LLM_LOG_WARNING("stop stream: %d %s, input len [%d], seq len [%d], timeout: [%ld] ms, running [%ld] ms",
streamId(), "cancel stream in waiting or remote running",
inputLength(), seqLength(),
getTimeoutMs(), cost_time_ms);
generate_status_.status = StreamState::STOPPED;
generate_status_.error_info = ErrorInfo(ErrorCode::CANCELLED, "cancel stream in waiting or remote running");
}
}
bool GenerateStream::finished() {
std::lock_guard<std::mutex> lock(*output_mutex_);
return generate_status_.status == StreamState::FINISHED;
}
bool GenerateStream::needRemoteGenerate() const {
return need_remote_generate_;
}
void GenerateStream::setRemoteGenerate() {
generate_status_.status = StreamState::REMOTE_RUNNING;
}
size_t GenerateStream::iterCount() const {
return iter_count_;
}
void GenerateStream::setKVCache(const BatchKVCacheResource& kv_cache_resource) {
stream_cache_resource_.setKVCache(kv_cache_resource);
}
const BatchKVCacheResource& GenerateStream::kvCache() const {
return stream_cache_resource_.kvCache();
}
const ResourceContext& GenerateStream::resourceContext() const {
return stream_cache_resource_.resourceContext();
}
size_t GenerateStream::maxBlockSize() const {
return stream_cache_resource_.maxBlockSize();
}
size_t GenerateStream::maxTokenNum() const {
return std::min(max_seq_len_,
generate_input_->generate_config->max_new_tokens + generate_input_->inputLength());
}
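// A stream finishes when it reaches its token budget (max_new_tokens capped by
// max_seq_len) or when EOS / stop-word matching marks the required sub-streams
// as finished (batch 0 for beam search, all batches otherwise).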
bool GenerateStream::needFinish() {
return seqLength() >= maxTokenNum() || needFinishBySPTokens();
}
bool GenerateStream::needFinishBySPTokens() {
matchEosToken();
matchStopWordsList();
// num beams, finished by batch 0
if (numBeams() != 1) {
return sub_generate_status_[0].status == StreamState::FINISHED;
}
// num sequence, finished by all batch
return std::all_of(sub_generate_status_.begin(), sub_generate_status_.end(), [](GenerateStatus& generate_status) {
return generate_status.status == StreamState::FINISHED;
});
}
void GenerateStream::matchEosToken() {
for (int i = 0; i < tileNum(); ++i) {
matchEosToken(i);
}
}
void GenerateStream::matchEosToken(int batch_id) {
if (complete_token_ids_->matchEosToken(batch_id, special_tokens_.eos_token_id_)) {
sub_generate_status_[batch_id].status = StreamState::FINISHED;
}
}
std::vector<int> GenerateStream::getLatestTokens(size_t token_num) {
return complete_token_ids_->getLatestTokens(token_num);
}
void GenerateStream::matchStopWordsList() {
if (seqLength() < generate_input_->generate_config->min_new_tokens + inputLength()) {
return;
}
if (seqLength() == inputLength()) {
return;
}
for (int i = 0; i < tileNum(); ++i) {
matchStopWordsList(i);
}
}
void GenerateStream::matchStopWordsList(int batch_id) {
// note: stop_words_list in generate_config contains stop_words_list in special_tokens
bool match = false;
for (auto& stop_words : generate_input_->generate_config->stop_words_list) {
if (complete_token_ids_->matchStopWordsList(batch_id, stop_words)) {
match = true;
break;
}
}
if (match) {
sub_generate_status_[batch_id].status = StreamState::FINISHED;
}
}
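// Applies one decoding step: appends the newly sampled tokens to the
// complete-token-id buffer (validating them against the vocab size) and then
// refreshes the stream output. Does nothing if the stream was already stopped.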
void GenerateStream::update(const StreamUpdateInfo& update_info) {
std::lock_guard<std::mutex> lock(*output_mutex_);
RTP_LLM_LOG_DEBUG("stream [%ld] update", streamId());
is_context_stream_ = false;
if (stoppedWithoutLock()) {
return;
}
const auto& new_tokens = update_info.new_tokens;
auto num_new_tokens = update_info.num_new_tokens;
int error_token_id = 0;
if (!complete_token_ids_->update(new_tokens, begin_time_us_, num_new_tokens, generate_input_->inputLength(), maxTokenNum(), vocab_size_, numBeams(), streamId(), error_token_id)) {
setStopWithoutLock(ErrorCode::OUT_OF_VOCAB_RANGE,
"output token id:" + std::to_string(error_token_id) +
" out of vocab size: " + std::to_string(vocab_size_));
return;
}
// TODO(xinfei.sxf) fix this (update_queue)
updateOutput(update_info);
}
// beam_idx: [beam_width] int; each element must be less than beam_width.
void GenerateStream::beamSearchKvCacheUpdate(rtp_llm::BufferPtr beam_idx) {
auto beam_idx_vec = rtp_llm::buffer2vector<int>(*beam_idx);
RTP_LLM_CHECK(beam_idx_vec.size() == tileNum());
stream_cache_resource_.beamSearchKvCacheUpdate(beam_idx_vec);
}
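// Accumulates per-token loss computed on the prompt into the host buffer
// allocated in the constructor (inputLength() - 1 values in total).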
void GenerateStream::setLoss(const rtp_llm::Buffer& loss) {
RTP_LLM_CHECK(loss_index_ + loss.size() < inputLength());
device_->copy({loss_->view(loss_index_, loss.size()), loss});
loss_index_ += loss.size();
}
void GenerateStream::setSoftmaxProbs(const rtp_llm::Buffer& softmax_probs, int start_pos) {
RTP_LLM_CHECK(softmax_probs.dim() == 2);
RTP_LLM_CHECK(softmax_probs.shape()[0] == tileNum());
for (int i = 0; i < tileNum(); ++i) {
device_->copy({(*softmax_probs_)[i].view(start_pos, softmax_probs.shape()[1]), softmax_probs[i]});
}
}
rtp_llm::BufferPtr GenerateStream::getLoss() {
return loss_;
}
rtp_llm::BufferPtr GenerateStream::getLastHiddenStates() {
return last_hidden_states_;
}
rtp_llm::BufferPtr GenerateStream::getSoftmaxProbs() {
return softmax_probs_;
}
void GenerateStream::setMetricsReporter(kmonitor::MetricsReporterPtr metrics_reporter) {
metrics_reporter_ = metrics_reporter;
}
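// Reports per-request metrics: QPS-style counters on every call, and latency,
// token-count, fallback and prefill statistics once the stream has finished,
// been cancelled, or timed out.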
void GenerateStream::reportMetric() {
    const auto error_code = statusInfo().code();
    bool cancelled = error_code == ErrorCode::CANCELLED;
    bool timeout = error_code == ErrorCode::GENERATE_TIMEOUT;
if (metrics_reporter_) {
RtpLLMStreamMetricsCollector collector;
collector.qps = true;
collector.cancel_qps = cancelled;
collector.error_qps = stopped() && !cancelled;
collector.is_streaming_qps = generate_input_->generate_config->is_streaming;
collector.not_streaming_qps = !generate_input_->generate_config->is_streaming;
if (finished() || cancelled || timeout) {
collector.reuse_length = initial_reuse_length_;
collector.input_token_length = inputLength();
collector.output_token_length = outputTokenLen();
collector.iterate_count = iter_count_;
collector.query_batch_size = tileNum();
collector.total_latency_us = autil::TimeUtility::currentTimeInMicroSeconds() - begin_time_us_;
collector.first_token_latency_us = complete_token_ids_->firstTokenLatencyUs();
RTP_LLM_LOG_DEBUG("stream [%ld] report first latency us = %ld", streamId(), collector.first_token_latency_us);
collector.wait_latency_us = wait_time_us_;
collector.pause_latency_us = pause_time_us_;
collector.fallback_tokens = fallback_blocks_ * seqSizePerBlock();
collector.fallback_times = fallback_times_;
collector.batch_with_prefill_times = batch_with_prefill_times_;
collector.batch_with_prefill_len = batch_with_prefill_len_;
collector.malloc_failed_times = stream_cache_resource_.mallocFailedTimes();
if (timeout) {
collector.timeout_latency_us = getTimeoutMs() * 1000;
}
}
// pass tag will cause default tags deep copy
static kmonitor::MetricsTags timeout_tag("timeout", "true");
metrics_reporter_->report<RtpLLMStreamMetrics, RtpLLMStreamMetricsCollector>(timeout ? &timeout_tag : nullptr,
&collector);
}
}
std::string GenerateStream::debugString() const {
std::stringstream debug_string;
debug_string << "GenerateStream {"
<< "generate_input:" << generate_input_->debugString() << ", max_seq_len:" << max_seq_len_
<< ", input_length:" << inputLength() << ", seq_length:" << seqLength()
<< ", reuse_length:" << reuse_length_ << ", current_chunk_len:" << current_chunk_len_
<< ", last_chunk_len_:" << last_chunk_len_ << ", max_chunk_len_:" << max_chunk_len_
<< ", batch_size:" << batchSize()
<< ", tile_num:" << tileNum()
<< ", need_release_resource: " << need_release_resource_
<< ", fallback_prefix_length: " << fallback_prefix_length_
<< ", sp_edit_search_index: " << sp_edit_search_index_
<< ", mtp token indices" << mtp_token_index_;
if (last_hidden_states_) {
debug_string << ", hidden_state_token_num: " << last_hidden_states_->shape()[0];
}
debug_string << ", complete_token_ids: [";
for (size_t i = 0; i < tileNum(); i++) {
debug_string << complete_token_ids_->toString(i) << ",";
}
debug_string << ", cum_log_probs: " << cum_log_probs_->debugStringWithData<float>();
debug_string << ", stream_cache_resource: "<< stream_cache_resource_.debugString();
debug_string << "}";
return debug_string.str();
}
int GenerateStream::reuseBlockSize() const {
int reuse_length = reuseLength();
int seq_size_per_block = seqSizePerBlock();
return reuse_length / seq_size_per_block;
}
void GenerateStream::setSeqLength(int seq_length) {
complete_token_ids_->setSeqLength(seq_length);
}
void GenerateStream::setPerfTest(bool perf_test) {
perf_test_ = perf_test;
}
void GenerateStream::setIsContextStream(bool is_context_stream) {
is_context_stream_ = is_context_stream;
}
StreamCacheResource& GenerateStream::streamCacheResource() {
return stream_cache_resource_;
}
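// Copies the mutable per-stream state (token ids, cumulative log probs and,
// optionally, loss) from other_stream so this stream can diverge without
// sharing buffers, and re-binds the cache resource to this stream.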
void GenerateStream::CopyOnWrite(const GenerateStream& other_stream, bool copy_loss) {
complete_token_ids_ = make_shared<CompleteTokenIds>(*other_stream.complete_token_ids_);
cum_log_probs_ = device_->clone({*other_stream.cum_log_probs_, rtp_llm::AllocationType::HOST});
if (other_stream.calculateLoss() && copy_loss) {
loss_ = device_->clone({*other_stream.loss_, rtp_llm::AllocationType::HOST});
} else {
loss_ = nullptr;
}
stream_cache_resource_.setStream(this);
}
GenerateStream::TimeInfo GenerateStream::getTimeInfo() {
return {begin_time_us_, wait_time_us_,
complete_token_ids_->firstTokenTimeUs(),
complete_token_ids_->firstTokenLatencyUs()};
}
bool GenerateStream::queryPdSep() const {
return generate_input_->generate_config->pd_separation;
}
void GenerateStream::incBatchWithPrefillTimes(int32_t times) {
batch_with_prefill_times_ += times;
}
void GenerateStream::incBatchWithPrefillLen(int32_t len) {
batch_with_prefill_len_ += len;
}
} // namespace rtp_llm