optimum_benchmark/backends/tensorrt_llm/backend.py (90 lines of code) (raw):
import shutil
from collections import OrderedDict
from tempfile import TemporaryDirectory
from typing import Any, Dict
from hydra.utils import get_class
from ..base import Backend
from .config import TRTLLMConfig
from .utils import MODEL_TYPE_TO_TRTLLMMODELS
class TRTLLMBackend(Backend[TRTLLMConfig]):
    """Benchmark backend that loads and runs models through a TensorRT-LLM model class.

    The concrete model class is chosen from ``MODEL_TYPE_TO_TRTLLMMODELS`` using
    ``config.model_type`` and is resolved with hydra's ``get_class``.
    """

    NAME = "tensorrt-llm"

    # Optional config fields forwarded to ``from_pretrained`` when they are not None.
    # The order matches the original hand-written chain, so the resulting dict keeps
    # the same key order.
    _TRTLLM_KWARG_NAMES = (
        "tp",
        "pp",
        "dtype",
        "use_fp8",
        "world_size",
        "gpus_per_node",
        "max_input_len",
        "max_output_len",
        "max_batch_size",
        "max_new_tokens",
        "max_prompt_length",
        "optimization_level",
        "use_cuda_graph",
    )

    def __init__(self, config: TRTLLMConfig):
        """Resolve the TensorRT-LLM model class for ``config.model_type``.

        Raises:
            NotImplementedError: if ``config.model_type`` has no entry in
                ``MODEL_TYPE_TO_TRTLLMMODELS``.
        """
        super().__init__(config)

        if self.config.model_type not in MODEL_TYPE_TO_TRTLLMMODELS:
            raise NotImplementedError(f"TRTLLMBackend does not support model_type {self.config.model_type}")

        self.trtllm_loader = get_class(MODEL_TYPE_TO_TRTLLMMODELS[self.config.model_type])
        self.logger.info(f"\t+ Using TRTLLMModel class {self.trtllm_loader.__name__}")

    def load(self) -> None:
        """Create or download the model, load it into ``self.pretrained_model``, then clean up.

        A temporary directory is exposed as ``self.tmpdir`` for the loading helpers
        while loading runs. Those helpers (``create_no_weights_model_slow`` and
        ``download_pretrained_model``) are defined outside this class. The directory
        is removed afterwards, even when loading fails.
        """
        self.logger.info("\t+ Creating backend temporary directory")
        self.tmpdir = TemporaryDirectory()

        try:
            if self.config.no_weights:
                self.logger.info("\t+ Creating no weights model")
                self.create_no_weights_model_slow()
                self.logger.info("\t+ Loading no weights model")
                self.load_model_with_no_weights()
            else:
                self.logger.info("\t+ Downloading pretrained model")
                self.download_pretrained_model()
                self.logger.info("\t+ Loading pretrained model")
                self.load_model_from_pretrained()
        finally:
            # Clean up in `finally` so a failed load does not leave the directory behind.
            try:
                self.tmpdir.cleanup()
            except Exception:
                # Best-effort fallback if cleanup() itself raises (e.g. permission issues).
                shutil.rmtree(self.tmpdir.name, ignore_errors=True)

    def load_model_with_no_weights(self) -> None:
        """Load the model from ``self.no_weights_model_path`` instead of ``config.model``.

        ``config.model`` is temporarily pointed at the no-weights path and is always
        restored, even if loading raises.
        """
        original_model = self.config.model
        self.config.model = self.no_weights_model_path.as_posix()
        try:
            self.load_model_from_pretrained()
        finally:
            self.config.model = original_model

    def load_model_from_pretrained(self) -> None:
        """Instantiate ``self.pretrained_model`` via the resolved class's ``from_pretrained``.

        ``config.model_kwargs`` and ``trtllm_kwargs`` are both forwarded. A key that
        appears in both makes the call raise ``TypeError``.
        """
        self.pretrained_model = self.trtllm_loader.from_pretrained(
            self.config.model,
            **self.config.model_kwargs,
            **self.trtllm_kwargs,
        )

    @property
    def trtllm_kwargs(self) -> Dict[str, Any]:
        """Return TensorRT-LLM options from the config, skipping those set to None."""
        kwargs: Dict[str, Any] = {}
        for name in self._TRTLLM_KWARG_NAMES:
            value = getattr(self.config, name)
            if value is not None:
                kwargs[name] = value
        return kwargs

    def _generate(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict:
        """Call ``pretrained_model.generate`` with default pad/eos token ids.

        ``pad_token_id`` defaults to 0 and ``eos_token_id`` to 1. Values present in
        ``kwargs`` take precedence. The defaults are merged into one dict so a
        caller-supplied ``pad_token_id`` / ``eos_token_id`` does not trigger
        "got multiple values for keyword argument".
        """
        generate_kwargs = {"pad_token_id": 0, "eos_token_id": 1, **kwargs}
        return self.pretrained_model.generate(
            input_ids=inputs.get("input_ids"),
            attention_mask=inputs.get("attention_mask"),
            **generate_kwargs,
        )

    def prefill(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict:
        """Run the prefill step.

        This issues the same ``generate`` call as ``generate``. Presumably the caller
        limits generation (e.g. via ``max_new_tokens`` in ``kwargs``); verify
        against the scenario code.
        """
        return self._generate(inputs, kwargs)

    def generate(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict:
        """Run a full generation on ``inputs`` with ``kwargs`` forwarded to the model."""
        return self._generate(inputs, kwargs)