arctic_inference/embedding/replica_manager.py (361 lines of code) (raw):

#!/usr/bin/env python3 """Replica manager and load-inger that launches multiple replicas on a **single GPU** (technically the same CUDA device) and forwards client requests to them. ## Architecture The replica manager consists of: 1. **Replica Manager**: Coordinates replica lifecycle and request routing 2. **Replica State**: Tracks health and status of each replica 3. **Load inger**: Selects replicas based on the configured policy 4. **gRPC Server**: Provides the same API as a single replica but handles distribution When a request comes in: 1. The replica manager selects a replica based on the load balancing policy 2. The request is forwarded to the selected replica 3. The response is returned to the client 4. If a replica fails, the request is retried on a different replica Usage example: ```bash python -m arctic_inference.embedding.replica_manager \ --port 60050 \ --num-replicas 2 \ --model "Snowflake/snowflake-arctic-embed-m-v1.5" ``` The replica manager will listen on ``--port`` (60050) and spawn two replicas on 60051, and 60052. """ from __future__ import annotations import asyncio import logging import os import random import subprocess import sys import time import uuid from argparse import Namespace from dataclasses import dataclass from enum import Enum from pathlib import Path from typing import Dict, List, Optional import grpc import grpc.aio os.environ["VLLM_PLUGINS"] = "" # vLLM imports – *lazy* import inference protos to avoid import-time failures. from vllm.engine.arg_utils import AsyncEngineArgs from vllm.utils import FlexibleArgumentParser # Ensure repo root is on the path so ``python -m`` works from anywhere. ROOT_DIR = Path(__file__).resolve().parent.parent.parent sys.path.insert(0, str(ROOT_DIR)) import arctic_inference.embedding.proto.python.inference_pb2 as inference_pb2 import arctic_inference.embedding.proto.python.inference_pb2_grpc as inference_pb2_grpc import contextlib logger = logging.getLogger("arctic_inference.embedding.manage4") class LoadingerType(str, Enum): """Different policies for selecting the next replica.""" ROUND_ROBIN = "round_robin" LEAST_LOADED = "least_loaded" RANDOM = "random" @dataclass class ReplicaInfo: """Book-keeping information for a single replica.""" id: str port: int process: subprocess.Popen channel: grpc.aio.Channel stub: inference_pb2_grpc.InferenceServiceStub healthy: bool = False current_load: int = 0 # active requests (best-effort) total_load: int = 0 # total number of requests (best-effort) # current_load_tokens: int = 0 # active tokens (best-effort) # total_load_tokens: int = 0 # total number of tokens (best-effort) last_checked: int = 0 async def close(self) -> None: if self.channel: await self.channel.close() if self.process and self.process.poll() is None: self.process.terminate() try: self.process.wait(5) except subprocess.TimeoutExpired: self.process.kill() class ReplicaManager: """Maintains a pool of replicas and provides selection + health logic.""" def __init__( self, args: Namespace, args_list: List[str], base_port: int, num_replicas: int, lb: LoadingerType = LoadingerType.ROUND_ROBIN, health_interval: float = 5.0, ) -> None: self.args = args self.args_list = args_list self.base_port = base_port self.num_replicas = num_replicas self.lb = lb self.health_interval = health_interval self.ready = False self._replicas: Dict[str, ReplicaInfo] = {} self._rr_index = -1 # will become 0 after first selection self._health_task: Optional[asyncio.Task] = None self._lock = asyncio.Lock() # configurable startup timeout (seconds) for each replica to become healthy self.startup_timeout: int = getattr(args, "startup_timeout", 300) # --------------------------------------------------------------------- # Public lifecycle helpers # --------------------------------------------------------------------- async def start(self) -> None: """Launch all replicas concurrently and start background health loop.""" ports = [self.base_port + 1 + i for i in range(self.num_replicas)] # check if the ports are available for p in ports: if not self._check_port_available(p): raise RuntimeError(f"Port {p} is already in use") results = await asyncio.gather( *(self._launch_replica_process(self.args.host, p) for p in ports) ) for replica in results: if replica: self._replicas[replica.id] = replica if not self._replicas: raise RuntimeError("Failed to start ANY replicas: aborting manager.") self._health_task = asyncio.create_task(self._health_loop()) self.ready = True async def stop(self) -> None: if self._health_task: self._health_task.cancel() with contextlib.suppress(asyncio.CancelledError): await self._health_task self.ready = False await asyncio.gather( *(r.close() for r in list(self._replicas.values())), return_exceptions=True ) # ------------------------------------------------------------------ # Request routing helpers # ------------------------------------------------------------------ async def route_request(self, method_name: str, request, context): """Route request to an appropriate replica and forward the response.""" replica = await self._select_replica() if not replica: context.set_code(grpc.StatusCode.UNAVAILABLE) context.set_details("No healthy replicas") return getattr(inference_pb2, self._response_type_name(method_name))() # Track load roughly. replica.current_load += request.n_prompts replica.total_load += request.n_prompts try: stub_method = getattr(replica.stub, method_name) response = await stub_method(request, timeout=self.args.forward_timeout) return response except grpc.RpcError as exc: logger.warning("RPC Error from replica %s - %s", replica.id, exc) # TODO(juncheng): consider exporting some metrics and launching a new replica await self._mark_unhealthy(replica) context.set_code(grpc.StatusCode.UNAVAILABLE) context.set_details("Replica failure: " + str(exc)) return getattr(inference_pb2, self._response_type_name(method_name))() finally: replica.current_load = max(replica.current_load - 1, 0) # ------------------------------------------------------------------ # Internal methods # ------------------------------------------------------------------ def _check_port_available(self, port: int) -> bool: """Check if a port is available. return True if the port is available, False otherwise. """ import socket with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: return s.connect_ex(("localhost", port)) != 0 def _build_replica_cmd(self, host: str, port: int) -> List[str]: cmd: List[str] = [ sys.executable, "-m", "arctic_inference.embedding.replica", "--host", host, "--port", str(port), ] skip_current_arg = False for arg in self.args_list: if arg.startswith("-"): skip_current_arg = False if skip_current_arg: continue if arg in [ "--host", "--port", "--num-replicas", "--load-balancing", "--health-interval", "--forward-timeout", "--startup-timeout", ]: skip_current_arg = True continue cmd.append(arg) return cmd async def _launch_replica_process( self, host: str, port: int ) -> Optional[ReplicaInfo]: """Start a replica process and wait until it reports healthy.""" cmd = self._build_replica_cmd(host, port) time.sleep(2) logger.info("Starting replica on port %d: %s", port, " ".join(cmd)) # Use line-buffered output so we can stream logs and avoid deadlocks. proc = subprocess.Popen( cmd, stdout=subprocess.PIPE, # capture stdout stderr=subprocess.STDOUT, # redirect stderr to same pipe text=True, bufsize=1, env=os.environ.copy(), ) # Stream logs asynchronously to avoid blocking on full pipe buffers. asyncio.create_task(self._stream_subprocess_output(proc, f"replica-{port}")) # Create gRPC stub. channel = grpc.aio.insecure_channel(f"localhost:{port}") stub = inference_pb2_grpc.InferenceServiceStub(channel) replica = ReplicaInfo( id=str(uuid.uuid4()), port=port, process=proc, channel=channel, stub=stub ) deadline = time.time() + self.startup_timeout while time.time() < deadline: if await self._check_health(replica): replica.healthy = True logger.info("Replica %s healthy on port %d", replica.id, port) return replica # Early exit if process already died. if proc.poll() is not None: logger.error( "Replica process on port %d exited prematurely with code %s", port, proc.returncode, ) break await asyncio.sleep(2) logger.error("Replica on port %d failed to become healthy within timeout", port) await replica.close() return None async def _stream_subprocess_output(self, proc: subprocess.Popen, prefix: str): """Continuously read a subprocess' combined stdout/stderr and log lines.""" if not proc.stdout: return loop = asyncio.get_running_loop() try: while True: line = await loop.run_in_executor(None, proc.stdout.readline) if not line: break print(line.rstrip()) except Exception as exc: logger.warning("Log stream for %s terminated: %s", prefix, exc) async def _health_loop(self): while True: await asyncio.sleep(self.health_interval) await asyncio.gather( *(self._periodic_health(r) for r in self._replicas.values()) ) async def _periodic_health(self, replica: ReplicaInfo): healthy = await self._check_health(replica) if healthy and not replica.healthy: logger.info("Replica %s recovered", replica.id) if not healthy and replica.healthy: logger.warning("Replica %s became UNHEALTHY", replica.id) replica.healthy = healthy replica.last_checked = time.time() async def _check_health(self, replica: ReplicaInfo) -> bool: try: resp = await replica.stub.HealthCheck( inference_pb2.HealthCheckRequest(), timeout=2 ) return resp.healthy except Exception: return False async def _mark_unhealthy(self, replica: ReplicaInfo): replica.healthy = False async def _select_replica(self) -> Optional[ReplicaInfo]: """Return a healthy replica based on the chosen LB policy.""" healthy = [r for r in self._replicas.values() if r.healthy] if not healthy: return None if self.lb == LoadingerType.ROUND_ROBIN: self._rr_index = (self._rr_index + 1) % len(healthy) return healthy[self._rr_index] if self.lb == LoadingerType.LEAST_LOADED: return min(healthy, key=lambda r: r.current_load) if self.lb == LoadingerType.RANDOM: return random.choice(healthy) # Fallback: return healthy[0] @staticmethod def _response_type_name(method_name: str) -> str: # Encode -> EncodeResponse, Generate -> GenerateResponse, etc. return method_name + "Response" class ManagerServicer(inference_pb2_grpc.InferenceServiceServicer): """Thin wrapper that forwards every call to ``ReplicaManager``.""" def __init__(self, replica_manager: ReplicaManager): self.replica_manager = replica_manager # ---------------------------------------------------------------------------- # Forwarded RPC methods – keep only Encode / Generate for brevity; add others # if you rely on them in production. # ---------------------------------------------------------------------------- async def Encode(self, request, context) -> inference_pb2.EncodeResponse: # type: ignore return await self.replica_manager.route_request("Encode", request, context) async def HealthCheck( self, request, context ) -> inference_pb2.HealthCheckResponse: # pass‐through to manager return inference_pb2.HealthCheckResponse(healthy=self.replica_manager.ready) async def GetReplicaInfo( self, request, context ) -> inference_pb2.ReplicaInfoResponse: # type: ignore """query each replica's info and return the info of all replicas """ replica_info_list = [] n_healthy_replicas = 0 for replica in self.replica_manager._replicas.values(): try: r = await replica.stub.GetReplicaInfo( request, timeout=self.replica_manager.args.forward_timeout ) replica_info_list.append(r.replica_infos[0]) n_healthy_replicas += 1 except grpc.RpcError as e: # Mark replica unhealthy and fall through to default response. logger.warning( f"grpc Error: {e} getting info from replica {replica.id} mark unhealthy" ) await self.replica_manager._mark_unhealthy(replica) except (AttributeError, IndexError) as e: # Handle case where replica_infos is missing or empty logger.warning( f"Failed to get replica info from {replica.id}: {e}, mark unhealthy" ) await self.replica_manager._mark_unhealthy(replica) # Fallback minimal info when no replica is healthy. return inference_pb2.ReplicaInfoResponse( replica_infos=replica_info_list, n_replicas=self.replica_manager.num_replicas, n_healthy_replicas=n_healthy_replicas, message="", ) async def serve(args_list: List[str]): from concurrent.futures import ThreadPoolExecutor args = parser.parse_args(args_list) lm = ReplicaManager( args=args, args_list=args_list, base_port=args.port, num_replicas=args.num_replicas, lb=LoadingerType(args.load_balancing), health_interval=args.health_interval, ) grpc_server = grpc.aio.server( ThreadPoolExecutor(max_workers=args.workers), options=[ ("grpc.max_message_length", 200 * 1024 * 1024), ("grpc.max_send_message_length", 200 * 1024 * 1024), ("grpc.max_receive_message_length", 200 * 1024 * 1024), ], ) inference_pb2_grpc.add_InferenceServiceServicer_to_server( ManagerServicer(lm), grpc_server ) listen_addr = f"{args.host}:{args.port}" grpc_server.add_insecure_port(listen_addr) logger.info("Manager server listening on %s", listen_addr) await grpc_server.start() await lm.start() try: await grpc_server.wait_for_termination() except (KeyboardInterrupt, asyncio.CancelledError): logger.info("Received shutdown signal") finally: await lm.stop() await grpc_server.stop(0) if __name__ == "__main__": from vllm import logger as vllm_logger # type: ignore logging.basicConfig( format=vllm_logger._FORMAT, datefmt=vllm_logger._DATE_FORMAT, level=logging.INFO ) parser = FlexibleArgumentParser( description="Replica manager that manages and load balances between multiple replicas on a single GPU" ) parser.add_argument("--host", default="0.0.0.0", help="Bind host") parser.add_argument( "--port", type=int, default=60050, help="Port for the manager, replicas will be on consecutive ports", ) parser.add_argument( "--num-replicas", type=int, default=2, help="Number of replicas to launch" ) parser.add_argument( "--load-balancing", default="round_robin", choices=[e.value for e in LoadingerType], help="Load balancing strategy", ) parser.add_argument( "--health-interval", type=int, default=2, help="Seconds between health checks" ) parser.add_argument( "--forward-timeout", type=int, default=20, help="Timeout (s) for requests forwarded to replicas", ) parser.add_argument( "--startup-timeout", type=int, default=60, help="Seconds to wait for each replica to become healthy", ) parser.add_argument( "--workers", type=int, default=16, help="Number of gRPC workers" ) # Propagate vLLM engine args so we can pass them through to each replica. parser = AsyncEngineArgs.add_cli_args(parser) parser.set_defaults( host="0.0.0.0", port=50050, num_replicas=4, health_interval=1, forward_timeout=120, startup_timeout=120, load_balancing="round_robin", model="Snowflake/snowflake-arctic-embed-m-v1.5", ) asyncio.run(serve(sys.argv[1:]))