# text-generation-inference/server/text_generation_server/generator_base.py (25 lines of code) (raw):
from abc import ABC
from typing import List, Optional, Tuple
from .pb.generate_pb2 import (
Batch,
CachedBatch,
Generation,
InfoResponse,
)
class Generator(ABC):
    """Abstract base for the engine that powers TextGenerationService.

    The interface currently exposes protobuf message types directly; ideally it
    would not depend on them, but as a first step it does.

    Concrete implementations typically need a model and a tokenizer in order to
    implement the methods declared here.
    """

    # NOTE: none of the methods below is decorated with `@abstractmethod`, so
    # this base class does not block instantiation; each unimplemented method
    # simply raises `NotImplementedError` when it is called.

    @property
    def info(self) -> InfoResponse:
        """Return the expected `InfoResponse`.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    def warmup(self, batch: Batch) -> int:
        """Check whether the hardware can sustain the target load.

        Args:
            batch (`Batch`):
                A batch sized for the maximum number of concurrent requests.

        Return:
            The maximum number of tokens the model supports.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
        """Add new requests to the generator.

        Prefill is invoked whenever new requests need to be added. Once it
        returns successfully, a decode call follows with both the current and
        the newly prefilled batch(es).

        Args:
            batch (`Batch`):
                A batch holding the new requests.

        Return:
            A list of `Generation` for each request and a `CachedBatch`
            containing all pending requests.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    def decode(self, batches: List[Batch]) -> Tuple[List[Generation], CachedBatch]:
        """Run a decode step following a prefill or a previous decode.

        Args:
            batches (`List[Batch]`):
                The batches to decode.

        Return:
            A list of `Generation` and a `CachedBatch`, as declared by the
            signature (presumably mirroring what prefill returns; confirm
            against concrete implementations).

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    def filter(self, batch_id: int, request_ids: List[int]) -> CachedBatch:
        """Drop from the specified batch every request that is not listed.

        Args:
            batch_id (`int`):
                Identifier of the batch to filter.
            request_ids (`List[int]`):
                Identifiers of the requests to keep.

        Return:
            A `CachedBatch` (presumably the batch after filtering; confirm
            against concrete implementations).

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    def clear(self, batch_id: Optional[int] = None):
        """Remove all requests from the generator.

        Args:
            batch_id (`Optional[int]`, defaults to `None`):
                NOTE(review): the role of this argument is not specified here;
                presumably it restricts the clearing to a single batch when
                given. Confirm against concrete implementations.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    @classmethod
    def from_pretrained(cls, model_id: str, revision: Optional[str]):
        """Factory method in the style of transformers' `from_pretrained`.

        Args:
            model_id (`str`):
                Identifier of the model to load.
            revision (`Optional[str]`):
                Revision of the model, if any.

        Return:
            Not specified by this base class (presumably a `Generator`
            instance; confirm against concrete implementations).

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()