optimum_benchmark/backends/py_txi/backend.py (118 lines of code) (raw):

import shutil
from tempfile import TemporaryDirectory
from typing import Any, Dict, List, Union

from py_txi import TEI, TGI, TEIConfig, TGIConfig

from ...task_utils import TEXT_EMBEDDING_TASKS, TEXT_GENERATION_TASKS
from ..base import Backend
from .config import PyTXIConfig


class PyTXIBackend(Backend[PyTXIConfig]):
    """Benchmark backend that serves models through py-txi.

    Tasks in ``TEXT_GENERATION_TASKS`` are served with ``TGI`` and tasks in
    ``TEXT_EMBEDDING_TASKS`` with ``TEI``; any other task raises
    ``NotImplementedError``.
    """

    NAME: str = "py-txi"

    pretrained_model: Union[TEI, TGI]

    # Config fields forwarded to both TGIConfig and TEIConfig when not None.
    _TXI_KWARG_NAMES = (
        "gpus",
        "image",
        "ports",
        "volumes",
        "devices",
        "shm_size",
        "environment",
        "connection_timeout",
        "first_request_timeout",
        "max_concurrent_requests",
    )
    # Config fields forwarded only to TEIConfig when not None.
    _TEI_KWARG_NAMES = ("dtype", "pooling")
    # Config fields forwarded only to TGIConfig when not None.
    _TGI_KWARG_NAMES = (
        "dtype",
        "sharded",
        "quantize",
        "num_shard",
        "speculate",
        "cuda_graphs",
        "trust_remote_code",
        "disable_custom_kernels",
    )

    def __init__(self, config: PyTXIConfig) -> None:
        super().__init__(config)

    def load(self) -> None:
        """Load the model into a TGI/TEI server.

        When ``config.no_weights`` is set, a no-weights model is created in a
        temporary directory and loaded from there. Otherwise the pretrained
        model is downloaded and loaded. The temporary directory is removed
        afterwards, including when loading fails.
        """
        self.logger.info("\t+ Creating backend temporary directory")
        self.tmpdir = TemporaryDirectory()

        try:
            if self.config.no_weights:
                self.logger.info("\t+ Creating no weights model")
                self.create_no_weights_model_slow()
                self.logger.info("\t+ Loading no weights model")
                self.load_model_with_no_weights()
            else:
                self.logger.info("\t+ Downloading pretrained model")
                self.download_pretrained_model()
                self.logger.info("\t+ Loading pretrained model")
                self.load_model_from_pretrained()
        finally:
            # Cleanup also runs on failure, so a partially created model does
            # not linger on disk until garbage collection.
            self._cleanup_tmpdir()

    def _cleanup_tmpdir(self) -> None:
        """Best-effort removal of the backend temporary directory."""
        try:
            self.tmpdir.cleanup()
        except Exception:
            # NOTE(review): presumably files written through the container bind
            # mount may not be removable by ``cleanup()``; fall back to a
            # best-effort delete that never raises.
            shutil.rmtree(self.tmpdir.name, ignore_errors=True)

    def load_model_with_no_weights(self) -> None:
        """Load the model with the temporary directory bind-mounted at ``/data``.

        ``config.volumes`` is replaced only for the duration of the load. It is
        restored even if loading raises.
        """
        original_volumes = self.config.volumes
        self.config.volumes = {self.tmpdir.name: {"bind": "/data", "mode": "rw"}}
        try:
            self.load_model_from_pretrained()
        finally:
            self.config.volumes = original_volumes

    def load_model_from_pretrained(self) -> None:
        """Start a TGI or TEI server for ``config.model`` according to ``config.task``.

        Raises:
            NotImplementedError: if the task is neither a text-generation nor a
                text-embedding task.
        """
        if self.config.task in TEXT_GENERATION_TASKS:
            self.pretrained_model = TGI(
                config=TGIConfig(model_id=self.config.model, **self.txi_kwargs, **self.tgi_kwargs),
            )
        elif self.config.task in TEXT_EMBEDDING_TASKS:
            self.pretrained_model = TEI(
                config=TEIConfig(model_id=self.config.model, **self.txi_kwargs, **self.tei_kwargs),
            )
        else:
            raise NotImplementedError(f"TXI does not support task {self.config.task}")

    def _collect_config_kwargs(self, names) -> Dict[str, Any]:
        """Return ``{name: config.<name>}`` for each name whose value is not None.

        Keys keep the order of ``names``.
        """
        kwargs = {}
        for name in names:
            value = getattr(self.config, name)
            if value is not None:
                kwargs[name] = value
        return kwargs

    @property
    def txi_kwargs(self) -> Dict[str, Any]:
        """Keyword arguments shared by ``TGIConfig`` and ``TEIConfig``."""
        return self._collect_config_kwargs(self._TXI_KWARG_NAMES)

    @property
    def tei_kwargs(self) -> Dict[str, Any]:
        """Keyword arguments specific to ``TEIConfig``."""
        return self._collect_config_kwargs(self._TEI_KWARG_NAMES)

    @property
    def tgi_kwargs(self) -> Dict[str, Any]:
        """Keyword arguments specific to ``TGIConfig``."""
        return self._collect_config_kwargs(self._TGI_KWARG_NAMES)

    def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Decode ``inputs["input_ids"]`` back to text for the server API.

        Generation tasks receive the text under ``"prompt"`` and embedding
        tasks under ``"text"``.

        Raises:
            NotImplementedError: if the task is neither a text-generation nor a
                text-embedding task.
        """
        if self.config.task in TEXT_GENERATION_TASKS:
            key = "prompt"
        elif self.config.task in TEXT_EMBEDDING_TASKS:
            key = "text"
        else:
            raise NotImplementedError(f"TXI does not support task {self.config.task}")

        # The servers take raw text, so token ids are decoded with the processor.
        return {key: self.pretrained_processor.batch_decode(inputs["input_ids"].tolist())}

    def forward(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> List[str]:
        """Run an embedding request through ``pretrained_model.encode``."""
        return self.pretrained_model.encode(**inputs, **kwargs)

    def _generate(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> Any:
        """Shared generation call used by ``prefill`` and ``generate``.

        Only ``do_sample`` (default False) and ``max_new_tokens`` (default None)
        are read from ``kwargs``.
        """
        return self.pretrained_model.generate(
            **inputs,
            do_sample=kwargs.get("do_sample", False),
            max_new_tokens=kwargs.get("max_new_tokens"),
        )

    def prefill(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Run a prefill request; the token budget comes from ``kwargs["max_new_tokens"]``."""
        return self._generate(inputs, kwargs)

    def generate(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> List[str]:
        """Run a generation request; the token budget comes from ``kwargs["max_new_tokens"]``."""
        return self._generate(inputs, kwargs)