optimum_benchmark/backends/llama_cpp/backend.py
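"""llama.cpp backend for optimum-benchmark.

Wraps ``llama_cpp.Llama`` behind the generic ``Backend`` interface so that GGUF/GGML
checkpoints can be benchmarked on the ``text-generation`` and ``feature-extraction`` tasks.
"""
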
from tempfile import TemporaryDirectory
from typing import Any, Dict

from llama_cpp import Llama

from ..base import Backend
from .config import LlamaCppConfig
class LlamaCppBackend(Backend[LlamaCppConfig]):
NAME: str = "llama_cpp"
pretrained_model: Llama
def __init__(self, config: LlamaCppConfig) -> None:
super().__init__(config)
def load(self) -> None:
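        """Create the backend temporary directory and load the pretrained model."""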
self.logger.info("\t+ Creating backend temporary directory")
self.tmpdir = TemporaryDirectory()
self.logger.info("\t+ Loading pretrained model")
self.load_model_from_pretrained()
self.tmpdir.cleanup()
def load_model_from_pretrained(self) -> None:
"""
Load the pretrained model from the given model name (normally GGUF, GGML)
"""
self.pretrained_model = Llama.from_pretrained(
self.config.model,
**self.llama_cpp_kwargs,
)
@property
def llama_cpp_kwargs(self) -> Dict[str, Any]:
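        """Keyword arguments forwarded to ``Llama.from_pretrained``."""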
return {
"embedding": self.config.task == "feature-extraction",
"filename": self.config.filename,
"verbose": False,
"echo": False,
}
def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
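        """
        Convert the benchmark ``input_ids`` into the format llama-cpp-python expects:
        a flat token list for ``text-generation`` and detokenized strings for ``feature-extraction``.
        """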
if self.config.task == "text-generation":
if inputs["input_ids"].shape[0] != 1:
raise ValueError("Batch size must be 1 for Text Generation with llama-cpp-python")
return {"tokens": inputs["input_ids"].squeeze(0).tolist()}
elif self.config.task == "feature-extraction":
return {"input": [self.pretrained_model.detokenize(x).decode("utf-8") for x in inputs["input_ids"]]}
else:
raise ValueError(f"Task {self.config.task} not supported by {self.NAME}")
def forward(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> Any:
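        """Compute embeddings for the prepared ``feature-extraction`` inputs."""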
        return self.pretrained_model.embed(**inputs)
def prefill(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> list[int]:
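        """Prefill step: force the prompt forward pass by sampling ``max_new_tokens`` token(s)."""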
        generated_tokens = []
        generator = self.pretrained_model.generate(**inputs, reset=True)
        for _ in range(kwargs["max_new_tokens"]):
            # each step evaluates the pending tokens and samples one new token id
            generated_tokens.append(next(generator))
        return generated_tokens
def generate(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> list[int]:
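        """Decode step: sample ``max_new_tokens`` new tokens autoregressively."""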
        generated_tokens = []
        generator = self.pretrained_model.generate(**inputs, reset=True)
        for _ in range(kwargs["max_new_tokens"]):
            generated_tokens.append(next(generator))
        return generated_tokens
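
# Minimal usage sketch (illustrative only, kept commented out). It assumes that
# ``LlamaCppConfig`` exposes the ``task``, ``model`` and ``filename`` fields
# referenced above; the repo id and GGUF filename below are placeholders, not
# values shipped with this module.
#
#   config = LlamaCppConfig(
#       task="text-generation",
#       model="TheBloke/Llama-2-7B-GGUF",   # placeholder GGUF repo id
#       filename="llama-2-7b.Q4_K_M.gguf",  # placeholder GGUF filename
#   )
#   backend = LlamaCppBackend(config)
#   backend.load()
#   inputs = backend.prepare_inputs({"input_ids": input_ids})  # input_ids: 1 x seq_len tensor
#   backend.prefill(inputs, {"max_new_tokens": 1})
#   backend.generate(inputs, {"max_new_tokens": 32})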