# neuron_explainer/api_client.py
import asyncio
import contextlib
import os
import random
import traceback
from functools import wraps
from typing import Any

import httpx
import orjson

API_KEY = os.getenv("OPENAI_API_KEY")
assert API_KEY, "Please set the OPENAI_API_KEY environment variable"

API_HTTP_HEADERS = {
"Content-Type": "application/json",
"Authorization": "Bearer " + API_KEY,
}

BASE_API_URL = "https://api.openai.com/v1"


def async_exponential_backoff(retry_on=lambda err: True):
"""
Returns a decorator which retries the wrapped function as long as the specified retry_on
function returns True for the exception, applying exponential backoff with jitter after
failures, up to a retry limit.
"""
init_delay_s = 1
max_delay_s = 10
backoff_multiplier = 2.0
jitter = 0.2
def decorate(f):
@wraps(f)
async def f_retry(self, *args, **kwargs):
max_tries = 200
delay_s = init_delay_s
for i in range(max_tries):
try:
return await f(self, *args, **kwargs)
except Exception as err:
if i == max_tries - 1:
print(f"Exceeded max tries ({max_tries}) on HTTP request")
raise
if not retry_on(err):
print("Unretryable error on HTTP request")
raise
jittered_delay = random.uniform(delay_s * (1 - jitter), delay_s * (1 + jitter))
await asyncio.sleep(jittered_delay)
delay_s = min(delay_s * backoff_multiplier, max_delay_s)
return f_retry
return decorate
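
# Illustrative sketch (not used in this module): the decorator is written for instance
# methods, since the retry wrapper takes `self` explicitly. For example:
#
#     class Fetcher:
#         @async_exponential_backoff(retry_on=lambda err: isinstance(err, httpx.TimeoutException))
#         async def fetch(self, url: str) -> httpx.Response:
#             async with httpx.AsyncClient() as client:
#                 return await client.get(url)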


def is_api_error(err: Exception) -> bool:
    """Returns True if the error is considered transient and worth retrying."""
if isinstance(err, httpx.HTTPStatusError):
        response = err.response
        # OpenAI API errors carry a JSON envelope: {"error": {"message": ..., "type": ...}}.
        error_data = response.json().get("error", {})
        error_message = error_data.get("message")
if response.status_code in [400, 404, 415]:
if error_data.get("type") == "idempotency_error":
print(f"Retrying after idempotency error: {error_message} ({response.url})")
return True
else:
# Invalid request
return False
else:
print(f"Retrying after API error: {error_message} ({response.url})")
return True
elif isinstance(err, httpx.ConnectError):
print(f"Retrying after connection error... ({err.request.url})")
return True
elif isinstance(err, httpx.TimeoutException):
print(f"Retrying after a timeout error... ({err.request.url})")
return True
elif isinstance(err, httpx.ReadError):
print(f"Retrying after a read error... ({err.request.url})")
return True
print(f"Retrying after an unexpected error: {repr(err)}")
traceback.print_tb(err.__traceback__)
return True
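
# For reference, is_api_error keys off OpenAI's error envelope; a retryable idempotency error
# looks roughly like this (values invented for illustration):
#
#     {"error": {"type": "idempotency_error", "message": "...", "code": null}}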


class ApiClient:
    """
    Performs inference using the OpenAI API. Supports response caching and concurrency limits.

    The cache is useful for rerunning code in Jupyter notebooks without repeating identical
    requests.
    """
def __init__(
self,
model_name: str,
        # If set, no more than this number of HTTP requests will be made concurrently.
        max_concurrent: int | None = None,
        # If True, identical requests (keyed by their kwargs) are served from an in-memory cache.
        cache: bool = False,
):
self.model_name = model_name
if max_concurrent is not None:
self.concurrency_check: asyncio.Semaphore | None = asyncio.Semaphore(max_concurrent)
else:
self.concurrency_check = None
        if cache:
            self.cache: dict[bytes, dict] | None = {}
        else:
            self.cache = None

    @async_exponential_backoff(retry_on=is_api_error)
    async def async_generate(self, timeout: int | None = None, **kwargs: Any) -> dict:
        if self.cache is not None:
            # Identical request kwargs serialize to identical bytes, which serve as the cache key.
            key = orjson.dumps(kwargs)
            if key in self.cache:
                return self.cache[key]
async with contextlib.AsyncExitStack() as stack:
if self.concurrency_check is not None:
await stack.enter_async_context(self.concurrency_check)
            # Note: httpx treats timeout=None as "no timeout", not as its default of 5 seconds.
            http_client = await stack.enter_async_context(httpx.AsyncClient(timeout=timeout))
# If the request has a "messages" key, it should be sent to the /chat/completions
# endpoint. Otherwise, it should be sent to the /completions endpoint.
url = BASE_API_URL + ("/chat/completions" if "messages" in kwargs else "/completions")
kwargs["model"] = self.model_name
response = await http_client.post(url, headers=API_HTTP_HEADERS, json=kwargs)
            # The response JSON contains useful error details that the raised exception omits,
            # so print it before re-raising.
            try:
                response.raise_for_status()
            except Exception:
                print(response.json())
                raise
            # Parse the body once; the same JSON is both cached and returned.
            response_json = response.json()
            if self.cache is not None:
                self.cache[key] = response_json
            return response_json
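
    # Illustrative sketch (assumes a `prompts` list defined by the caller): with max_concurrent
    # set, many requests can be issued at once and the semaphore throttles actual HTTP concurrency.
    #
    #     client = ApiClient(model_name="gpt-3.5-turbo", max_concurrent=4, cache=True)
    #     results = await asyncio.gather(
    #         *(client.async_generate(prompt=p, max_tokens=16) for p in prompts)
    #     )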


if __name__ == "__main__":

    async def main() -> None:
client = ApiClient(model_name="gpt-3.5-turbo", max_concurrent=1)
print(
await client.async_generate(prompt="Why did the chicken cross the road?", max_tokens=9)
)
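        # A chat-style variant (illustrative sketch): including "messages" would route the call
        # to /chat/completions instead of /completions.
        # print(
        #     await client.async_generate(
        #         messages=[{"role": "user", "content": "Why did the chicken cross the road?"}],
        #         max_tokens=9,
        #     )
        # )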
asyncio.run(main())