optimum_benchmark/backends/vllm/backend.py

import asyncio
import shutil
from tempfile import TemporaryDirectory
from typing import Any, Dict, Union

from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.llm_engine import LLMEngine
from vllm.sampling_params import SamplingParams

from ...task_utils import TEXT_GENERATION_TASKS
from ..base import Backend
from .config import VLLMConfig


class VLLMBackend(Backend[VLLMConfig]):
    NAME: str = "vllm"

    pretrained_model: Union[LLMEngine, AsyncLLMEngine]

    def __init__(self, config: VLLMConfig) -> None:
        super().__init__(config)

        if self.config.task not in TEXT_GENERATION_TASKS:
            raise NotImplementedError(f"We only support text generation tasks for VLLM, but got {self.config.task}")

    def load(self) -> None:
        self.logger.info("\t+ Creating backend temporary directory")
        self.tmpdir = TemporaryDirectory()

        if self.config.no_weights:
            self.logger.info("\t+ Creating no weights model")
            self.create_no_weights_model_slow()
            self.logger.info("\t+ Loading no weights model")
            self.load_model_with_no_weights()
        else:
            self.logger.info("\t+ Downloading pretrained model")
            self.download_pretrained_model()
            self.logger.info("\t+ Loading pretrained model")
            self.load_model_from_pretrained()

        # The temporary directory is only needed while the model is created/loaded,
        # so it is cleaned up as soon as the engine is up.
        try:
            self.tmpdir.cleanup()
        except Exception:
            shutil.rmtree(self.tmpdir.name, ignore_errors=True)

    def load_model_with_no_weights(self) -> None:
        # Temporarily point the config at the no-weights model, load it, then restore the original model name.
        original_model, self.config.model = self.config.model, self.no_weights_model_path.as_posix()
        self.load_model_from_pretrained()
        self.config.model = original_model

    def load_model_from_pretrained(self) -> None:
        if self.config.serving_mode == "offline":
            self.pretrained_model = LLMEngine.from_engine_args(EngineArgs(**self.vllm_kwargs))
        else:
            self.pretrained_model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**self.vllm_kwargs))

    @property
    def vllm_kwargs(self):
        return {
            "model": self.config.model,
            "tokenizer": self.config.processor,
            "device": self.config.device,
            **self.config.engine_args,
        }

    def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        if self.config.task in TEXT_GENERATION_TASKS:
            # vLLM consumes raw text prompts, so decode the tokenized inputs back to strings.
            inputs = {"prompts": self.pretrained_processor.batch_decode(inputs["input_ids"])}
        else:
            raise NotImplementedError(f"vLLM does not support task {self.config.task}")

        return inputs

    def batch_offline_engine_generate(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> Any:
        # Queue every prompt as a separate request, then step the engine until all requests finish.
        for i, prompt in enumerate(inputs["prompts"]):
            self.pretrained_model.add_request(
                prompt=prompt,
                request_id=str(i),
                params=self.get_sampling_params(kwargs),
            )

        while self.pretrained_model.has_unfinished_requests():
            self.pretrained_model.step()

    def get_sampling_params(self, kwargs: Dict[str, Any]) -> SamplingParams:
        # ignore_eos=True keeps generation running for the full token budget, as benchmarking requires.
        return SamplingParams(
            ignore_eos=True,
            detokenize=True,
            seed=self.config.seed,
            n=kwargs.get("num_return_sequences"),
            max_tokens=kwargs.get("max_new_tokens"),
            min_tokens=kwargs.get("min_new_tokens"),
            logits_processors=kwargs.get("logits_processors", None),
        )

    async def single_online_engine_generate(self, prompt: str, request_id: str, kwargs: Dict[str, Any]) -> Any:
        stream = await self.pretrained_model.add_request(
            prompt=prompt,
            request_id=request_id,
            params=self.get_sampling_params(kwargs),
        )

        # Drain the stream so the request runs to completion; the outputs themselves are not needed.
        async for _ in stream:
            pass

    async def batch_online_engine_generate(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> Any:
        tasks = [
            self.single_online_engine_generate(prompt, str(i), kwargs)
            for i, prompt in enumerate(inputs["prompts"])
        ]
        await asyncio.gather(*tasks)

    def prefill(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> Dict[str, Any]:
        if self.config.serving_mode == "offline":
            self.batch_offline_engine_generate(inputs, kwargs)
        else:
            asyncio.run(self.batch_online_engine_generate(inputs, kwargs))

    def generate(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> Any:
        if self.config.serving_mode == "offline":
            self.batch_offline_engine_generate(inputs, kwargs)
        else:
            asyncio.run(self.batch_online_engine_generate(inputs, kwargs))
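

# --- Illustrative usage sketch (not part of the upstream file) ---
# A minimal sketch of how this backend might be driven, assuming a VLLMConfig built
# elsewhere with the fields read above (task, model, processor, device, serving_mode,
# no_weights, seed, engine_args); the exact VLLMConfig constructor signature and the
# shape of `batch_input_ids` are assumptions for illustration only.
#
#     backend = VLLMBackend(config)      # config.task must be a text generation task
#     backend.load()                     # builds the LLMEngine / AsyncLLMEngine
#     prompts = backend.prepare_inputs({"input_ids": batch_input_ids})
#     gen_kwargs = {"max_new_tokens": 32, "min_new_tokens": 32, "num_return_sequences": 1}
#     backend.prefill(prompts, gen_kwargs)
#     backend.generate(prompts, gen_kwargs)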