maga_transformer/async_decoder_engine/async_model.py

import torch
import logging
from typing import Optional, Dict

from maga_transformer.models.propose_model.propose_model import ProposeModel
from maga_transformer.models.base_model import BaseModel, GenerateInput
from maga_transformer.config.generate_config import GenerateConfig
from maga_transformer.async_decoder_engine.engine_creator import create_engine
from maga_transformer.distribute.worker_info import g_parallel_info
from maga_transformer.config.task_type import TaskType
from maga_transformer.config.exceptions import ExceptionType, FtRuntimeException
from maga_transformer.models.multimodal.multimodal_mixin import MultiModalMixin
from maga_transformer.ops import LoadBalanceInfo, EngineScheduleInfo
from maga_transformer.utils.gemm_utils.device_map import get_device


class AsyncModel:
    def __init__(self, model: BaseModel, propose_model: Optional[ProposeModel] = None) -> None:
        self.model = model
        self.propose_model = propose_model
        self.config = model.config
        self.model_runtime_meta = self._model_runtime_meta()
        assert self.config.max_seq_len > 0
        self.tokenizer = model.tokenizer
        # create and start the underlying asynchronous decoder engine
        self.decoder_engine_ = create_engine(self.model, self.propose_model)
        self.decoder_engine_.start()

    def is_multimodal(self) -> bool:
        return self.config.is_multimodal

    def _model_runtime_meta(self) -> str:
        # summarize device name, parallelism (TP/PP/EP) and quantization bits as one tag
        try:
            device_name = torch.cuda.get_device_name(0)
            machine_name = get_device(device_name).upper()
        except Exception as e:
            logging.info(f"error get device name with error: {e}")
            machine_name = "unknown"
        parallel_info = f"TP{g_parallel_info.tp_size}_PP{g_parallel_info.pp_size}_EP{g_parallel_info.ep_size}"
        weight_info = f"W{self.config.gpt_init_params.quant_algo.getWeightBits()}A{self.config.gpt_init_params.quant_algo.getActivationBits()}"
        return "_".join([machine_name, parallel_info, weight_info])

    @property
    def default_generate_config(self) -> GenerateConfig:
        return self.model.default_generate_config

    @property
    def task_type(self) -> TaskType:
        return self.model.task_type

    def stop(self):
        self.decoder_engine_.stop()

    def ready(self):
        return self.decoder_engine_.ready()

    @torch.no_grad()
    def enqueue(self, input: GenerateInput):
        # only the TP rank-0 worker is allowed to enqueue requests
        if g_parallel_info.tp_size > 1 and g_parallel_info.tp_rank > 0:
            raise Exception('bug, not supposed to be here')
        if input.prompt_length <= 0:
            raise FtRuntimeException(ExceptionType.LONG_PROMPT_ERROR,
                                     f"model tokens can not be empty, request length is {input.prompt_length}")
        # reject prompts that leave no room for new tokens within max_seq_len
        max_new_tokens = min(self.config.max_seq_len - input.prompt_length, input.generate_config.max_new_tokens)
        if max_new_tokens <= 0:
            raise FtRuntimeException(ExceptionType.LONG_PROMPT_ERROR,
                                     f"model max tokens is {self.config.max_seq_len}, "
                                     f"request length is {input.prompt_length}, max_new_tokens is {max_new_tokens}")
        return self.decoder_engine_.decode(input)

    def get_load_balance_info(self) -> LoadBalanceInfo:
        return self.decoder_engine_.get_load_balance_info()

    def get_engine_schedule_info(self) -> EngineScheduleInfo:
        return self.decoder_engine_.get_engine_schedule_info()
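

# --- Hypothetical usage sketch (not part of the original module) ---
# A minimal illustration of the lifecycle above, assuming `model` is an already
# constructed BaseModel and `generate_input` is a prepared GenerateInput; both
# names are placeholders, and only methods defined in this file are used.
#
#   async_model = AsyncModel(model)               # builds and starts the decoder engine
#   if async_model.ready():
#       stream = async_model.enqueue(generate_input)   # raises FtRuntimeException for empty or over-long prompts
#       load_info = async_model.get_load_balance_info()
#   async_model.stop()                            # shut the engine down when done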