neuron-explainer/neuron_explainer/api_client.py

import asyncio
import contextlib
import os
import random
import traceback
from asyncio import Semaphore
from functools import wraps
from typing import Any, Callable, Optional

import httpx
import orjson


def is_api_error(err: Exception) -> bool:
    if isinstance(err, httpx.HTTPStatusError):
        response = err.response
        error_data = response.json().get("error", {})
        error_message = error_data.get("message")
        if response.status_code in [400, 404, 415]:
            if error_data.get("type") == "idempotency_error":
                print(f"Retrying after idempotency error: {error_message} ({response.url})")
                return True
            else:
                # Invalid request
                return False
        else:
            print(f"Retrying after API error: {error_message} ({response.url})")
            return True

    elif isinstance(err, httpx.ConnectError):
        print(f"Retrying after connection error... ({err.request.url})")
        return True

    elif isinstance(err, httpx.TimeoutException):
        print(f"Retrying after a timeout error... ({err.request.url})")
        return True

    elif isinstance(err, httpx.ReadError):
        print(f"Retrying after a read error... ({err.request.url})")
        return True

    print(f"Retrying after an unexpected error: {repr(err)}")
    traceback.print_tb(err.__traceback__)
    return True


def exponential_backoff(
    retry_on: Callable[[Exception], bool] = lambda err: True
) -> Callable[[Callable], Callable]:
    """
    Returns a decorator which retries the wrapped function as long as the specified retry_on
    function returns True for the exception, applying exponential backoff with jitter after
    failures, up to a retry limit.
    """
    init_delay_s = 1.0
    max_delay_s = 10.0
    # Roughly 30 minutes before we give up.
    max_tries = 200
    backoff_multiplier = 2.0
    jitter = 0.2

    def decorate(f: Callable) -> Callable:
        assert asyncio.iscoroutinefunction(f)

        @wraps(f)
        async def f_retry(*args: Any, **kwargs: Any) -> None:
            delay_s = init_delay_s
            for i in range(max_tries):
                try:
                    return await f(*args, **kwargs)
                except Exception as err:
                    if not retry_on(err) or i == max_tries - 1:
                        raise
                    jittered_delay = random.uniform(delay_s * (1 - jitter), delay_s * (1 + jitter))
                    await asyncio.sleep(jittered_delay)
                    delay_s = min(delay_s * backoff_multiplier, max_delay_s)

        return f_retry

    return decorate


API_KEY = os.getenv("OPENAI_API_KEY")
assert API_KEY, "Please set the OPENAI_API_KEY environment variable"

API_HTTP_HEADERS = {
    "Content-Type": "application/json",
    "Authorization": "Bearer " + API_KEY,
}
BASE_API_URL = "https://api.openai.com/v1"


class ApiClient:
    """Performs inference using the OpenAI API. Supports response caching and concurrency limits."""

    def __init__(
        self,
        model_name: str,
        # If set, no more than this number of HTTP requests will be made concurrently.
        max_concurrent: Optional[int] = None,
        # Whether to cache request/response pairs in memory to avoid duplicating requests.
        cache: bool = False,
    ):
        self.model_name = model_name

        if max_concurrent is not None:
            self._concurrency_check: Optional[Semaphore] = Semaphore(max_concurrent)
        else:
            self._concurrency_check = None

        if cache:
            self._cache: Optional[dict[str, Any]] = {}
        else:
            self._cache = None

    @exponential_backoff(retry_on=is_api_error)
    async def make_request(
        self, timeout_seconds: Optional[int] = None, **kwargs: Any
    ) -> dict[str, Any]:
        if self._cache is not None:
            key = orjson.dumps(kwargs)
            if key in self._cache:
                return self._cache[key]
        async with contextlib.AsyncExitStack() as stack:
            if self._concurrency_check is not None:
                await stack.enter_async_context(self._concurrency_check)
            http_client = await stack.enter_async_context(
                httpx.AsyncClient(timeout=timeout_seconds)
            )
            # If the request has a "messages" key, it should be sent to the /chat/completions
            # endpoint. Otherwise, it should be sent to the /completions endpoint.
            url = BASE_API_URL + ("/chat/completions" if "messages" in kwargs else "/completions")
            kwargs["model"] = self.model_name
            response = await http_client.post(url, headers=API_HTTP_HEADERS, json=kwargs)
        # The response json has useful information but the exception doesn't include it, so print it
        # out then reraise.
        try:
            response.raise_for_status()
        except Exception as e:
            print(response.json())
            raise e
        if self._cache is not None:
            self._cache[key] = response.json()
        return response.json()


if __name__ == "__main__":

    async def main() -> None:
        client = ApiClient(model_name="gpt-3.5-turbo", max_concurrent=1)
        print(await client.make_request(prompt="Why did the chicken cross the road?", max_tokens=9))

    asyncio.run(main())
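
# Sketch (not part of the original module): the same client targets the
# /chat/completions endpoint simply by passing a "messages" kwarg instead of
# "prompt" -- make_request() picks the endpoint based on that key. The model
# name and message content below are illustrative, not taken from this repo.
#
#     async def chat_demo() -> None:
#         client = ApiClient(model_name="gpt-3.5-turbo", max_concurrent=1, cache=True)
#         response = await client.make_request(
#             messages=[{"role": "user", "content": "Say hello."}],
#             max_tokens=5,
#         )
#         # A /chat/completions response nests the text under message.content;
#         # repeating the call with cache=True returns the stored response.
#         print(response["choices"][0]["message"]["content"])
#
#     asyncio.run(chat_demo())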