maga_transformer/cpp/normal_engine/NormalBatchStreamProcessor.h (55 lines of code) (raw):
#pragma once
#include "maga_transformer/cpp/devices/DeviceBase.h"
#include "maga_transformer/cpp/th_op/GptInitParameter.h"
#include "maga_transformer/cpp/dataclass/MergedQuery.h"
#include "maga_transformer/cpp/stream/StreamGroups.h"
#include "absl/status/statusor.h"
#include "absl/status/status.h"
namespace rtp_llm {
/// Turns a group of generate streams into batched model / sampler inputs and
/// hands merged model outputs back via dispatch().
///
/// The constructor copies the handful of model and KV-cache settings it needs
/// out of GptInitParameter / CacheConfig, so the (const) gather/dispatch
/// methods do not keep references to those objects. The public methods are
/// virtual so that specialised engines can override them; the implementations
/// live outside this header.
class NormalBatchStreamProcessor {
public:
    /// @param params        model configuration; only the fields read in the
    ///                      initializer list below are copied.
    /// @param cache_config  KV-cache layout; the k/v/scale block strides and
    ///                      seq_size_per_block are copied.
    /// @param warm_up       stored in warm_up_. NOTE(review): presumably marks a
    ///                      warm-up run — confirm against the implementation.
    NormalBatchStreamProcessor(const rtp_llm::GptInitParameter& params,
                               const CacheConfig&                cache_config,
                               bool                              warm_up):
        num_layers_(params.num_layers_),
        vocab_size_(params.vocab_size_),
        input_vocab_size_(params.input_vocab_size_),
        use_int8_kv_cache_(params.kv_cache_data_type_ == rtp_llm::DataType::TYPE_INT8),
        has_positional_encoding_(params.has_positional_encoding_),
        is_multimodal_(params.is_multimodal_),
        // Named cast instead of a C-style cast: same conversion, but greppable
        // and it cannot silently become a reinterpret/const cast.
        mm_position_ids_style_(static_cast<PositionIdsStyle>(params.mm_position_ids_style_)),
        position_id_len_factor_(params.position_id_len_factor_),
        pd_separation_(params.pd_separation_),
        k_block_size_(cache_config.k_block_stride),
        v_block_size_(cache_config.v_block_stride),
        scale_block_size_(cache_config.kv_scale_block_stride),
        seq_size_per_block_(cache_config.seq_size_per_block),
        warm_up_(warm_up),
        device_(rtp_llm::DeviceFactory::getDefaultDevice()) {}

    /// Virtual because this class declares virtual methods and is therefore a
    /// polymorphic base: deleting a derived processor through a base pointer
    /// must run the derived destructor (otherwise the behavior is undefined).
    virtual ~NormalBatchStreamProcessor() = default;

    /// Hands the merged model outputs back to the streams in @p stream_groups.
    /// NOTE(review): exact per-stream effects are defined by the implementation.
    /// @return an error status if dispatching fails.
    virtual absl::Status dispatch(const StreamGroups& stream_groups,
                                  const MergedOutput& merge_outputs) const;

    /// Builds the batched model input for all streams in @p stream_groups.
    virtual absl::StatusOr<GptModelInputs> gatherModelInput(const StreamGroups& stream_groups) const;

    /// Builds the sampler input for the batch, given the model inputs that were
    /// gathered for it and the model's outputs.
    virtual absl::StatusOr<SamplerInputs> gatherSamplerInput(const StreamGroups&    stream_groups,
                                                             const GptModelInputs&  model_inputs,
                                                             const GptModelOutputs& model_output) const;

protected:
    /// Allocates a SamplerInputs sized for @p total_batch_size rows.
    SamplerInputs allocateSamplerInputs(const StreamGroups&       stream_groups,
                                        size_t                    total_batch_size,
                                        const rtp_llm::BufferPtr& sequence_length) const;

    /// Fills the sampler fields shared by all streams in @p all_streams.
    /// @param score_batch  NOTE(review): presumably selects scoring-batch
    ///                     layout — confirm against the implementation.
    void setCommonSamplerInputs(SamplerInputs&                sampler_inputs,
                                std::list<GenerateStreamPtr>& all_streams,
                                bool                          score_batch = false) const;

    /// Fills the think-mode logits-processor part of @p sampler_inputs.
    void setThinkModeLogitsProcessorInputs(SamplerInputs&                sampler_inputs,
                                           std::list<GenerateStreamPtr>& all_streams,
                                           bool                          score_batch = false) const;

    // ---- Settings copied from GptInitParameter (declaration order must match
    //      the constructor's initializer list). ----
    size_t           num_layers_;               // params.num_layers_
    size_t           vocab_size_;               // params.vocab_size_
    size_t           input_vocab_size_;         // params.input_vocab_size_
    bool             use_int8_kv_cache_;        // true iff kv_cache_data_type_ == TYPE_INT8
    bool             has_positional_encoding_;  // params.has_positional_encoding_
    bool             is_multimodal_;            // params.is_multimodal_
    PositionIdsStyle mm_position_ids_style_;    // params.mm_position_ids_style_
    size_t           position_id_len_factor_;   // params.position_id_len_factor_
    bool             pd_separation_;            // params.pd_separation_

    // ---- Settings copied from CacheConfig. Note: the "*_size_" members hold
    //      the CacheConfig "*_stride" values. ----
    size_t k_block_size_;        // cache_config.k_block_stride
    size_t v_block_size_;        // cache_config.v_block_stride
    size_t scale_block_size_;    // cache_config.kv_scale_block_stride
    size_t seq_size_per_block_;  // cache_config.seq_size_per_block

    bool warm_up_;  // constructor argument, stored verbatim

    // Default device from DeviceFactory; never deleted by this class, so its
    // lifetime is managed elsewhere.
    rtp_llm::DeviceBase* device_;
};
} // namespace rtp_llm