arctic_inference/embedding/replica.py (300 lines of code) (raw):

#!/usr/bin/env python3 """ gRPC server for running one model on a given GPU. It can be used directly if only one model is needed. More commonly, it is used together with a load balancer (manager.py) to serve multiple models on one GPU concurrently. This is useful for load balancing and improving throughput. """ import os import asyncio import logging import torch import numpy as np import uuid import sys from concurrent import futures from typing import Optional from argparse import Namespace os.environ["VLLM_PLUGINS"] = "" import grpc from grpc import aio from grpc import ServicerContext from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.arg_utils import AsyncEngineArgs from vllm.usage.usage_lib import UsageContext from vllm.inputs import TokensPrompt from vllm.pooling_params import PoolingParams from vllm.utils import FlexibleArgumentParser # Ensure we're using vLLM v0 for embedding support os.environ["VLLM_USE_V1"] = "0" # Import the generated protobuf code try: import arctic_inference.embedding.proto.python.inference_pb2 as inference_pb2 import arctic_inference.embedding.proto.python.inference_pb2_grpc as inference_pb2_grpc except ImportError: print( "Error: Could not import gRPC modules. Make sure to run generate_proto.py first." ) print("Run: python arctic_inference/grpc/generate_proto.py") sys.exit(1) # Configure logger logger = logging.getLogger("arctic_inference.embedding.replica") class InferenceServicer(inference_pb2_grpc.InferenceServiceServicer): """Implementation of the InferenceService gRPC service. This class handles incoming gRPC requests for inference operations including text embedding and generation. """ def __init__(self, engine_args: AsyncEngineArgs): """Initialize the servicer with AsyncLLMEngine arguments. Args: engine_args: Configuration for the AsyncLLMEngine. """ self.engine_args = engine_args self.engine = None self.tokenizer = None self.model_config = None self.model_name = None self.max_model_len = None self._active_requests = {} # Tracks active generation requests self.ready = False async def start(self): """Initialize and start the LLM engine. This method creates the AsyncLLMEngine instance and initializes the tokenizer and model configuration. """ # Create the engine from the provided arguments self.engine = AsyncLLMEngine.from_engine_args( self.engine_args, usage_context=UsageContext.API_SERVER ) self.engine.log_requests = False self.engine.start_background_loop() # Get tokenizer and model configuration self.tokenizer = await self.engine.get_tokenizer() self.model_config = await self.engine.get_model_config() self.model_name = self.model_config.model self.max_model_len = self.model_config.max_model_len self.ready = True async def stop(self): """Shut down the LLM engine.""" await self.engine.shutdown() async def _encode_one_prompt( self, prompt: TokensPrompt, pooling_params: PoolingParams, request_id: str, priority: int, ) -> Optional[torch.Tensor]: """Encode a single prompt into an embedding. Args: prompt: The tokenized prompt to encode. pooling_params: Parameters for embedding pooling. request_id: Unique identifier for this request. priority: Priority level of this request. Returns: Tensor containing the embedding, or None if encoding failed. """ result = None async for encode_result in self.engine.encode( prompt=prompt, pooling_params=pooling_params, request_id=request_id, priority=priority, ): result = encode_result.outputs.data return result async def Encode( self, request: inference_pb2.EncodeRequest, context: ServicerContext ) -> inference_pb2.EncodeResponse: """Handle a request to encode prompts into embeddings. Args: request: The client request containing prompts to encode. context: gRPC service context. Returns: Response containing the generated embeddings or an error message. """ # Check if the model is ready if not self.ready: return inference_pb2.EncodeResponse( request_id=request.request_id, error="Model not ready", ) # Ensure we have a request ID request_id = request.request_id or str(uuid.uuid4()) pooling_params = PoolingParams() # Validate the request contains the required number of prompts n_prompts = request.n_prompts if not hasattr(request, "n_prompts"): return inference_pb2.EncodeResponse( request_id=request_id, error="n_prompts is required", ) # Validate the model name if specified if hasattr(request, "model_name") and request.model_name != self.model_name: return inference_pb2.EncodeResponse( request_id=request_id, error=f"Expected model name {self.model_name}, got {request.model_name}", ) logger.info( f"Received encoding request: {request_id} having {n_prompts} prompts for {self.model_name}" ) # Process pre-tokenized prompts or raw text prompts if hasattr(request, "token_id_bytes_i32") and request.token_id_bytes_i32: logging.info("Using tokenized prompts for encoding") if len(request.token_id_bytes_i32) != n_prompts: return inference_pb2.EncodeResponse( request_id=request_id, error=f"Expected {n_prompts} prompts, got {len(request.token_id_bytes_i32)}", ) # Convert binary token IDs to TokensPrompt objects tokenized_prompts = [] for i in range(n_prompts): token_ids = np.frombuffer(request.token_id_bytes_i32[i], dtype=np.int32) # TODO(juncheng): need to patch vllm to support tensor input tokenized_prompts.append( TokensPrompt(prompt_token_ids=token_ids.tolist()) ) else: # Check if the correct number of text prompts is provided if len(request.prompts) != n_prompts: return inference_pb2.EncodeResponse( request_id=request_id, error=f"Expected {n_prompts} prompts, got {len(request.prompts)}", ) # Tokenize the prompts before sending to vLLM # This enables batch tokenization which is more efficient than # vLLM's per-prompt tokenization tokens = self.tokenizer( [text_prompt for text_prompt in request.prompts], return_tensors=None, padding=False, return_token_type_ids=False, return_attention_mask=False, truncation=True, ).input_ids tokenized_prompts = [ TokensPrompt(prompt_token_ids=token_ids) for token_ids in tokens ] # Create encoding tasks for all prompts tasks = [ self._encode_one_prompt( tokenized_prompt, pooling_params, f"{request_id}:{i}", request.priority, ) for i, tokenized_prompt in enumerate(tokenized_prompts) ] # Process all tasks concurrently embedding_tensors = await asyncio.gather(*tasks) if any(tensor is None for tensor in embedding_tensors): return inference_pb2.EncodeResponse( request_id=request_id, error="No encoding result", ) # Get embedding dimension from the first result embedding_dim = embedding_tensors[0].shape[0] # Convert embeddings to bytes for the response embedding_bytes_fp32 = [ tensor.numpy().astype(dtype="<f4").tobytes() for tensor in embedding_tensors ] # Return the response with embeddings return inference_pb2.EncodeResponse( request_id=request_id, n_prompts=n_prompts, embedding_dim=embedding_dim, embedding_bytes_fp32=embedding_bytes_fp32, ) async def Abort( self, request: inference_pb2.AbortRequest, context: ServicerContext ) -> inference_pb2.AbortResponse: """Abort an ongoing generation request. Args: request: The abort request containing the request ID to abort. context: gRPC service context. Returns: Response indicating whether the abort was successful. """ request_id = request.request_id logger.info(f"Received abort request for: {request_id}") try: await self.engine.abort(request_id) return inference_pb2.AbortResponse( success=True, message=f"Request {request_id} aborted successfully" ) except Exception as e: logger.exception(f"Error aborting request {request_id}: {e}") return inference_pb2.AbortResponse( success=False, message=f"Error aborting request: {str(e)}" ) async def GetReplicaInfo( self, request: inference_pb2.ReplicaInfoRequest, context: ServicerContext ) -> inference_pb2.ReplicaInfoResponse: """Get information about this replica including model and configuration. Args: request: The info request. context: gRPC service context. Returns: Response containing model and configuration information. """ try: # Gather all configuration information model_config = await self.engine.get_model_config() parallel_config = await self.engine.get_parallel_config() decoding_config = await self.engine.get_decoding_config() scheduler_config = await self.engine.get_scheduler_config() lora_config = await self.engine.get_lora_config() replica_info = inference_pb2.SingleReplicaInfoResponse( model_name=str(model_config.model), task=str(model_config.task), dtype=str(model_config.dtype), ready=self.ready, parallel_config=str(parallel_config), decoding_config=str(decoding_config), scheduler_config=str(scheduler_config), lora_config=str(lora_config), ) return inference_pb2.ReplicaInfoResponse( replica_infos=[replica_info, replica_info], n_replicas=2, n_healthy_replicas=2, message="", ) except Exception as e: logger.exception(f"Error getting model info: {e}") context.set_code(grpc.StatusCode.INTERNAL) context.set_details(f"Error getting model info: {str(e)}") return inference_pb2.ReplicaInfoResponse( replica_infos=[], n_replicas=1, n_healthy_replicas=0, message=f"Error getting model info: {str(e)}", ) async def HealthCheck( self, request: inference_pb2.HealthCheckRequest, context: ServicerContext ) -> inference_pb2.HealthCheckResponse: """Check if the service is healthy. Args: request: The health check request. context: gRPC service context. Returns: Response indicating whether the service is healthy. """ try: await self.engine.check_health() return inference_pb2.HealthCheckResponse(healthy=self.ready) except Exception as e: logger.exception(f"Health check failed: {e}") return inference_pb2.HealthCheckResponse( healthy=False, message=f"Service is unhealthy: {str(e)}" ) class InferenceServer: """gRPC server for the InferenceService. This class manages the lifecycle of the gRPC server and the InferenceServicer. """ def __init__( self, args: Namespace, ): """Initialize the server with command line arguments. Args: args: Command line arguments containing server configuration. """ self.args = args self.engine_args = AsyncEngineArgs.from_cli_args(self.args) self.server = None self.host = args.host self.port = args.port self.workers = args.workers self.ready = False async def start(self): """Start the gRPC server and initialize the servicer. This method configures and starts the gRPC server, then waits for termination signals. """ # Create the gRPC server with appropriate concurrency and message size limits self.server = aio.server( futures.ThreadPoolExecutor(max_workers=self.workers), options=[ ("grpc.max_message_length", 200 * 1024 * 1024), ("grpc.max_send_message_length", 200 * 1024 * 1024), ("grpc.max_receive_message_length", 200 * 1024 * 1024), ], ) # TODO(juncheng): set up metrics # Create and start the servicer self.servicer = InferenceServicer(self.engine_args) await self.servicer.start() # Register the servicer with the server inference_pb2_grpc.add_InferenceServiceServicer_to_server( self.servicer, self.server ) # Start the server address = f"{self.host}:{self.port}" self.server.add_insecure_port(address) logger.info(f"Starting gRPC replica on {address}") # Set up exception handling # loop = asyncio.get_running_loop() # loop.set_exception_handler(exception_handler) await self.server.start() logger.info("arctic_inference gRPC replica started") try: # Wait for replica termination await self.server.wait_for_termination() except asyncio.CancelledError: # Handle task cancellation print("Server task cancelled.") except KeyboardInterrupt: # Handle Ctrl+C logger.info("KeyboardInterrupt detected. Shutting down server...") finally: # Ensure server is stopped await self.stop() async def stop(self): """Stop the replica and clean up resources.""" if self.server: logger.info("Stopping replica") await self.server.stop(0) await self.servicer.stop() self.server = None self.servicer = None # Clean up PyTorch distributed process group if initialized try: import torch.distributed as dist if dist.is_initialized(): dist.destroy_process_group() except Exception: pass async def serve(args: Namespace) -> None: """Main entry point to start the replica. Args: args: Command line arguments. """ logger.info("args: %s", args) server = InferenceServer(args) await server.start() def patch_embedding_performance(): from functools import lru_cache import vllm.model_executor.model_loader.utils as vllm_utils # Get the original function original_function = vllm_utils.get_model_architecture # Apply your decorator decorated_function = lru_cache(maxsize=None)(original_function) # Replace the original function in the module with the decorated one vllm_utils.get_model_architecture = decorated_function logger.debug("Patched get_model_architecture for embedding performance") def patch_model_config_hash(): from vllm.config import ModelConfig # compute hash in int def compute_hash_int(self): return int(ModelConfig.compute_hash(self), 16) ModelConfig.__hash__ = compute_hash_int if __name__ == "__main__": # patch the get_model_architecture for embedding performance patch_embedding_performance() patch_model_config_hash() # Configure logging from vllm import logger as vllm_logger # type: ignore logging.basicConfig( format=vllm_logger._FORMAT, datefmt=vllm_logger._DATE_FORMAT, level=logging.INFO ) # Parse command line arguments parser = FlexibleArgumentParser(description="gRPC replica for vLLM inference") parser = AsyncEngineArgs.add_cli_args(parser) # Replica-specific arguments parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to") parser.add_argument( "--port", type=int, default=50050, help="Port to bind to for inference" ) parser.add_argument( "--metrics-port", type=int, default=80001, help="Port to bind to for metrics" ) parser.add_argument( "--workers", type=int, default=16, help="Number of gRPC workers" ) # Default model parser.set_defaults(model="Snowflake/snowflake-arctic-embed-m-v1.5") args = parser.parse_args() # Start the replica asyncio.run(serve(args))