// maga_transformer/cpp/model_rpc/DecodeRpcServer.cc
#include <mutex>
#include <memory>
#include <unistd.h>
#include <limits.h>
#include <condition_variable>
#include "maga_transformer/cpp/core/Buffer.h"
#include "maga_transformer/cpp/utils/NetUtil.h"
#include "maga_transformer/cpp/utils/KVCacheUtils.h"
#include "maga_transformer/cpp/model_rpc/QueryConverter.h"
#include "maga_transformer/cpp/model_rpc/DecodeRpcServer.h"
#include "maga_transformer/cpp/devices/utils/DebugUtils.h"
using namespace std;
using namespace autil::legacy;
using grpc::Status;
using grpc::ClientContext;
using grpc::CompletionQueue;
using grpc::ClientAsyncResponseReader;
const int LOAD_TIMEOUT_MS = 5 * 1000;
const int EXTRA_TIMEOUT_MS = 100;
const int RDMA_CONNECT_RETRY_TIME = 3;
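// Convenience macro for the stage functions below: on failure, record the error on the decode
// context as a grpc::Status and return from the enclosing (void) function.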
#define GRPC_RET_IF_ERROR(decode_context, stat, code, msg) \
if (!(stat)) { \
decode_context.error_status = grpc::Status(code, msg); \
return; \
}
string makeRequestKey(const string& client_id, size_t request_id) {
return client_id + "_request_id_" + std::to_string(request_id);
}
namespace rtp_llm {
grpc::Status DecodeRpcServer::init(const EngineInitParams& maga_init_params,
py::object mm_process_engine,
std::unique_ptr<rtp_llm::ProposeModelEngineInitParams> propose_params) {
auto ret = RemoteRpcServer::init(maga_init_params, mm_process_engine, std::move(propose_params));
if (!ret.ok()) {
return ret;
}
if (!maga_init_params_.gpt_init_parameter.decode_use_async_load_cache_) {
initThreadPool();
}
return grpc::Status::OK;
}
void DecodeRpcServer::initThreadPool() {
// Nothing to load remotely when there are no workers; this also avoids creating a zero-sized pool.
if (resource_.workers.size() == 0) {
return;
}
thread_pool_ = std::make_shared<autil::LockFreeThreadPool>(
resource_.workers.size() * 8, resource_.workers.size() * 8, nullptr, "RemoteCacheLoadPool");
RTP_LLM_CHECK_WITH_INFO(thread_pool_->start(), "DecodeRpcServer init ThreadPool failed");
RTP_LLM_LOG_INFO("DecodeRpcServer remote cache load thread pool init done");
}
DecodeRpcServer::~DecodeRpcServer() {
if (thread_pool_) {
thread_pool_->stop();
thread_pool_.reset();
}
}
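// Stage 1 of RemoteGenerate: read the first message of the stream, check that it is in the
// ALLOCATE stage, and record the request id, request key and prefill peer addresses on the context.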
void DecodeRpcServer::prepareGenerateContext(DecodeGenerateContext& decode_context) {
decode_context.time_info.updateRequestBegineTime();
auto& allocate_request = decode_context.allocate_request;
GRPC_RET_IF_ERROR(decode_context,
decode_context.rpc_context.grpc_stream->Read(&allocate_request),
grpc::StatusCode::INTERNAL,
"failed to get message");
GRPC_RET_IF_ERROR(decode_context,
allocate_request.stage() == RemoteStage::ALLOCATE,
grpc::StatusCode::INTERNAL,
"first message stage != RemoteStage::ALLOCATE");
decode_context.request_id = allocate_request.request_id();
decode_context.request_key = makeRequestKey(allocate_request.client_id(), allocate_request.request_id());
for (auto& addr : allocate_request.peer_addrs()) {
decode_context.peer_addrs.push_back(addr);
}
RTP_LLM_LOG_DEBUG("request [%s] prepare generate context done", decode_context.request_key.c_str());
}
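// Stage 2 of RemoteGenerate: build a generate stream from the request input, reserve KV cache
// blocks for it, and acknowledge the allocation by writing an empty GenerateOutputsPB back.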
void DecodeRpcServer::allocateResource(DecodeGenerateContext& decode_context) {
RTP_LLM_LOG_DEBUG("request [%s] start to allocate resource", decode_context.request_key.c_str());
auto input = QueryConverter::transQuery(&decode_context.allocate_request.input());
auto generate_stream = engine_->makeStream(input);
decode_context.setStream(generate_stream);
decode_context.request_timeout_ms = generate_stream->getTimeoutMs();
auto status = generate_stream->initKVBlock(0);
if (!status.ok()) {
string error_msg = "request: [" + decode_context.request_key + "] malloc kv cache block failed at decode node";
RTP_LLM_LOG_ERROR(error_msg);
decode_context.error_status = grpc::Status(grpc::StatusCode::RESOURCE_EXHAUSTED, error_msg);
return;
}
GRPC_RET_IF_ERROR(decode_context,
decode_context.rpc_context.grpc_stream->Write(GenerateOutputsPB()),
grpc::StatusCode::INTERNAL,
"failed to write allocate output");
RTP_LLM_LOG_DEBUG("request [%s] allocate resource done", decode_context.request_key.c_str());
}
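// Stage 3 of RemoteGenerate: read the load request, pull the KV cache from the prefill peers for
// all local ranks, and report the load result (error code) back over the stream.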
void DecodeRpcServer::loadCacheFromPrefill(DecodeGenerateContext& decode_context) {
RTP_LLM_LOG_DEBUG("request [%s] load cache from prefill", decode_context.request_key.c_str());
AtomicGuard request_guard(loading_cache_requests_);
auto& grpc_stream = decode_context.rpc_context.grpc_stream;
GenerateRequestPB load_request;
GRPC_RET_IF_ERROR(
decode_context, grpc_stream->Read(&load_request), grpc::StatusCode::INTERNAL, "failed to read load request");
decode_context.time_info.updateLoadBeginTime();
auto error_info = loadCacheForAllRank(decode_context);
decode_context.time_info.updateLoadEndTime();
if (!error_info.ok()) {
RTP_LLM_LOG_WARNING("request [%s] load kv cache failed, error code [%s], cost time [%ld] ms",
decode_context.request_key.c_str(), error_info.ToString().c_str(), decode_context.time_info.loadCacheTimeMs());
}
GenerateOutputsPB load_response;
load_response.mutable_error_info()->set_error_code(transErrorCodeToRPC(error_info.code()));
GRPC_RET_IF_ERROR(
decode_context, grpc_stream->Write(load_response), grpc::StatusCode::INTERNAL, "send load response failed");
GRPC_RET_IF_ERROR(decode_context, error_info.ok(), grpc::StatusCode::INTERNAL, error_info.ToString().c_str());
RTP_LLM_LOG_DEBUG("request [%s] load cache from prefill done", decode_context.request_key.c_str());
}
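// Stage 4 of RemoteGenerate: read the GENERATE message, seed the stream with the first generated
// token id carried in the request, enqueue it into the local engine, and stream outputs back.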
void DecodeRpcServer::localGenerate(DecodeGenerateContext& decode_context) {
RTP_LLM_LOG_DEBUG("request [%s] start to local generate", decode_context.request_key.c_str());
auto& grpc_stream = decode_context.rpc_context.grpc_stream;
auto& generate_stream = decode_context.getStream();
GenerateRequestPB generate_request;
GRPC_RET_IF_ERROR(decode_context,
grpc_stream->Read(&generate_request),
grpc::StatusCode::INTERNAL,
"poll generate request failed");
GRPC_RET_IF_ERROR(decode_context,
generate_request.stage() == RemoteStage::GENERATE,
grpc::StatusCode::INTERNAL,
"generate message stage != RemoteStage::GENERATE");
decode_context.time_info.updateGenerateBeginTime();
generate_stream->setIsContextStream(false);
generate_stream->step();
auto new_tokens = engine_->getDevice()->allocateBuffer(
{rtp_llm::DataType::TYPE_INT32, {(size_t)generate_stream->tileNum(), (size_t)1}, rtp_llm::AllocationType::HOST}, {});
auto data = new_tokens->data<int32_t>();
auto first_token_id = generate_request.first_generate_token_id();
*data = first_token_id;
generate_stream->incLastOutputPos();
generate_stream->update({new_tokens, 1, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr});
if (propose_maga_init_params_) {
generate_stream->setReuseLength(generate_stream->seqLength() - 1);
generate_stream->setFallbackPrefixLength(generate_stream->reuseLength());
generate_stream->setSpEditRun(false);
}
generate_stream->resetBeginTime(currentTimeUs());
RTP_LLM_LOG_DEBUG("decode init stream[%d]: %s", generate_stream->streamId(), generate_stream->debugString().c_str());
engine_->enqueue(generate_stream);
RTP_LLM_LOG_DEBUG("request [%s] enqueue success", decode_context.request_key.c_str());
decode_context.error_status =
pollStreamOutput(decode_context.server_context,
decode_context.request_key,
dynamic_cast<grpc::internal::WriterInterface<GenerateOutputsPB>*>(grpc_stream),
generate_stream);
decode_context.time_info.updateGenerateEndTime();
RTP_LLM_LOG_DEBUG("request [%s] local generate done", decode_context.request_key.c_str());
}
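// Build the per-rank broadcast load request for MLA models: blocks are not partitioned
// (partition_count is fixed to 1) and each decode rank loads from a single prefill peer.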
BroadcastLoadRequestPB DecodeRpcServer::constructRemoteLoadRequestForMla(
const LoadKVCacheContext& load_context, int index, const std::vector<std::string>& peer_addrs) const {
BroadcastLoadRequestPB request;
request.set_request_id(load_context.request_id);
request.set_request_key(load_context.request_key);
request.set_dp_rank(maga_init_params_.gpt_init_parameter.dp_rank_);
request.set_partition_count(1);
request.set_partition_id(0);
// More decode workers than prefill peers: each decode rank maps to a single prefill peer.
if (resource_.workers.size() % peer_addrs.size() == 0) {
int part_cnt = resource_.workers.size() / peer_addrs.size();
request.add_peer_addrs(peer_addrs[index / part_cnt]);
} else {
// More prefill peers than decode workers: MLA blocks are not partitioned, so each decode rank
// still loads full blocks from a single prefill peer of its group.
int group_num = peer_addrs.size() / resource_.workers.size();
request.add_peer_addrs(peer_addrs[index * group_num]);
}
for (auto& cache_key : load_context.cache_keys) {
request.add_cache_keys(cache_key);
}
for (auto& block_id : load_context.block_ids) {
request.add_block_ids(block_id);
}
request.set_timeout_ms(load_context.timeout_ms);
return request;
}
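// Build the per-rank broadcast load request for the general (non-MLA) case: when there are more
// decode workers than prefill peers each rank loads one partition of a peer's blocks, otherwise
// each rank loads full blocks from a group of prefill peers.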
BroadcastLoadRequestPB DecodeRpcServer::constructRemoteLoadRequest(const LoadKVCacheContext& load_context,
int index,
const std::vector<std::string>& peer_addrs) const {
BroadcastLoadRequestPB request;
request.set_request_id(load_context.request_id);
request.set_request_key(load_context.request_key);
request.set_dp_rank(maga_init_params_.gpt_init_parameter.dp_rank_);
if (resource_.workers.size() % peer_addrs.size() == 0) {
// More decode workers than prefill peers: each decode rank loads one partition of its peer's blocks.
int part_cnt = resource_.workers.size() / peer_addrs.size();
request.set_partition_count(part_cnt);
request.set_partition_id(index % part_cnt);
request.add_peer_addrs(peer_addrs[index / part_cnt]);
} else {
// More prefill peers than decode workers: each decode rank loads full blocks from several prefill peers.
request.set_partition_count(1);
request.set_partition_id(0);
int group_num = peer_addrs.size() / resource_.workers.size();
for (int i = 0; i < group_num; i++) {
request.add_peer_addrs(peer_addrs[index * group_num + i]);
}
}
for (auto& cache_key : load_context.cache_keys) {
request.add_cache_keys(cache_key);
}
for (auto& block_id : load_context.block_ids) {
request.add_block_ids(block_id);
}
request.set_timeout_ms(load_context.timeout_ms);
return request;
}
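// Validate cache keys / block ids and the worker-vs-peer layout, derive the effective load timeout,
// then run the load: directly for the 1x1 case, otherwise across all ranks via the async or sync
// TP path, depending on decode_use_async_load_cache_.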
ErrorInfo DecodeRpcServer::loadCacheForAllRank(DecodeGenerateContext& decode_context) {
auto* generate_stream = decode_context.getStream().get();
auto& cache_keys = generate_stream->cacheKeys(0);
auto& block_ids = generate_stream->kvCache().blocks(0);
if (cache_keys.size() != block_ids.size()) {
return ErrorInfo(ErrorCode::LOAD_KV_CACHE_FAILED,
"cache keys size " + std::to_string(cache_keys.size()) + " not equal to block size "
+ std::to_string(block_ids.size()));
}
if (resource_.workers.size() % decode_context.peer_addrs.size() != 0
&& decode_context.peer_addrs.size() % resource_.workers.size() != 0) {
RTP_LLM_LOG_WARNING("request:[%s] peer addr count %zu is not a multiple or divisor of worker count %zu",
decode_context.request_key.c_str(),
decode_context.peer_addrs.size(),
resource_.workers.size());
return ErrorInfo(ErrorCode::LOAD_KV_CACHE_FAILED, "peer addr count is incompatible with worker count");
}
auto load_cache_timeout_ms = maga_init_params_.gpt_init_parameter.load_cache_timeout_ms_;
load_cache_timeout_ms = load_cache_timeout_ms > 0 ? load_cache_timeout_ms : LOAD_TIMEOUT_MS;
auto max_rpc_timeout_ms = maga_init_params_.gpt_init_parameter.max_rpc_timeout_ms_;
auto rpc_timeout = max_rpc_timeout_ms > 0 ? max_rpc_timeout_ms : MAX_GRPC_TIMEOUT_MS;
auto min_timeout_ms = std::min(load_cache_timeout_ms, rpc_timeout);
auto request_timeout_ms = decode_context.request_timeout_ms;
min_timeout_ms = request_timeout_ms > 0 ? std::min(request_timeout_ms, min_timeout_ms) : min_timeout_ms;
LoadKVCacheContext load_context{decode_context.request_id,
decode_context.request_key,
decode_context.peer_addrs,
cache_keys,
block_ids,
generate_stream->reuseBlockSize(),
min_timeout_ms,
1,
0,
decode_context.server_context,
generate_stream->returnEmptyHiddenStates()};
// Prefill: TP = 1 && Decode: TP = 1
if (resource_.workers.size() == 1 && decode_context.peer_addrs.size() == 1) {
for (size_t i = 0; i < maga_init_params_.gpt_init_parameter.rdma_connect_retry_times_ + 1; i++) {
auto error_info = loadCache(load_context);
if (error_info.code() != ErrorCode::CACHE_STORE_LOAD_CONNECT_FAILED
&& error_info.code() != ErrorCode::CACHE_STORE_LOAD_RDMA_CONNECT_FAILED) {
return error_info;
}
}
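// If every retry failed with a (RDMA) connect error, fall through to the TP load path below.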
}
if (maga_init_params_.gpt_init_parameter.decode_use_async_load_cache_) {
return loadCacheAsyncForTp(decode_context, load_context);
} else {
return loadCacheSyncForTp(decode_context, load_context);
}
}
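// Issue AsyncRemoteLoad RPCs to every grpc worker and poll the completion queues until all ranks
// have responded, the request is cancelled, or the load timeout expires.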
ErrorInfo DecodeRpcServer::loadCacheAsyncForTp(DecodeGenerateContext& decode_context,
LoadKVCacheContext& load_context) {
int64_t load_cache_begin_time_us = currentTimeUs();
struct WorkerRpcContext {
WorkerRpcContext() {
client_context = make_shared<ClientContext>();
}
BroadcastLoadResponsePB response;
Status status;
std::shared_ptr<RpcService::Stub> stub;
std::shared_ptr<ClientContext> client_context;
};
uint32_t worker_size = resource_.grpc_workers.size();
vector<WorkerRpcContext> all_context(worker_size);
uint32_t cq_size = (worker_size + 1) / 2;
vector<CompletionQueue> completion_queues(cq_size);
vector<int> each_finished_count(cq_size, 0);
if (worker_size == 0 || cq_size == 0) {
RTP_LLM_LOG_WARNING("request:[%s] cq_size or worker_size is 0, worker size = %u, cq size = %u",
decode_context.request_key.c_str(), worker_size, cq_size);
return ErrorInfo(ErrorCode::LOAD_KV_CACHE_FAILED, "worker size or cq size is 0");
}
// Workers are dispatched to completion queues round-robin (i % cq_size), so queue q expects
// worker_size / cq_size completions, plus one extra for the first worker_size % cq_size queues.
vector<int> expected_count_per_queue(cq_size, worker_size / cq_size);
for (uint32_t q = 0; q < worker_size % cq_size; q++) {
expected_count_per_queue[q]++;
}
RTP_LLM_LOG_DEBUG("request:[%s] start to async remote load for all rank", decode_context.request_key.c_str());
for (int i = 0; i < worker_size; i++) {
auto& worker = resource_.grpc_workers[i];
auto connect_status = resource_.rpc_pool.getConnection(worker);
if (!connect_status.ok()) {
string error_msg = "get grpc connection for rank:" + std::to_string(i) + ", addr:" + worker + " failed";
return ErrorInfo(ErrorCode::GET_CONNECTION_FAILED, error_msg);
}
// all_context is pre-sized to worker_size; growing it here could reallocate the vector and
// invalidate the response/status pointers already registered with in-flight async calls.
auto& rpc_context = all_context[i];
rpc_context.stub = connect_status.value().stub;
BroadcastLoadRequestPB load_request;
if (engine_->resourceContext().cache_manager->cacheConfig().use_mla) {
load_request = constructRemoteLoadRequestForMla(load_context, i, decode_context.peer_addrs);
} else {
load_request = constructRemoteLoadRequest(load_context, i, decode_context.peer_addrs);
}
std::unique_ptr<ClientAsyncResponseReader<BroadcastLoadResponsePB>> reader(rpc_context.stub->AsyncRemoteLoad(
rpc_context.client_context.get(), load_request, &completion_queues[i % completion_queues.size()]));
reader->Finish(&rpc_context.response, &rpc_context.status, reinterpret_cast<void*>(i));
}
bool all_success = true;
size_t finished_count = 0;
auto total_timeout_ms = load_context.timeout_ms + EXTRA_TIMEOUT_MS;
ErrorCode error_code = ErrorCode::NONE_ERROR;
std::string error_msg = "failed to load kv cache in rank: ";
int64_t min_response_done_time_us = 1lu << 60;
int64_t max_response_done_time_us = 0;
while (true) {
RTP_LLM_LOG_DEBUG("request [%s] load cache loop step", decode_context.request_key.c_str());
auto cost_time_ms = (currentTimeUs() - load_cache_begin_time_us) / 1000;
if (cost_time_ms > total_timeout_ms) {
error_msg = "load cache timeout : cost time is " + std::to_string(cost_time_ms)
+ "ms, "
"total timeout for load cache is "
+ std::to_string(total_timeout_ms) + "ms";
return ErrorInfo(ErrorCode::LOAD_CACHE_TIMEOUT, error_msg);
}
if (load_context.server_context->IsCancelled()) {
string error_msg = "request is cancelled";
return ErrorInfo(ErrorCode::CANCELLED, error_msg);
}
auto once_deadline =
std::chrono::system_clock::now()
+ std::chrono::milliseconds(maga_init_params_.gpt_init_parameter.decode_polling_kv_cache_step_ms_);
RTP_LLM_LOG_DEBUG("request [%s] start to execute async next", decode_context.request_key.c_str());
// TODO(xinfei.sxf) There is a problem with complete queue next call delay here, the reason is yet to be
// investigated
void* got_tag;
bool ok = false;
for (uint32_t i = 0; i < completion_queues.size(); i++) {
if (each_finished_count[i] == expected_count_per_queue[i]) {
continue;
}
if (completion_queues[i].AsyncNext(&got_tag, &ok, once_deadline)
== grpc::CompletionQueue::NextStatus::TIMEOUT) {
RTP_LLM_LOG_DEBUG("request [%s] async next timeout", decode_context.request_key.c_str());
continue;
}
each_finished_count[i]++;
if (!ok) {
string error_msg = "async get next event from grpc completion queue failed";
return ErrorInfo(ErrorCode::LOAD_KV_CACHE_FAILED, error_msg);
}
auto rank = reinterpret_cast<uintptr_t>(got_tag);
const auto& status = all_context[rank].status;
const auto& response = all_context[rank].response;
const auto& pb_error_code = response.error_info().error_code();
const auto& pb_error_message = response.error_info().error_message();
min_response_done_time_us = std::min(min_response_done_time_us, response.done_time_us());
max_response_done_time_us = std::max(max_response_done_time_us, response.done_time_us());
RTP_LLM_LOG_DEBUG("request [%s] load cache for rank [%d] done", decode_context.request_key.c_str(), rank);
if (!status.ok()) {
all_success = false;
error_code = ErrorCode::LOAD_KV_CACHE_FAILED;
error_msg += std::to_string(rank) + ": " + status.error_message() + ", ";
} else if (pb_error_code != ErrorCodePB::NONE_ERROR) {
all_success = false;
error_code = transRPCErrorCode(pb_error_code);
error_msg += std::to_string(rank) + ": " + pb_error_message + ", ";
}
finished_count++;
if (finished_count == worker_size) {
break;
}
}
if (finished_count == worker_size) {
break;
}
}
for (auto& completion_queue : completion_queues) {
completion_queue.Shutdown();
}
if (finished_count != worker_size) {
all_success = false;
}
if (!all_success) {
return ErrorInfo(error_code, error_msg);
}
decode_context.stat_info.load_cache_min_rt_us = min_response_done_time_us - load_cache_begin_time_us;
decode_context.stat_info.load_cache_max_rt_us = max_response_done_time_us - load_cache_begin_time_us;
decode_context.stat_info.load_cache_polling_cost_us = currentTimeUs() - max_response_done_time_us;
RTP_LLM_LOG_DEBUG("load_cache_min_rt_us = %ld, load_cache_max_rt_us = %ld, load_cache_polling_cost_us = %ld",
decode_context.stat_info.load_cache_min_rt_us,
decode_context.stat_info.load_cache_max_rt_us,
decode_context.stat_info.load_cache_polling_cost_us);
return ErrorInfo::OkStatus();
}
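// Synchronous variant: run the local load plus one RemoteLoad RPC per grpc worker on the thread
// pool, then wait for every future and report the combined result.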
ErrorInfo DecodeRpcServer::loadCacheSyncForTp(DecodeGenerateContext& decode_context, LoadKVCacheContext& load_context) {
int64_t load_cache_begin_time_us = currentTimeUs();
int64_t min_response_done_time_us = 1lu << 60;
int64_t max_response_done_time_us = 0;
// The remote tasks below run concurrently on the thread pool and all update the response-time
// min/max, so guard those updates with a mutex.
std::mutex response_time_mutex;
std::vector<autil::ThreadPoolBase::Future<ErrorInfo>> futures;
auto local_task = [&] { return this->loadCache(load_context); };
futures.emplace_back(thread_pool_->async(local_task));
for (int i = 0; i < resource_.grpc_workers.size(); i++) {
auto& worker = resource_.grpc_workers[i];
// Capture the loop index and worker address by value: the task runs on the thread pool after the
// loop has moved on, so capturing them by reference would leave dangling references.
auto remote_task = [&, i, worker]() {
auto connect_status = resource_.rpc_pool.getConnection(worker);
if (!connect_status.ok()) {
string error_msg = "get grpc connection for ip " + worker + " failed";
return ErrorInfo(ErrorCode::LOAD_KV_CACHE_FAILED, error_msg);
}
auto stub = connect_status.value().stub.get();
ClientContext client_context;
BroadcastLoadRequestPB load_request;
if (engine_->resourceContext().cache_manager->cacheConfig().use_mla) {
load_request = constructRemoteLoadRequestForMla(load_context, i, decode_context.peer_addrs);
} else {
load_request = constructRemoteLoadRequest(load_context, i, decode_context.peer_addrs);
}
BroadcastLoadResponsePB response;
auto grpc_status = stub->RemoteLoad(&client_context, load_request, &response);
const auto& pb_error_code = response.error_info().error_code();
const auto& pb_error_message = response.error_info().error_message();
if (!grpc_status.ok()) {
return ErrorInfo(ErrorCode::LOAD_KV_CACHE_FAILED, grpc_status.error_message());
} else if (pb_error_code != ErrorCodePB::NONE_ERROR) {
auto error_code = transRPCErrorCode(pb_error_code);
return ErrorInfo(error_code, pb_error_message);
}
{
std::lock_guard<std::mutex> lock(response_time_mutex);
min_response_done_time_us = std::min(min_response_done_time_us, response.done_time_us());
max_response_done_time_us = std::max(max_response_done_time_us, response.done_time_us());
}
return ErrorInfo::OkStatus();
};
futures.emplace_back(thread_pool_->async(remote_task));
}
std::string err_msg = "failed to load kv cache, failed tasks: ";
bool success = true;
for (size_t i = 0; i < futures.size(); i++) {
auto status = futures[i].get();
if (!status.ok()) {
// TODO(xinfei.sxf) can we avoid waiting for the results of the other ranks?
success = false;
// futures[0] is the local load task; futures[i] for i >= 1 corresponds to grpc worker i - 1.
err_msg += (i == 0 ? std::string("local") : std::to_string(i - 1)) + ": " + status.ToString() + ", ";
}
}
if (!success) {
RTP_LLM_LOG_WARNING(err_msg);
return ErrorInfo(ErrorCode::LOAD_KV_CACHE_FAILED, err_msg);
}
decode_context.stat_info.load_cache_min_rt_us = min_response_done_time_us - load_cache_begin_time_us;
decode_context.stat_info.load_cache_max_rt_us = max_response_done_time_us - load_cache_begin_time_us;
decode_context.stat_info.load_cache_polling_cost_us = currentTimeUs() - max_response_done_time_us;
RTP_LLM_LOG_DEBUG("load_cache_min_rt_us = %ld, load_cache_max_rt_us = %ld, load_cache_polling_cost_us = %ld",
decode_context.stat_info.load_cache_min_rt_us,
decode_context.stat_info.load_cache_max_rt_us,
decode_context.stat_info.load_cache_polling_cost_us);
return ErrorInfo::OkStatus();
}
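// Load this request's KV cache blocks (and, for MTP on tp rank 0, the hidden states) from the given
// prefill peers through the cache store, building one RequestBlockBuffer per layer, and wait for
// each transfer to finish.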
ErrorInfo DecodeRpcServer::loadCache(const LoadKVCacheContext& load_context) {
AtomicGuard request_guard(onflight_load_cache_requests_);
const auto& request_key = load_context.request_key;
auto cache_manager = engine_->resourceContext().cache_manager;
const auto& cache_config = cache_manager->cacheConfig();
auto k_block_size = cache_config.k_block_stride;
auto v_block_size = cache_config.v_block_stride;
auto scale_block_size = cache_config.kv_scale_block_stride;
auto layer_num = maga_init_params_.gpt_init_parameter.num_layers_;
if (v_block_size % load_context.peer_addrs.size() != 0 || k_block_size % load_context.peer_addrs.size() != 0
|| scale_block_size % load_context.peer_addrs.size() != 0) {
RTP_LLM_LOG_WARNING(
"k block size [%zu] or v block size [%zu] or scale block size [%zu] is not divisible by peer addr count [%zu]",
(size_t)k_block_size,
(size_t)v_block_size,
(size_t)scale_block_size,
load_context.peer_addrs.size());
return ErrorInfo(ErrorCode::LOAD_KV_CACHE_FAILED, "block size is not divisible by peer addr count");
}
k_block_size = k_block_size / load_context.peer_addrs.size();
v_block_size = v_block_size / load_context.peer_addrs.size();
scale_block_size = scale_block_size / load_context.peer_addrs.size();
auto cancel_check_func = [&load_context]() -> bool { return load_context.server_context->IsCancelled(); };
auto start_load_time_us = currentTimeUs();
std::vector<std::shared_ptr<LoadContext>> load_contexts;
for (int i = 0; i < load_context.peer_addrs.size(); i++) {
auto& peer_addr = load_context.peer_addrs[i];
std::vector<std::shared_ptr<RequestBlockBuffer>> layer_caches;
for (size_t layer_id = 0; layer_id < layer_num; layer_id++) {
// Use a distinct name so this per-layer key does not shadow the request_key declared above.
auto layer_request_key = std::to_string(load_context.request_id) + "-" + std::to_string(layer_id);
auto load_layer_cache =
std::make_shared<RequestBlockBuffer>(std::to_string(load_context.request_id), layer_request_key);
auto block_num = load_context.block_ids.size();
for (size_t block_pos = load_context.reuse_block_size; block_pos < block_num; block_pos++) {
auto cache_key = makeCacheKey(std::to_string(load_context.cache_keys[block_pos]), layer_id);
auto block_id = load_context.block_ids[block_pos];
auto addr_info = cache_manager->convertIndexToAddr(block_id, layer_id);
void* k_addr = (void*)((int64_t)addr_info.k_addr + i * k_block_size);
std::shared_ptr<void> k_block_addr(k_addr, [](void* p) {});
load_layer_cache->addBlock("k_" + cache_key, k_block_addr, k_block_size, true, true);
if (addr_info.k_scale_addr) {
void* k_scale_addr = (void*)((int64_t)addr_info.k_scale_addr + i * scale_block_size);
std::shared_ptr<void> k_block_scale_addr(k_scale_addr, [](void* p) {});
load_layer_cache->addBlock("k_scale" + cache_key, k_block_scale_addr, scale_block_size, true, true);
}
if (engine_->resourceContext().cache_manager->cacheConfig().use_mla) {
continue;
}
void* v_addr = (void*)((int64_t)addr_info.v_addr + i * v_block_size);
std::shared_ptr<void> v_block_addr(v_addr, [](void* p) {});
load_layer_cache->addBlock("v_" + cache_key, v_block_addr, v_block_size, true, true);
if (addr_info.v_scale_addr) {
void* v_scale_addr = (void*)((int64_t)addr_info.v_scale_addr + i * scale_block_size);
std::shared_ptr<void> v_block_scale_addr(v_scale_addr, [](void* p) {});
load_layer_cache->addBlock("v_scale" + cache_key, v_block_scale_addr, scale_block_size, true, true);
}
}
layer_caches.push_back(load_layer_cache);
}
auto ip_parts = autil::StringUtil::split(peer_addr, ":");
if (ip_parts.size() != 3) {
RTP_LLM_LOG_WARNING("invalid peer ip to load [%s]", peer_addr.c_str());
return ErrorInfo(ErrorCode::LOAD_KV_CACHE_FAILED, "invalid peer ip");
}
auto layer_cache_load_context =
resource_.cache_store->loadBuffers(layer_caches,
ip_parts[0],
autil::StringUtil::strToInt32WithDefault(ip_parts[1].c_str(), 0),
autil::StringUtil::strToInt32WithDefault(ip_parts[2].c_str(), 0),
load_context.timeout_ms,
cancel_check_func,
load_context.partition_count,
load_context.partition_id);
if (!layer_cache_load_context) {
RTP_LLM_LOG_WARNING("request [%s] load cache failed, layer cache load context is nullptr", request_key.c_str());
return ErrorInfo(ErrorCode::LOAD_KV_CACHE_FAILED, "load kv cache failed");
}
load_contexts.push_back(layer_cache_load_context);
if (engine_->isMTP() &&
engine_->getDevice()->getDeviceProperties().tp_rank == 0 &&
load_context.hidden_states_ != nullptr)
{
RTP_LLM_LOG_DEBUG("mtp hidden states rdma get.");
std::vector<std::shared_ptr<RequestBlockBuffer>> hidden_states_caches;
auto load_layer_cache =
std::make_shared<RequestBlockBuffer>(std::to_string(load_context.request_id), "");
auto hidden_states_ptr = load_context.hidden_states_->data();
auto hidden_states_size = load_context.hidden_states_->sizeBytes();
RTP_LLM_LOG_DEBUG("token_num is %d", load_context.hidden_states_->shape()[0]);
RTP_LLM_LOG_DEBUG("decoder need hidden states size is %d", hidden_states_size);
std::shared_ptr<void> hidden_states_addr(hidden_states_ptr, [](void* p) {});
load_layer_cache->addBlock("hidden_states", hidden_states_addr, hidden_states_size, true, true);
hidden_states_caches.push_back(load_layer_cache);
auto hidden_states_load_context =
resource_.cache_store->loadBuffers(hidden_states_caches,
ip_parts[0],
autil::StringUtil::strToInt32WithDefault(ip_parts[1].c_str(), 0),
autil::StringUtil::strToInt32WithDefault(ip_parts[2].c_str(), 0),
load_context.timeout_ms,
cancel_check_func,
load_context.partition_count,
load_context.partition_id);
if (!hidden_states_load_context) {
RTP_LLM_LOG_WARNING("request [%s] load hidden_states failed, layer cache load context is nullptr", request_key.c_str());
return ErrorInfo(ErrorCode::LOAD_KV_CACHE_FAILED, "load hidden_states failed");
}
load_contexts.push_back(hidden_states_load_context);
RTP_LLM_LOG_DEBUG("mtp hidden states rdma end.");
}
}
for (auto& layer_cache_load_context : load_contexts) {
layer_cache_load_context->waitDone();
if (layer_cache_load_context->success()) {
RTP_LLM_LOG_DEBUG("request [%s] load kv cache success", request_key.c_str());
} else {
// TODO(xinfei.sxf) add retry for part failed blocks.
auto load_done_time_us = currentTimeUs();
RTP_LLM_LOG_WARNING("request [%s] load cache failed, status [%s], cost time [%ld] ms",
request_key.c_str(),
layer_cache_load_context->getErrorInfoString().c_str(),
(load_done_time_us - start_load_time_us) / 1000);
// Fail as soon as any load context reports an error; otherwise keep waiting for the rest.
return layer_cache_load_context->getErrorInfo();
}
}
return ErrorInfo::OkStatus();
}
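// RPC handler for the per-rank broadcast issued by loadCacheAsyncForTp / loadCacheSyncForTp:
// rebuild a LoadKVCacheContext from the request and run loadCache on this rank.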
grpc::Status DecodeRpcServer::RemoteLoad(grpc::ServerContext* server_context,
const BroadcastLoadRequestPB* request,
BroadcastLoadResponsePB* response) {
if (request->dp_rank() != maga_init_params_.gpt_init_parameter.dp_rank_) {
RTP_LLM_LOG_WARNING("only load when in dp group, skip load for dp rank %d", request->dp_rank());
return grpc::Status::OK;
}
std::vector<int64_t> cache_keys(request->cache_keys().begin(), request->cache_keys().end());
std::vector<int32_t> block_ids(request->block_ids().begin(), request->block_ids().end());
std::vector<std::string> peer_addrs(request->peer_addrs().begin(), request->peer_addrs().end());
// TODO(xinfei.sxf) add retry
auto error_info = loadCache({request->request_id(),
request->request_key(),
peer_addrs,
cache_keys,
block_ids,
request->reuse_block_size(),
request->timeout_ms(),
request->partition_count(),
request->partition_id(),
server_context,
nullptr});
response->mutable_error_info()->set_error_code(transErrorCodeToRPC(error_info.code()));
response->mutable_error_info()->set_error_message(error_info.ToString());
response->set_done_time_us(currentTimeUs());
RTP_LLM_LOG_DEBUG("request: %s, remote load cache grpc done", request->request_key().c_str());
return grpc::Status::OK;
}
grpc::Status DecodeRpcServer::allocateResourceFunc(DecodeGenerateContext& decode_context) {
EXECUTE_STAGE_FUNC(allocateResource, decode_context);
return grpc::Status::OK;
}
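// Top-level decode handler for prefill/decode separated generation: prepare the context, allocate
// resources (with retry), load the KV cache from prefill, then generate locally, converting any
// exception into an INTERNAL grpc status.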
grpc::Status DecodeRpcServer::RemoteGenerate(grpc::ServerContext* server_context, ServerStream* grpc_stream) {
AtomicGuard request_guard(onflight_requests_);
DecodeRpcContext rpc_context{grpc_stream};
// TODO(xinfei.sxf) request id is 0 here
auto decode_context = DecodeGenerateContext(rpc_context, 0, server_context, metrics_reporter_);
decode_context.onflight_requests = onflight_requests_;
decode_context.loading_cache_requests = loading_cache_requests_;
auto max_retry_times = maga_init_params_.gpt_init_parameter.decode_retry_times_;
auto max_retry_timeout_ms = maga_init_params_.gpt_init_parameter.decode_retry_timeout_ms_;
try {
EXECUTE_STAGE_FUNC(prepareGenerateContext, decode_context);
EXECUTE_WITH_RETRY(allocateResourceFunc, decode_context, max_retry_times, max_retry_timeout_ms);
if (decode_context.hasError()) {
RTP_LLM_LOG_WARNING("request [%s] allocate resource failed after retry %d times, cost time ms [%ld], "
"max retry time [%ld], max retry timeout ms [%ld]",
decode_context.request_key.c_str(),
decode_context.retry_times,
decode_context.retry_cost_time_ms,
max_retry_times + 1,
max_retry_timeout_ms);
return decode_context.error_status;
}
EXECUTE_STAGE_FUNC(loadCacheFromPrefill, decode_context);
EXECUTE_STAGE_FUNC(localGenerate, decode_context);
decode_context.stat_info.nextStage();
} catch (const std::exception& e) {
auto error_msg = "request [" + decode_context.request_key + "] catch exception [" + e.what() + "]";
decode_context.error_status = grpc::Status(grpc::StatusCode::INTERNAL, error_msg);
return decode_context.error_status;
} catch (...) {
auto error_msg = "request [" + decode_context.request_key + "] catch unknown exception";
decode_context.error_status = grpc::Status(grpc::StatusCode::INTERNAL, error_msg);
return decode_context.error_status;
}
return grpc::Status::OK;
}
} // namespace rtp_llm