maga_transformer/cpp/multimodal_processor/MultimodalProcessor.h (39 lines of code) (raw):
#pragma once
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include <torch/python.h>
#include "maga_transformer/cpp/dataclass/Query.h"
#include "maga_transformer/cpp/utils/ErrorCode.h"
#include "maga_transformer/cpp/utils/StatusUtil.h"
#include "maga_transformer/cpp/utils/PyUtils.h"
#include "maga_transformer/cpp/core/Buffer.h"
#include "maga_transformer/cpp/th_op/GptInitParameter.h"
// Short alias for pybind11; MultimodalProcessor below spells its Python handle as py::object.
namespace py = pybind11;
namespace rtp_llm {
/// Plain aggregate of three buffer handles.
///
/// It is the payload type returned by MultimodalProcessor::expandTokenIds and
/// consumed by MultimodalProcessor::checkExpandLength. Every member defaults
/// to nullptr, so a default-constructed ExpandedOutput carries no buffers.
struct ExpandedOutput {
    // NOTE(review): the meaning of each buffer is inferred from its name only;
    // confirm against the expandTokenIds implementation.
    rtp_llm::BufferPtr expanded_ids;      // presumably the token ids after expansion
    rtp_llm::BufferPtr text_tokens_mask;  // presumably a mask marking text tokens
    rtp_llm::BufferPtr locs;              // presumably positions of the multimodal segments

    /// Stores the three handles. The arguments are sink parameters: they are
    /// taken by value and moved into the members, so the caller's copy is the
    /// only copy made (the BufferPtr definition is not visible in this header;
    /// for a ref-counted handle this avoids one extra refcount round-trip per
    /// member).
    ExpandedOutput(rtp_llm::BufferPtr expanded_ids     = nullptr,
                   rtp_llm::BufferPtr text_tokens_mask = nullptr,
                   rtp_llm::BufferPtr locs             = nullptr):
        expanded_ids(std::move(expanded_ids)),
        text_tokens_mask(std::move(text_tokens_mask)),
        locs(std::move(locs)) {}
};
/// Abstract base class for multimodal processors.
///
/// Concrete processors must implement MultimodalEmbedding(). All other member
/// functions are declared here and defined out of line; only their signatures
/// are visible in this header.
class MultimodalProcessor {
public:
    /// Stores the Python engine handle and the init parameters, and caches
    /// three fields of the parameters (mm_sep_tokens_, include_sep_tokens_,
    /// max_seq_len_) in dedicated members.
    ///
    /// @param mm_process_engine Python object kept in mm_process_engine_.
    /// @param params            Init parameters kept in gpt_init_parameter_.
    MultimodalProcessor(py::object mm_process_engine, rtp_llm::GptInitParameter params):
        mm_process_engine_(std::move(mm_process_engine)),
        gpt_init_parameter_(std::move(params)),
        // Members are initialised in declaration order, so gpt_init_parameter_
        // is fully constructed at this point. Read the cached fields from it,
        // never from the moved-from `params`.
        sep_token_ids_(gpt_init_parameter_.mm_sep_tokens_),
        include_sep_tokens_(gpt_init_parameter_.include_sep_tokens_),
        max_seq_len_(gpt_init_parameter_.max_seq_len_) {}

    /// Virtual because the class is a polymorphic base: destroying a derived
    /// processor through a MultimodalProcessor pointer must run the derived
    /// destructor.
    virtual ~MultimodalProcessor() = default;

    // Declaring the destructor would otherwise suppress the implicit move
    // operations; default all of them so copy/move availability is unchanged.
    MultimodalProcessor(const MultimodalProcessor&)            = default;
    MultimodalProcessor(MultimodalProcessor&&)                 = default;
    MultimodalProcessor& operator=(const MultimodalProcessor&) = default;
    MultimodalProcessor& operator=(MultimodalProcessor&&)      = default;

    /// Public entry point operating on a whole GenerateInput.
    /// NOTE(review): presumably fills the input's multimodal features in
    /// place and reports failure through the returned ErrorInfo; confirm in
    /// the .cpp.
    ErrorInfo updateMultimodalFeatures(std::shared_ptr<rtp_llm::GenerateInput>& input);

    /// Public entry point operating on raw token ids plus multimodal inputs;
    /// yields a MultimodalFeature or an error.
    ErrorResult<MultimodalFeature> getMultimodalFeatures(const rtp_llm::BufferPtr& input_ids, const std::vector<MultimodalInput> &mm_inputs);

protected:
    py::object                mm_process_engine_;   // Python-side engine handle passed to the constructor
    rtp_llm::GptInitParameter gpt_init_parameter_;  // init parameters passed to the constructor

private:
    // Cached from gpt_init_parameter_ at construction. These must stay
    // declared after gpt_init_parameter_ (see the constructor).
    std::vector<std::vector<int64_t>> sep_token_ids_;       // from mm_sep_tokens_
    bool                              include_sep_tokens_;  // from include_sep_tokens_
    int64_t                           max_seq_len_;         // from max_seq_len_

    /// NOTE(review): the name suggests a hash derived from `url` is written
    /// into `token_ids` for `mm_emb_len` positions; confirm in the .cpp.
    ErrorInfo getStrHash(int32_t* token_ids, std::string& url, int mm_emb_len);

    /// Customisation point: the only pure virtual member, to be implemented
    /// by every concrete processor. It is private, so it can be overridden
    /// but is only callable from within this class.
    /// NOTE(review): mm_inputs is taken by value (one vector copy per call);
    /// changing that would have to be coordinated with all overriders.
    virtual ErrorResult<MultimodalOutput> MultimodalEmbedding(const std::vector<rtp_llm::MultimodalInput> mm_inputs) = 0;

    /// Produces an ExpandedOutput from the embeddings, the token ids and the
    /// multimodal inputs, or an error.
    ErrorResult<ExpandedOutput> expandTokenIds(const std::vector<torch::Tensor>& mm_embedding, rtp_llm::BufferPtr token_ids, const std::vector<rtp_llm::MultimodalInput> mm_inputs);

    /// Yields a list of int32 pairs computed from `token_ids`, or an error.
    /// NOTE(review): presumably (begin, end) positions delimited by
    /// sep_token_ids_; confirm in the .cpp.
    ErrorResult<std::vector<std::pair<int32_t, int32_t>>> getMultimodalTags(rtp_llm::BufferPtr token_ids);

    /// Validates an ExpandedOutput.
    /// NOTE(review): presumably checks the expanded length against
    /// max_seq_len_; confirm in the .cpp.
    ErrorInfo checkExpandLength(const ExpandedOutput& expand_output);
};
}