// maga_transformer/cpp/dataclass/Query.h
#pragma once
#include <cstdint>
#include <optional>
#include <sstream>
#include <string>
#include <utility>
#include <torch/python.h>
#include "maga_transformer/cpp/dataclass/GenerateConfig.h"
#include "maga_transformer/cpp/utils/ErrorCode.h"
#include "maga_transformer/cpp/position_ids_generator/PositionIdsGenerator.h"
#include "maga_transformer/cpp/core/Buffer.h"
#include "maga_transformer/cpp/core/BufferHelper.h"
#include "maga_transformer/cpp/devices/DeviceFactory.h"
namespace rtp_llm {
/// Preprocessing options carried with a multimodal input.
///
/// Every field defaults to -1, both through the in-class initializers and
/// through the constructor's default arguments.
/// NOTE(review): -1 presumably means "not specified" — confirm with the code
/// that consumes this config.
struct MMPreprocessConfig {
    int32_t width{-1};
    int32_t height{-1};
    int32_t min_pixels{-1};
    int32_t max_pixels{-1};
    int32_t fps{-1};
    int32_t min_frames{-1};
    int32_t max_frames{-1};

    /// Stores each argument in the field of the same name; any trailing
    /// argument that is omitted stays at -1.
    MMPreprocessConfig(int32_t width      = -1,
                       int32_t height     = -1,
                       int32_t min_pixels = -1,
                       int32_t max_pixels = -1,
                       int32_t fps        = -1,
                       int32_t min_frames = -1,
                       int32_t max_frames = -1)
        : width{width}
        , height{height}
        , min_pixels{min_pixels}
        , max_pixels{max_pixels}
        , fps{fps}
        , min_frames{min_frames}
        , max_frames{max_frames} {}
};
/// One multimodal input item: a URL, a tensor payload, a type tag and the
/// preprocessing options that go with it.
struct MultimodalInput {
    std::string        url;
    /// Defaults to an empty tensor created with torch::empty({0}).
    torch::Tensor      tensor = torch::empty({0});
    /// Type tag of the input; defaults to 0.
    /// NOTE(review): the meaning of each value is not defined in this header.
    int32_t            mm_type = 0;
    MMPreprocessConfig mm_preprocess_config = MMPreprocessConfig();

    /// Builds an input from a URL and a tensor. The trailing integer arguments
    /// are forwarded, in order, to the MMPreprocessConfig constructor.
    ///
    /// `url` and `t` are taken by value and moved into the members, so passing
    /// an rvalue costs no extra copy.
    MultimodalInput(std::string   url,
                    torch::Tensor t,
                    int32_t       mm_type    = 0,
                    int32_t       width      = -1,
                    int32_t       height     = -1,
                    int32_t       min_pixels = -1,
                    int32_t       max_pixels = -1,
                    int32_t       fps        = -1,
                    int32_t       min_frames = -1,
                    int32_t       max_frames = -1)
        : url(std::move(url))
        , tensor(std::move(t))
        , mm_type(mm_type)
        , mm_preprocess_config(width, height, min_pixels, max_pixels, fps, min_frames, max_frames) {}

    /// Builds an input that only has a URL; the tensor is an empty tensor and
    /// the remaining fields keep their in-class defaults.
    MultimodalInput(std::string url)
        : url(std::move(url))
        , tensor(torch::empty(0)) {}
};
struct MultimodalOutput {
std::vector<torch::Tensor> mm_features = {};
std::optional<std::vector<torch::Tensor>> mm_position_ids = std::nullopt;
};
// Groups multimodal feature tensors with the inputs and the token-level
// buffers (mask, locations, expanded ids) that accompany them.
// All members are public and default-constructed by the empty constructor.
class MultimodalFeature {
public:
// Feature tensors.
// NOTE(review): presumably one entry per element of `inputs` — confirm with the producer.
std::vector<torch::Tensor> features;
// The multimodal inputs held alongside `features`.
std::vector<MultimodalInput> inputs;
rtp_llm::BufferPtr text_tokens_mask; // token mask: 1 for the text part, 0 for the multimodal part
rtp_llm::BufferPtr locs; // locations of the multimodal inputs
// NOTE(review): presumably the token ids after multimodal placeholders are expanded — confirm with the producer.
rtp_llm::BufferPtr expanded_ids;
MultimodalFeature() {}
};
class GenerateInput {
public:
int inputLength() {
RTP_LLM_CHECK(input_ids->shape().size() == 1);
return input_ids->shape()[0];
}
int promptLength() {
return inputLength() - prefix_length;
}
std::string debugString() const {
std::stringstream debug_string;
debug_string << "GenerateInput {"
<< "request_id: " << request_id << ", generate_config:" << generate_config->debugString()
<< ", input_ids:" << input_ids->debugStringWithData<int32_t>() << ", lora_id: " << lora_id
<< ", prefix_length:" << prefix_length << "}";
return debug_string.str();
}
void updatePrefix(const std::vector<int>& prefix_prompt) {
prefix_length = prefix_prompt.size();
auto device = rtp_llm::DeviceFactory::getDefaultDevice();
input_ids = device->concat({{rtp_llm::vector2Buffer(prefix_prompt), input_ids}});
}
public:
int64_t request_id = 0;
std::shared_ptr<GenerateConfig> generate_config;
rtp_llm::BufferPtr input_ids;
int lora_id = -1;
bool need_release_resource = true;
bool fake_query = false;
// For multi-modality models
std::optional<std::vector<MultimodalInput>> multimodal_inputs;
std::optional<std::vector<torch::Tensor>> multimodal_features;
std::optional<rtp_llm::BufferPtr> text_tokens_mask; // text part for 1 and multimodal part for 0
std::optional<rtp_llm::BufferPtr> mm_locs; // multimodal input locations
std::optional<std::vector<torch::Tensor>> mm_position_ids;
int prefix_length = 0;
int64_t begin_time_us = 0;
};
/// Auxiliary counters and optional probability buffers; GenerateOutput carries
/// one instance in its `aux_info` member.
///
/// All counters start at 0, `pd_sep` starts as false and the optional buffers
/// start disengaged.
/// NOTE(review): fields with the `_us` suffix are presumably microseconds —
/// confirm with the code that fills them in.
class AuxInfo {
public:
    int  cost_time_us{0};
    int  first_token_cost_time_us{0};
    int  wait_time_us{0};
    int  iter_count{0};
    int  input_len{0};
    int  prefix_len{0};
    int  reuse_len{0};
    int  output_len{0};
    int  fallback_tokens{0};
    int  fallback_times{0};
    int  step_output_len{0};
    bool pd_sep{false};

    std::optional<rtp_llm::ConstBufferPtr> cum_log_probs;
    std::optional<rtp_llm::ConstBufferPtr> all_probs;
    std::optional<rtp_llm::ConstBufferPtr> softmax_probs;
};
/// One generation result: the output ids, a finished flag, auxiliary info,
/// error info and optional hidden states / logits / loss buffers.
class GenerateOutput {
public:
    rtp_llm::ConstBufferPtr output_ids;
    /// Defaults to false so that a default-constructed output never exposes an
    /// indeterminate flag.
    bool                    finished = false;
    AuxInfo                 aux_info;
    ErrorInfo               error_info;

    std::optional<rtp_llm::ConstBufferPtr> hidden_states;
    std::optional<rtp_llm::ConstBufferPtr> logits;
    std::optional<rtp_llm::ConstBufferPtr> loss;
};
/// A list of GenerateOutput objects together with a request id.
class GenerateOutputs {
public:
    std::vector<GenerateOutput> generate_outputs;
    /// Defaults to 0 (the same default as GenerateInput::request_id) so that a
    /// default-constructed object never exposes an indeterminate id.
    int64_t                     request_id = 0;
};
// States a stream can be in. StreamStateToString() returns the enumerator
// name for each of these values.
// NOTE(review): the declaration order fixes the underlying integer values;
// check for code that depends on them before reordering or inserting entries.
enum class StreamState {
WAITING,
RUNNING,
PAUSED,
STOPPED,
FINISHED,
REMOTE_RUNNING
};
/// Returns the enumerator name of `state` as a string.
///
/// A value that matches no enumerator yields
/// "Error: Unrecognized Generate State".
inline std::string StreamStateToString(StreamState state) {
    // Start from the fallback text and overwrite it when the value is known.
    const char* label = "Error: Unrecognized Generate State";
    switch (state) {
        case StreamState::WAITING:
            label = "WAITING";
            break;
        case StreamState::RUNNING:
            label = "RUNNING";
            break;
        case StreamState::PAUSED:
            label = "PAUSED";
            break;
        case StreamState::STOPPED:
            label = "STOPPED";
            break;
        case StreamState::FINISHED:
            label = "FINISHED";
            break;
        case StreamState::REMOTE_RUNNING:
            label = "REMOTE_RUNNING";
            break;
        default:
            break;
    }
    return std::string(label);
}
/// Pairs a stream state with error information.
struct GenerateStatus {
    /// Starts out as StreamState::WAITING.
    StreamState status{StreamState::WAITING};
    ErrorInfo   error_info;
};
// Declaration only; the definition is not part of this header.
// NOTE(review): presumably registers the MultimodalInput Python binding on
// module `m` — confirm in the definition.
void registerMultimodalInput(const py::module& m);
} // namespace rtp_llm