# maga_transformer/async_decoder_engine/engine_creator.py
import logging
import torch
from enum import Enum
from typing import Optional
from maga_transformer.models.propose_model.propose_model import ProposeModel
from maga_transformer.models.base_model import BaseModel
from maga_transformer.async_decoder_engine.embedding.embedding_engine import EmbeddingCppEngine
from maga_transformer.async_decoder_engine.rpc_engine import RPCEngine
from maga_transformer.async_decoder_engine.base_engine import BaseEngine
class ExecutorType(Enum):
    """Kind of execution backend to build for a model.

    ``Normal`` is the standard generation path; ``Embedding`` is used
    when the model carries a custom (embedding) module.
    """

    Normal = "normal"
    Embedding = "embedding"
# NOTE(review): "exeutor" looks like a typo for "executor", but the name is
# part of the public interface (callers may exist outside this file), so it
# is kept as-is.
def check_exeutor_type(model: BaseModel) -> ExecutorType:
    """Pick the executor type for *model*.

    A model that provides a ``custom_module`` runs on the embedding
    executor; everything else uses the normal one.
    """
    has_custom = model.custom_module is not None
    return ExecutorType.Embedding if has_custom else ExecutorType.Normal
def create_engine(model: BaseModel, propose_model: Optional[ProposeModel] = None) -> BaseEngine:
    """Create the decoder engine appropriate for *model*.

    Initializes the native rtp_llm engine op, then dispatches on the
    executor type derived from the model:

    Args:
        model: The model the engine will serve.
        propose_model: Optional draft/propose model, forwarded to the
            RPC engine for speculative decoding (ignored for embedding).

    Returns:
        An ``RPCEngine`` for normal models, or an ``EmbeddingCppEngine``
        for models with a custom (embedding) module.

    Raises:
        ValueError: If the executor type is not recognized (defensive;
            should not happen with the current ``ExecutorType`` members).
    """
    torch.ops.rtp_llm.init_engine()
    executor_type = check_exeutor_type(model)
    # Lazy %-style args so formatting is skipped when INFO is disabled.
    logging.info("executor_type: %s", executor_type)
    if executor_type == ExecutorType.Normal:
        return RPCEngine(model, propose_model)
    elif executor_type == ExecutorType.Embedding:
        return EmbeddingCppEngine(model)
    else:
        # ValueError subclasses Exception, so existing `except Exception`
        # handlers in callers remain compatible.
        raise ValueError(f"unsupported executor type: {executor_type}")