// maga_transformer/cpp/model_rpc/QueryConverter.cc
#include "maga_transformer/cpp/model_rpc/QueryConverter.h"
#include "maga_transformer/cpp/core/Buffer.h"
#include "maga_transformer/cpp/core/Types.h"
#include "maga_transformer/cpp/devices/DeviceFactory.h"
namespace rtp_llm {
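
// Copies an optional proto field into the local GenerateConfig only when the
// field is set on the wire; absent fields keep the config's defaults. The
// macro relies on `config_proto` and `generate_config` being in scope. For
// example, TRANS_OPTIONAL(random_seed); expands to a guarded
// `generate_config->random_seed = config_proto->random_seed().value();`.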
#define TRANS_OPTIONAL(name)                                              \
    do {                                                                  \
        if (config_proto->has_##name()) {                                 \
            generate_config->name = config_proto->name().value();         \
        }                                                                 \
    } while (0)
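
// Translates a GenerateConfigPB message into the engine-side GenerateConfig.
// Required scalar fields are copied unconditionally; optional fields go
// through TRANS_OPTIONAL so that unset fields keep their default values.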
std::shared_ptr<GenerateConfig> QueryConverter::transGenerateConfig(const GenerateConfigPB* config_proto) {
    std::shared_ptr<GenerateConfig> generate_config = std::make_shared<GenerateConfig>();
    generate_config->global_request_id = config_proto->global_request_id();
    generate_config->max_new_tokens = config_proto->max_new_tokens();
    generate_config->min_new_tokens = config_proto->min_new_tokens();
    generate_config->num_beams = config_proto->num_beams();
    generate_config->num_return_sequences = config_proto->num_return_sequences();
    generate_config->return_logits = config_proto->return_logits();
    generate_config->return_incremental = config_proto->return_incremental();
    generate_config->return_hidden_states = config_proto->return_hidden_states();
    generate_config->calculate_loss = config_proto->calculate_loss();
    generate_config->is_streaming = config_proto->is_streaming();
    generate_config->timeout_ms = config_proto->timeout_ms();
    generate_config->sp_edit = config_proto->sp_edit();
    generate_config->force_disable_sp_run = config_proto->force_disable_sp_run();
    generate_config->return_cum_log_probs = config_proto->return_cum_log_probs();
    generate_config->return_all_probs = config_proto->return_all_probs();
    generate_config->return_softmax_probs = config_proto->return_softmax_probs();
    generate_config->can_use_pd_separation = config_proto->can_use_pd_separation();
    generate_config->gen_timeline = config_proto->gen_timeline();
    // Element-wise assign avoids the raw memcpy's assumption that
    // sizeof(int) == sizeof(int32_t) on every platform.
    generate_config->select_tokens_id.assign(config_proto->select_tokens_id().begin(),
                                             config_proto->select_tokens_id().end());
    for (const auto& stop_words_proto : config_proto->stop_words_list().rows()) {
        std::vector<int> stop_words;
        for (const int value : stop_words_proto.values()) {
            stop_words.push_back(value);
        }
        generate_config->stop_words_list.push_back(std::move(stop_words));
    }
    for (const auto& token_id : config_proto->sp_advice_prompt_token_ids()) {
        generate_config->sp_advice_prompt_token_ids.push_back(token_id);
    }
    generate_config->top_k = config_proto->top_k();
    generate_config->top_p = config_proto->top_p();
    generate_config->temperature = config_proto->temperature();
    generate_config->repetition_penalty = config_proto->repetition_penalty();
    TRANS_OPTIONAL(no_repeat_ngram_size);
    TRANS_OPTIONAL(random_seed);
    TRANS_OPTIONAL(top_p_decay);
    TRANS_OPTIONAL(top_p_min);
    TRANS_OPTIONAL(top_p_reset_ids);
    TRANS_OPTIONAL(task_id);
    TRANS_OPTIONAL(adapter_name);
    generate_config->in_think_mode = config_proto->in_think_mode();
    generate_config->max_thinking_tokens = config_proto->max_thinking_tokens();
    for (const auto& token_id : config_proto->end_think_token_ids()) {
        generate_config->end_think_token_ids.push_back(token_id);
    }
    return generate_config;
}
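
// Builds a GenerateInput from a GenerateInputPB request: records the arrival
// timestamp, converts the nested generate_config when present, copies the
// prompt token ids into a host-side buffer, and forwards any multimodal
// inputs (with an empty placeholder tensor, filled in later by the
// multimodal pipeline).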
std::shared_ptr<GenerateInput> QueryConverter::transQuery(const GenerateInputPB* input) {
    std::shared_ptr<GenerateInput> generate_input = std::make_shared<GenerateInput>();
    generate_input->request_id = input->request_id();
    generate_input->begin_time_us = autil::TimeUtility::currentTimeInMicroSeconds();
    if (input->has_generate_config()) {
        generate_input->generate_config = transGenerateConfig(&(input->generate_config()));
    }
    auto device = rtp_llm::DeviceFactory::getDefaultDevice();
    generate_input->input_ids = device->allocateBuffer(
        {rtp_llm::DataType::TYPE_INT32, {(size_t)input->token_ids_size()}, rtp_llm::AllocationType::HOST}, {});
    memcpy(generate_input->input_ids->data(), input->token_ids().data(), generate_input->input_ids->sizeBytes());
    if (input->multimodal_inputs_size() > 0) {
        std::vector<MultimodalInput> mm_inputs;
        for (int i = 0; i < input->multimodal_inputs_size(); i++) {
            auto mm_input = &input->multimodal_inputs(i);
            auto mm_preprocess_config = &mm_input->mm_preprocess_config();
            mm_inputs.emplace_back(mm_input->multimodal_url(),
                                   torch::empty(1),
                                   mm_input->multimodal_type(),
                                   mm_preprocess_config->width(),
                                   mm_preprocess_config->height(),
                                   mm_preprocess_config->min_pixels(),
                                   mm_preprocess_config->max_pixels(),
                                   mm_preprocess_config->fps(),
                                   mm_preprocess_config->min_frames(),
                                   mm_preprocess_config->max_frames());
        }
        generate_input->multimodal_inputs = std::move(mm_inputs);
    }
    return generate_input;
}
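
// Converts a standalone MultimodalInputsPB into engine MultimodalInput
// objects. Note that, unlike the loop in transQuery above, this path does
// not forward min_frames/max_frames from the preprocess config.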
std::vector<MultimodalInput> QueryConverter::transMMInput(const MultimodalInputsPB* mm_inputs) {
    std::vector<MultimodalInput> inputs_vec;
    for (int i = 0; i < mm_inputs->multimodal_inputs_size(); i++) {
        auto mm_input = &mm_inputs->multimodal_inputs(i);
        auto mm_preprocess_config = &mm_input->mm_preprocess_config();
        // The tensor should also be converted from the input pb; it is only used by
        // some embedding models, so an empty placeholder is passed for now.
        inputs_vec.emplace_back(mm_input->multimodal_url(),
                                torch::empty(1),
                                mm_input->multimodal_type(),
                                mm_preprocess_config->width(),
                                mm_preprocess_config->height(),
                                mm_preprocess_config->min_pixels(),
                                mm_preprocess_config->max_pixels(),
                                mm_preprocess_config->fps());
    }
    return inputs_vec;
}
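
// Serializes engine MultimodalInput objects back into a MultimodalInputsPB,
// including the (possibly placeholder) tensor and the preprocess config.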
MultimodalInputsPB QueryConverter::transMMInputsPB(const std::vector<MultimodalInput> mm_inputs) {
    MultimodalInputsPB mm_inputs_pb;
    for (const auto& mm_input : mm_inputs) {
        auto now_input = mm_inputs_pb.add_multimodal_inputs();
        now_input->set_multimodal_url(mm_input.url);
        now_input->set_multimodal_type(mm_input.mm_type);
        transTensorPB(now_input->mutable_multimodal_tensor(), rtp_llm::torchTensor2Buffer(mm_input.tensor).get());
        transMMPreprocessConfig(now_input->mutable_mm_preprocess_config(), mm_input.mm_preprocess_config);
    }
    return mm_inputs_pb;
}
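
// Copies the multimodal preprocess parameters (resize bounds, pixel limits,
// sampling fps) field by field into the corresponding proto message.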
void QueryConverter::transMMPreprocessConfig(MMPreprocessConfigPB* config_pb, const MMPreprocessConfig config) {
    config_pb->set_width(config.width);
    config_pb->set_height(config.height);
    config_pb->set_min_pixels(config.min_pixels);
    config_pb->set_max_pixels(config.max_pixels);
    config_pb->set_fps(config.fps);
}
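
// Reassembles a MultimodalOutput from its proto form: one feature tensor per
// entry, plus optional position-id tensors collected into the optional
// vector, which is initialized lazily on first occurrence.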
MultimodalOutput QueryConverter::transMMOutput(const MultimodalOutputsPB* outputs_pb) {
    MultimodalOutput mm_output;
    for (int i = 0; i < outputs_pb->multimodal_outputs_size(); i++) {
        // Bind by const reference to avoid deep-copying the proto message.
        const auto& output_pb = outputs_pb->multimodal_outputs(i);
        mm_output.mm_features.emplace_back(transTensor(output_pb.multimodal_embedding()));
        if (output_pb.has_multimodal_pos_id()) {
            if (!mm_output.mm_position_ids.has_value()) {
                mm_output.mm_position_ids = std::vector<torch::Tensor>();
            }
            mm_output.mm_position_ids.value().emplace_back(transTensor(output_pb.multimodal_pos_id()));
        }
    }
    return mm_output;
}
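
// Deserializes a TensorPB into a torch::Tensor: the proto bytes are wrapped
// with torch::from_blob and then cloned, so the returned tensor owns its own
// storage and can outlive the proto message. A minimal call-site sketch
// (hypothetical `pb` variable, mirroring the use in transMMOutput above):
//
//   TensorPB pb = /* received over RPC */;
//   torch::Tensor t = QueryConverter::transTensor(pb);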
torch::Tensor QueryConverter::transTensor(const TensorPB& tensor_pb) {
    std::vector<int64_t> shape(tensor_pb.shape().begin(), tensor_pb.shape().end());
    void* data_ptr = nullptr;
    switch (tensor_pb.data_type()) {
        case TensorPB::FP32: {
            data_ptr = const_cast<char*>(tensor_pb.fp32_data().data());
            auto options = torch::TensorOptions().dtype(torch::kFloat32);
            return torch::from_blob(data_ptr, shape, options).clone();
        }
        case TensorPB::INT32: {
            data_ptr = const_cast<char*>(tensor_pb.int32_data().data());
            auto options = torch::TensorOptions().dtype(torch::kInt32);
            return torch::from_blob(data_ptr, shape, options).clone();
        }
        case TensorPB::FP16: {
            data_ptr = const_cast<char*>(tensor_pb.fp16_data().data());
            auto options = torch::TensorOptions().dtype(torch::kFloat16);
            return torch::from_blob(data_ptr, shape, options).clone();
        }
        case TensorPB::BF16: {
            data_ptr = const_cast<char*>(tensor_pb.bf16_data().data());
            auto options = torch::TensorOptions().dtype(torch::kBFloat16);
            return torch::from_blob(data_ptr, shape, options).clone();
        }
        default:
            throw std::runtime_error("Unsupported data type.");
    }
}
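
// Serializes an rtp_llm::Buffer into a TensorPB: the shape is copied into the
// repeated shape field and the raw bytes are stored in the dtype-specific
// bytes field, so the receiving side can rebuild the tensor via transTensor.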
void QueryConverter::transTensorPB(TensorPB* t, const rtp_llm::Buffer* buffer) {
    RTP_LLM_CHECK(t != nullptr);
    auto shape = t->mutable_shape();
    auto shape_array = buffer->shape();
    shape->Resize(shape_array.size(), 0);
    memcpy(shape->mutable_data(), shape_array.data(), shape_array.size() * sizeof(int64_t));
    TensorPB_DataType data_type;
    switch (buffer->type()) {
        case rtp_llm::DataType::TYPE_FP32:
            data_type = TensorPB_DataType::TensorPB_DataType_FP32;
            t->set_fp32_data(reinterpret_cast<const char*>(buffer->data()), buffer->sizeBytes());
            break;
        case rtp_llm::DataType::TYPE_INT32:
            data_type = TensorPB_DataType::TensorPB_DataType_INT32;
            t->set_int32_data(reinterpret_cast<const char*>(buffer->data()), buffer->sizeBytes());
            break;
        case rtp_llm::DataType::TYPE_FP16:
            data_type = TensorPB_DataType::TensorPB_DataType_FP16;
            t->set_fp16_data(reinterpret_cast<const char*>(buffer->data()), buffer->sizeBytes());
            break;
        case rtp_llm::DataType::TYPE_BF16:
            data_type = TensorPB_DataType::TensorPB_DataType_BF16;
            t->set_bf16_data(reinterpret_cast<const char*>(buffer->data()), buffer->sizeBytes());
            break;
        default:
            throw std::invalid_argument("unsupported buffer data type: "
                                        + std::to_string(static_cast<int>(buffer->type())));
    }
    t->set_data_type(data_type);
}
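
// Fills a GenerateOutputsPB from the engine's GenerateOutputs: per-sequence
// finish flag, timing/length bookkeeping in aux_info, and the optional
// tensors (probabilities, hidden states, loss, logits) only when they were
// actually produced for this request.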
void QueryConverter::transResponse(GenerateOutputsPB* outputs, const GenerateOutputs* responses) {
    RTP_LLM_LOG_DEBUG(__PRETTY_FUNCTION__);
    outputs->set_request_id(responses->request_id);
    for (size_t i = 0; i < responses->generate_outputs.size(); i++) {
        const auto& response = responses->generate_outputs[i];
        GenerateOutputPB* output = outputs->add_generate_outputs();
        output->set_finished(response.finished);
        auto aux_info = output->mutable_aux_info();
        aux_info->set_cost_time_us(response.aux_info.cost_time_us);
        aux_info->set_first_token_cost_time_us(response.aux_info.first_token_cost_time_us);
        aux_info->set_wait_time_us(response.aux_info.wait_time_us);
        aux_info->set_iter_count(response.aux_info.iter_count);
        aux_info->set_fallback_tokens(response.aux_info.fallback_tokens);
        aux_info->set_fallback_times(response.aux_info.fallback_times);
        aux_info->set_input_len(response.aux_info.input_len);
        aux_info->set_reuse_len(response.aux_info.reuse_len);
        aux_info->set_prefix_len(response.aux_info.prefix_len);
        aux_info->set_output_len(response.aux_info.output_len);
        aux_info->set_step_output_len(response.aux_info.step_output_len);
        aux_info->set_pd_sep(response.aux_info.pd_sep);
        if (response.aux_info.cum_log_probs.has_value()) {
            transTensorPB(aux_info->mutable_cum_log_probs(), response.aux_info.cum_log_probs.value().get());
        }
        if (response.aux_info.softmax_probs.has_value()) {
            transTensorPB(aux_info->mutable_softmax_probs(), response.aux_info.softmax_probs.value().get());
        }
        if (response.aux_info.all_probs.has_value()) {
            transTensorPB(output->mutable_all_probs(), response.aux_info.all_probs.value().get());
        }
        transTensorPB(output->mutable_output_ids(), response.output_ids.get());
        if (response.hidden_states.has_value()) {
            transTensorPB(output->mutable_hidden_states(), response.hidden_states.value().get());
        }
        if (response.loss.has_value()) {
            transTensorPB(output->mutable_loss(), response.loss.value().get());
        }
        if (response.logits.has_value()) {
            transTensorPB(output->mutable_logits(), response.logits.value().get());
        }
    }
    RTP_LLM_LOG_DEBUG("transResponse done");
}
} // namespace rtp_llm