maga_transformer/cpp/normal_engine/NormalGenerateStream.cc (150 lines of code) (raw):
#include "maga_transformer/cpp/normal_engine/NormalGenerateStream.h"

#include <utility>

#include "maga_transformer/cpp/core/torch_utils/BufferTorchUtils.h"
namespace rtp_llm {
/**
 * Pops the next GenerateOutputs produced by this stream, waiting for one if necessary.
 *
 * While the stream is neither stopped nor finished and the output queue is empty, this calls
 * checkTimeout() and then blocks on the queue.
 *
 * @return the front of the output queue when one is available; statusInfo() if the stream is
 *         stopped (checked first); ErrorCode::FINISHED if the stream finished with an empty
 *         queue; otherwise ErrorCode::OUTPUT_QUEUE_IS_EMPTY.
 */
ErrorResult<GenerateOutputs> NormalGenerateStream::nextOutput() {
    // TODO(xinfei.sxf): in some cases this wait can last about 1s.
    auto must_wait = [this]() {
        return !stopped() && !finished() && generate_outputs_queue_.isEmpty();
    };
    while (must_wait()) {
        checkTimeout();
        generate_outputs_queue_.waitNotEmpty();
    }

    // A stopped stream reports its status even if outputs are still queued.
    if (stopped()) {
        return statusInfo();
    }
    if (!generate_outputs_queue_.isEmpty()) {
        return generate_outputs_queue_.getAndPopFront();
    }
    return finished() ? ErrorInfo(ErrorCode::FINISHED, "")
                      : ErrorInfo(ErrorCode::OUTPUT_QUEUE_IS_EMPTY, "output queue is empty");
}
/// Returns true when at least one GenerateOutputs entry is waiting in the output queue.
bool NormalGenerateStream::hasOutput() {
    const bool queue_is_empty = generate_outputs_queue_.isEmpty();
    return !queue_is_empty;
}
/**
 * Builds the GenerateOutputs to hand back to the caller for the current step.
 *
 * One GenerateOutput is produced per tile (i in [0, tileNum())). Each one carries the tokens
 * written since the previous emission (positions [last_output_pos_, seqLength())) plus the
 * optional host-side extras (logits, hidden states, loss, softmax probs, all_probs) that are
 * enabled by the generate config and present in update_info. aux_info.cum_log_probs is always
 * attached.
 *
 * @param update_info  results of the latest step; its logits / hidden_states / cum_log_probs /
 *                     all_probs members gate which optional fields are filled in.
 * @return outputs tagged with this stream's request_id_.
 * @throws std::runtime_error if return_all_probs is enabled but update_info.all_probs is null.
 */
GenerateOutputs NormalGenerateStream::prepareGenerateOutput(const StreamUpdateInfo& update_info) {
    const auto& generate_config = generate_input_->generate_config;
    size_t      output_len      = seqLength() - last_output_pos_;

    GenerateOutputs generate_results;
    generate_results.request_id = request_id_;

    // The loss does not depend on the tile. Validate it and apply the optional reduction once per
    // call instead of once per tile. Guarded by tileNum() > 0 so that nothing is checked when no
    // outputs are produced, as before.
    auto output_loss = loss_;
    if (loss_ && tileNum() > 0) {
        RTP_LLM_CHECK_WITH_INFO(loss_index_ == inputLength() - 1,
                                "loss index should be input len [%d] - 1 but is [%d]",
                                inputLength(),
                                loss_index_);
        if (generate_config->calculate_loss == 1) {
            // calculate_loss == 1 reports exp(mean(loss)) instead of the raw loss buffer.
            output_loss = device_->clone(
                {*rtp_llm::torchTensor2Buffer(torch::mean(rtp_llm::Buffer2torchTensor(*loss_)).exp()),
                 rtp_llm::AllocationType::HOST});
        }
    }

    for (int i = 0; i < tileNum(); i++) {
        GenerateOutput generate_output;
        generate_output.aux_info.iter_count      = iter_count_;
        generate_output.aux_info.fallback_tokens = fallback_blocks_ * seqSizePerBlock();
        generate_output.aux_info.fallback_times  = fallback_times_;

        // Tokens written since the last emission: positions [last_output_pos_, seqLength()).
        generate_output.output_ids = device_->allocateBuffer(
            {rtp_llm::DataType::TYPE_INT32, {1lu, output_len}, rtp_llm::AllocationType::HOST}, {});
        // TODO(xinfei.sxf) optimize this copy : only copy last token
        complete_token_ids_->copyTokensTo(i, generate_output.output_ids->data(), last_output_pos_, output_len);

        if (returnLogits() && update_info.logits) {
            // Logits with a single row are given to every tile; otherwise tile i takes row i.
            rtp_llm::BufferPtr host_logits;
            if (update_info.logits->shape()[0] == 1) {
                host_logits = device_->clone({*update_info.logits, rtp_llm::AllocationType::HOST});
            } else {
                host_logits = device_->clone({update_info.logits->view(i, 1), rtp_llm::AllocationType::HOST});
            }
            if (!generate_config->select_tokens_id.empty()) {
                // Keep only the configured token ids (the trailing 1 is presumably the select dim).
                auto select_buf        = rtp_llm::vector2Buffer(generate_config->select_tokens_id);
                generate_output.logits = device_->select({*host_logits, *select_buf, 1});
            } else {
                // TODO(xinfei.sxf) not set logits in middle step for streaming
                generate_output.logits = host_logits;
            }
        }

        if (generate_config->return_hidden_states && update_info.hidden_states) {
            // Same row-sharing rule as the logits above.
            if (update_info.hidden_states->shape()[0] == 1) {
                generate_output.hidden_states =
                    device_->clone({*update_info.hidden_states, rtp_llm::AllocationType::HOST});
            } else {
                generate_output.hidden_states =
                    device_->clone({update_info.hidden_states->view(i, 1), rtp_llm::AllocationType::HOST});
            }
        }

        if (loss_) {
            generate_output.loss = output_loss;
        }

        if (generate_config->return_softmax_probs && softmax_probs_) {
            generate_output.aux_info.softmax_probs = device_->clone(
                {(*softmax_probs_)[i].view(last_output_pos_, output_len), rtp_llm::AllocationType::HOST});
        }

        generate_output.finished = sub_generate_status_[i].status == StreamState::FINISHED;
        generate_output.aux_info.cost_time_us =
            autil::TimeUtility::currentTimeInMicroSeconds() - begin_time_us_;
        generate_output.aux_info.first_token_cost_time_us = complete_token_ids_->firstTokenLatencyUs();
        generate_output.aux_info.wait_time_us             = wait_time_us_;
        generate_output.aux_info.input_len                = generate_input_->promptLength();
        generate_output.aux_info.prefix_len               = generate_input_->prefix_length;
        // TODO(xinfei.sxf) output_len must be set correctly for queries that finish early.
        generate_output.aux_info.output_len      = seqLength() - generate_input_->inputLength();
        generate_output.aux_info.step_output_len = output_len;
        generate_output.aux_info.reuse_len       = reuse_length_;
        generate_output.aux_info.pd_sep          = queryPdSep();

        // cum_log_probs is always attached. allocateBuffer is not relied upon to zero the host
        // buffer, so write 0 when this step provided no cum_log_probs instead of leaking whatever
        // the allocation happened to contain.
        generate_output.aux_info.cum_log_probs =
            device_->allocateBuffer({rtp_llm::DataType::TYPE_FP32, {1lu}, rtp_llm::AllocationType::HOST}, {});
        auto* cum_log_probs_dst = generate_output.aux_info.cum_log_probs.value()->data();
        if (update_info.cum_log_probs) {
            memcpy(cum_log_probs_dst, cum_log_probs_->dataWithOffset<float>(i), sizeof(float));
        } else {
            const float no_cum_log_prob = 0.0f;
            memcpy(cum_log_probs_dst, &no_cum_log_prob, sizeof(float));
        }

        if (generate_config->return_all_probs) {
            if (!update_info.all_probs) {
                throw std::runtime_error("all_probs is not set while generate_config return_all_probs is true");
            }
            // Tile i takes row i of the host copy stored by updateOutput().
            generate_output.aux_info.all_probs =
                device_->clone({all_probs_->view(i, 1), rtp_llm::AllocationType::HOST});
        }

        generate_results.generate_outputs.push_back(std::move(generate_output));
    }
    return generate_results;
}
/**
 * Publishes a prepared GenerateOutputs to the output queue consumed by nextOutput().
 *
 * If the queue is already at capacity, the outputs are dropped and the stream is stopped with
 * ErrorCode::OUTPUT_QUEUE_FULL, so pushing to the queue never gets stuck.
 *
 * @param generate_results  outputs to enqueue; taken by value and moved into the queue.
 */
void NormalGenerateStream::enqueueGenerateOutput(GenerateOutputs generate_results) {
    if (generate_outputs_queue_.getSize() >= generate_outputs_queue_.getCapacity()) {
        // Whatever the reason the queue is full, stop the stream directly so that the push
        // to the queue cannot get stuck.
        setStop(ErrorCode::OUTPUT_QUEUE_FULL, "output queue is full");
        return;
    }
    // The parameter is a sink: move it instead of copying every GenerateOutput again.
    generate_outputs_queue_.push(std::move(generate_results));
}
/**
 * Applies the results of one engine step to this stream and, when appropriate, publishes a new
 * GenerateOutputs to the output queue.
 *
 * The steps run in this order:
 *   1. Record loss, hidden states and softmax probs.
 *   2. Evaluate needFinish().
 *   3. Store cum_log_probs and all_probs.
 *   4. Flag remote generation in the PD-separation case.
 *   5. Enqueue an output if there are unemitted tokens and the stream is streaming, finished,
 *      or at its PD-separation first token.
 *
 * last_output_pos_ advances only when enqueueing did not leave the stream stopped
 * (e.g. because the output queue was full).
 */
void NormalGenerateStream::updateOutput(const StreamUpdateInfo& update_info) {
RTP_LLM_LOG_DEBUG(__PRETTY_FUNCTION__);
// TODO(xinfei.sxf) consider the case of pd-sep first token finished.
if (update_info.loss) {
setLoss(*update_info.loss);
}
// Keep a device-side clone of this step's hidden states.
if (needReturnHiddenStates()) {
RTP_LLM_CHECK(update_info.all_hidden_states != nullptr);
last_hidden_states_ = device_->clone(
{*update_info.all_hidden_states, rtp_llm::AllocationType::DEVICE});
}
// softmax_probs must be 2-D with num_new_tokens columns. It is stored at offset
// seqLength() - num_new_tokens, presumably the position of the first token added this step.
if (generate_input_->generate_config->return_softmax_probs && update_info.softmax_probs) {
RTP_LLM_CHECK(update_info.softmax_probs->dim() == 2);
RTP_LLM_CHECK(update_info.softmax_probs->shape()[1] == update_info.num_new_tokens);
setSoftmaxProbs(*update_info.softmax_probs, seqLength() - update_info.num_new_tokens);
}
finished_ = needFinish();
if (finished_) {
setFinishedWithoutLock();
}
// Refresh cum_log_probs_ from this step's values (assumes CopyParams is {dst, src}).
if (update_info.cum_log_probs) {
device_->copy({*cum_log_probs_, *update_info.cum_log_probs});
}
// Keep a host copy; prepareGenerateOutput() reads rows of all_probs_.
if (update_info.all_probs) {
all_probs_ = device_->clone({*update_info.all_probs, rtp_llm::AllocationType::HOST});
}
//TODO: move it to better position
if (!finished_ && queryPdSep()) {
need_remote_generate_ = true;
}
// Only publish an output for the PD-separation first token, every step in streaming mode,
// or when the stream has finished.
bool pd_sep_first_token = queryPdSep();
bool need_update = pd_sep_first_token || isStreaming() || finished_;
if (!need_update) {
return;
}
// Nothing new since the last emission.
if (seqLength() - last_output_pos_ == 0) {
return;
}
RTP_LLM_LOG_DEBUG("stream [%ld] enqueue generate output", streamId());
enqueueGenerateOutput(prepareGenerateOutput(update_info));
// If the stream is now stopped (e.g. the output queue was full), leave last_output_pos_ unchanged.
if (stoppedWithoutLock()) {
return;
}
last_output_pos_ = seqLength();
}
}; // namespace rtp_llm