// maga_transformer/cpp/proto/model_rpc_service.proto
// Messages and services for the model RPC interface (model_rpc_service.proto).
syntax = "proto3";
import "google/protobuf/wrappers.proto";
// Serialized tensor: an element-type tag, a shape, and one bytes payload
// field per supported element type.
// NOTE(review): presumably only the payload matching data_type is populated;
// confirm against the producer.
message TensorPB {
  // Element type of the tensor. FP32 = 0 is the proto3 default when unset.
  enum DataType {
    FP32 = 0;
    INT32 = 1;
    FP16 = 2;
    BF16 = 3;
  }

  DataType data_type = 1;

  // Tensor shape (presumably one entry per dimension).
  repeated int64 shape = 2;

  // Raw payloads, one field per DataType value. Byte order and layout are not
  // specified in this file - TODO confirm.
  bytes fp32_data = 3;
  bytes int32_data = 4;
  bytes fp16_data = 5;
  bytes bf16_data = 6;
}
// A list of int32 values; used as the row type of IntMatrix.
message IntVector {
  repeated int32 values = 1;
}
// A list of IntVector rows. Rows are independent repeated fields, so nothing
// in the schema forces them to have equal length.
message IntMatrix {
  repeated IntVector rows = 1;
}
// Per-request generation options; embedded in GenerateInputPB as
// generate_config.
// NOTE(review): field meanings are inferred from names only; confirm against
// the code that consumes this message.
message GenerateConfigPB {
  // Plain proto3 scalars: an unset field is indistinguishable from 0 / 0.0.
  int32 max_new_tokens = 1;
  int32 num_beams = 2;
  int32 num_return_sequences = 3;
  int32 min_new_tokens = 4;
  int32 top_k = 5;
  float top_p = 6;
  float temperature = 7;
  float repetition_penalty = 8;

  // Wrapper-typed fields carry explicit presence (unset vs. set to default).
  google.protobuf.Int32Value no_repeat_ngram_size = 9;
  google.protobuf.Int64Value random_seed = 10;
  google.protobuf.FloatValue top_p_decay = 11;
  google.protobuf.FloatValue top_p_min = 12;
  // NOTE(review): plural name, but this holds a single Int32Value.
  google.protobuf.Int32Value top_p_reset_ids = 13;
  google.protobuf.StringValue task_id = 14;

  // Declared as int32, not bool.
  int32 calculate_loss = 15;
  bool return_incremental = 16;
  bool return_hidden_states = 17;
  bool return_logits = 18;
  bool is_streaming = 19;
  // Milliseconds, per the field-name suffix.
  int32 timeout_ms = 20;
  IntMatrix stop_words_list = 21;
  repeated int32 select_tokens_id = 22;
  google.protobuf.StringValue adapter_name = 23;
  bool sp_edit = 24;
  repeated int32 sp_advice_prompt_token_ids = 25;
  bool force_disable_sp_run = 26;
  bool return_all_probs = 27;
  bool sp_input_lookup = 28;
  bool can_use_pd_separation = 29;
  bool return_softmax_probs = 30;
  bool return_cum_log_probs = 31;
  bool in_think_mode = 32;
  int32 max_thinking_tokens = 33;
  repeated int32 end_think_token_ids = 34;
  bool gen_timeline = 35;
  // NOTE(review): int32 here, while the request_id fields elsewhere in this
  // file are int64 - confirm the intended width.
  int32 global_request_id = 36;
}
// Preprocessing options attached to a multimodal input
// (MultimodalInputPB.mm_preprocess_config).
// All fields are plain int32, so an unset field reads as 0.
message MMPreprocessConfigPB {
  int32 width = 1;
  int32 height = 2;
  int32 min_pixels = 3;
  int32 max_pixels = 4;
  int32 fps = 5;
  int32 min_frames = 6;
  int32 max_frames = 7;
}
// One multimodal input: a URL, an integer type tag, an optional tensor, and
// preprocessing options.
message MultimodalInputPB {
  string multimodal_url = 1;
  // Integer type code; the valid values are not defined in this file.
  int32 multimodal_type = 2;
  TensorPB multimodal_tensor = 3;
  MMPreprocessConfigPB mm_preprocess_config = 4;
}
// Batch of multimodal inputs; request type of
// MultimodalRpcService.RemoteMultimodalEmbedding.
message MultimodalInputsPB {
  repeated MultimodalInputPB multimodal_inputs = 1;
}
// One multimodal result: an embedding tensor and a position-id tensor.
message MultimodalOutputPB {
  TensorPB multimodal_embedding = 1;
  TensorPB multimodal_pos_id = 2;
}
// Batch of multimodal results; response type of
// MultimodalRpcService.RemoteMultimodalEmbedding.
// NOTE(review): presumably one output per input, in order - confirm.
message MultimodalOutputsPB {
  repeated MultimodalOutputPB multimodal_outputs = 1;
}
// A generation request: request type of RpcService.GenerateStreamCall, and
// also embedded in GenerateRequestPB as `input`.
message GenerateInputPB {
  int64 request_id = 1;
  repeated int32 token_ids = 2;
  repeated MultimodalInputPB multimodal_inputs = 3;
  GenerateConfigPB generate_config = 4;
  string client_id = 5;
  // Time unit and epoch are not stated in this file - TODO confirm.
  int64 start_time = 6;
}
// Auxiliary statistics attached to each GenerateOutputPB (aux_info).
// NOTE(review): the *_us fields are int32; a 32-bit microsecond value
// overflows after roughly 35.8 minutes (2^31 us). Confirm that the recorded
// durations cannot exceed that.
message AuxInfoPB {
  int32 cost_time_us = 1;
  int32 iter_count = 2;
  int32 input_len = 3;
  int32 reuse_len = 4;
  int32 prefix_len = 5;
  int32 output_len = 6;
  int32 fallback_tokens = 7;
  int32 fallback_times = 8;
  TensorPB cum_log_probs = 9;
  int32 step_output_len = 10;
  bool pd_sep = 11;
  int32 first_token_cost_time_us = 12;
  TensorPB softmax_probs = 13;
  int32 wait_time_us = 14;
}
// One generation result: a finished flag, auxiliary statistics, and a set of
// tensor-valued outputs. Message-typed fields have explicit presence, so any
// of the tensors may be left unset.
message GenerateOutputPB {
  bool finished = 1;
  AuxInfoPB aux_info = 2;
  TensorPB output_ids = 3;
  TensorPB hidden_states = 4;
  TensorPB loss = 5;
  TensorPB logits = 6;
  TensorPB all_probs = 7;
}
// Streamed response of RpcService.GenerateStreamCall and
// RpcService.RemoteGenerate: the outputs for one request plus timing marks
// and error information.
message GenerateOutputsPB {
  int64 request_id = 1;
  repeated GenerateOutputPB generate_outputs = 2;

  // Timing marks. Unit and epoch are not stated in this file - TODO confirm.
  int64 receive_load_time = 3;
  int64 start_load_time = 4;
  int64 receive_generate_time = 5;
  int64 load_done_time = 6;
  int64 begin_compute_time = 7;
  int64 compute_done_time = 8;

  // RpcErrorPB is declared further down in this file.
  RpcErrorPB error_info = 9;
}
// Error details returned to the Python client.
// Not referenced by any other definition in this file. error_code is a plain
// int64 here, not the ErrorCodePB enum.
message ErrorDetailsPB {
  int64 error_code = 1;
  string error_message = 2;
}
// Error codes carried in RpcErrorPB.error_code.
// NONE_ERROR = 0 is the proto3 default, so an unset error_code reads as
// NONE_ERROR.
enum ErrorCodePB {
  NONE_ERROR = 0;
  UNKNOWN_ERROR = 1;
  CANCELLED = 2;
  LOAD_CACHE_TIMEOUT = 3;
  CACHE_STORE_LOAD_CONNECT_FAILED = 4;
  CACHE_STORE_LOAD_SEND_REQUEST_FAILED = 5;
  CACHE_STORE_CALL_PREFILL_TIMEOUT = 6;
  CACHE_STORE_LOAD_RDMA_CONNECT_FAILED = 7;
  CACHE_STORE_LOAD_RDMA_WRITE_FAILED = 8;
  CACHE_STORE_LOAD_BUFFER_TIMEOUT = 9;
}
// Error information transferred between prefill and decode.
// Used as the error_info field of GenerateOutputsPB and
// BroadcastLoadResponsePB.
message RpcErrorPB {
  ErrorCodePB error_code = 1;
  string error_message = 2;
}
// Message with no fields; response type of RpcService.RemoteFinish.
message EmptyPB {
}
// Stage selector carried in GenerateRequestPB.stage.
// ALLOCATE = 0 is the proto3 default, so an unset stage reads as ALLOCATE.
enum RemoteStage {
  ALLOCATE = 0;
  LOAD = 1;
  GENERATE = 2;
}
// One message on the client-to-server stream of RpcService.RemoteGenerate.
// NOTE(review): which fields are meaningful for each stage is not stated in
// this file - confirm against the sender.
message GenerateRequestPB {
  RemoteStage stage = 1;
  int64 request_id = 2;
  string client_id = 3;
  // Time unit and epoch are not stated in this file - TODO confirm.
  int64 start_time = 4;
  int32 first_generate_token_id = 5;
  GenerateInputPB input = 6;
  repeated string peer_addrs = 7;
}
// In the TP case, the decode master broadcasts this request to the decode
// workers. Request type of RpcService.RemoteLoad.
// NOTE(review): "TP" presumably means tensor parallelism - confirm.
message BroadcastLoadRequestPB {
  int64 request_id = 1;
  string request_key = 2;
  repeated string peer_addrs = 3;
  // NOTE(review): cache_keys and block_ids look like parallel lists; whether
  // they are paired by index is not stated here - confirm.
  repeated int64 cache_keys = 4;
  repeated int32 block_ids = 5;
  int64 block_num = 6;
  int64 reuse_block_size = 7;
  // Milliseconds, per the field-name suffix.
  int64 timeout_ms = 8;
  int64 dp_rank = 9;
  int32 partition_count = 10;
  int32 partition_id = 11;
}
// Response type of RpcService.RemoteLoad: error information plus a completion
// time in microseconds (per the field-name suffix).
message BroadcastLoadResponsePB {
  RpcErrorPB error_info = 1;
  int64 done_time_us = 2;
}
// Request type of RpcService.RemoteFinish; identifies the request by id.
message RemoteFinishRequestPB {
  int64 request_id = 1;
}
// Generation RPCs.
service RpcService {
  // Server-streaming: one GenerateInputPB in, a stream of GenerateOutputsPB
  // out.
  rpc GenerateStreamCall(GenerateInputPB)
      returns (stream GenerateOutputsPB);

  // Unary: takes a BroadcastLoadRequestPB, returns a BroadcastLoadResponsePB.
  rpc RemoteLoad(BroadcastLoadRequestPB)
      returns (BroadcastLoadResponsePB);

  // Bidirectional streaming: a stream of GenerateRequestPB in, a stream of
  // GenerateOutputsPB out.
  rpc RemoteGenerate(stream GenerateRequestPB)
      returns (stream GenerateOutputsPB);

  // Unary: takes a RemoteFinishRequestPB, returns an empty message.
  rpc RemoteFinish(RemoteFinishRequestPB)
      returns (EmptyPB);
}
// Multimodal embedding RPC.
service MultimodalRpcService {
  // Unary: takes a batch of multimodal inputs, returns a batch of multimodal
  // outputs.
  rpc RemoteMultimodalEmbedding(MultimodalInputsPB)
      returns (MultimodalOutputsPB);
}