in maga_transformer/cpp/normal_engine/NormalBatchStreamProcessor.cc [22:213]
absl::StatusOr<GptModelInputs> NormalBatchStreamProcessor::gatherModelInput(const StreamGroups& stream_groups) const {
    RTP_LLM_LOG_DEBUG(__PRETTY_FUNCTION__);
    auto context_streams = stream_groups.contextStreams();
    auto decode_streams = stream_groups.decodeStreams();
    RTP_LLM_LOG_DEBUG("context_streams size = %zu, decode_streams size = %zu", context_streams.size(), decode_streams.size());
    GptModelInputs model_input;
    size_t current_tokens_size = stream_groups.modelExecuteTokenSize();
    size_t total_batch_size = stream_groups.totalModelBatchSize();
    size_t total_decode_batch_size = stream_groups.totalDecodeBatchSize();
    size_t total_context_batch_size = stream_groups.totalContextBatchSize();
    size_t max_block_size = stream_groups.maxBlockSize();
    size_t multimodal_features_len = stream_groups.mmFeaturesLen();
    const bool has_multimodal_input = is_multimodal_ && stream_groups.has_multimodal_input();
    const bool need_cal_position_id = (mm_position_ids_style_ != PositionIdsStyle::DEFAULT) || has_positional_encoding_;
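    // Allocate host-side staging buffers sized from the stream groups: token-level buffers
    // use current_tokens_size, batch-level buffers use the decode/context/total batch sizes,
    // and KV-cache block id tables use max_block_size columns.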
    model_input.combo_tokens =
        device_->allocateBuffer({rtp_llm::DataType::TYPE_INT32, {current_tokens_size}, rtp_llm::AllocationType::HOST}, {});
    if (max_block_size) {
        model_input.kv_cache_block_id = device_->allocateBuffer(
            {rtp_llm::DataType::TYPE_INT32, {total_batch_size, max_block_size}, rtp_llm::AllocationType::HOST}, {});
        model_input.cache_keys = device_->allocateBuffer(
            {rtp_llm::DataType::TYPE_INT64, {total_context_batch_size, max_block_size}, rtp_llm::AllocationType::HOST}, {});
    }
    model_input.request_id = device_->allocateBuffer(
        {rtp_llm::DataType::TYPE_INT64, {total_context_batch_size}, rtp_llm::AllocationType::HOST}, {});
    model_input.request_pd_separation = device_->allocateBuffer(
        {rtp_llm::DataType::TYPE_BOOL, {total_context_batch_size}, rtp_llm::AllocationType::HOST}, {});
    model_input.input_lengths =
        device_->allocateBuffer({rtp_llm::DataType::TYPE_INT32, {total_batch_size}, rtp_llm::AllocationType::HOST}, {});
    model_input.lora_ids =
        device_->allocateBuffer({rtp_llm::DataType::TYPE_INT32, {total_batch_size}, rtp_llm::AllocationType::HOST}, {});
    model_input.lora_input_lengths =
        device_->allocateBuffer({rtp_llm::DataType::TYPE_INT32, {total_batch_size}, rtp_llm::AllocationType::HOST}, {});
    model_input.sequence_lengths =
        device_->allocateBuffer({rtp_llm::DataType::TYPE_INT32, {total_decode_batch_size}, rtp_llm::AllocationType::HOST}, {});
    model_input.lm_output_indexes =
        device_->allocateBuffer({rtp_llm::DataType::TYPE_INT32, {total_batch_size}, rtp_llm::AllocationType::HOST}, {});
    model_input.prefix_lengths =
        device_->allocateBuffer({rtp_llm::DataType::TYPE_INT32, {total_context_batch_size}, rtp_llm::AllocationType::HOST}, {});
    if (need_cal_position_id) {
        model_input.combo_position_ids =
            device_->allocateBuffer({rtp_llm::DataType::TYPE_INT32, {current_tokens_size * position_id_len_factor_}, rtp_llm::AllocationType::HOST}, {});
    }
    if (has_multimodal_input) {
        model_input.text_tokens_mask =
            device_->allocateBuffer({rtp_llm::DataType::TYPE_INT32, {current_tokens_size}, rtp_llm::AllocationType::HOST}, {});
        model_input.mm_features_locs =
            device_->allocateBuffer({rtp_llm::DataType::TYPE_INT32, {multimodal_features_len}, rtp_llm::AllocationType::HOST}, {});
    }
    model_input.k_block_size = k_block_size_;
    model_input.v_block_size = v_block_size_;
    model_input.seq_size_per_block = seq_size_per_block_;
    model_input.scale_block_size = scale_block_size_;
    model_input.pd_separation = pd_separation_;
    model_input.warmup = warm_up_;
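    // Raw host pointers into the buffers above; the two batch loops below fill them in place.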
    int* merged_tokens = (int*)model_input.combo_tokens->data();
    int* input_lengths = (int*)model_input.input_lengths->data();
    int* lora_ids = (int*)model_input.lora_ids->data();
    int* lora_input_lengths = (int*)model_input.lora_input_lengths->data();
    int* sequence_lengths = (int*)model_input.sequence_lengths->data();
    int* lm_output_indexes = (int*)model_input.lm_output_indexes->data();
    int* prefix_lengths = (int*)model_input.prefix_lengths->data();
    int* combo_position_ids = need_cal_position_id ? (int*)model_input.combo_position_ids->data() : nullptr;
    int* merged_text_mask = has_multimodal_input ? (int*)model_input.text_tokens_mask->data() : nullptr;
    int* mm_features_locs = has_multimodal_input ? (int*)model_input.mm_features_locs->data() : nullptr;
    int batch_idx = 0;
    int input_vocab_size = input_vocab_size_ ? input_vocab_size_ : vocab_size_;
    if (merged_text_mask) {
        std::fill(merged_text_mask, merged_text_mask + current_tokens_size, 1);
    }
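    // Decode streams come first in the batch: each batch entry contributes exactly one new
    // token, and its KV-cache block ids are copied into the shared block-id table.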
    for (const auto& stream : decode_streams) {
        model_input.need_all_logits = model_input.need_all_logits || stream->calculateLoss();
        auto current_batch_size = stream->batchSize();
        auto kv_cache = stream->kvCache();
        RTP_LLM_LOG_DEBUG("decode kv_cache: %s", kv_cache.debugString().c_str());
        RTP_LLM_LOG_DEBUG("decode stream: %s", stream->debugString().c_str());
        for (auto i = 0; i < current_batch_size; ++i) {
            auto currentTokens = stream->currentExecuteTokens(i);
            if (currentTokens[0] >= input_vocab_size) {
                std::ostringstream error_msg;
                error_msg << "stream [" << stream->streamId() << "] token_id " << currentTokens[0] << " exceeds vocab_size " << input_vocab_size;
                return absl::InvalidArgumentError(error_msg.str());
            }
            merged_tokens[batch_idx] = currentTokens[0];
            input_lengths[batch_idx] = stream->inputLength();
            sequence_lengths[batch_idx] = stream->seqLength() - 1;  // need remove
            if (need_cal_position_id) {
                stream->generateNextPositionId(combo_position_ids + batch_idx * position_id_len_factor_);
            }
            lora_ids[batch_idx] = stream->loraId();
            lora_input_lengths[batch_idx] = 1;
            lm_output_indexes[batch_idx] = batch_idx;
            if (max_block_size) {
                std::memcpy((*model_input.kv_cache_block_id)[batch_idx].data(),
                            kv_cache.batch_block_id[i].data(),
                            kv_cache.batch_block_id[i].size() * sizeof(int));
            }
            batch_idx += 1;
        }
        stream->step();
    }
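    // Context (prefill) streams are appended after the decode entries. token_idx and
    // cum_output_seq_len index into the flattened combo_tokens buffer, which already holds
    // one token per decode entry.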
    std::vector<rtp_llm::BufferPtr> gathered_mm_features;
    int token_idx = batch_idx;
    int cum_output_seq_len = batch_idx;
    int mm_feature_index = 0;
    for (const auto& stream : context_streams) {
        // Context streams also run batched, to cover the fallback and perf-test scenarios.
        model_input.need_all_logits = model_input.need_all_logits || stream->calculateLoss();
        auto current_batch_size = stream->batchSize();
        auto kv_cache = stream->kvCache();
        RTP_LLM_LOG_DEBUG("context kv_cache: %s", kv_cache.debugString().c_str());
        RTP_LLM_LOG_DEBUG("context stream: %s", stream->debugString().c_str());
        // TODO(xinfei.sxf) deal with adjusted common seq len.
        for (auto i = 0; i < current_batch_size; ++i) {
            auto input_tokens = stream->currentExecuteTokens(i);
            auto input_masks = stream->textTokensMask();
            memcpy(merged_tokens + token_idx, input_tokens.data(), input_tokens.size() * sizeof(int));
            cum_output_seq_len += input_tokens.size();
            for (int index = 0; index < input_tokens.size(); ++index) {
                if (input_tokens[index] >= input_vocab_size && (index >= input_masks.size() || input_masks[index])) {
                    std::ostringstream error_msg;
                    error_msg << "stream [" << stream->streamId() << "] token_id " << input_tokens[index] << " exceeds vocab_size " << input_vocab_size;
                    return absl::InvalidArgumentError(error_msg.str());
                }
            }
            input_lengths[batch_idx] = input_tokens.size();
            prefix_lengths[batch_idx - total_decode_batch_size] = stream->prefixLength();
            lm_output_indexes[batch_idx] = cum_output_seq_len - 1;
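            // For multimodal inputs, record where each feature splices into the merged token
            // stream (shifted to the global token index and adjusted by the reuse length),
            // clone feature tensors onto the device when not already GPU-resident, and copy
            // the per-token text/feature mask.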
            if (has_multimodal_input) {
                std::vector<torch::Tensor> mm_features = stream->multimodalFeatures();
                rtp_llm::BufferPtr mm_locs = stream->multimodalLocations();
                if (mm_locs != nullptr) {
                    for (int loc_idx = 0; loc_idx < mm_locs->size(); ++loc_idx) {
                        mm_features_locs[mm_feature_index] = *mm_locs->dataWithOffset<int>(loc_idx) + token_idx - stream->reuseLength();
                        mm_feature_index++;
                    }
                    for (auto& mm_feature : mm_features) {
                        auto feature_buffer = torchTensor2Buffer(mm_feature);
                        if (feature_buffer->where() != rtp_llm::MemoryType::MEMORY_GPU) {
                            gathered_mm_features.emplace_back(device_->clone({*feature_buffer}));
                        } else {
                            gathered_mm_features.emplace_back(feature_buffer);
                        }
                    }
                    auto text_token_mask = stream->textTokensMask();
                    memcpy(merged_text_mask + token_idx, text_token_mask.data(), text_token_mask.size() * sizeof(int));
                }
            }
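            // Position ids for the context part: skip the positions of the reused (prefix-cached)
            // tokens so combo_position_ids only covers the tokens executed this step.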
            if (need_cal_position_id) {
                auto context_pos_ids = stream->generateContextPositionIds(device_);
                memcpy(combo_position_ids + token_idx * position_id_len_factor_,
                       context_pos_ids->dataWithOffset<int>(stream->reuseLength() * position_id_len_factor_),
                       (context_pos_ids->size() - stream->reuseLength() * position_id_len_factor_) * context_pos_ids->typeSize());
            }
            lora_ids[batch_idx] = stream->loraId();
            lora_input_lengths[batch_idx] = input_lengths[batch_idx];
            if (max_block_size) {
                std::memcpy((*model_input.kv_cache_block_id)[batch_idx].data(),
                            kv_cache.batch_block_id[i].data(),
                            kv_cache.batch_block_id[i].size() * sizeof(int));
                if (stream->hasCacheKeys()) {
                    std::memcpy((*model_input.cache_keys)[batch_idx - total_decode_batch_size].data(),
                                stream->cacheKeys(i).data(),
                                stream->cacheKeys(i).size() * sizeof(int64_t));
                }
            }
            *(model_input.request_id->dataWithOffset<int64_t>(batch_idx - total_decode_batch_size)) = stream->streamId();
            *(model_input.request_pd_separation->dataWithOffset<bool>(batch_idx - total_decode_batch_size)) = stream->queryPdSep();
            batch_idx += 1;
            token_idx += input_tokens.size();
        }
        stream->step();
    }
    if (is_multimodal_ && !gathered_mm_features.empty()) {
        model_input.multimodal_features = std::move(gathered_mm_features);
    }
    return model_input;
}