arctic_inference/dynasor/openai_server.py (272 lines of code) (raw):

# Copyright 2025 Snowflake Inc.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
OpenAI-compatible proxy server that injects Dynasor probes.

Known issues:
- Prompt formatting is currently hardcoded: the proxy cannot obtain the
  system prompt / chat template needed to build a proper prefix. This can
  degrade accuracy (unknown prompt format) and/or performance (less KV-cache
  reuse).
- The API key defaults to the "EMPTY" placeholder.
"""

import argparse
import asyncio
import httpx
import json
import logging
import os
import time
import uvicorn
from dataclasses import dataclass
from fastapi import FastAPI, Request
from fastapi import HTTPException
from fastapi.responses import StreamingResponse
from openai import AsyncOpenAI
from typing import Optional, Dict, Any, List, Union, AsyncGenerator

from arctic_inference.dynasor.cot import (
    obtain_answer,
    formalize_final_response,
    uncertain_words,
    default_probing_suffix,
    format_prompt_for_completions,
)
from arctic_inference.dynasor.evaluator import count_not_empty, equal_group


def init_logger() -> logging.Logger:
    """Initialize and configure the logger for the OpenAI proxy server.

    Safe to call more than once: the stream handler is attached only when
    the logger has no handlers yet, so log lines are not duplicated.

    Returns:
        logging.Logger: Configured logger instance
    """
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)
    if not logger.handlers:
        handler = logging.StreamHandler()
        handler.setLevel(logging.DEBUG)
        formatter = logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        )
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger


logger = init_logger()


@dataclass
class ProxyConfig:
    """Configuration for the OpenAI proxy server.

    Attributes:
        host: Host address to bind the server to
        port: Port number to run the server on
        target_base_url: Base URL of the target OpenAI API server
        api_key: API key for authentication (defaults to "EMPTY")
    """
    host: str
    port: int
    target_base_url: str
    api_key: str = "EMPTY"


def make_parser() -> argparse.ArgumentParser:
    """Create and configure the argument parser for the server.

    Returns:
        argparse.ArgumentParser: Configured argument parser
    """
    parser = argparse.ArgumentParser(description="OpenAI API Proxy Server")
    parser.add_argument(
        "--host",
        type=str,
        default="0.0.0.0",
        help="Host to bind the server to (default: 0.0.0.0)",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8001,
        help="Port to run the server on (default: 8001)",
    )
    parser.add_argument(
        "--target-base-url",
        type=str,
        default="http://localhost:8000",
        help="Base URL of the target OpenAI API server (default: http://localhost:8000)",
    )
    return parser


def parse_args(args_: Optional[List[str]] = None) -> ProxyConfig:
    """Parse command line arguments and create a ProxyConfig instance.

    Args:
        args_: Optional list of command line arguments (defaults to sys.argv)

    Returns:
        ProxyConfig: Configuration object with parsed arguments
    """
    parser = make_parser()
    args = parser.parse_args(args=args_)
    return ProxyConfig(
        host=args.host,
        port=args.port,
        target_base_url=args.target_base_url,
    )


app = FastAPI()

# Initialize with None; set via set_config() / start_server() before serving.
config: Optional[ProxyConfig] = None


def set_config(c: ProxyConfig) -> None:
    """Set the global configuration for the server.

    Args:
        c: ProxyConfig instance to set as global config
    """
    global config
    config = c


def _sse_event(payload: str) -> bytes:
    """Encode *payload* as a single server-sent event (``data: ...``)."""
    return f"data: {payload}\n\n".encode("utf-8")


def _extract_choice_text(choice: Any) -> Optional[str]:
    """Return the newly generated text carried by one streamed choice.

    ``/v1/chat/completions`` chunks carry text in ``choice.delta.content``,
    while ``/v1/completions`` chunks have no ``delta`` and carry it in
    ``choice.text``. Returns ``None`` when the choice carries no text.
    """
    delta = getattr(choice, "delta", None)
    if delta is not None:
        return delta.content
    return getattr(choice, "text", None)


def _set_choice_text(choice: Any, text: str) -> None:
    """Overwrite the text of a streamed choice (chat or legacy completion)."""
    delta = getattr(choice, "delta", None)
    if delta is not None:
        delta.content = text
    else:
        choice.text = text


def _is_confident(
    probe_answers: List[str],
    probe_responses: List[str],
    certainty_window: int,
) -> bool:
    """Return True when the most recent probes justify an early exit.

    Over the last ``certainty_window`` probes, all of these must hold:
    ``equal_group`` accepts the answers, ``count_not_empty`` of the answers
    equals ``certainty_window``, and no probe response contains any of
    ``uncertain_words`` (case-insensitive).
    """
    recent_answers = probe_answers[-certainty_window:]
    recent_responses = probe_responses[-certainty_window:]
    certain_count = sum(
        not any(word in res.lower() for word in uncertain_words)
        for res in recent_responses
    )
    return (
        equal_group(recent_answers)
        and count_not_empty(recent_answers) == certainty_window
        and certain_count == certainty_window
    )


async def execute_single_probe(
    client: AsyncOpenAI,
    model_id: str,
    prompt: str,
    generated: str,
    probe_in_progress_event: asyncio.Event,
    max_tokens: int = 32,
) -> str:
    """Execute a single probe request to check the model's certainty.

    The probe is sent to the ``/v1/completions`` endpoint with a prompt built
    by ``format_prompt_for_completions`` from the original prompt and the
    text generated so far. ``probe_in_progress_event`` is cleared when the
    probe finishes, whether it succeeded or raised.

    Args:
        client: AsyncOpenAI client instance
        model_id: ID of the model to use
        prompt: Original prompt text
        generated: Generated text so far
        probe_in_progress_event: Event that is set while a probe is running
        max_tokens: Maximum tokens for probe response

    Returns:
        str: Probe response text ("" when the server returned no text)
    """
    try:
        # TODO(GindaChen)(Refactor): Prompt formatting is currently highly hardcoded.
        # The main issue is that we have to control the `</think>` token, and
        # apart from the `/v1/completions` endpoint we have no proper way to
        # do so.
        # Cases:
        # - If the template is unknown, we can only submit something reasonable.
        # - If the template is knowable, the proxy server has to know about it:
        #   either the server provides an endpoint, or the user overrides this
        #   function.
        text = format_prompt_for_completions(prompt, generated)
        probe_response = await client.completions.create(
            model=model_id,
            prompt=text,
            max_tokens=max_tokens,
            temperature=0.6,
            top_p=0.95,
        )
    finally:
        probe_in_progress_event.clear()

    if probe_response.choices and probe_response.choices[0].text:
        return probe_response.choices[0].text
    return ""


async def handle_chat_completion_request(
    request: Request, path: str
) -> AsyncGenerator[bytes, None]:
    """Handle chat completion requests with Dynasor probing.

    Streams the target server's response back to the caller chunk by chunk
    as server-sent events, while periodically launching background probe
    requests and keeping their results. Once ``_is_confident`` accepts the
    latest probes, the stream is cut short. The cut is signalled by a chunk
    carrying ``formalize_final_response(...)``, followed by a chunk with
    ``finish_reason="stop"``.

    Dynasor is disabled unless the request body contains a ``"dynasor"``
    object. ``probe_interval`` (default ``1e9``, i.e. effectively never) is
    the number of content chunks between probes. ``certainty_window``
    (default 3) is how many recent probes are checked.

    Args:
        request: FastAPI request object
        path: API endpoint path ("/v1/chat/completions" or "/v1/completions")

    Yields:
        bytes: Server-sent-event chunks, terminated by ``data: [DONE]``

    Raises:
        HTTPException: If the endpoint is not found
    """
    auth_header = request.headers.get("Authorization")
    api_key = config.api_key  # Fallback to default
    if auth_header and auth_header.startswith("Bearer "):
        api_key = auth_header[len("Bearer "):].strip() or config.api_key

    if path not in ("/v1/chat/completions", "/v1/completions"):
        raise HTTPException(status_code=404)

    body = await request.body()
    body_json = json.loads(body) if body else {}
    logger.debug("Handle chat completion request: %s", body_json)

    model_id = body_json.get("model")
    max_tokens = body_json.get("max_tokens", 1024)

    # By default disable dynasor, unless client specifies it.
    # `or {}` also covers an explicit `"dynasor": null`.
    dynasor_body = body_json.get("dynasor") or {}
    probe_interval = dynasor_body.get("probe_interval", 1e9)
    certainty_window = dynasor_body.get("certainty_window", 3)

    client = AsyncOpenAI(
        api_key=api_key, base_url=f"{config.target_base_url}/v1", max_retries=1
    )
    response_stream = None
    probe_task: Optional[asyncio.Task] = None
    try:
        if path == "/v1/chat/completions":
            messages = body_json.get("messages")
            # Only the last message is used as the probe prompt (see the
            # known issues in the module docstring).
            prompt = messages[-1].get("content")
            response_stream = await client.chat.completions.create(
                messages=messages,
                model=model_id,
                max_tokens=max_tokens,
                stream=True,
            )
        else:
            prompt = body_json.get("prompt")
            response_stream = await client.completions.create(
                model=model_id,
                prompt=prompt,
                max_tokens=max_tokens,
                stream=True,
            )

        probe_in_progress_event = asyncio.Event()
        probe_answers: List[str] = []
        probe_responses: List[str] = []
        should_launch_next_probe = False
        generated_text = ""
        chunks_processed = 0

        async for chunk in response_stream:
            # Always forward the upstream chunk unchanged first.
            yield _sse_event(chunk.to_json(indent=None))

            # Chunks without choices (e.g. usage-only chunks) carry nothing
            # to accumulate or probe on.
            if not chunk.choices:
                continue
            choice = chunk.choices[0]

            # TODO: Properly set the exit condition.
            if choice.finish_reason is not None and choice.finish_reason != "length":
                break

            text = _extract_choice_text(choice)
            if text is not None:
                generated_text += text
                chunks_processed += 1
                # A non-positive interval would divide by zero / never fire
                # sensibly, so it is treated as "probing disabled".
                if probe_interval > 0 and chunks_processed % probe_interval == 0:
                    should_launch_next_probe = True

            if probe_task is not None and probe_task.done():
                finished_probe, probe_task = probe_task, None
                try:
                    probe_text = finished_probe.result()
                except Exception:
                    # A failed probe must not abort the user's main stream.
                    logger.exception("Dynasor probe failed; ignoring its result")
                else:
                    probe_answers.append(obtain_answer(probe_text))
                    probe_responses.append(probe_text)

                    # Check the certaindex for the early-exit condition.
                    if _is_confident(probe_answers, probe_responses, certainty_window):
                        # TODO: Make the probe customizable
                        output_text = formalize_final_response(
                            generated_text, probe_answers[-1]
                        )
                        # Emit the formalized answer, then a terminating
                        # "stop" chunk, both derived from the current chunk.
                        final_chunk = chunk.model_copy(deep=True)
                        final_choice = final_chunk.choices[0]
                        _set_choice_text(final_choice, output_text)
                        yield _sse_event(final_chunk.to_json(indent=None))
                        _set_choice_text(final_choice, "")
                        final_choice.finish_reason = "stop"
                        yield _sse_event(final_chunk.to_json(indent=None))
                        break

            # Launch at most one probe at a time; a pending request to probe
            # waits until the running probe clears the event.
            if should_launch_next_probe and not probe_in_progress_event.is_set():
                should_launch_next_probe = False
                probe_in_progress_event.set()
                probe_task = asyncio.create_task(
                    execute_single_probe(
                        client,
                        model_id,
                        prompt,
                        generated_text,
                        probe_in_progress_event,
                        max_tokens=32,
                    )
                )
    finally:
        # Runs on normal exit, early exit, errors and client disconnects:
        # don't leave a background probe or open connections behind.
        if probe_task is not None and not probe_task.done():
            probe_task.cancel()
        if response_stream is not None:
            await response_stream.close()
        await client.close()

    yield _sse_event("[DONE]")


@app.post("/v1/chat/completions")
async def chat_completions_endpoint(request: Request) -> StreamingResponse:
    """Handle chat completions endpoint requests.

    Args:
        request: FastAPI request object

    Returns:
        StreamingResponse: Streaming response with chat completion results
    """
    gen = handle_chat_completion_request(request, "/v1/chat/completions")
    return StreamingResponse(
        gen,
        media_type="text/event-stream",
    )


@app.post("/v1/completions")
async def completions_endpoint(request: Request) -> StreamingResponse:
    """Handle completions endpoint requests.

    Args:
        request: FastAPI request object

    Returns:
        StreamingResponse: Streaming response with completion results
    """
    gen = handle_chat_completion_request(request, "/v1/completions")
    return StreamingResponse(
        gen,
        media_type="text/event-stream",
    )


# Upstream response headers that must not be copied to the client: the body
# we re-send comes from httpx's `aiter_bytes()`, which is already de-chunked
# and content-decoded, so these headers would describe different bytes.
_EXCLUDED_RESPONSE_HEADERS = frozenset(
    {"content-length", "transfer-encoding", "content-encoding"}
)


def _body_requests_stream(body: bytes) -> bool:
    """Return True if *body* is a JSON object with a truthy ``"stream"`` key.

    Non-JSON bodies and JSON values that are not objects are treated as
    non-streaming instead of raising.
    """
    if not body:
        return False
    try:
        body_json = json.loads(body)
    except ValueError:  # includes JSONDecodeError and UnicodeDecodeError
        return False
    return isinstance(body_json, dict) and bool(body_json.get("stream", False))


async def proxy_request(request: Request, path: str) -> StreamingResponse:
    """Proxy requests to the target OpenAI API server.

    The method, headers (minus ``Host``), query string and body are forwarded
    unchanged. Note that the upstream response is read in full before it is
    relayed (``httpx.AsyncClient.request`` is not a streaming call).

    Args:
        request: FastAPI request object
        path: API endpoint path

    Returns:
        StreamingResponse: Response relayed from the target server

    Raises:
        HTTPException: If the endpoint is not found
    """
    # Skip chat/completions endpoints since they are handled separately
    if request.method == "POST" and request.url.path in ["/v1/chat/completions", "/v1/completions"]:
        raise HTTPException(status_code=404, detail="Not Found")

    # Get the raw request body
    body = await request.body()

    # Check if streaming is requested
    is_stream = _body_requests_stream(body)

    # Forward headers but exclude host
    headers = {k: v for k, v in request.headers.items() if k.lower() != "host"}

    # Construct target URL, preserving the original query string.
    target_url = config.target_base_url.rstrip('/') + '/' + path.lstrip('/')
    if request.url.query:
        target_url += '?' + request.url.query

    async with httpx.AsyncClient() as client:
        # Forward the request with same method, headers, and body. The body
        # is fully read here, so it remains available after the client closes.
        response = await client.request(
            method=request.method,
            url=target_url,
            headers=headers,
            content=body,
        )

    proxy_headers = {
        k: v for k, v in response.headers.items()
        if k.lower() not in _EXCLUDED_RESPONSE_HEADERS
    }
    if is_stream:
        return StreamingResponse(
            response.aiter_bytes(),
            status_code=response.status_code,
            headers=proxy_headers,
            media_type="text/event-stream",
        )
    return StreamingResponse(
        response.aiter_bytes(),
        status_code=response.status_code,
        headers=proxy_headers,
    )


@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH"])
async def proxy(request: Request, path: str) -> StreamingResponse:
    """Generic proxy endpoint for all other API routes.

    Args:
        request: FastAPI request object
        path: API endpoint path (without the leading slash)

    Returns:
        StreamingResponse: Streaming response from the target server
    """
    return await proxy_request(request, "/" + path)


def start_server(config: ProxyConfig) -> None:
    """Start the FastAPI server with the given configuration.

    The configuration is also installed as the module-level config read by
    the request handlers, so importing this module and calling
    ``start_server`` directly yields a working proxy.

    Args:
        config: ProxyConfig instance with server settings
    """
    set_config(config)
    uvicorn.run(app, host=config.host, port=config.port)


if __name__ == "__main__":
    config = parse_args()
    logger.info("Starting server with config: %s", config)
    start_server(config)