in maga_transformer/cpp/model_rpc/DecodeRpcServer.cc [295:432]
ErrorInfo DecodeRpcServer::loadCacheAsyncForTp(DecodeGenerateContext& decode_context,
                                               LoadKVCacheContext& load_context) {
    int64_t load_cache_begin_time_us = currentTimeUs();
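    // Per-worker state for one asynchronous RemoteLoad call: response buffer, completion
    // status, cached stub, and a dedicated grpc::ClientContext.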
    struct WorkerRpcContext {
        WorkerRpcContext() {
            client_context = make_shared<ClientContext>();
        }
        BroadcastLoadResponsePB response;
        Status status;
        std::shared_ptr<RpcService::Stub> stub;
        std::shared_ptr<ClientContext> client_context;
    };
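    // One RPC context per worker rank; completion queues are shared, with each queue
    // serving up to two workers (assigned round-robin by rank).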
    uint32_t worker_size = resource_.grpc_workers.size();
    vector<WorkerRpcContext> all_context(worker_size);
    uint32_t cq_size = (worker_size + 1) / 2;  // one completion queue per two workers
    vector<CompletionQueue> completion_queues(cq_size);
    vector<uint32_t> each_finished_count(cq_size, 0);
    if (worker_size == 0 || cq_size == 0) {
        RTP_LLM_LOG_WARNING("request:[%s] cq_size or worker_size is 0, worker size = %u, cq size = %u",
                            decode_context.request_key.c_str(),
                            worker_size,
                            cq_size);
        return ErrorInfo(ErrorCode::LOAD_KV_CACHE_FAILED, "worker size or cq size is 0");
    }
    // Worker i posts its completion on completion_queues[i % cq_size], so queue q is expected
    // to deliver floor(worker_size / cq_size) events, plus one more when q < worker_size % cq_size.
    vector<uint32_t> expected_count_per_queue(cq_size, worker_size / cq_size);
    for (uint32_t q = 0; q < worker_size % cq_size; q++) {
        expected_count_per_queue[q]++;
    }
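    // Fan out one asynchronous RemoteLoad RPC per rank; each call is tagged with its rank
    // index so the completion event can be matched back to its WorkerRpcContext.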
RTP_LLM_LOG_DEBUG("request:[%s] start to async remote load for all rank", decode_context.request_key.c_str());
for (int i = 0; i < worker_size; i++) {
auto& worker = resource_.grpc_workers[i];
auto connect_status = resource_.rpc_pool.getConnection(worker);
if (!connect_status.ok()) {
string error_msg = "get grpc connection for rank:" + std::to_string(i) + ", addr:" + worker + " failed";
return ErrorInfo(ErrorCode::GET_CONNECTION_FAILED, error_msg);
}
all_context.push_back(WorkerRpcContext());
auto& rpc_context = all_context[i];
rpc_context.stub = connect_status.value().stub;
BroadcastLoadRequestPB load_request;
if (engine_->resourceContext().cache_manager->cacheConfig().use_mla) {
load_request = constructRemoteLoadRequestForMla(load_context, i, decode_context.peer_addrs);
} else {
load_request = constructRemoteLoadRequest(load_context, i, decode_context.peer_addrs);
}
std::unique_ptr<ClientAsyncResponseReader<BroadcastLoadResponsePB>> reader(rpc_context.stub->AsyncRemoteLoad(
rpc_context.client_context.get(), load_request, &completion_queues[i % completion_queues.size()]));
reader->Finish(&rpc_context.response, &rpc_context.status, reinterpret_cast<void*>(i));
}
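    // Aggregated result state: overall success flag, combined error message, and the
    // earliest/latest response completion times reported across all ranks.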
    bool all_success = true;
    size_t finished_count = 0;
    auto total_timeout_ms = load_context.timeout_ms + EXTRA_TIMEOUT_MS;
    ErrorCode error_code = ErrorCode::NONE_ERROR;
    std::string error_msg = "failed to load kv cache in rank: ";
    int64_t min_response_done_time_us = int64_t(1) << 60;
    int64_t max_response_done_time_us = 0;
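    // Poll the completion queues until every rank has responded, re-checking the overall
    // timeout and client-side cancellation between polling steps.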
    while (true) {
        RTP_LLM_LOG_DEBUG("request [%s] load cache loop step", decode_context.request_key.c_str());
        auto cost_time_ms = (currentTimeUs() - load_cache_begin_time_us) / 1000;
        if (cost_time_ms > total_timeout_ms) {
            error_msg = "load cache timeout : cost time is " + std::to_string(cost_time_ms)
                        + "ms, total timeout for load cache is " + std::to_string(total_timeout_ms) + "ms";
            return ErrorInfo(ErrorCode::LOAD_CACHE_TIMEOUT, error_msg);
        }
        if (load_context.server_context->IsCancelled()) {
            return ErrorInfo(ErrorCode::CANCELLED, "request is cancelled");
        }
        auto once_deadline =
            std::chrono::system_clock::now()
            + std::chrono::milliseconds(maga_init_params_.gpt_init_parameter.decode_polling_kv_cache_step_ms_);
        RTP_LLM_LOG_DEBUG("request [%s] start to execute async next", decode_context.request_key.c_str());
        // TODO(xinfei.sxf) The completion queue's AsyncNext call shows unexpected delay here;
        // the cause has not been investigated yet.
        void* got_tag;
        bool ok = false;
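        // Drain at most one event from each completion queue in this pass; queues whose
        // expected events have all arrived are skipped, and a timed-out AsyncNext simply
        // falls through to the next queue.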
        for (uint32_t i = 0; i < completion_queues.size(); i++) {
            if (each_finished_count[i] == expected_count_per_queue[i]) {
                continue;
            }
            if (completion_queues[i].AsyncNext(&got_tag, &ok, once_deadline)
                == grpc::CompletionQueue::NextStatus::TIMEOUT) {
                RTP_LLM_LOG_DEBUG("request [%s] async next timeout", decode_context.request_key.c_str());
                continue;
            }
            each_finished_count[i]++;
            if (!ok) {
                return ErrorInfo(ErrorCode::LOAD_KV_CACHE_FAILED,
                                 "async get next event from grpc completion queue failed");
            }
            auto rank = reinterpret_cast<uintptr_t>(got_tag);
            const auto& status = all_context[rank].status;
            const auto& response = all_context[rank].response;
            const auto& pb_error_code = response.error_info().error_code();
            const auto& pb_error_message = response.error_info().error_message();
            min_response_done_time_us = std::min(min_response_done_time_us, response.done_time_us());
            max_response_done_time_us = std::max(max_response_done_time_us, response.done_time_us());
            RTP_LLM_LOG_DEBUG("request [%s] load cache for rank [%lu] done",
                              decode_context.request_key.c_str(),
                              static_cast<unsigned long>(rank));
            if (!status.ok()) {
                all_success = false;
                error_code = ErrorCode::LOAD_KV_CACHE_FAILED;
                error_msg += std::to_string(rank) + ": " + status.error_message() + ", ";
            } else if (pb_error_code != ErrorCodePB::NONE_ERROR) {
                all_success = false;
                error_code = transRPCErrorCode(pb_error_code);
                error_msg += std::to_string(rank) + ": " + pb_error_message + ", ";
            }
            finished_count++;
            if (finished_count == worker_size) {
                break;
            }
        }
        if (finished_count == worker_size) {
            break;
        }
    }
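    // Every rank has responded; shut down the completion queues before inspecting the
    // aggregated results.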
    for (auto& completion_queue : completion_queues) {
        completion_queue.Shutdown();
    }
    if (finished_count != worker_size) {
        all_success = false;
    }
    if (!all_success) {
        return ErrorInfo(error_code, error_msg);
    }
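    // Record load timing: fastest and slowest rank completion relative to the start of the
    // load, plus the time between the last reported completion and now (polling overhead).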
    decode_context.stat_info.load_cache_min_rt_us = min_response_done_time_us - load_cache_begin_time_us;
    decode_context.stat_info.load_cache_max_rt_us = max_response_done_time_us - load_cache_begin_time_us;
    decode_context.stat_info.load_cache_polling_cost_us = currentTimeUs() - max_response_done_time_us;
    RTP_LLM_LOG_DEBUG("load_cache_min_rt_us = %ld, load_cache_max_rt_us = %ld, load_cache_polling_cost_us = %ld",
                      decode_context.stat_info.load_cache_min_rt_us,
                      decode_context.stat_info.load_cache_max_rt_us,
                      decode_context.stat_info.load_cache_polling_cost_us);
    return ErrorInfo::OkStatus();
}