absl::StatusOr<GptModelInputs> NormalBatchStreamProcessor::gatherModelInput()

in maga_transformer/cpp/normal_engine/NormalBatchStreamProcessor.cc [22:213]


absl::StatusOr<GptModelInputs> NormalBatchStreamProcessor::gatherModelInput(const StreamGroups& stream_groups) const {
    RTP_LLM_LOG_DEBUG(__PRETTY_FUNCTION__);
    auto           context_streams = stream_groups.contextStreams();
    auto           decode_streams  = stream_groups.decodeStreams();
    RTP_LLM_LOG_DEBUG("context_streams size = %zu, decode_streams size = %zu", context_streams.size(), decode_streams.size());
    GptModelInputs model_input;
    size_t         current_tokens_size      = stream_groups.modelExecuteTokenSize();
    size_t         total_batch_size         = stream_groups.totalModelBatchSize();
    size_t         total_decode_batch_size  = stream_groups.totalDecodeBatchSize();
    size_t         total_context_batch_size = stream_groups.totalContextBatchSize();
    size_t         max_block_size           = stream_groups.maxBlockSize();
    size_t         multimodal_features_len  = stream_groups.mmFeaturesLen();

    const bool has_multimodal_input = is_multimodal_ && stream_groups.has_multimodal_input();
    const bool need_cal_position_id = (mm_position_ids_style_ != PositionIdsStyle::DEFAULT) || has_positional_encoding_;

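    // Allocate host-side input buffers; sizes come from the stream groups (merged token count, batch sizes, kv-cache block count).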
    model_input.combo_tokens =
        device_->allocateBuffer({rtp_llm::DataType::TYPE_INT32, {current_tokens_size}, rtp_llm::AllocationType::HOST}, {});
    if (max_block_size) {
        model_input.kv_cache_block_id = device_->allocateBuffer(
                {rtp_llm::DataType::TYPE_INT32, {total_batch_size, max_block_size}, rtp_llm::AllocationType::HOST}, {});
        model_input.cache_keys = device_->allocateBuffer(
            {rtp_llm::DataType::TYPE_INT64, {total_context_batch_size, max_block_size}, rtp_llm::AllocationType::HOST}, {});
    }
    model_input.request_id = device_->allocateBuffer(
            {rtp_llm::DataType::TYPE_INT64, {total_context_batch_size}, rtp_llm::AllocationType::HOST}, {});
    model_input.request_pd_separation = device_->allocateBuffer(
            {rtp_llm::DataType::TYPE_BOOL, {total_context_batch_size}, rtp_llm::AllocationType::HOST}, {});
    model_input.input_lengths =
        device_->allocateBuffer({rtp_llm::DataType::TYPE_INT32, {total_batch_size}, rtp_llm::AllocationType::HOST}, {});
    model_input.lora_ids =
        device_->allocateBuffer({rtp_llm::DataType::TYPE_INT32, {total_batch_size}, rtp_llm::AllocationType::HOST}, {});
    model_input.lora_input_lengths =
        device_->allocateBuffer({rtp_llm::DataType::TYPE_INT32, {total_batch_size}, rtp_llm::AllocationType::HOST}, {});
    model_input.sequence_lengths =
        device_->allocateBuffer({rtp_llm::DataType::TYPE_INT32, {total_decode_batch_size}, rtp_llm::AllocationType::HOST}, {});
    model_input.lm_output_indexes =
        device_->allocateBuffer({rtp_llm::DataType::TYPE_INT32, {total_batch_size}, rtp_llm::AllocationType::HOST}, {});
    model_input.prefix_lengths =
        device_->allocateBuffer({rtp_llm::DataType::TYPE_INT32, {total_context_batch_size}, rtp_llm::AllocationType::HOST}, {});
    if (need_cal_position_id) {
        model_input.combo_position_ids =
            device_->allocateBuffer({rtp_llm::DataType::TYPE_INT32, {current_tokens_size * position_id_len_factor_}, rtp_llm::AllocationType::HOST}, {});
    }
    if (has_multimodal_input) {
        model_input.text_tokens_mask =
            device_->allocateBuffer({rtp_llm::DataType::TYPE_INT32, {current_tokens_size}, rtp_llm::AllocationType::HOST}, {});
        model_input.mm_features_locs =
            device_->allocateBuffer({rtp_llm::DataType::TYPE_INT32, {multimodal_features_len}, rtp_llm::AllocationType::HOST}, {});
    }
    model_input.k_block_size = k_block_size_;
    model_input.v_block_size = v_block_size_;
    model_input.seq_size_per_block = seq_size_per_block_;
    model_input.scale_block_size = scale_block_size_;
    model_input.pd_separation = pd_separation_;
    model_input.warmup = warm_up_;

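    // Raw pointers into the host buffers above, used to fill them in the loops below.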
    int*      merged_tokens    = (int*)model_input.combo_tokens->data();
    int*      input_lengths    = (int*)model_input.input_lengths->data();
    int*      lora_ids         = (int*)model_input.lora_ids->data();
    int*      lora_input_lengths = (int*)model_input.lora_input_lengths->data();
    int*      sequence_lengths = (int*)model_input.sequence_lengths->data();
    int*      lm_output_indexes = (int*)model_input.lm_output_indexes->data();
    int*      prefix_lengths   = (int*)model_input.prefix_lengths->data();
    int*      combo_position_ids = need_cal_position_id ? (int*)model_input.combo_position_ids->data() : nullptr;
    int*      merged_text_mask = has_multimodal_input ? (int*)model_input.text_tokens_mask->data() : nullptr;
    int*      mm_features_locs = has_multimodal_input ? (int*)model_input.mm_features_locs->data() : nullptr;
    int       batch_idx        = 0;
    int       input_vocab_size = input_vocab_size_ ? input_vocab_size_ : vocab_size_;

    if (merged_text_mask) {
        std::fill(merged_text_mask, merged_text_mask + current_tokens_size, 1);
    }

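    // Decode streams come first: each batch entry contributes exactly one new token, plus its lengths and kv-cache block ids.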
    for (const auto& stream : decode_streams) {
        model_input.need_all_logits = model_input.need_all_logits || stream->calculateLoss();
        auto current_batch_size = stream->batchSize();
        auto kv_cache = stream->kvCache();
        RTP_LLM_LOG_DEBUG("decode kv_cache: %s", kv_cache.debugString().c_str());
        RTP_LLM_LOG_DEBUG("decode stream: %s", stream->debugString().c_str());
        for (auto i = 0; i < current_batch_size; ++i) {
            auto currentTokens      = stream->currentExecuteTokens(i);
            if (currentTokens[0] >= input_vocab_size) {
                std::ostringstream error_msg;
                error_msg << "stream [" << stream->streamId() << "] token_id " << currentTokens[0] << " exceeds vocab_size " << input_vocab_size;
                return absl::InvalidArgumentError(error_msg.str());
            }
            merged_tokens[batch_idx] = currentTokens[0];
            input_lengths[batch_idx]    = stream->inputLength();
            sequence_lengths[batch_idx] = stream->seqLength() - 1; // need remove
            if (need_cal_position_id) {
                stream->generateNextPositionId(combo_position_ids + batch_idx * position_id_len_factor_);
            }
            lora_ids[batch_idx]         = stream->loraId();
            lora_input_lengths[batch_idx] = 1;
            lm_output_indexes[batch_idx] = batch_idx;
            if (max_block_size) {
                std::memcpy((*model_input.kv_cache_block_id)[batch_idx].data(),
                            kv_cache.batch_block_id[i].data(),
                            kv_cache.batch_block_id[i].size() * sizeof(int));
            }
            batch_idx += 1;
        }
        stream->step();
    }

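    // Context (prefill) streams follow the decode entries; token_idx and cum_output_seq_len track offsets into the flattened token buffer.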
    std::vector<rtp_llm::BufferPtr> gathered_mm_features;
    int token_idx = batch_idx;
    int cum_output_seq_len = batch_idx;
    int mm_feature_index = 0;

    for (const auto& stream : context_streams) {
        // context streams also need to run batched, to support the fallback and perf-test scenarios
        model_input.need_all_logits = model_input.need_all_logits || stream->calculateLoss();
        auto current_batch_size = stream->batchSize();
        auto kv_cache                 = stream->kvCache();
        RTP_LLM_LOG_DEBUG("context kv_cache: %s", kv_cache.debugString().c_str());
        RTP_LLM_LOG_DEBUG("context stream: %s", stream->debugString().c_str());

        // TODO(xinfei.sxf) deal with adjusted common seq len.
        for (auto i = 0; i < current_batch_size; ++i) {
            auto input_tokens    = stream->currentExecuteTokens(i);
            auto input_masks     = stream->textTokensMask();
            memcpy(merged_tokens + token_idx, input_tokens.data(), input_tokens.size() * sizeof(int));
            cum_output_seq_len += input_tokens.size();

            for (size_t index = 0; index < input_tokens.size(); ++index) {
                if (input_tokens[index] >= input_vocab_size && (index >= input_masks.size() || input_masks[index])) {
                    std::ostringstream error_msg;
                    error_msg << "stream [" << stream->streamId() << "] token_id " << input_tokens[index] << " exceeds vocab_size " << input_vocab_size;
                    return absl::InvalidArgumentError(error_msg.str());
                }
            }

            input_lengths[batch_idx] = input_tokens.size();
            prefix_lengths[batch_idx - total_decode_batch_size] = stream->prefixLength();
            lm_output_indexes[batch_idx] = cum_output_seq_len - 1;

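            // Gather multimodal feature tensors and record where each one lands in the merged token sequence.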
            if (has_multimodal_input) {
                std::vector<torch::Tensor> mm_features = stream->multimodalFeatures();
                rtp_llm::BufferPtr mm_locs = stream->multimodalLocations();
                if (mm_locs != nullptr) {
                    for (size_t loc_idx = 0; loc_idx < mm_locs->size(); ++loc_idx) {
                        mm_features_locs[mm_feature_index] = *mm_locs->dataWithOffset<int>(loc_idx) + token_idx - stream->reuseLength();
                        mm_feature_index++;
                    }
                    for (auto& mm_feature: mm_features) {
                        auto feature_buffer = torchTensor2Buffer(mm_feature);
                        if (feature_buffer->where() != rtp_llm::MemoryType::MEMORY_GPU) {
                            gathered_mm_features.emplace_back(device_->clone({*feature_buffer}));
                        } else {
                            gathered_mm_features.emplace_back(feature_buffer);
                        }
                    }
                    auto text_token_mask = stream->textTokensMask();
                    memcpy(merged_text_mask + token_idx, text_token_mask.data(), text_token_mask.size() * sizeof(int));
                }
            }

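            // Copy this stream's context position ids, skipping any prefix reused from the kv cache.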
            if (need_cal_position_id) {
                auto context_pos_ids = stream->generateContextPositionIds(device_);
                memcpy(combo_position_ids + token_idx * position_id_len_factor_,
                       context_pos_ids->dataWithOffset<int>(stream->reuseLength() * position_id_len_factor_),
                       (context_pos_ids->size() - stream->reuseLength() * position_id_len_factor_) * context_pos_ids->typeSize());
            }
            lora_ids[batch_idx]           = stream->loraId();
            lora_input_lengths[batch_idx] = input_lengths[batch_idx];
            if (max_block_size) {
                std::memcpy((*model_input.kv_cache_block_id)[batch_idx].data(),
                            kv_cache.batch_block_id[i].data(),
                            kv_cache.batch_block_id[i].size() * sizeof(int));
                if (stream->hasCacheKeys()) {
                    std::memcpy((*model_input.cache_keys)[batch_idx - total_decode_batch_size].data(),
                        stream->cacheKeys(i).data(),
                        stream->cacheKeys(i).size() * sizeof(int64_t));
                }
            }
            *(model_input.request_id->dataWithOffset<int64_t>(batch_idx - total_decode_batch_size)) = stream->streamId();
            *(model_input.request_pd_separation->dataWithOffset<bool>(batch_idx - total_decode_batch_size)) = stream->queryPdSep();
            batch_idx += 1;
            token_idx += input_tokens.size();
        }

        stream->step();
    }

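    // Attach any gathered multimodal feature buffers to the model input.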
    if (is_multimodal_ && gathered_mm_features.size() > 0) {
        model_input.multimodal_features = std::move(gathered_mm_features);
    }
    return model_input;
}