text-generation-inference/server/text_generation_server/auto_generator.py:
from typing import Optional

from loguru import logger

from .generator_base import Generator
from .jetstream_pt_support import model_can_use_jetstream_pt


class AutoGenerator:
    @staticmethod
    def from_pretrained(
        model_path: str,
        revision: str,
        max_batch_size: int,
        max_sequence_length: int,
        max_input_tokens: Optional[int] = None,
    ) -> Generator:
        """Instantiate a Generator for TPU using Jetstream PyTorch or PyTorch/XLA.

        Args:
            model_path (`str`):
                The path to a local model. This path must also contain a tokenizer.
            revision (`str`):
                The revision of the model.
            max_batch_size (`int`):
                The maximum batch size.
            max_sequence_length (`int`):
                The maximum sequence length.
            max_input_tokens (`int`, *optional*):
                The maximum number of tokens allowed in the input. When set to `None`, it defaults to 80% of
                `max_sequence_length`.

        Returns:
            A `Generator` for TPU inference, backed by Jetstream PyTorch or PyTorch/XLA.
        """
        if max_input_tokens is None:
            max_input_tokens = int(0.8 * max_sequence_length)
        # Import the concrete generator lazily so that only the selected backend's dependencies are loaded.
        if model_can_use_jetstream_pt(model_path):
            logger.debug("Using Jetstream PyTorch generator.")
            from .jetstream_pt_support.generator import TpuGeneratorJetStream

            return TpuGeneratorJetStream.from_pretrained(
                model_path,
                revision=revision,
                max_batch_size=max_batch_size,
                max_sequence_length=max_sequence_length,
                max_input_tokens=max_input_tokens,
            )
        else:
            logger.debug("Using PyTorch/XLA generator.")
            from .generator import TpuGenerator

            return TpuGenerator.from_pretrained(
                model_path, revision=revision, max_batch_size=max_batch_size, max_sequence_length=max_sequence_length
            )
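

# Illustrative usage (a minimal sketch; the model directory, revision and sizes below are hypothetical
# placeholders, and the call only relies on the `AutoGenerator.from_pretrained` signature above):
#
#     generator = AutoGenerator.from_pretrained(
#         "/data/models/my-model",   # hypothetical local path that also contains the tokenizer
#         revision="main",
#         max_batch_size=4,
#         max_sequence_length=1024,
#         # max_input_tokens is omitted, so it defaults to int(0.8 * 1024) == 819
#     )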