databao/configs/llm.py (131 lines of code) (raw):
import os
from pathlib import Path
from typing import Any, Literal
from langchain.chat_models import init_chat_model
from langchain_core.language_models import BaseChatModel
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Self
_OPENAI_PREFIXES = ["gpt", "o1", "o3", "o4"]
_ANTHROPIC_PREFIXES = ["claude", "anthropic"]
_OPENAI_REASONING_INFIXES = ["o1", "o3", "o4", "gpt-5", "openai/gpt-oss"]
# TODO: add a config folder for LLM configs, make it initializable from hydra configs
class LLMConfig(BaseModel):
"""Base class with all fields and computed logic for LLM configurations."""
# Fields declared in parent - can be overridden in children with different defaults
name: str
"""The model name can be of the form 'provider:name' or 'name'."""
temperature: float = 0.0
max_tokens: int = 8192
"""Maximum number of tokens to generate."""
reasoning_effort: str = "medium"
"""Reasoning effort is used for OpenAI reasoning models only.
Warning: reasoning can use a lot of tokens! OpenAI recommends at least 25000 tokens"""
cache_system_prompt: bool = True
"""Cache system prompt with prompt caching. Only used for Anthropic models."""
# TODO multi-turn prompt caching
max_tokens_before_cleaning: int = 10000
"""Number of tokens to start history cleaning. Each Executor has it's own cleaning strategy."""
timeout: int | None | Literal["auto"] = "auto"
"""Timeout in seconds for LLM calls. If None, use the LLM provider's defaults.
If 'auto', use a default timeout (60s) that increases for reasoning models."""
api_base_url: str | None = None
"""Base URL for an OpenAI-compatible API like 'http://localhost:8080/v1'. Mostly used for running local models."""
use_responses_api: bool = True
"""Use the [responses API](https://platform.openai.com/docs/guides/migrate-to-responses) for OpenAI models.
If False, use the old Chat Completions API (useful for local models that don't support the new responses API)."""
ollama_pull_model: bool = True
"""Pull the model from Ollama if it's not already downloaded. Only applicable for the 'ollama' model provider."""
model_kwargs: dict[str, Any] = Field(default_factory=dict)
"""Additional kwargs for the model constructor."""
model_config = ConfigDict(frozen=True, extra="forbid")
def _resolve_timeout(self) -> float | None:
if self.timeout == "auto":
return 360 if _is_reasoning_model(self.name) else 60
else:
return self.timeout
def new_chat_model(self) -> BaseChatModel:
"""Create a chat model from this config using init_chat_model for provider detection."""
provider, name = _parse_model_provider(self.name)
if provider == "openai" or self.api_base_url is not None:
from langchain_openai import ChatOpenAI
# Use the verbatim name if using an OAI server
model_name = self.name if self.api_base_url is not None else name
is_reasoning = _is_reasoning_model(model_name)
extra_kwargs: dict[str, Any] = {}
if self.use_responses_api:
extra_kwargs.update(
# Without "summary", no reasoning traces will be returned by the API
reasoning={"effort": self.reasoning_effort, "summary": "auto"} if is_reasoning else None,
temperature=self.temperature,
# TODO output_version="responses/v1"
)
else:
extra_kwargs.update(
reasoning_effort=self.reasoning_effort if is_reasoning else None,
# The old API errors out if you provide a temperature for reasoning models
temperature=self.temperature if not is_reasoning else None,
)
# Set a default API key for local models if the user didn't provide one
if (
self.api_base_url is not None
and "api_key" not in self.model_kwargs
and "OPENAI_API_KEY" not in os.environ
):
extra_kwargs["api_key"] = "local-api-key"
return ChatOpenAI(
model=model_name,
timeout=self._resolve_timeout(),
max_tokens=self.max_tokens,
base_url=self.api_base_url,
use_responses_api=self.use_responses_api,
**extra_kwargs,
**self.model_kwargs,
)
elif provider == "anthropic":
from langchain_anthropic import ChatAnthropic
return ChatAnthropic(
model_name=name,
timeout=self._resolve_timeout(),
temperature=self.temperature,
max_tokens_to_sample=self.max_tokens,
**self.model_kwargs,
)
if provider == "ollama" and self.ollama_pull_model:
import ollama
# Download with ollama. If the model already exists it will not be re-downloaded.
ollama.pull(name)
return init_chat_model(
self.name,
configurable_fields=None, # Ensures we match the BaseChatModel overload
temperature=self.temperature,
max_tokens=self.max_tokens,
timeout=self._resolve_timeout(),
**self.model_kwargs,
)
@classmethod
def from_yaml(cls, path: str | Path) -> Self:
"""Load an LLM config from a YAML file."""
import yaml
path = Path(path)
if path.exists():
content = path.read_text()
else:
raise ValueError(f"LLM config file {path} not found.")
model_dict = yaml.safe_load(content)
return cls.model_validate(model_dict)
def _is_reasoning_model(model_name: str) -> bool:
"""Check if a model is a reasoning model based on its name."""
return any(prefix in model_name for prefix in _OPENAI_REASONING_INFIXES)
def _is_openai_model(model_name: str) -> bool:
"""Check if a model is an OpenAI model based on its name."""
return any(model_name.startswith(prefix) for prefix in _OPENAI_PREFIXES)
def _is_anthropic_model(model_name: str) -> bool:
"""Check if a model is an Anthropic model based on its name."""
return any(model_name.startswith(prefix) for prefix in _ANTHROPIC_PREFIXES)
def _parse_model_provider(model: str) -> tuple[str, str]:
"""Parse the provider and model name from a string of the form 'provider:name' or 'name'."""
provider, sep, name = model.partition(":")
if len(sep) == 0 and len(name) == 0:
if _is_openai_model(model):
return "openai", model
elif _is_anthropic_model(model):
return "anthropic", model
else:
return "", model
return provider, name
class LLMConfigDirectory:
"""Namespace for preconfigured LLM configurations."""
@classmethod
def list_all(cls) -> list[LLMConfig]:
return [config for name, config in vars(cls).items() if name.isupper()]
DEFAULT = LLMConfig(name="gpt-4o-mini")
# https://huggingface.co/Qwen/Qwen3-8B-GGUF#best-practices
QWEN3_8B_OAI = LLMConfig(
name="qwen/qwen3-8b",
api_base_url="http://localhost:8080/v1",
max_tokens=32768,
temperature=0.6,
use_responses_api=False,
timeout=600,
)
# https://huggingface.co/Qwen/Qwen3-8B-GGUF#best-practices
QWEN3_8B_OLLAMA = LLMConfig(
name="ollama:qwen3:8b",
max_tokens=32768,
temperature=0.6,
timeout=600,
# Refer to https://python.langchain.com/api_reference/ollama/chat_models/langchain_ollama.chat_models.ChatOllama.html
model_kwargs={
"reasoning": True,
"num_ctx": 40960, # Override the global context size: https://docs.ollama.com/context-length
"num_predict": 32768,
"validate_model_on_init": True,
},
)