ErrorInfo DecodeRpcServer::loadCache()

in maga_transformer/cpp/model_rpc/DecodeRpcServer.cc [502:638]


ErrorInfo DecodeRpcServer::loadCache(const LoadKVCacheContext& load_context) {
    // Pulls the KV-cache blocks (and, for MTP, the hidden states) for one
    // request from its prefill peers via the cache store. Each peer transfers
    // an equal 1/peer_num slice of every block. Blocks until every transfer
    // finishes; returns the first failure encountered, or OK.
    AtomicGuard request_guard(onflight_load_cache_requests_);
    const auto& request_key      = load_context.request_key;
    auto        cache_manager    = engine_->resourceContext().cache_manager;
    const auto& cache_config     = cache_manager->cacheConfig();
    auto        k_block_size     = cache_config.k_block_stride;
    auto        v_block_size     = cache_config.v_block_stride;
    auto        scale_block_size = cache_config.kv_scale_block_stride;
    auto        layer_num        = maga_init_params_.gpt_init_parameter.num_layers_;
    const auto  peer_num         = load_context.peer_addrs.size();

    // Guard: an empty peer list would make the modulo/divisions below divide by zero.
    if (peer_num == 0) {
        RTP_LLM_LOG_WARNING("request [%s] has no peer addrs to load kv cache from", request_key.c_str());
        return ErrorInfo(ErrorCode::LOAD_KV_CACHE_FAILED, "peer addrs is empty");
    }

    if (v_block_size % peer_num != 0 || k_block_size % peer_num != 0
        || scale_block_size % peer_num != 0) {
        // %zu (not %d) — these arguments are unsigned/size_t; mismatched
        // printf-style specifiers in varargs are undefined behavior.
        RTP_LLM_LOG_WARNING(
            "k block size [%zu] or v block size [%zu] or scale block size [%zu] is not divisible by peer ips size [%zu]",
            static_cast<size_t>(k_block_size),
            static_cast<size_t>(v_block_size),
            static_cast<size_t>(scale_block_size),
            peer_num);
        return ErrorInfo(ErrorCode::LOAD_KV_CACHE_FAILED, "block size is not divisible by peer ips size");
    }
    // From here on the sizes are per-peer slice sizes.
    k_block_size     = k_block_size / peer_num;
    v_block_size     = v_block_size / peer_num;
    scale_block_size = scale_block_size / peer_num;

    const bool use_mla = cache_config.use_mla;  // MLA caches carry no separate V blocks.

    auto cancel_check_func  = [&load_context]() -> bool { return load_context.server_context->IsCancelled(); };
    auto start_load_time_us = currentTimeUs();
    std::vector<std::shared_ptr<LoadContext>> load_contexts;
    for (size_t i = 0; i < peer_num; i++) {
        auto&                                            peer_addr = load_context.peer_addrs[i];
        std::vector<std::shared_ptr<RequestBlockBuffer>> layer_caches;

        for (size_t layer_id = 0; layer_id < layer_num; layer_id++) {
            // Per-layer buffer key; distinct from the request-level request_key above
            // (the original shadowed that name — renamed for clarity).
            auto layer_cache_key = std::to_string(load_context.request_id) + "-" + std::to_string(layer_id);
            auto load_layer_cache =
                std::make_shared<RequestBlockBuffer>(std::to_string(load_context.request_id), layer_cache_key);
            auto block_num = load_context.block_ids.size();

            // Blocks before reuse_block_size are already present locally
            // (prefix-cache reuse) and need no transfer.
            for (size_t block_pos = load_context.reuse_block_size; block_pos < block_num; block_pos++) {
                auto  cache_key = makeCacheKey(std::to_string(load_context.cache_keys[block_pos]), layer_id);
                auto  block_id  = load_context.block_ids[block_pos];
                auto  addr_info = cache_manager->convertIndexToAddr(block_id, layer_id);
                // Peer i fills the i-th slice of each block. The no-op deleters are
                // intentional: the cache manager owns this memory; these shared_ptrs
                // are non-owning views the cache store API requires.
                void* k_addr    = (void*)((int64_t)addr_info.k_addr + i * k_block_size);
                std::shared_ptr<void> k_block_addr(k_addr, [](void* p) {});
                load_layer_cache->addBlock("k_" + cache_key, k_block_addr, k_block_size, true, true);
                if (addr_info.k_scale_addr) {
                    void* k_scale_addr = (void*)((int64_t)addr_info.k_scale_addr + i * scale_block_size);
                    std::shared_ptr<void> k_block_scale_addr(k_scale_addr, [](void* p) {});
                    load_layer_cache->addBlock("k_scale" + cache_key, k_block_scale_addr, scale_block_size, true, true);
                }
                if (use_mla) {
                    continue;  // MLA: no V cache (nor V scale) to transfer.
                }
                void* v_addr    = (void*)((int64_t)addr_info.v_addr + i * v_block_size);
                std::shared_ptr<void> v_block_addr(v_addr, [](void* p) {});
                load_layer_cache->addBlock("v_" + cache_key, v_block_addr, v_block_size, true, true);
                if (addr_info.v_scale_addr) {
                    void* v_scale_addr = (void*)((int64_t)addr_info.v_scale_addr + i * scale_block_size);
                    std::shared_ptr<void> v_block_scale_addr(v_scale_addr, [](void* p) {});
                    load_layer_cache->addBlock("v_scale" + cache_key, v_block_scale_addr, scale_block_size, true, true);
                }
            }
            layer_caches.push_back(load_layer_cache);
        }

        // peer_addr is expected to be "ip:port1:port2".
        auto ip_parts = autil::StringUtil::split(peer_addr, ":");
        if (ip_parts.size() != 3) {
            RTP_LLM_LOG_WARNING("invalid peer ip to load [%s]", peer_addr.c_str());
            return ErrorInfo(ErrorCode::LOAD_KV_CACHE_FAILED, "invalid peer ip");
        }

        auto layer_cache_load_context =
            resource_.cache_store->loadBuffers(layer_caches,
                                               ip_parts[0],
                                               autil::StringUtil::strToInt32WithDefault(ip_parts[1].c_str(), 0),
                                               autil::StringUtil::strToInt32WithDefault(ip_parts[2].c_str(), 0),
                                               load_context.timeout_ms,
                                               cancel_check_func,
                                               load_context.partition_count,
                                               load_context.partition_id);
        if (!layer_cache_load_context) {
            RTP_LLM_LOG_WARNING("request [%s] load cache failed, layer cache load context is nullptr", request_key.c_str());
            return ErrorInfo(ErrorCode::LOAD_KV_CACHE_FAILED, "load kv cache failed");
        }
        load_contexts.push_back(layer_cache_load_context);

        // MTP additionally pulls the prefill's hidden states; only tp_rank 0
        // issues this transfer.
        if (engine_->isMTP() &&
            engine_->getDevice()->getDeviceProperties().tp_rank == 0 &&
            load_context.hidden_states_ != nullptr)
        {
            RTP_LLM_LOG_DEBUG("mtp hidden states rdma get.");
            std::vector<std::shared_ptr<RequestBlockBuffer>> hidden_states_caches;
            auto hidden_states_cache =
                std::make_shared<RequestBlockBuffer>(std::to_string(load_context.request_id), "");
            auto hidden_states_ptr  = load_context.hidden_states_->data();
            auto hidden_states_size = load_context.hidden_states_->sizeBytes();
            RTP_LLM_LOG_DEBUG("token_num is %d", load_context.hidden_states_->shape()[0]);
            RTP_LLM_LOG_DEBUG("decoder need hidden states size is %d", hidden_states_size);
            // Non-owning view again: the hidden-states buffer outlives this load.
            std::shared_ptr<void> hidden_states_addr(hidden_states_ptr, [](void* p) {});
            hidden_states_cache->addBlock("hidden_states", hidden_states_addr, hidden_states_size, true, true);
            hidden_states_caches.push_back(hidden_states_cache);
            auto hidden_states_load_context =
                resource_.cache_store->loadBuffers(hidden_states_caches,
                                                    ip_parts[0],
                                                    autil::StringUtil::strToInt32WithDefault(ip_parts[1].c_str(), 0),
                                                    autil::StringUtil::strToInt32WithDefault(ip_parts[2].c_str(), 0),
                                                    load_context.timeout_ms,
                                                    cancel_check_func,
                                                    load_context.partition_count,
                                                    load_context.partition_id);
            if (!hidden_states_load_context) {
                RTP_LLM_LOG_WARNING("request [%s] load hidden_states failed, layer cache load context is nullptr", request_key.c_str());
                return ErrorInfo(ErrorCode::LOAD_KV_CACHE_FAILED, "load hidden_states failed");
            }
            load_contexts.push_back(hidden_states_load_context);
            RTP_LLM_LOG_DEBUG("mtp hidden states rdma end.");
        }
    }

    // BUG FIX: the previous version returned unconditionally inside this loop,
    // so only the FIRST context was ever waited on — every other transfer
    // (remaining peers, hidden states) was abandoned while potentially still
    // writing into our buffers. Wait on all of them, then report the first
    // failure, if any.
    bool      has_error   = false;
    ErrorInfo first_error = ErrorInfo::OkStatus();
    for (auto& layer_cache_load_context : load_contexts) {
        layer_cache_load_context->waitDone();
        if (layer_cache_load_context->success()) {
            RTP_LLM_LOG_DEBUG("request [%s] load kv cache success", request_key.c_str());
        } else {
            // TODO(xinfei.sxf) add retry for part failed blocks.
            auto load_done_time_us = currentTimeUs();
            RTP_LLM_LOG_WARNING("request [%s] load cache failed, status [%s], cost time [%ld] ms",
                           request_key.c_str(),
                           layer_cache_load_context->getErrorInfoString().c_str(),
                           (load_done_time_us - start_load_time_us) / 1000);
            if (!has_error) {
                has_error   = true;
                first_error = layer_cache_load_context->getErrorInfo();
            }
        }
    }
    return has_error ? first_error : ErrorInfo::OkStatus();
}