# text-generation-inference/server/text_generation_server/cli.py
import os
import sys
from typing import Optional
import typer
from loguru import logger

app = typer.Typer()

@app.command()
def serve(
model_id: str,
revision: Optional[str] = None,
sharded: bool = False,
    trust_remote_code: Optional[bool] = None,
uds_path: str = "/tmp/text-generation-server",
logger_level: str = "INFO",
json_output: bool = False,
otlp_service_name: str = "text-generation-inference.server",
max_input_tokens: Optional[int] = None,
):
"""This is the main entry-point for the server CLI.
Args:
model_id (`str`):
The *model_id* of a model on the HuggingFace hub or the path to a local model.
revision (`Optional[str]`, defaults to `None`):
The revision of the model on the HuggingFace hub.
sharded (`bool`):
Whether the model must be sharded or not. Kept for compatibility with the
text-generation-launcher, but must be set to False.
        trust_remote_code (`Optional[bool]`, defaults to `None`):
            Kept for compatibility with text-generation-launcher. Ignored.
        uds_path (`str`):
            The local path on which the server will expose its gRPC services.
logger_level (`str`):
The server logger level. Defaults to *INFO*.
json_output (`bool`):
Use JSON format for log serialization.
otlp_service_name (`str`):
The name of the OTLP service. For now it is ignored.
max_input_tokens (`Optional[int]`):
The maximum number of tokens allowed in the input. For now it is ignored.
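
    Example:
        An illustrative invocation, assuming the usual `text-generation-server`
        console script is installed (Typer maps the `uds_path` parameter to the
        `--uds-path` option):

            text-generation-server serve gpt2 --uds-path /tmp/text-generation-server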
"""
if sharded:
raise ValueError("Sharding is not supported.")
# Remove default handler
logger.remove()
logger.add(
sys.stdout,
format="{message}",
filter="text_generation_server",
level=logger_level,
serialize=json_output,
backtrace=True,
diagnose=False,
)
if trust_remote_code is not None:
logger.warning("'trust_remote_code' argument is not supported and will be ignored.")
    # TODO: the 'otlp_service_name' and 'max_input_tokens' parameters are passed in when the server is started, but
    # they are not used yet, so just inform the user about that.
    logger.info("'otlp_service_name' argument is not supported and will be ignored.")
# This is a workaround to pass the logger level to other threads, it's only used in
# Pytorch/XLA generator.
os.environ["LOGGER_LEVEL_GENERATOR"] = logger_level
# Import here after the logger is added to log potential import exceptions
from optimum.tpu.model import fetch_model
from .server import serve
# Read environment variables forwarded by the launcher
max_batch_size = int(os.environ.get("MAX_BATCH_SIZE", "4"))
max_total_tokens = int(os.environ.get("MAX_TOTAL_TOKENS", "64"))
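    # For a manual launch (without text-generation-launcher), these can be set inline,
    # e.g. with illustrative values:
    #   MAX_BATCH_SIZE=8 MAX_TOTAL_TOKENS=512 text-generation-server serve gpt2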
# Start the server
model_path = fetch_model(model_id, revision)
serve(
model_path,
revision=revision,
max_batch_size=max_batch_size,
max_sequence_length=max_total_tokens,
max_input_tokens=max_input_tokens,
uds_path=uds_path
)


@app.command()
def download_weights(
model_id: str,
revision: Optional[str] = None,
logger_level: str = "INFO",
json_output: bool = False,
auto_convert: Optional[bool] = None,
extension: Optional[str] = None,
trust_remote_code: Optional[bool] = None,
merge_lora: Optional[bool] = None,
):
"""Download the model weights.
This command will be called by text-generation-launcher before serving the model.
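
    Example:
        Assuming the same console script as above; Typer exposes this command as
        `download-weights` (underscores become dashes in command names):

            text-generation-server download-weights gpt2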
"""
# Remove default handler
logger.remove()
logger.add(
sys.stdout,
format="{message}",
filter="text_generation_server",
level=logger_level,
serialize=json_output,
backtrace=True,
diagnose=False,
)
if extension is not None:
logger.warning("'extension' argument is not supported and will be ignored.")
if trust_remote_code is not None:
logger.warning("'trust_remote_code' argument is not supported and will be ignored.")
if auto_convert is not None:
logger.warning("'auto_convert' argument is not supported and will be ignored.")
if merge_lora is not None:
logger.warning("'merge_lora' argument is not supported and will be ignored.")
# Import here after the logger is added to log potential import exceptions
from optimum.tpu.model import fetch_model
fetch_model(model_id, revision)
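

# Convenience entry point (an assumption, not part of the launcher wiring): lets the
# module be run directly, e.g. `python -m text_generation_server.cli serve <model_id>`.
if __name__ == "__main__":
    app()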