in maga_transformer/cpp/model_rpc/DecodeRpcServer.cc [502:638]
ErrorInfo DecodeRpcServer::loadCache(const LoadKVCacheContext& load_context) {
    // Pull KV-cache blocks (and, for MTP on tp rank 0, the hidden-states
    // buffer) for this request from every prefill peer via the cache store,
    // then wait for ALL transfers to complete.
    // Returns OkStatus on success, otherwise the first observed error.
    AtomicGuard request_guard(onflight_load_cache_requests_);
    const auto& request_key      = load_context.request_key;
    auto        cache_manager    = engine_->resourceContext().cache_manager;
    const auto& cache_config     = cache_manager->cacheConfig();
    auto        k_block_size     = cache_config.k_block_stride;
    auto        v_block_size     = cache_config.v_block_stride;
    auto        scale_block_size = cache_config.kv_scale_block_stride;
    const size_t layer_num       = maga_init_params_.gpt_init_parameter.num_layers_;
    const size_t peer_count      = load_context.peer_addrs.size();

    // Guard the modulo/division below: an empty peer list would be UB.
    if (peer_count == 0) {
        RTP_LLM_LOG_WARNING("request [%s] has no peer addrs to load from", request_key.c_str());
        return ErrorInfo(ErrorCode::LOAD_KV_CACHE_FAILED, "empty peer addrs");
    }
    // Each peer serves an equal slice of every block, so every stride must
    // split evenly across the peers.
    if (v_block_size % peer_count != 0 || k_block_size % peer_count != 0
        || scale_block_size % peer_count != 0) {
        RTP_LLM_LOG_WARNING(
            "k block size [%zu] or v block size [%zu] or scale block size [%zu] is not divisible by peer ips size [%zu]",
            (size_t)k_block_size,
            (size_t)v_block_size,
            (size_t)scale_block_size,
            peer_count);
        return ErrorInfo(ErrorCode::LOAD_KV_CACHE_FAILED, "block size is not divisible by peer ips size");
    }
    // From here on the sizes are per-peer slice sizes.
    k_block_size     = k_block_size / peer_count;
    v_block_size     = v_block_size / peer_count;
    scale_block_size = scale_block_size / peer_count;

    auto cancel_check_func  = [&load_context]() -> bool { return load_context.server_context->IsCancelled(); };
    auto start_load_time_us = currentTimeUs();
    std::vector<std::shared_ptr<LoadContext>> load_contexts;
    for (size_t i = 0; i < peer_count; i++) {
        auto& peer_addr = load_context.peer_addrs[i];
        // Validate the "ip:port:rdma_port" address up front so we fail fast
        // instead of building all per-layer buffers first.
        auto ip_parts = autil::StringUtil::split(peer_addr, ":");
        if (ip_parts.size() != 3) {
            RTP_LLM_LOG_WARNING("invalid peer ip to load [%s]", peer_addr.c_str());
            return ErrorInfo(ErrorCode::LOAD_KV_CACHE_FAILED, "invalid peer ip");
        }
        std::vector<std::shared_ptr<RequestBlockBuffer>> layer_caches;
        for (size_t layer_id = 0; layer_id < layer_num; layer_id++) {
            // Per-layer key; distinct from the request-level `request_key` above.
            auto layer_key = std::to_string(load_context.request_id) + "-" + std::to_string(layer_id);
            auto load_layer_cache =
                std::make_shared<RequestBlockBuffer>(std::to_string(load_context.request_id), layer_key);
            auto block_num = load_context.block_ids.size();
            // Reused prefix blocks are already local; only fetch the rest.
            for (size_t block_pos = load_context.reuse_block_size; block_pos < block_num; block_pos++) {
                auto cache_key = makeCacheKey(std::to_string(load_context.cache_keys[block_pos]), layer_id);
                auto block_id  = load_context.block_ids[block_pos];
                auto addr_info = cache_manager->convertIndexToAddr(block_id, layer_id);
                // Peer i owns the i-th slice of each block. The no-op deleters
                // mark these as borrowed views into cache-manager memory.
                void* k_addr = reinterpret_cast<char*>(addr_info.k_addr) + i * k_block_size;
                std::shared_ptr<void> k_block_addr(k_addr, [](void* p) {});
                load_layer_cache->addBlock("k_" + cache_key, k_block_addr, k_block_size, true, true);
                if (addr_info.k_scale_addr) {
                    void* k_scale_addr = reinterpret_cast<char*>(addr_info.k_scale_addr) + i * scale_block_size;
                    std::shared_ptr<void> k_block_scale_addr(k_scale_addr, [](void* p) {});
                    load_layer_cache->addBlock("k_scale" + cache_key, k_block_scale_addr, scale_block_size, true, true);
                }
                // MLA keeps no separate V cache, so skip the V blocks.
                if (cache_config.use_mla) {
                    continue;
                }
                void* v_addr = reinterpret_cast<char*>(addr_info.v_addr) + i * v_block_size;
                std::shared_ptr<void> v_block_addr(v_addr, [](void* p) {});
                load_layer_cache->addBlock("v_" + cache_key, v_block_addr, v_block_size, true, true);
                if (addr_info.v_scale_addr) {
                    void* v_scale_addr = reinterpret_cast<char*>(addr_info.v_scale_addr) + i * scale_block_size;
                    std::shared_ptr<void> v_block_scale_addr(v_scale_addr, [](void* p) {});
                    load_layer_cache->addBlock("v_scale" + cache_key, v_block_scale_addr, scale_block_size, true, true);
                }
            }
            layer_caches.push_back(load_layer_cache);
        }
        auto layer_cache_load_context =
            resource_.cache_store->loadBuffers(layer_caches,
                                               ip_parts[0],
                                               autil::StringUtil::strToInt32WithDefault(ip_parts[1].c_str(), 0),
                                               autil::StringUtil::strToInt32WithDefault(ip_parts[2].c_str(), 0),
                                               load_context.timeout_ms,
                                               cancel_check_func,
                                               load_context.partition_count,
                                               load_context.partition_id);
        if (!layer_cache_load_context) {
            RTP_LLM_LOG_WARNING("request [%s] load cache failed, layer cache load context is nullptr", request_key.c_str());
            return ErrorInfo(ErrorCode::LOAD_KV_CACHE_FAILED, "load kv cache failed");
        }
        load_contexts.push_back(layer_cache_load_context);
        // MTP: only tp rank 0 additionally pulls the hidden-states buffer.
        if (engine_->isMTP() && engine_->getDevice()->getDeviceProperties().tp_rank == 0
            && load_context.hidden_states_ != nullptr) {
            RTP_LLM_LOG_DEBUG("mtp hidden states rdma get.");
            std::vector<std::shared_ptr<RequestBlockBuffer>> hidden_states_caches;
            auto load_layer_cache =
                std::make_shared<RequestBlockBuffer>(std::to_string(load_context.request_id), "");
            auto hidden_states_ptr  = load_context.hidden_states_->data();
            auto hidden_states_size = load_context.hidden_states_->sizeBytes();
            RTP_LLM_LOG_DEBUG("token_num is %zu", (size_t)load_context.hidden_states_->shape()[0]);
            RTP_LLM_LOG_DEBUG("decoder need hidden states size is %zu", (size_t)hidden_states_size);
            // Borrowed view; the hidden-states tensor outlives the transfer.
            std::shared_ptr<void> hidden_states_addr(hidden_states_ptr, [](void* p) {});
            load_layer_cache->addBlock("hidden_states", hidden_states_addr, hidden_states_size, true, true);
            hidden_states_caches.push_back(load_layer_cache);
            auto hidden_states_load_context =
                resource_.cache_store->loadBuffers(hidden_states_caches,
                                                   ip_parts[0],
                                                   autil::StringUtil::strToInt32WithDefault(ip_parts[1].c_str(), 0),
                                                   autil::StringUtil::strToInt32WithDefault(ip_parts[2].c_str(), 0),
                                                   load_context.timeout_ms,
                                                   cancel_check_func,
                                                   load_context.partition_count,
                                                   load_context.partition_id);
            if (!hidden_states_load_context) {
                RTP_LLM_LOG_WARNING("request [%s] load hidden_states failed, layer cache load context is nullptr", request_key.c_str());
                return ErrorInfo(ErrorCode::LOAD_KV_CACHE_FAILED, "load hidden_states failed");
            }
            load_contexts.push_back(hidden_states_load_context);
            RTP_LLM_LOG_DEBUG("mtp hidden states rdma end.");
        }
    }
    // Wait on EVERY load context before returning. (The previous code
    // returned inside the loop on the first context — success or not —
    // leaving the remaining transfers un-awaited and making the final
    // OkStatus return unreachable.) We still drain all contexts on failure
    // so no transfer is left writing into freed buffers; the first error
    // encountered is the one reported.
    bool      failed = false;
    ErrorInfo result = ErrorInfo::OkStatus();
    for (auto& layer_cache_load_context : load_contexts) {
        layer_cache_load_context->waitDone();
        if (layer_cache_load_context->success()) {
            RTP_LLM_LOG_DEBUG("request [%s] load kv cache success", request_key.c_str());
            continue;
        }
        // TODO(xinfei.sxf) add retry for part failed blocks.
        auto load_done_time_us = currentTimeUs();
        RTP_LLM_LOG_WARNING("request [%s] load cache failed, status [%s], cost time [%ld] ms",
                            request_key.c_str(),
                            layer_cache_load_context->getErrorInfoString().c_str(),
                            (load_done_time_us - start_load_time_us) / 1000);
        if (!failed) {
            failed = true;
            result = layer_cache_load_context->getErrorInfo();
        }
    }
    return result;
}