ErrorInfo DecodeRpcServer::loadCacheAsyncForTp()

in maga_transformer/cpp/model_rpc/DecodeRpcServer.cc [295:432]
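
Fans out one asynchronous RemoteLoad RPC to every worker rank over a small pool of gRPC completion queues, then polls the queues until every rank has responded, the overall timeout elapses, or the upstream request is cancelled.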


ErrorInfo DecodeRpcServer::loadCacheAsyncForTp(DecodeGenerateContext& decode_context,
                                               LoadKVCacheContext&    load_context) {
    int64_t load_cache_begin_time_us = currentTimeUs();

    // Per-rank state for one in-flight async RemoteLoad call. gRPC writes the
    // response and status into this struct asynchronously, so each entry must
    // stay alive, at a stable address, until its completion tag is drained.
    struct WorkerRpcContext {
        WorkerRpcContext() {
            client_context = make_shared<ClientContext>();
        }
        BroadcastLoadResponsePB           response;
        Status                            status;
        std::shared_ptr<RpcService::Stub> stub;
        std::shared_ptr<ClientContext>    client_context;
    };

    uint32_t                 worker_size = resource_.grpc_workers.size();
    vector<WorkerRpcContext> all_context(worker_size);
    uint32_t                 cq_size = (worker_size + 1) / 2;  // two workers share one completion queue
    vector<CompletionQueue>  completion_queues(cq_size);
    vector<int>              each_finished_count(cq_size, 0);
    if (worker_size == 0 || cq_size == 0) {
        RTP_LLM_LOG_WARNING("request:[%s] cq_size or worker_size is 0, worker size = %d, cq size = %d",
                            decode_context.request_key.c_str(), worker_size, cq_size);
        return ErrorInfo(ErrorCode::LOAD_KV_CACHE_FAILED, "worker size or cq size is 0");
    }
    // Worker i is assigned to queue (i % cq_size), so queues with an index below
    // (worker_size % cq_size) receive one extra worker; track the exact count per
    // queue instead of assuming an even split, which stalls for odd worker counts.
    vector<int> expected_count_per_queue(cq_size, worker_size / cq_size);
    for (uint32_t q = 0; q < worker_size % cq_size; q++) {
        expected_count_per_queue[q]++;
    }
    RTP_LLM_LOG_DEBUG("request:[%s] start to async remote load for all rank", decode_context.request_key.c_str());
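    // Fan out: issue one async RemoteLoad per rank; calls are spread across the
    // completion queues round-robin via (rank % cq_size).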
    for (uint32_t i = 0; i < worker_size; i++) {
        auto& worker         = resource_.grpc_workers[i];
        auto  connect_status = resource_.rpc_pool.getConnection(worker);
        if (!connect_status.ok()) {
            string error_msg = "failed to get grpc connection for rank:" + std::to_string(i) + ", addr:" + worker;
            return ErrorInfo(ErrorCode::GET_CONNECTION_FAILED, error_msg);
        }
        // NOTE: all_context was already sized to worker_size above; appending here
        // would reallocate the vector and invalidate the response/status pointers
        // handed to gRPC below, so only index into the preallocated slots.
        auto& rpc_context                   = all_context[i];
        rpc_context.stub                    = connect_status.value().stub;
        BroadcastLoadRequestPB load_request;
        if (engine_->resourceContext().cache_manager->cacheConfig().use_mla) {
            load_request = constructRemoteLoadRequestForMla(load_context, i, decode_context.peer_addrs);
        } else {
            load_request = constructRemoteLoadRequest(load_context, i, decode_context.peer_addrs);
        }
        std::unique_ptr<ClientAsyncResponseReader<BroadcastLoadResponsePB>> reader(rpc_context.stub->AsyncRemoteLoad(
            rpc_context.client_context.get(), load_request, &completion_queues[i % completion_queues.size()]));
        // Tag the call with its rank so the polling loop can map completions back to workers.
        reader->Finish(&rpc_context.response, &rpc_context.status, reinterpret_cast<void*>(static_cast<uintptr_t>(i)));
    }

    bool        all_success               = true;
    size_t      finished_count            = 0;
    auto        total_timeout_ms          = load_context.timeout_ms + EXTRA_TIMEOUT_MS;
    ErrorCode   error_code                = ErrorCode::NONE_ERROR;
    std::string error_msg                 = "failed to load kv cache in rank: ";
    // Sentinel well above any plausible timestamp; narrowed below as responses arrive.
    int64_t     min_response_done_time_us = int64_t{1} << 60;
    int64_t     max_response_done_time_us = 0;
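    // Poll all queues until every rank has reported back. Each AsyncNext call uses
    // a short deadline so this loop can also re-check the overall timeout and
    // client cancellation between events.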
    while (true) {
        RTP_LLM_LOG_DEBUG("request [%s] load cache loop step", decode_context.request_key.c_str());
        auto cost_time_ms = (currentTimeUs() - load_cache_begin_time_us) / 1000;
        if (cost_time_ms > total_timeout_ms) {
            error_msg = "load cache timeout: cost time is " + std::to_string(cost_time_ms)
                        + "ms, total timeout for load cache is " + std::to_string(total_timeout_ms) + "ms";
            return ErrorInfo(ErrorCode::LOAD_CACHE_TIMEOUT, error_msg);
        }
        if (load_context.server_context->IsCancelled()) {
            string error_msg = "request is cancelled";
            return ErrorInfo(ErrorCode::CANCELLED, error_msg);
        }
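        // Wake up at least every decode_polling_kv_cache_step_ms_ even if no event
        // arrives, so the timeout and cancellation checks above stay responsive.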
        auto once_deadline =
            std::chrono::system_clock::now()
            + std::chrono::milliseconds(maga_init_params_.gpt_init_parameter.decode_polling_kv_cache_step_ms_);
        RTP_LLM_LOG_DEBUG("request [%s] start to execute async next", decode_context.request_key.c_str());
        // TODO(xinfei.sxf) The completion queue's Next call sometimes returns with
        // noticeable delay here; the root cause has not been identified yet.
        void* got_tag;
        bool  ok = false;
        for (uint32_t i = 0; i < completion_queues.size(); i++) {
            // Skip queues that have already delivered all of their expected completions.
            if (each_finished_count[i] == expected_count_per_queue[i]) {
                continue;
            }
            if (completion_queues[i].AsyncNext(&got_tag, &ok, once_deadline)
                == grpc::CompletionQueue::NextStatus::TIMEOUT) {
                RTP_LLM_LOG_DEBUG("request [%s] async next timeout", decode_context.request_key.c_str());
                continue;
            }
            each_finished_count[i]++;
            if (!ok) {
                string error_msg = "async get next event from grpc completion queue failed";
                return ErrorInfo(ErrorCode::LOAD_KV_CACHE_FAILED, error_msg);
            }
            auto        rank             = reinterpret_cast<uintptr_t>(got_tag);
            const auto& status           = all_context[rank].status;
            const auto& response         = all_context[rank].response;
            const auto& pb_error_code    = response.error_info().error_code();
            const auto& pb_error_message = response.error_info().error_message();
            min_response_done_time_us    = std::min(min_response_done_time_us, response.done_time_us());
            max_response_done_time_us    = std::max(max_response_done_time_us, response.done_time_us());
            RTP_LLM_LOG_DEBUG("request [%s] load cache for rank [%d] done",
                              decode_context.request_key.c_str(),
                              static_cast<int>(rank));
            if (!status.ok()) {
                all_success = false;
                error_code  = ErrorCode::LOAD_KV_CACHE_FAILED;
                error_msg += std::to_string(rank) + ": " + status.error_message() + ", ";
            } else if (pb_error_code != ErrorCodePB::NONE_ERROR) {
                all_success = false;
                error_code  = transRPCErrorCode(pb_error_code);
                error_msg += std::to_string(rank) + ": " + pb_error_message + ", ";
            }
            finished_count++;
            if (finished_count == worker_size) {
                break;
            }
        }
        if (finished_count == worker_size) {
            break;
        }
    }

    // Shut down all queues. Note that gRPC expects a completion queue to be fully
    // drained after Shutdown() before it is destroyed; on the early-return paths
    // above, tags may still be in flight when the queues go out of scope.
    for (auto& completion_queue : completion_queues) {
        completion_queue.Shutdown();
    }

    if (finished_count != worker_size) {
        all_success = false;
    }
    if (!all_success) {
        return ErrorInfo(error_code, error_msg);
    }

    // min/max rt: elapsed time from fan-out start to the earliest/latest rank
    // completion; polling cost: time spent draining after the last rank finished.
    decode_context.stat_info.load_cache_min_rt_us       = min_response_done_time_us - load_cache_begin_time_us;
    decode_context.stat_info.load_cache_max_rt_us       = max_response_done_time_us - load_cache_begin_time_us;
    decode_context.stat_info.load_cache_polling_cost_us = currentTimeUs() - max_response_done_time_us;

    RTP_LLM_LOG_DEBUG("load_cache_min_rt_us = %ld, load_cache_max_rt_us = %ld, load_cache_polling_cost_us = %ld",
                      decode_context.stat_info.load_cache_min_rt_us,
                      decode_context.stat_info.load_cache_max_rt_us,
                      decode_context.stat_info.load_cache_polling_cost_us);

    return ErrorInfo::OkStatus();
}
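
The tag-based fan-out/polling pattern above can be exercised in isolation. The following is a minimal sketch, not part of DecodeRpcServer: it substitutes grpc::Alarm for the AsyncRemoteLoad stub call (an alarm simply posts its tag to a completion queue at a deadline), but drains the queues with the same AsyncNext loop, tagging each operation with its index.

#include <grpcpp/alarm.h>
#include <grpcpp/completion_queue.h>

#include <chrono>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    constexpr uint32_t worker_size = 5;
    const uint32_t     cq_size     = (worker_size + 1) / 2;  // two "workers" per queue
    std::vector<grpc::CompletionQueue> queues(cq_size);
    std::vector<grpc::Alarm>           alarms(worker_size);

    // Fan out: alarm i stands in for rank i's async RPC; its completion tag is i.
    for (uint32_t i = 0; i < worker_size; i++) {
        alarms[i].Set(&queues[i % cq_size],
                      std::chrono::system_clock::now() + std::chrono::milliseconds(10 * (i + 1)),
                      reinterpret_cast<void*>(static_cast<uintptr_t>(i)));
    }

    // Poll with a short per-step deadline, as loadCacheAsyncForTp() does, so a
    // real caller could also check timeout/cancellation between events.
    uint32_t finished = 0;
    while (finished < worker_size) {
        for (auto& cq : queues) {
            void* tag = nullptr;
            bool  ok  = false;
            auto  step_deadline = std::chrono::system_clock::now() + std::chrono::milliseconds(5);
            if (cq.AsyncNext(&tag, &ok, step_deadline) != grpc::CompletionQueue::GOT_EVENT) {
                continue;  // TIMEOUT on this queue; try the next one
            }
            std::printf("tag %d done (ok=%d)\n",
                        static_cast<int>(reinterpret_cast<uintptr_t>(tag)), ok ? 1 : 0);
            finished++;
        }
    }
    // All tags have been drained, so Shutdown() followed by destruction is safe.
    for (auto& cq : queues) {
        cq.Shutdown();
    }
    return 0;
}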