// maga_transformer/cpp/normal_engine/NormalBatchStreamProcessor.cc
#include <algorithm>
#include <cstring>
#include <memory>
#include <random>
#include <limits>
#include <utility>
#include "c10/core/DeviceType.h"
#include "c10/core/ScalarType.h"
#include "maga_transformer/cpp/models/Sampler.h"
#include "maga_transformer/cpp/core/Buffer.h"
#include "maga_transformer/cpp/utils/AssertUtils.h"
#include "maga_transformer/cpp/core/Types.h"
#include "maga_transformer/cpp/normal_engine/NormalBatchStreamProcessor.h"
#include "maga_transformer/cpp/dataclass/MergedQuery.h"
#include "maga_transformer/cpp/core/torch_utils/BufferTorchUtils.h"
#include "maga_transformer/cpp/devices/utils/DebugUtils.h"
using namespace std;
namespace rtp_llm {
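// Gather the inputs for one model forward pass: decode and context streams are flattened into a
// single batched GptModelInputs (combo tokens, lengths, LoRA ids, KV-cache block ids, and optionally
// position ids and multimodal features).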
absl::StatusOr<GptModelInputs> NormalBatchStreamProcessor::gatherModelInput(const StreamGroups& stream_groups) const {
RTP_LLM_LOG_DEBUG(__PRETTY_FUNCTION__);
auto context_streams = stream_groups.contextStreams();
auto decode_streams = stream_groups.decodeStreams();
RTP_LLM_LOG_DEBUG("context_streams size = %d, decode_streams size = %d", context_streams.size(), decode_streams.size());
GptModelInputs model_input;
size_t current_tokens_size = stream_groups.modelExecuteTokenSize();
size_t total_batch_size = stream_groups.totalModelBatchSize();
size_t total_decode_batch_size = stream_groups.totalDecodeBatchSize();
size_t total_context_batch_size = stream_groups.totalContextBatchSize();
size_t max_block_size = stream_groups.maxBlockSize();
size_t multimodal_features_len = stream_groups.mmFeaturesLen();
const bool has_multimodal_input = is_multimodal_ && stream_groups.has_multimodal_input();
const bool need_cal_position_id = (mm_position_ids_style_ != PositionIdsStyle::DEFAULT) || has_positional_encoding_;
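// Allocate host-side buffers sized from the stream groups' aggregate token and batch counts;
// position-id and multimodal buffers are only allocated when the corresponding features are enabled.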
model_input.combo_tokens =
device_->allocateBuffer({rtp_llm::DataType::TYPE_INT32, {current_tokens_size}, rtp_llm::AllocationType::HOST}, {});
if (max_block_size) {
model_input.kv_cache_block_id = device_->allocateBuffer(
{rtp_llm::DataType::TYPE_INT32, {total_batch_size, max_block_size}, rtp_llm::AllocationType::HOST}, {});
model_input.cache_keys = device_->allocateBuffer(
{rtp_llm::DataType::TYPE_INT64, {total_context_batch_size, max_block_size}, rtp_llm::AllocationType::HOST}, {});
}
model_input.request_id = device_->allocateBuffer(
{rtp_llm::DataType::TYPE_INT64, {total_context_batch_size}, rtp_llm::AllocationType::HOST}, {});
model_input.request_pd_separation = device_->allocateBuffer(
{rtp_llm::DataType::TYPE_BOOL, {total_context_batch_size}, rtp_llm::AllocationType::HOST}, {});
model_input.input_lengths =
device_->allocateBuffer({rtp_llm::DataType::TYPE_INT32, {total_batch_size}, rtp_llm::AllocationType::HOST}, {});
model_input.lora_ids =
device_->allocateBuffer({rtp_llm::DataType::TYPE_INT32, {total_batch_size}, rtp_llm::AllocationType::HOST}, {});
model_input.lora_input_lengths =
device_->allocateBuffer({rtp_llm::DataType::TYPE_INT32, {total_batch_size}, rtp_llm::AllocationType::HOST}, {});
model_input.sequence_lengths =
device_->allocateBuffer({rtp_llm::DataType::TYPE_INT32, {total_decode_batch_size}, rtp_llm::AllocationType::HOST}, {});
model_input.lm_output_indexes =
device_->allocateBuffer({rtp_llm::DataType::TYPE_INT32, {total_batch_size}, rtp_llm::AllocationType::HOST}, {});
model_input.prefix_lengths =
device_->allocateBuffer({rtp_llm::DataType::TYPE_INT32, {total_context_batch_size}, rtp_llm::AllocationType::HOST}, {});
if (need_cal_position_id) {
model_input.combo_position_ids =
device_->allocateBuffer({rtp_llm::DataType::TYPE_INT32, {current_tokens_size * position_id_len_factor_}, rtp_llm::AllocationType::HOST}, {});
}
if (has_multimodal_input) {
model_input.text_tokens_mask =
device_->allocateBuffer({rtp_llm::DataType::TYPE_INT32, {current_tokens_size}, rtp_llm::AllocationType::HOST}, {});
model_input.mm_features_locs =
device_->allocateBuffer({rtp_llm::DataType::TYPE_INT32, {multimodal_features_len}, rtp_llm::AllocationType::HOST}, {});
}
model_input.k_block_size = k_block_size_;
model_input.v_block_size = v_block_size_;
model_input.seq_size_per_block = seq_size_per_block_;
model_input.scale_block_size = scale_block_size_;
model_input.pd_separation = pd_separation_;
model_input.warmup = warm_up_;
int* merged_tokens = (int*)model_input.combo_tokens->data();
int* input_lengths = (int*)model_input.input_lengths->data();
int* lora_ids = (int*)model_input.lora_ids->data();
int* lora_input_lengths = (int*)model_input.lora_input_lengths->data();
int* sequence_lengths = (int*)model_input.sequence_lengths->data();
int* lm_output_indexes = (int*)model_input.lm_output_indexes->data();
int* prefix_lengths = (int*)model_input.prefix_lengths->data();
int* combo_position_ids = need_cal_position_id ? (int*)model_input.combo_position_ids->data() : nullptr;
int* merged_text_mask = has_multimodal_input ? (int*)model_input.text_tokens_mask->data() : nullptr;
int* mm_features_locs = has_multimodal_input ? (int*)model_input.mm_features_locs->data() : nullptr;
int batch_idx = 0;
int input_vocab_size = input_vocab_size_ ? input_vocab_size_ : vocab_size_;
if (merged_text_mask) {
std::fill(merged_text_mask, merged_text_mask + current_tokens_size, 1);
}
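// Decode streams contribute exactly one token per batch entry; copy that token along with lengths,
// position ids, LoRA ids and KV-cache block ids.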
for (const auto& stream : decode_streams) {
model_input.need_all_logits = model_input.need_all_logits || stream->calculateLoss();
auto current_batch_size = stream->batchSize();
auto kv_cache = stream->kvCache();
RTP_LLM_LOG_DEBUG("decode kv_cache: %s", kv_cache.debugString().c_str());
RTP_LLM_LOG_DEBUG("decode stream: %s", stream->debugString().c_str());
for (auto i = 0; i < current_batch_size; ++i) {
auto currentTokens = stream->currentExecuteTokens(i);
if (currentTokens[0] >= input_vocab_size) {
std::ostringstream error_msg;
error_msg << "stream [" << stream->streamId() << "] token_id " << currentTokens[0] << " exceed vocab_size " << input_vocab_size;
return absl::InvalidArgumentError(error_msg.str());
}
merged_tokens[batch_idx] = currentTokens[0];
input_lengths[batch_idx] = stream->inputLength();
sequence_lengths[batch_idx] = stream->seqLength() - 1; // TODO: remove
if (need_cal_position_id) {
stream->generateNextPositionId(combo_position_ids + batch_idx * position_id_len_factor_);
}
lora_ids[batch_idx] = stream->loraId();
lora_input_lengths[batch_idx] = 1;
lm_output_indexes[batch_idx] = batch_idx;
if (max_block_size) {
std::memcpy((*model_input.kv_cache_block_id)[batch_idx].data(),
kv_cache.batch_block_id[i].data(),
kv_cache.batch_block_id[i].size() * sizeof(int));
}
batch_idx += 1;
}
stream->step();
}
std::vector<rtp_llm::BufferPtr> gathered_mm_features;
int token_idx = batch_idx;
int cum_output_seq_len = batch_idx;
int mm_feature_index = 0;
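// Context streams contribute their current prompt tokens; validate them against the vocabulary and
// gather prefix lengths, multimodal features/locations, position ids, KV-cache block ids and cache keys.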
for (const auto& stream : context_streams) {
// Context streams also run in batch mode to support the fallback and perf-test scenarios.
model_input.need_all_logits = model_input.need_all_logits || stream->calculateLoss();
auto current_batch_size = stream->batchSize();
auto kv_cache = stream->kvCache();
RTP_LLM_LOG_DEBUG("context kv_cache: %s", kv_cache.debugString().c_str());
RTP_LLM_LOG_DEBUG("context stream: %s", stream->debugString().c_str());
// TODO(xinfei.sxf) deal with adjusted common seq len.
for (auto i = 0; i < current_batch_size; ++i) {
auto input_tokens = stream->currentExecuteTokens(i);
auto input_masks = stream->textTokensMask();
memcpy(merged_tokens + token_idx, input_tokens.data(), input_tokens.size() * sizeof(int));
cum_output_seq_len += input_tokens.size();
for (int index = 0; index < input_tokens.size(); ++index) {
if (input_tokens[index] >= input_vocab_size && (index >= input_masks.size() || input_masks[index])) {
std::ostringstream error_msg;
error_msg << "stream [" << stream->streamId() << "] token_id " << input_tokens[index] << " exceed vocab_size " << input_vocab_size;
return absl::InvalidArgumentError(error_msg.str());
}
}
input_lengths[batch_idx] = input_tokens.size();
prefix_lengths[batch_idx - total_decode_batch_size] = stream->prefixLength();
lm_output_indexes[batch_idx] = cum_output_seq_len - 1;
if (has_multimodal_input) {
std::vector<torch::Tensor> mm_features = stream->multimodalFeatures();
rtp_llm::BufferPtr mm_locs = stream->multimodalLocations();
if (mm_locs != nullptr) {
for (int i = 0;i < mm_locs->size(); ++i) {
mm_features_locs[mm_feature_index] = *mm_locs->dataWithOffset<int>(i) + token_idx - stream->reuseLength();
mm_feature_index++;
}
for (auto& mm_feature: mm_features) {
auto feature_buffer = torchTensor2Buffer(mm_feature);
if (feature_buffer->where() != rtp_llm::MemoryType::MEMORY_GPU) {
gathered_mm_features.emplace_back(device_->clone({*feature_buffer}));
} else {
gathered_mm_features.emplace_back(feature_buffer);
}
}
auto text_token_mask = stream->textTokensMask();
memcpy(merged_text_mask + token_idx, text_token_mask.data(), text_token_mask.size() * sizeof(int));
}
}
if (need_cal_position_id) {
auto context_pos_ids = stream->generateContextPositionIds(device_);
memcpy(combo_position_ids + token_idx * position_id_len_factor_,
context_pos_ids->dataWithOffset<int>(stream->reuseLength() * position_id_len_factor_),
(context_pos_ids->size() - stream->reuseLength() * position_id_len_factor_) * context_pos_ids->typeSize());
}
lora_ids[batch_idx] = stream->loraId();
lora_input_lengths[batch_idx] = input_lengths[batch_idx];
if (max_block_size) {
std::memcpy((*model_input.kv_cache_block_id)[batch_idx].data(),
kv_cache.batch_block_id[i].data(),
kv_cache.batch_block_id[i].size() * sizeof(int));
if (stream->hasCacheKeys()) {
std::memcpy((*model_input.cache_keys)[batch_idx - total_decode_batch_size].data(),
stream->cacheKeys(i).data(),
stream->cacheKeys(i).size() * sizeof(int64_t));
}
}
*(model_input.request_id->dataWithOffset<int64_t>(batch_idx - total_decode_batch_size)) = stream->streamId();
*(model_input.request_pd_separation->dataWithOffset<bool>(batch_idx - total_decode_batch_size)) = stream->queryPdSep();
batch_idx += 1;
token_idx += input_tokens.size();
}
stream->step();
}
if (is_multimodal_ && gathered_mm_features.size() > 0) {
model_input.multimodal_features = std::move(gathered_mm_features);
}
return model_input;
}
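// Build the SamplerInputs for this step: copy every stream's token history and per-request sampling
// parameters, then prepare the logits buffer the sampler will read.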
absl::StatusOr<SamplerInputs>
NormalBatchStreamProcessor::gatherSamplerInput(const StreamGroups& stream_groups,
const GptModelInputs& model_inputs,
const GptModelOutputs& model_output) const {
RTP_LLM_LOG_DEBUG(__PRETTY_FUNCTION__);
RTP_LLM_CHECK(!stream_groups.empty());
const auto& context_streams = stream_groups.contextStreams();
size_t total_decode_batch_size = stream_groups.totalDecodeBatchSize();
auto all_streams = stream_groups.allStreams();
auto total_batch_size = stream_groups.totalSamplerBatchSize();
bool return_all_probs = stream_groups.needReturnAllProbs();
SamplerInputs sampler_inputs = allocateSamplerInputs(stream_groups, total_batch_size, model_inputs.sequence_lengths);
setCommonSamplerInputs(sampler_inputs, all_streams);
setThinkModeLogitsProcessorInputs(sampler_inputs, all_streams);
int batch_idx = 0;
bool return_logits = false;
bool calculate_softmax_probs = false;
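// Copy each stream's complete token ids into the sampler token buffer, one row of width (step + 1) per tile.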
for (auto& stream : all_streams) {
const auto& complete_token_ids = stream->completeTokenIds();
auto complete_seq_len = complete_token_ids->shape()[1];
auto seq_len = stream->seqLength();
auto current_batch_size = stream->tileNum();
for (int i = 0; i < current_batch_size; ++i) {
memcpy(sampler_inputs.token_ids->dataWithOffset<int32_t>((batch_idx) * (sampler_inputs.step + 1)),
complete_token_ids->dataWithOffset<int32_t>(i * complete_seq_len),
seq_len * sizeof(int));
batch_idx += 1;
}
return_logits |= stream->returnLogits();
calculate_softmax_probs |= stream->calculateSoftmaxProbs();
RTP_LLM_LOG_DEBUG("stream [%d], complete token ids = [%s]", stream->streamId(), complete_token_ids->debugStringWithData<int32_t>(sampler_inputs.step).c_str());
RTP_LLM_LOG_DEBUG("stream [%d], sampler inputs token ids = [%s]", stream->streamId(), sampler_inputs.token_ids->debugStringWithData<int32_t>().c_str());
}
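// The vocab size is taken from the logits shape; when any stream needs the full probability
// distribution, allocate a zero-initialized all_probs buffer on the device.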
auto vocab_size = model_output.logits->shape()[1];
sampler_inputs.vocab_size = vocab_size;
if (return_all_probs) {
sampler_inputs.all_probs = device_->allocateBuffer({rtp_llm::DataType::TYPE_FP32, {total_batch_size, vocab_size}, rtp_llm::AllocationType::DEVICE}, {});
device_->bufMemset(*sampler_inputs.all_probs, 0);
}
batch_idx = 0;
// The logits must be copied when any stream returns logits, needs softmax probs, or when context streams are tiled (total sampler batch > stream count).
if (return_logits || calculate_softmax_probs || (context_streams.size() && total_batch_size > all_streams.size())) {
sampler_inputs.logits = device_->allocateBuffer({model_output.logits->type(), {total_batch_size, vocab_size}, rtp_llm::AllocationType::DEVICE}, {});
device_->copy({sampler_inputs.logits->view(0, total_decode_batch_size), model_output.logits->view(0, total_decode_batch_size)});
} else {
sampler_inputs.logits = model_output.logits;
}
batch_idx += total_decode_batch_size;
size_t logits_offset = batch_idx;
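// Context streams produce a single logits row each; replicate that row across all of the stream's tiles.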
for (auto& stream : context_streams) {
auto current_batch_size = stream->tileNum();
for (int i = 0; i < current_batch_size; ++i) {
device_->copy({sampler_inputs.logits->view(batch_idx, 1), model_output.logits->view(logits_offset, 1)});
batch_idx += 1;
}
logits_offset += 1;
}
RTP_LLM_LOG_DEBUG("sampler inputs logits [%s]",
device_->clone({*sampler_inputs.logits, rtp_llm::AllocationType::HOST})->debugStringWithData<float>(10).c_str());
RTP_LLM_LOG_DEBUG("gatherSamplerInput done");
return std::move(sampler_inputs);
}
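// Allocate host buffers for the per-request sampling parameters, sized by the total sampler batch.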
SamplerInputs NormalBatchStreamProcessor::allocateSamplerInputs(const StreamGroups& stream_groups, size_t total_batch_size, const rtp_llm::BufferPtr& sequence_lengths) const {
// TODO(xinfei.sxf) don't sample for chunk stream
SamplerInputs sampler_inputs;
sampler_inputs.step = stream_groups.maxSeqLen();
sampler_inputs.batch_size = total_batch_size;
sampler_inputs.sequence_lengths = device_->allocateBuffer({rtp_llm::DataType::TYPE_INT32, {total_batch_size}, rtp_llm::AllocationType::HOST}, {});
sampler_inputs.grammars.clear();
sampler_inputs.beam_search_sequence_lengths = device_->allocateBuffer({rtp_llm::DataType::TYPE_INT32, {total_batch_size}, rtp_llm::AllocationType::HOST}, {});
sampler_inputs.beam_index = device_->allocateBuffer({rtp_llm::DataType::TYPE_INT32, {total_batch_size}, rtp_llm::AllocationType::HOST}, {});
// TODO(lidongjin.ldj) use bufMemset after arm/amd support this op.
// eg: device_->bufMemset(*sampler_inputs.beam_index, 0);
sampler_inputs.input_lengths = device_->allocateBuffer({rtp_llm::DataType::TYPE_INT32, {total_batch_size}, rtp_llm::AllocationType::HOST}, {});
sampler_inputs.num_beams = device_->allocateBuffer({rtp_llm::DataType::TYPE_UINT64, {total_batch_size}, rtp_llm::AllocationType::HOST}, {});
sampler_inputs.top_k = device_->allocateBuffer({rtp_llm::DataType::TYPE_UINT32, {total_batch_size}, rtp_llm::AllocationType::HOST}, {});
sampler_inputs.top_p = device_->allocateBuffer({rtp_llm::DataType::TYPE_FP32, {total_batch_size}, rtp_llm::AllocationType::HOST}, {});
sampler_inputs.temperature = device_->allocateBuffer({rtp_llm::DataType::TYPE_FP32, {total_batch_size}, rtp_llm::AllocationType::HOST}, {});
sampler_inputs.random_seeds = device_->allocateBuffer({rtp_llm::DataType::TYPE_UINT64, {total_batch_size}, rtp_llm::AllocationType::HOST}, {});
sampler_inputs.repetition_penalty = device_->allocateBuffer({rtp_llm::DataType::TYPE_FP32, {total_batch_size}, rtp_llm::AllocationType::HOST}, {});
sampler_inputs.min_lengths = device_->allocateBuffer({rtp_llm::DataType::TYPE_INT32, {total_batch_size}, rtp_llm::AllocationType::HOST}, {});
sampler_inputs.no_repeat_ngram_size = device_->allocateBuffer({rtp_llm::DataType::TYPE_INT32, {total_batch_size}, rtp_llm::AllocationType::HOST}, {});
if (stream_groups.needReturnCumLogProbs()) {
sampler_inputs.cum_log_probs = device_->allocateBuffer({rtp_llm::DataType::TYPE_FP32, {total_batch_size}, rtp_llm::AllocationType::HOST}, {});
}
sampler_inputs.token_ids = device_->allocateBuffer(
{rtp_llm::DataType::TYPE_INT32, {total_batch_size, sampler_inputs.step + 1}, rtp_llm::AllocationType::HOST}, {});
return sampler_inputs;
}
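// Fill per-entry sampling parameters (lengths, beam counts, top-k/top-p, temperature, penalties, seeds)
// from each stream's generate config; the random-seed buffer is dropped when no stream specified an explicit seed.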
void NormalBatchStreamProcessor::setCommonSamplerInputs(SamplerInputs& sampler_inputs, std::list<GenerateStreamPtr>& all_streams, bool score_batch) const {
int* input_lengths = sampler_inputs.input_lengths->data<int32_t>();
int* sequence_lengths = sampler_inputs.sequence_lengths->data<int32_t>();
uint64_t* num_beams = sampler_inputs.num_beams->data<uint64_t>();
uint32_t* top_k = sampler_inputs.top_k->data<uint32_t>();
float* top_p = sampler_inputs.top_p->data<float>();
float* temperature = sampler_inputs.temperature->data<float>();
uint64_t* random_seeds = sampler_inputs.random_seeds->data<uint64_t>();
float* repetition_penalty = sampler_inputs.repetition_penalty->data<float>();
int32_t* min_lengths = sampler_inputs.min_lengths->data<int32_t>();
int32_t* no_repeat_ngram_size = sampler_inputs.no_repeat_ngram_size->data<int32_t>();
int* beam_search_sequence_lengths = sampler_inputs.beam_search_sequence_lengths->data<int32_t>();
int batch_idx = 0;
bool has_random_seed = false;
for (auto& stream : all_streams) {
int current_batch_size;
if (!score_batch) {
current_batch_size = stream->tileNum();
} else {
current_batch_size = stream->scoreLen();
}
if (sampler_inputs.cum_log_probs) {
const auto& cum_log_probs = stream->cumLogProbs();
memcpy(sampler_inputs.cum_log_probs->dataWithOffset<float>(batch_idx), cum_log_probs->data(), cum_log_probs->sizeBytes());
}
for (int i = 0; i < current_batch_size; ++i) {
input_lengths[batch_idx] = stream->inputLength();
sequence_lengths[batch_idx] = stream->seqLength();
beam_search_sequence_lengths[batch_idx] = stream->seqLength();
// TODO(xinfei.sxf) fix num beams after sampler support
num_beams[batch_idx] = stream->numBeams();
top_k[batch_idx] = stream->generateConfig()->top_k;
top_p[batch_idx] = stream->generateConfig()->top_p;
temperature[batch_idx] = stream->generateConfig()->temperature;
repetition_penalty[batch_idx] = stream->generateConfig()->repetition_penalty;
min_lengths[batch_idx] = stream->generateConfig()->min_new_tokens;
if (stream->generateConfig()->random_seed.has_value()) {
random_seeds[batch_idx] = stream->generateConfig()->random_seed.value();
has_random_seed = true;
} else {
std::random_device rd;
std::mt19937_64 gen(rd());
std::uniform_int_distribution<std::int64_t> distrib(0, std::numeric_limits<std::int64_t>::max());
random_seeds[batch_idx] = distrib(gen);
}
no_repeat_ngram_size[batch_idx] = stream->generateConfig()->no_repeat_ngram_size.value_or(0);
batch_idx += 1;
}
}
if (!has_random_seed) {
sampler_inputs.random_seeds.reset();
}
}
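// Collect every stream's think-mode info and register a ThinkModeLogitsProcessor in the sampler's
// logits-processor (grammar) list.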
void NormalBatchStreamProcessor::setThinkModeLogitsProcessorInputs(SamplerInputs& sampler_inputs, std::list<GenerateStreamPtr>& all_streams, bool score_batch) const {
std::vector<StreamThinkInfo> think_infos;
for (auto& stream : all_streams) {
const std::vector<StreamThinkInfo> streamThinkInfo = stream->streamThinkInfo();
think_infos.insert(think_infos.end(), streamThinkInfo.begin(), streamThinkInfo.end());
}
RTP_LLM_CHECK(think_infos.size() == sampler_inputs.batch_size);
BaseLogitsProcessorPtr processor_ptr = std::make_shared<ThinkModeLogitsProcessor>(device_, think_infos);
sampler_inputs.grammars.push_back(processor_ptr);
}
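// Scatter the merged model and sampler outputs back to their originating streams and update each
// stream with its newly sampled tokens.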
absl::Status NormalBatchStreamProcessor::dispatch(const StreamGroups& stream_groups,
const MergedOutput& merge_outputs) const {
RTP_LLM_LOG_DEBUG(__PRETTY_FUNCTION__);
const auto& model_output = merge_outputs.model_output;
const auto& sampler_output = merge_outputs.sampler_output;
const auto& new_all_token_ids = sampler_output.token_ids;
RTP_LLM_LOG_DEBUG("new_all_token_ids = [%s]", new_all_token_ids->debugStringWithData<int32_t>().c_str());
const size_t step = new_all_token_ids->shape()[1];
size_t total_batch_size = stream_groups.totalSamplerBatchSize();
RTP_LLM_CHECK(total_batch_size == new_all_token_ids->shape()[0]);
int batch_idx = 0;
int offset = 0;
int token_offset = 0;
bool return_all_probs = stream_groups.needReturnAllProbs();
rtp_llm::BufferPtr new_tokens_all = device_->allocateBuffer({rtp_llm::DataType::TYPE_INT32, {(size_t)total_batch_size, (size_t)1}, rtp_llm::AllocationType::HOST}, {});
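// Iterate streams in batch order (chunk streams are skipped), slicing the batched logits, hidden
// states, cum log probs and probabilities per stream.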
for (auto& stream : stream_groups.allStreams()) {
if (stream->isChunkStream()) {
continue;
}
auto current_batch_size = stream->tileNum();
auto new_tokens = new_tokens_all->slice(batch_idx, current_batch_size);
auto token_size = stream->currentExecuteTokenSize();
auto batch = stream->isContextStream() ? 1 : current_batch_size;
auto batch_logits = model_output.logits->slice(offset, batch);
auto batch_hidden_states = model_output.hidden_states->slice(offset, batch);
BufferPtr batch_cum_log_probs;
if (sampler_output.cum_log_probs) {
batch_cum_log_probs = sampler_output.cum_log_probs->slice(batch_idx, current_batch_size);
}
auto all_probs = return_all_probs ? sampler_output.all_probs->slice(batch_idx, current_batch_size) : nullptr;
BufferPtr loss = nullptr;
BufferPtr beam_index = (sampler_output.beam_index == nullptr) ? nullptr : sampler_output.beam_index->slice(batch_idx, current_batch_size);
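// Loss is computed from the stream's logits for all but the last position, using the shifted input tokens as labels.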
if (stream->calculateLoss()) {
auto all_logits = model_output.all_logits->view(token_offset, token_size - 1);
auto tokens = stream->currentExecuteTokens(0);
rtp_llm::BufferPtr label = device_->clone({{rtp_llm::MemoryType::MEMORY_CPU, rtp_llm::DataType::TYPE_INT32, {tokens.size() - 1}, tokens.data() + 1}});
loss = device_->loss({all_logits, *label});
}
BufferPtr all_hidden_states = nullptr;
if (stream->needReturnHiddenStates()) {
all_hidden_states = model_output.all_hidden_states->slice(token_offset, token_size);
}
BufferPtr batch_softmax_result;
BufferPtr current_softmax_result;
if (stream->calculateSoftmaxProbs()) {
current_softmax_result = device_->allocateBuffer({rtp_llm::DataType::TYPE_FP32, {(size_t)current_batch_size, (size_t)1}, rtp_llm::AllocationType::HOST}, {});
batch_softmax_result = device_->softmax({batch_logits, std::nullopt, std::nullopt, 1.0f, DataType::TYPE_FP32, std::nullopt});
}
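// Per tile: extract the newly sampled token (last column of the sampler output), optionally look up
// its softmax probability, and stop the stream if the sampler reported a failure.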
for (int i = 0; i < current_batch_size; ++i) {
memcpy(new_tokens->dataWithOffset<int32_t>(i), new_all_token_ids->dataWithOffset<int32_t>(batch_idx * step + step - 1), sizeof(int32_t));
if (stream->calculateSoftmaxProbs()) {
device_->copy({(*current_softmax_result)[i], (*batch_softmax_result)[i].view(*(new_tokens->dataWithOffset<int32_t>(i)), 1)});
}
if (sampler_output.success && !(*(sampler_output.success->dataWithOffset<bool>(batch_idx)))) {
stream->setStop(ErrorCode::UNKNOWN_ERROR, "sampler generate token id failed");
}
batch_idx += 1;
}
RTP_LLM_LOG_DEBUG("stream [%d], new_tokens = [%s]", stream->streamId(), new_tokens->debugStringWithData<int32_t>().c_str());
if (stream->numBeams() > 1 && beam_index != nullptr) {
StreamUpdateInfo update_info{new_all_token_ids, 1, batch_hidden_states, batch_logits,
current_softmax_result, batch_cum_log_probs, all_probs, loss, all_hidden_states};
stream->update(update_info);
stream->beamSearchKvCacheUpdate(beam_index);
} else {
stream->update({new_tokens, 1, batch_hidden_states, batch_logits,
current_softmax_result, batch_cum_log_probs, all_probs, loss, all_hidden_states});
}
offset += batch;
token_offset += token_size;
}
RTP_LLM_LOG_DEBUG("dispatch done");
return absl::OkStatus();
}
} // namespace rtp_llm