maga_transformer/cpp/engine_base/EngineBase.cc (35 lines of code) (raw):

#include "maga_transformer/cpp/engine_base/EngineBase.h" #include "maga_transformer/cpp/utils/SignalUtils.h" #include "maga_transformer/cpp/devices/DeviceFactory.h" #include "autil/EnvUtil.h" #include <stdexcept> using namespace autil; namespace rtp_llm { EngineBase::EngineBase(const EngineInitParams& params) { initDevices(params); lora_manager_ = std::make_shared<lora::LoraManager>(); } EngineBase::~EngineBase() {} std::shared_ptr<GenerateStream> EngineBase::makeStream(const std::shared_ptr<GenerateInput>& input) { throw std::runtime_error("not implemeted"); } void EngineBase::initDevices(const EngineInitParams& params) { const auto rank = params.gpt_init_parameter.dp_rank_ * params.gpt_init_parameter.tp_size_ + params.gpt_init_parameter.tp_rank_; Logger::getEngineLogger().setRank(rank); Logger::getEngineLogger().flush(); rtp_llm::DeviceFactory::initDevices(params.gpt_init_parameter); device_ = rtp_llm::DeviceFactory::getDefaultDevice(); } void EngineBase::addLora(const std::string& adapter_name, rtp_llm::lora::loraLayerWeightsMap lora_a, rtp_llm::lora::loraLayerWeightsMap lora_b) { lora_manager_->addLora(adapter_name, lora_a, lora_b); } void EngineBase::removeLora(const std::string& adapter_name) { lora_manager_->removeLora(adapter_name); } std::shared_ptr<lora::LoraManager> EngineBase::getLoraManager() { return lora_manager_; } } // namespace rtp_llm