// maga_transformer/cpp/normal_engine/NormalEngine.cc
#include "maga_transformer/cpp/stream/GenerateStream.h"
#include "maga_transformer/cpp/engine_base/EngineBase.h"
#include "maga_transformer/cpp/normal_engine/NormalExecutor.h"
#include "maga_transformer/cpp/normal_engine/NormalEngine.h"
#include "maga_transformer/cpp/normal_engine/NormalGenerateStream.h"
#include "maga_transformer/cpp/utils/StatusUtil.h"
#include "maga_transformer/cpp/schedulers/FIFOScheduler.h"
#include "maga_transformer/cpp/schedulers/BatchDecodeScheduler.h"
#include "maga_transformer/cpp/cache/CacheConfigCreator.h"
#include "maga_transformer/cpp/system_prompt/SystemPromptConstructor.h"
#include "maga_transformer/cpp/utils/Logger.h"
#include "maga_transformer/cpp/utils/AssertUtils.h"
#include "autil/TimeUtility.h"
#include <memory>
#include <thread>
#include <random>
using namespace std;
namespace rtp_llm {
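// NormalEngine construction:
//   1. optionally run a warm-up query (skipped for multimodal models) to measure peak device memory,
//   2. create the KV-cache manager, sized with the warm-up result when present,
//   3. create the executor and the scheduler,
//   4. start the background step loop,
//   5. on tp_rank 0, prime the load-balance statistics with a fake query.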
NormalEngine::NormalEngine(const EngineInitParams& params) :
EngineBase(params),
params_(params.gpt_init_parameter),
metrics_reporter_(params.metrics_reporter)
{
RTP_LLM_LOG_INFO(__PRETTY_FUNCTION__);
std::optional<WarmUpResult> warm_up_result = std::nullopt;
if (params_.warm_up_ && (!params_.is_multimodal_)) {
// warm up
RTP_LLM_LOG_INFO("warm up (max_context_batch_size %d, max_seq_len %d calculate_loss %d) query begin",
params_.max_context_batch_size_, params_.max_seq_len_, int(params_.warm_up_with_loss_));
warm_up_result = warmUp(params);
RTP_LLM_LOG_INFO("warm up done, max runtime used memory: %ld bytes (%ld MiB), device reserved memory: %ld bytes (%ld MiB)",
warm_up_result->max_used_memory,
warm_up_result->max_used_memory / 1024 / 1024,
warm_up_result->device_reserved_bytes,
warm_up_result->device_reserved_bytes / 1024 / 1024);
} else {
RTP_LLM_LOG_INFO("skip warm up.");
}
initCacheManager(warm_up_result);
RTP_LLM_LOG_INFO("create cache manager done");
executor_.reset(new NormalExecutor(params, resource_context_.cache_manager, device_, getLoraManager()));
RTP_LLM_LOG_INFO("create normal executor done");
initScheduler();
(void)startLoop();
if (device_->getDeviceProperties().tp_rank == 0 && scheduler_->canLoadBalance()) {
initLoadBalance();
}
}
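// Scheduler selection: setting the env var USE_BATCH_DECODE_SCHEDULER=1 picks the
// BatchDecodeScheduler; otherwise the default FIFOScheduler is used.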
void NormalEngine::initScheduler() {
const char* use_batch_decode = getenv("USE_BATCH_DECODE_SCHEDULER");
if (use_batch_decode && std::string(use_batch_decode) == "1") {
scheduler_.reset(new BatchDecodeScheduler(params_, resource_context_.cache_manager, metrics_reporter_, device_));
RTP_LLM_LOG_INFO("create batch decode scheduler done");
} else {
scheduler_.reset(new FIFOScheduler(params_, resource_context_.cache_manager, metrics_reporter_));
RTP_LLM_LOG_INFO("create fifo scheduler done");
}
}
NormalEngine::~NormalEngine() {
RTP_LLM_LOG_INFO("destroy normal engine");
(void)stop();
}
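// Run a single stream through the executor outside the normal step loop.
// Used for prefill warm-up (perf-test mode), decode warm-up (fake KV blocks) and
// for building the system prompt cache (real KV block allocation).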
absl::StatusOr<GenerateStreamPtr> NormalEngine::preRun(
const std::shared_ptr<GenerateInput>& generate_input, preRunMode mode) {
auto stream = std::make_shared<NormalGenerateStream>(generate_input, params_, resource_context_, nullptr);
if (mode == preRunMode::prefill_warm_up) {
stream->setPerfTest(true);
} else if (mode == preRunMode::decode_warm_up) {
stream->setIsContextStream(false);
stream->fakeInitKVBlock();
} else if (mode == preRunMode::build_system_prompt) {
THROW_IF_STATUSOR_ERROR(stream->initKVBlock(0, 0));
}
std::list<GenerateStreamPtr> streams{stream};
THROW_IF_STATUS_ERROR(executor_->process(streams));
return stream;
}
int64_t NormalEngine::getLastScheduleTime() {
return scheduler_->lastScheduleTime();
}
WarmUpResult NormalEngine::warmUp(const EngineInitParams& params) {
if (params_.isPDFusion() || params_.isPrefillRole()) {
return prefillWarmUp(params);
} else {
return decodeWarmUp(params);
}
}
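// Build a fake request of seq_len random token ids. Ids are drawn uniformly below
// min(embedding_size, vocab_size) (or vocab_size when embedding_size is unset),
// presumably so the warm-up input never indexes past the embedding table.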
std::shared_ptr<GenerateInput> NormalEngine::makeFakeInput(size_t seq_len) {
std::shared_ptr<GenerateInput> fake_input = make_shared<GenerateInput>();
fake_input->generate_config = make_shared<GenerateConfig>();
fake_input->input_ids = device_->allocateBuffer(
{rtp_llm::DataType::TYPE_INT32, {seq_len}, rtp_llm::AllocationType::HOST});
std::default_random_engine generator;
size_t token_size = params_.embedding_size_ ? std::min(params_.embedding_size_, params_.vocab_size_) : params_.vocab_size_;
std::uniform_int_distribution<int> distribution(0, token_size - 1);
for (size_t i = 0; i < fake_input->input_ids->size(); ++i) {
*fake_input->input_ids->dataWithOffset<int32_t>(i) = distribution(generator);
}
return fake_input;
}
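// Prefill warm-up: run a (max_seq_len - 1) context query replicated
// max_context_batch_size times through a temporary executor with memory tracing
// enabled, and return the reserved and peak consumed device memory.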
WarmUpResult NormalEngine::prefillWarmUp(const EngineInitParams& params) {
auto fake_input = makeFakeInput((size_t)params_.max_seq_len_ - 1);
fake_input->generate_config->num_return_sequences = params_.max_context_batch_size_;
fake_input->generate_config->calculate_loss = int(params_.warm_up_with_loss_);
fake_input->begin_time_us = autil::TimeUtility::currentTimeInMicroSeconds();
device_->setTraceMemory(true);
executor_.reset(new NormalExecutor(params, nullptr, device_, nullptr, true));
THROW_IF_STATUSOR_ERROR(preRun(fake_input, preRunMode::prefill_warm_up));
const auto device_status = device_->getDeviceStatus();
device_->setTraceMemory(false);
(void)executor_.reset(nullptr);
return WarmUpResult({
device_status.device_memory_status.preserved_bytes,
device_status.device_memory_status.max_consumed_bytes});
}
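// Decode warm-up: same fake query driven as a decode stream against a tiny
// temporary cache manager (5 blocks), replicated max_generate_batch_size times,
// again recording reserved and peak consumed device memory.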
WarmUpResult NormalEngine::decodeWarmUp(const EngineInitParams& params) {
auto fake_input = makeFakeInput((size_t)params_.max_seq_len_ - 1);
fake_input->generate_config->num_return_sequences = params_.max_generate_batch_size_;
fake_input->generate_config->calculate_loss = int(params_.warm_up_with_loss_);
fake_input->begin_time_us = autil::TimeUtility::currentTimeInMicroSeconds();
device_->setTraceMemory(true);
auto cache_config = CacheConfigCreator::createBasicConfig(params_);
cache_config.seq_size_per_block = params_.seq_size_per_block_;
cache_config.block_nums = 5;
auto cache_manager = make_shared<CacheManager>(cache_config, device_, true);
executor_.reset(new NormalExecutor(params, cache_manager, device_, nullptr, true));
THROW_IF_STATUSOR_ERROR(preRun(fake_input, preRunMode::decode_warm_up));
const auto device_status = device_->getDeviceStatus();
device_->setTraceMemory(false);
(void)executor_.reset(nullptr);
return WarmUpResult({
device_status.device_memory_status.preserved_bytes,
device_status.device_memory_status.max_consumed_bytes});
}
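// Enqueue a single-token greedy (top_k = 1) fake query with metrics reporting
// disabled. Used to prime the step recorder and to keep data-parallel ranks
// stepping when the scheduler has no real work.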
std::shared_ptr<GenerateStream> NormalEngine::enqueueMinFakeQuery(int32_t max_new_tokens) {
RTP_LLM_LOG_DEBUG("enqueue min fake query");
auto fake_input = makeFakeInput(1);
fake_input->generate_config->max_new_tokens = max_new_tokens;
fake_input->generate_config->top_k = 1;
fake_input->begin_time_us = autil::TimeUtility::currentTimeInMicroSeconds();
fake_input->fake_query = true;
auto stream = makeStream(fake_input);
stream->setMetricsReporter(nullptr);
enqueue(stream);
return stream;
}
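// Run a short fake query to completion so that step_recorder_ holds meaningful
// step-per-minute and latency values before real traffic arrives.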
void NormalEngine::initLoadBalance() {
RTP_LLM_LOG_INFO("init load balance start");
auto stream = enqueueMinFakeQuery(3);
while (!stream->finished() && !stream->stopped()) {
RTP_LLM_LOG_INFO("wait load balance init run over for 1s");
this_thread::sleep_for(std::chrono::seconds(1));
}
RTP_LLM_LOG_INFO("init load balance done and (StepPerMin: %ld , StepLatencyUs: %ld)",
step_recorder_.getStepPerMin(), step_recorder_.getStepLatency());
}
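// Create the KV-cache manager from the derived cache config; the warm-up result,
// when available, is forwarded to the config creator so the block pool can be
// sized against measured memory use.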
void NormalEngine::initCacheManager(std::optional<WarmUpResult> warm_up_result) {
auto result = CacheConfigCreator::createConfig(params_, warm_up_result);
RTP_LLM_LOG_INFO("create cache manager with block nums %d, block size %ld KB",
result.block_nums, result.block_size / 1024);
resource_context_.cache_manager = make_shared<CacheManager>(result, device_, false, metrics_reporter_);
}
absl::Status NormalEngine::initSystemPrompt() {
resource_context_.reuse_cache = params_.reuse_cache_;
if (!params_.multi_task_prompt_tokens_.empty()) {
resource_context_.reuse_cache = true;
CHECK_AND_RETURN_REF(system_prompt_param,
SystemPromptConstructor::construct(params_, this, resource_context_.cache_manager.get(),
device_->getDeviceProperties().tp_rank == 0));
resource_context_.system_prompt.reset(new SystemPrompt(system_prompt_param));
}
return absl::OkStatus();
}
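// Snapshot of step-recorder and KV-cache statistics exposed for load balancing.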
LoadBalanceInfo NormalEngine::getLoadBalanceInfo() {
auto kv_cache_info = resource_context_.cache_manager->getKVCacheInfo();
return LoadBalanceInfo{
(int64_t)step_recorder_.getStepLatency(),
(int64_t)step_recorder_.getStepCount(),
(int64_t)step_recorder_.getStepPerMin(),
(int64_t)kv_cache_info.available_kv_cache,
(int64_t)kv_cache_info.total_kv_cache,
(int64_t)scheduler_->onflightStreams()
};
}
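// Start the engine: build the multi-task system prompt cache if configured, then
// spawn the background thread that repeatedly calls step().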
absl::Status NormalEngine::startLoop() {
RTP_LLM_LOG_INFO("start init system prompt");
THROW_IF_STATUS_ERROR(initSystemPrompt());
RTP_LLM_LOG_INFO("init system prompt done");
RTP_LLM_LOG_INFO("start normal engine loop");
running_ = true;
loop_thread_ = autil::Thread::createThread(std::bind(&NormalEngine::loop, this), "normal_engine_loop");
return absl::OkStatus();
}
absl::Status NormalEngine::stop() {
RTP_LLM_LOG_INFO("stop normal engine");
running_ = false;
RETURN_IF_STATUS_ERROR(scheduler_->stop());
loop_thread_->join();
return absl::OkStatus();
}
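// Background loop: keep stepping while running_ is set; a failed step is logged
// and trySaveStepError() is invoked (currently unimplemented).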
void NormalEngine::loop() {
RTP_LLM_LOG_INFO("loop begin");
device_->preRun();
while (running_) {
auto status = step();
if (!status.ok()) {
RTP_LLM_LOG_ERROR("step running error: %s", status.ToString().c_str());
THROW_IF_STATUS_ERROR(trySaveStepError());
}
}
}
absl::Status NormalEngine::trySaveStepError() const {
return absl::UnimplementedError("can not save yet!");
}
std::shared_ptr<GenerateStream> NormalEngine::makeStream(const std::shared_ptr<GenerateInput>& input) {
std::shared_ptr<GenerateStream> stream = std::make_shared<NormalGenerateStream>(
input, params_, resource_context_, metrics_reporter_);
return stream;
}
void NormalEngine::enqueue(std::shared_ptr<GenerateStream>& stream) {
(void)scheduler_->enqueue(stream);
}
std::shared_ptr<GenerateStream> NormalEngine::enqueue(const std::shared_ptr<GenerateInput>& input) {
std::shared_ptr<GenerateStream> stream = std::make_shared<NormalGenerateStream>(
input, params_, resource_context_, metrics_reporter_);
(void)scheduler_->enqueue(stream);
return stream;
}
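// One engine step. On tp_rank 0 the scheduler picks the next batch of streams;
// with data parallelism (dp_size > 1) an empty batch triggers a minimal fake
// query so every rank keeps stepping. The executor then processes the batch and,
// on tp_rank 0, step latency and finished-stream iteration counts are reported.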
absl::Status NormalEngine::step() {
list<GenerateStreamPtr> streams;
if (device_->getDeviceProperties().tp_rank == 0) {
if (scheduler_->empty() || step_recorder_.empty()) {
step_recorder_.reset();
step_recorder_.registerStep(autil::TimeUtility::currentTimeInMicroSeconds());
}
CHECK_AND_ASSIGN(streams, scheduler_->schedule());
if (streams.empty()) {
if (params_.dp_size_ > 1) {
CHECK_AND_ASSIGN(streams, scheduler_->schedule());
if (streams.empty()) {
enqueueMinFakeQuery(1);
return absl::OkStatus();
}
} else {
return absl::OkStatus();
}
}
}
RTP_LLM_LOG_DEBUG(__PRETTY_FUNCTION__);
int64_t step_begin_time_us = autil::TimeUtility::currentTimeInMicroSeconds();
absl::Status status;
if (params_.world_size_ > 1) {
status = executor_->process(streams);
} else {
try {
status = executor_->process(streams);
} catch (const std::exception& e) {
RTP_LLM_LOG_ERROR("step running error: %s", e.what());
for (auto& stream: streams) {
stream->stopAndRelease(ErrorCode::EXECUTION_EXCEPTION, e.what());
}
}
}
// report step metrics
if (device_->getDeviceProperties().tp_rank == 0) {
auto step_latency = autil::TimeUtility::currentTimeInMicroSeconds() - step_begin_time_us;
reportMetrics({false, false, step_latency});
for (auto& stream: streams) {
if (stream->finished()) {
step_recorder_.addStepCount(stream->iterCount());
}
}
step_recorder_.registerStep(autil::TimeUtility::currentTimeInMicroSeconds());
}
return status;
}
const rtp_llm::GptInitParameter NormalEngine::gptInitParameter() const {
return params_;
}
} // namespace rtp_llm