# backends/python/server/text_embeddings_server/server.py
import asyncio
import torch
from grpc import aio
from loguru import logger
from grpc_reflection.v1alpha import reflection
from pathlib import Path
from typing import Optional
from text_embeddings_server.models import Model, get_model
from text_embeddings_server.pb import embed_pb2_grpc, embed_pb2
from text_embeddings_server.utils.tracing import UDSOpenTelemetryAioServerInterceptor
from text_embeddings_server.utils.interceptor import ExceptionInterceptor
class EmbeddingService(embed_pb2_grpc.EmbeddingServiceServicer):
    """gRPC servicer exposing health, embedding, and prediction endpoints
    backed by a single loaded model."""

    def __init__(self, model: Model):
        """Keep a handle on the model and enter torch inference mode for
        the whole lifetime of this servicer."""
        self.model = model
        # RAII-style guard: inference mode stays active as long as this
        # servicer instance is alive.
        self._inference_mode_raii_guard = torch._C._InferenceMode(True)

    async def Health(self, request, context):
        """Liveness probe; allocates a tiny tensor to exercise the CUDA
        context when the model lives on a GPU."""
        if self.model.device.type == "cuda":
            torch.zeros((2, 2), device="cuda")
        return embed_pb2.HealthResponse()

    async def Embed(self, request, context):
        """Decode the request into a batch and return its embeddings."""
        limit = self.model.max_input_length
        pb_batch = self.model.batch_type.from_pb(request, self.model.device, limit)
        return embed_pb2.EmbedResponse(embeddings=self.model.embed(pb_batch))

    async def Predict(self, request, context):
        """Decode the request into a batch and return prediction scores."""
        limit = self.model.max_input_length
        pb_batch = self.model.batch_type.from_pb(request, self.model.device, limit)
        return embed_pb2.PredictResponse(scores=self.model.predict(pb_batch))
def serve(
    model_path: Path,
    dtype: Optional[str],
    uds_path: Path,
    pool: str,
):
    """Load the model and run the gRPC embedding server on a unix socket.

    Blocks until the server terminates (or a KeyboardInterrupt stops it).
    `dtype` and `pool` are forwarded to `get_model`; `uds_path` is the
    filesystem path the unix-domain socket is bound to.
    """

    async def serve_inner(
        model_path: Path,
        dtype: Optional[str] = None,
    ):
        unix_socket = f"unix://{uds_path}"

        # Fail loudly if the model cannot be constructed — the launcher
        # relies on the logged traceback to surface the root cause.
        try:
            model = get_model(model_path, dtype, pool)
        except Exception:
            logger.exception("Error when initializing model")
            raise

        interceptors = [
            ExceptionInterceptor(),
            UDSOpenTelemetryAioServerInterceptor(),
        ]
        server = aio.server(interceptors=interceptors)
        embed_pb2_grpc.add_EmbeddingServiceServicer_to_server(
            EmbeddingService(model), server
        )

        # Register server reflection so clients can discover the service
        # without the .proto files.
        service_names = (
            embed_pb2.DESCRIPTOR.services_by_name["EmbeddingService"].full_name,
            reflection.SERVICE_NAME,
        )
        reflection.enable_server_reflection(service_names, server)
        server.add_insecure_port(unix_socket)

        await server.start()
        logger.info(f"Server started at {unix_socket}")

        try:
            await server.wait_for_termination()
        except KeyboardInterrupt:
            logger.info("Signal received. Shutting down")
            # Grace period of 0: abort in-flight RPCs immediately.
            await server.stop(0)

    asyncio.run(serve_inner(model_path, dtype))