// maga_transformer/cpp/stream/GenerateStream.h
#pragma once
#include "absl/status/statusor.h"
#include "autil/TimeUtility.h"
#include "autil/SynchronizedQueue.h"
#include "kmonitor/client/MetricsReporter.h"
#include "maga_transformer/cpp/models/GptModel.h"
#include "maga_transformer/cpp/models/Sampler.h"
#include "maga_transformer/cpp/stream/StreamCacheResource.h"
#include "maga_transformer/cpp/stream/CompleteTokenIds.h"
#include "maga_transformer/cpp/system_prompt/SystemPrompt.h"
#include "maga_transformer/cpp/position_ids_generator/PositionIdsGenerator.h"
namespace rtp_llm {
// WARNING: buffers in a generate stream should all live in host memory, so that
// GPU buffers are not held longer than necessary (the kv cache is the exception).
// Bundle of one decoding step's model outputs handed to GenerateStream::update().
// Buffers other than new_tokens appear optional (only some generate configs
// request logits/probs/loss) -- confirm nullability at the producing call sites.
struct StreamUpdateInfo {
    const rtp_llm::BufferPtr new_tokens;       // token ids produced in this step
    int num_new_tokens;                        // number of valid entries in new_tokens
    const rtp_llm::BufferPtr hidden_states;    // last hidden states, when requested
    const rtp_llm::BufferPtr logits;           // raw logits, when returnLogits() is set
    const rtp_llm::BufferPtr softmax_probs;    // softmax probabilities, when calculateSoftmaxProbs() is set
    const rtp_llm::BufferPtr cum_log_probs;    // cumulative log-probs, when returnCumLogProbs() is set
    const rtp_llm::BufferPtr all_probs;        // full distribution, when return_all_probs_ is set
    const rtp_llm::BufferPtr loss;             // per-token loss, when calculateLoss() is set
    // for mtp (multi-token prediction)
    const rtp_llm::BufferPtr all_hidden_states;
};
// One in-flight generation request. Owns the request's token state
// (CompleteTokenIds), its kv-cache blocks (StreamCacheResource), run/stop
// status, timing and metric bookkeeping, plus speculative-decoding and
// beam-search state. Concrete subclasses implement the output-delivery
// hooks (nextOutput / updateOutput / ...).
class GenerateStream {
public:
    GenerateStream(const std::shared_ptr<GenerateInput>& query, const rtp_llm::GptInitParameter& params,
                   const ResourceContext& resource_context, kmonitor::MetricsReporterPtr metrics_reporter);
    // NOTE(review): releaseResource() is virtual, but inside this base-class
    // destructor the call statically resolves to GenerateStream's own
    // implementation; derived overrides are NOT invoked here (standard C++
    // destructor semantics). Confirm subclasses do not rely on that.
    virtual ~GenerateStream() {
        reportMetric();
        releaseResource();
    }
public:
    // ---- Exported to python world. ----
    virtual void cancel();
    // Delivers the next chunk of generated output; pure virtual.
    virtual ErrorResult<GenerateOutputs> nextOutput() = 0;
    virtual bool hasOutput() {return false;}
    // Consumes one step's model results; pure virtual.
    virtual void updateOutput(const StreamUpdateInfo& update_info) = 0;
    void update(const StreamUpdateInfo& update_info);
    // Number of score entries per step; the base stream produces 1.
    virtual size_t scoreLen() const {
        return 1;
    }
    // ---- Only used in C++ world. ----
    // kv-cache block management (backed by stream_cache_resource_).
    int reuseBlockSize() const;
    void fakeInitKVBlock();
    virtual absl::StatusOr<int> initKVBlock(int token_capacity, size_t reserve_step = 0);
    virtual absl::StatusOr<int> incrKVBlock(int token_capacity, size_t reserve_step = 0);
    virtual int tryReleaseKVBlock(int nums);
    virtual void releaseResource();
    int nextNeedBlockNums(size_t reserve_step) const;
    void setNeedReleaseResource(bool need_release_resource);
    void incrFallbackBlock(int fallback_blocks);
    bool hasCacheKeys() const;
    const std::vector<int64_t>& cacheKeys(int32_t batch_id = 0) const;
    // Request / generate-config accessors.
    std::shared_ptr<GenerateInput> generateInput() const;
    // NOTE(review): returns a mutable reference from a const method, so callers
    // can modify the config through a const stream -- confirm this is intended.
    std::shared_ptr<GenerateConfig>& generateConfig() const;
    std::vector<int> textTokensMask() const;
    bool isStreaming() const;
    int64_t streamId() const;
    int loraId() const;
    std::string adapterName() const;
    rtp_llm::SpecialTokens specialTokens() const;
    int tileNum() const;
    int batchSize() const;
    int numBeams() const;
    int numReturnSequences() const;
    bool calculateLoss() const;
    bool calculateSoftmaxProbs() const;
    bool returnLogits() const;
    bool returnCumLogProbs() const;
    bool genTimeline() const;
    bool updatePrefix(const std::shared_ptr<SystemPrompt>& system_prompt);
    // Sequence-length / prefix / reuse bookkeeping.
    size_t maxSeqLen() const;
    int inputLength() const;
    int seqLength() const;
    // NOTE: always change the sequence length through setSeqLength(); it also
    // records start_check_seq_length_, which is required when checking for EOS
    // and stop words.
    void setSeqLength(int seq_length);
    int adjustedCommonLen() const;
    int seqSizePerBlock() const;
    int contextLength() const;
    int prefixLength() const;
    int inputPrefixLength() const;
    int reuseLength() const;
    int initialReuseLength() const;
    size_t maxTokenNum() const;
    void setReuseLength(int reuse_length);
    void setInitialReuseLength(int initial_reuse_length);
    int fallbackPrefixLength() const;
    void setFallbackPrefixLength(int fallback_prefix_length);
    void incLastOutputPos();
    absl::StatusOr<int> acquireCapacity(int token_capacity);
    // Chunked-prefill state.
    int currentChunkLen() const;
    void resetChunkLen(int chunck_len, int max_chunk_len);
    bool isContextStream() const;
    bool isChunkStream() const;
    // Token / probability accessors.
    const rtp_llm::BufferPtr& cumLogProbs() const;
    const rtp_llm::BufferPtr& completeTokenIds();
    std::vector<int> completeTokenIdsVec(int batch_idx = 0);
    std::vector<int> commonCompleteTokenIdsVec(int batch_idx = 0);
    int currentExecuteTokenSize();
    std::vector<int> currentExecuteTokens(int batch_idx = 0) const;
    void step();
    // Multimodal inputs.
    std::vector<torch::Tensor> multimodalFeatures() const;
    int multimodalFeaturesLength() const;
    rtp_llm::BufferPtr multimodalLocations() const;
    std::vector<std::vector<int>> multimodalIntervals() const;
    // Status / lifecycle. The *WithoutLock variants presumably assume the
    // caller already holds output_mutex_ -- verify against the .cpp.
    int64_t getTimeoutMs() const;
    void checkTimeout();
    void setStop(ErrorCode error_code, const std::string& error_msg);
    void setStopWithoutLock(ErrorCode error_code, const std::string& error_msg);
    void stopAndRelease(ErrorCode error_code, const std::string& error_msg);
    ErrorInfo statusInfo();
    bool isDoneWithoutLock(int batch_id) const;
    void setPaused();
    bool setRunning();
    bool stoppedWithoutLock();
    virtual bool stopped();
    bool paused();
    std::string stopReason();
    virtual bool finished();
    bool running();
    bool waiting();
    bool finishedWithoutLock();
    void cancelIfNotRunning();
    void setFinishedWithoutLock();
    bool needRemoteGenerate() const;
    void setRemoteGenerate();
    size_t iterCount() const;
    const ResourceContext& resourceContext() const;
    void setKVCache(const BatchKVCacheResource &kv_cache_resource);
    void setLoss(const rtp_llm::Buffer& loss);
    void setSoftmaxProbs(const rtp_llm::Buffer& softmax_probs, int start_pos);
    const BatchKVCacheResource& kvCache() const;
    size_t maxBlockSize() const;
    // EOS / stop-word matching.
    bool needFinish();
    bool needFinishBySPTokens();
    void matchEosToken();
    void matchEosToken(int batch_id);
    void matchStopWordsList();
    void matchStopWordsList(int batch_id);
    // Metrics / debugging.
    void setMetricsReporter(kmonitor::MetricsReporterPtr metrics_reporter);
    void reportMetric();
    std::string debugString() const;
    void resetBeginTime(int64_t begin_time_us);
    // for test
    void setIsContextStream(bool is_context_stream);
    rtp_llm::BufferPtr getLoss();
    rtp_llm::BufferPtr getLastHiddenStates();
    void setLastHiddenStates(rtp_llm::BufferPtr hidden_states) {
        last_hidden_states_ = hidden_states;
    };
    rtp_llm::BufferPtr getSoftmaxProbs();
    StreamCacheResource& streamCacheResource();
    void setPerfTest(bool perf_test_);
    // Forwards to StreamCacheResource to shrink the kv cache of a sequence.
    absl::Status releaseSequenceKVCache(size_t total_seq_len, size_t release_seq_len) {
        return stream_cache_resource_.releaseSequenceKVCache(total_seq_len, release_seq_len);
    }
    void CopyOnWrite(const GenerateStream& other_stream, bool copy_loss = true);
    void setReturnAllProbs(bool return_all_probs) {
        return_all_probs_ = return_all_probs;
    }
    bool getReturnAllProbs() {
        return return_all_probs_;
    }
    // [sic] "acceped bouns" -- misspelling of "accepted bonus" kept for ABI
    // compatibility; flag for speculative decoding bonus-token acceptance.
    void setAccepedBounsToken(bool acceped_bouns_token) {
        acceped_bouns_token_ = acceped_bouns_token;
    }
    bool getAccepedBounsToken() {
        return acceped_bouns_token_;
    }
    // Reorders kv-cache blocks after a beam-search step.
    void beamSearchKvCacheUpdate(rtp_llm::BufferPtr beam_idx);
    rtp_llm::BufferPtr generateContextPositionIds(rtp_llm::DeviceBase* device);
    void generateNextPositionId(int32_t* now_pos);
    int64_t vocabSize() const {
        return vocab_size_;
    }
    // NOTE(review): int-returning seqLength()/inputLength() are subtracted and
    // converted to size_t; would wrap if seqLength() < inputLength() -- confirm
    // that invariant holds.
    size_t outputTokenLen() const {
        return seqLength() - inputLength();
    }
    // Speculative-edit (sp_edit) bookkeeping.
    size_t spEditSearchIndex() const {
        return sp_edit_search_index_;
    }
    // Advances the search index only while an sp-edit run is active.
    void incSpEditSearchIndex(size_t accepted_num) {
        if (sp_edit_run_) {
            sp_edit_search_index_ += accepted_num;
        }
    }
    void setSpEditRun(bool is_sp_edit_run) {
        sp_edit_run_ = is_sp_edit_run;
    }
    bool spEditFirstTime() const {
        return sp_edit_first_time_;
    }
    void setSpEditFirstTime(bool sp_edit_first_time) {
        sp_edit_first_time_ = sp_edit_first_time;
    }
    void setReturnLastHiddenStates(bool flag) {
        return_all_hidden_states_ = flag;
    }
    bool forceDisableSpRun() const {
        return generate_input_->generate_config->force_disable_sp_run;
    }
    // Speculative decoding is disabled for beam search or when forced off.
    bool disableSpRun() const {
        return numBeams() > 1 || forceDisableSpRun();
    }
    bool needReturnHiddenStates() {
        return return_all_hidden_states_;
    }
    void setMtpTokenIndex(int mtp_token_index) {
        mtp_token_index_ = mtp_token_index;
    }
    // Lazily allocates (and caches) a [seqLength, hidden_size] hidden-states
    // buffer. NOTE(review): allocated on DEVICE, an explicit exception to the
    // file-level "host buffers only" warning -- confirm intended.
    rtp_llm::BufferPtr returnEmptyHiddenStates() {
        RTP_LLM_CHECK(last_hidden_states_ == nullptr);
        RTP_LLM_CHECK(seqLength() > 0);
        last_hidden_states_ = device_->allocateBuffer(
            {dtype_, {(size_t)seqLength(), hidden_size_}, rtp_llm::AllocationType::DEVICE});
        return last_hidden_states_;
    }
    std::vector<int> getLatestTokens(size_t token_num);
    void incBatchWithPrefillTimes(int32_t times);
    void incBatchWithPrefillLen(int32_t len);
    // Returns a copy of the think-phase infos (const value return).
    const std::vector<StreamThinkInfo> streamThinkInfo() {
        return think_infos_;
    }
public:
    // Timing snapshot (all timestamps in microseconds).
    struct TimeInfo {
        int64_t begin_time_us;
        int64_t wait_time_us;
        int64_t first_token_time_us;
        int64_t first_token_rt_us;
    };
    TimeInfo getTimeInfo();
    bool queryPdSep() const;
protected:
    rtp_llm::DeviceBase* device_;                      // non-owning device handle
    std::shared_ptr<GenerateInput> generate_input_;
    GenerateStatus generate_status_;                   // overall stream status
    std::vector<GenerateStatus> sub_generate_status_;  // presumably per-batch status -- verify in .cpp
    int max_seq_len_;
    bool acceped_bouns_token_ = false;                 // [sic] accepted bonus token (speculative decoding)
    int64_t vocab_size_;
    std::shared_ptr<CompleteTokenIds> complete_token_ids_;
    // Timing bookkeeping (microseconds).
    int64_t begin_time_us_;
    int64_t last_pause_us_ = 0;
    int64_t pause_time_us_ = 0;
    int64_t wait_time_us_ = 0;
    StreamCacheResource stream_cache_resource_;        // kv-cache block ownership
    bool is_context_stream_ = true;                    // context (prefill) vs decode phase
    size_t iter_count_ = 0;
    size_t last_output_pos_ = 0;
    // Prefix-reuse / fallback accounting.
    int initial_reuse_length_ = 0;
    int reuse_length_ = 0;
    int reuse_mm_length_ = 0;
    int fallback_blocks_ = 0;
    int fallback_times_ = 0;
    int fallback_prefix_length_ = 0;
    // TODO(xinfei.sxf) fix state
    bool done_ = false;
    bool released_ = false;
    bool need_release_resource_ = true;
    bool enable_fast_gen_ = false;
    bool return_all_probs_ = false;
    // Chunked-prefill lengths.
    int current_chunk_len_ = 0;
    int last_chunk_len_ = 0;
    int max_chunk_len_ = 0;
    // Speculative-edit state.
    int sp_edit_search_index_ = 0;
    bool sp_edit_first_time_ = true;
    bool sp_edit_run_ = false;
    bool last_block_aligned_ = false;
    bool need_remote_generate_ = false;
    bool use_cache_store_ = false;
    bool gen_timeline_ = false;
    // The number of times this stream has been interfered by prefills
    int32_t batch_with_prefill_times_ = 0;
    int32_t batch_with_prefill_len_ = 0;
    kmonitor::MetricsReporterPtr metrics_reporter_;
    rtp_llm::SpecialTokens special_tokens_;
    // Cached per-stream output buffers.
    rtp_llm::BufferPtr cum_log_probs_;
    rtp_llm::BufferPtr all_probs_;
    rtp_llm::BufferPtr softmax_probs_;
    rtp_llm::BufferPtr loss_;
    rtp_llm::BufferPtr last_hidden_states_;
    int loss_index_ = 0;
    std::shared_ptr<std::mutex> output_mutex_;         // presumably guards status/output (see *WithoutLock methods)
    bool return_all_hidden_states_ = false;
    int mtp_token_index_ = 0;                          // multi-token-prediction index
    std::optional<rtp_llm::BufferPtr> context_position_ids_;
    PositionIdsStyle mm_position_ids_style_;
    std::vector<StreamThinkInfo> think_infos_;
    rtp_llm::DataType dtype_;                          // dtype for hidden-states allocation
    size_t hidden_size_;                               // hidden dim for hidden-states allocation
    // just for perf test (see setPerfTest)
    bool perf_test_ = false;
    friend class StreamCacheResource;
};
// Shared-ownership handle for a stream; `using` alias preferred over `typedef`
// in modern C++ (file already relies on C++17 features such as std::optional).
using GenerateStreamPtr = std::shared_ptr<GenerateStream>;
} // namespace rtp_llm