src/agents/agent_output.py (114 lines of code) (raw):
import abc
from dataclasses import dataclass
from typing import Any
from pydantic import BaseModel, TypeAdapter
from typing_extensions import TypedDict, get_args, get_origin
from .exceptions import ModelBehaviorError, UserError
from .strict_schema import ensure_strict_json_schema
from .tracing import SpanError
from .util import _error_tracing, _json
_WRAPPER_DICT_KEY = "response"
class AgentOutputSchemaBase(abc.ABC):
    """An object that captures the JSON schema of the output, as well as validating/parsing JSON
    produced by the LLM into the output type.

    Subclasses must implement every abstract method below; the class cannot be instantiated
    directly.
    """

    @abc.abstractmethod
    def is_plain_text(self) -> bool:
        """Whether the output type is plain text (versus a JSON object)."""
        pass

    @abc.abstractmethod
    def name(self) -> str:
        """The name of the output type."""
        pass

    @abc.abstractmethod
    def json_schema(self) -> dict[str, Any]:
        """Returns the JSON schema of the output. Will only be called if the output type is not
        plain text.
        """
        pass

    @abc.abstractmethod
    def is_strict_json_schema(self) -> bool:
        """Whether the JSON schema is in strict mode. Strict mode constrains the JSON schema
        features, but guarantees valid JSON. See here for details:
        https://platform.openai.com/docs/guides/structured-outputs#supported-schemas
        """
        pass

    @abc.abstractmethod
    def validate_json(self, json_str: str) -> Any:
        """Validate a JSON string against the output type. You must return the validated object,
        or raise a `ModelBehaviorError` if the JSON is invalid.

        Args:
            json_str: The raw JSON text produced by the model.

        Returns:
            The parsed and validated output object.
        """
        pass
@dataclass(init=False)
class AgentOutputSchema(AgentOutputSchemaBase):
    """An object that captures the JSON schema of the output, as well as validating/parsing JSON
    produced by the LLM into the output type.

    `None` and `str` output types are treated as plain text. Other types that are not a pydantic
    `BaseModel` or `dict` subclass are wrapped in a single-key object so the schema is always a
    JSON object; `validate_json` unwraps them again.
    """

    output_type: type[Any]
    """The type of the output."""

    _type_adapter: TypeAdapter[Any]
    """A type adapter that wraps the output type, so that we can validate JSON."""

    _is_wrapped: bool
    """Whether the output type is wrapped in a dictionary. This is generally done if the base
    output type cannot be represented as a JSON Schema object.
    """

    _output_schema: dict[str, Any]
    """The JSON schema of the output."""

    _strict_json_schema: bool
    """Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True,
    as it increases the likelihood of correct JSON input.
    """

    def __init__(self, output_type: type[Any], strict_json_schema: bool = True):
        """
        Args:
            output_type: The type of the output. `None` and `str` mean plain text.
            strict_json_schema: Whether the JSON schema is in strict mode. We **strongly** recommend
                setting this to True, as it increases the likelihood of correct JSON input.

        Raises:
            UserError: If `strict_json_schema` is True and the output type's schema cannot be
                converted to a strict schema.
        """
        self.output_type = output_type
        self._strict_json_schema = strict_json_schema

        if self.is_plain_text():
            # Plain text is never wrapped and its schema is not post-processed for strict mode.
            self._is_wrapped = False
            self._type_adapter = TypeAdapter(output_type)
            self._output_schema = self._type_adapter.json_schema()
            return

        # We should wrap for things that are not plain text, and for things that would definitely
        # not be a JSON Schema object.
        self._is_wrapped = not _is_subclass_of_base_model_or_dict(output_type)

        if self._is_wrapped:
            OutputType = TypedDict(
                "OutputType",
                {
                    _WRAPPER_DICT_KEY: output_type,  # type: ignore
                },
            )
            self._type_adapter = TypeAdapter(OutputType)
        else:
            self._type_adapter = TypeAdapter(output_type)
        self._output_schema = self._type_adapter.json_schema()

        if self._strict_json_schema:
            try:
                self._output_schema = ensure_strict_json_schema(self._output_schema)
            except UserError as e:
                # Point users at the opt-out this class actually exposes.
                raise UserError(
                    "Strict JSON schema is enabled, but the output type is not valid. "
                    "Either make the output type strict, or wrap your type with "
                    "AgentOutputSchema(YourType, strict_json_schema=False)"
                ) from e

    def is_plain_text(self) -> bool:
        """Whether the output type is plain text (versus a JSON object)."""
        return self.output_type is None or self.output_type is str

    def is_strict_json_schema(self) -> bool:
        """Whether the JSON schema is in strict mode."""
        return self._strict_json_schema

    def json_schema(self) -> dict[str, Any]:
        """The JSON schema of the output type.

        Raises:
            UserError: If the output type is plain text.
        """
        if self.is_plain_text():
            raise UserError("Output type is plain text, so no JSON schema is available")
        return self._output_schema

    def validate_json(self, json_str: str) -> Any:
        """Validate a JSON string against the output type. Returns the validated object, or raises
        a `ModelBehaviorError` if the JSON is invalid.

        Parsing and validation are delegated to `_json.validate_json`. For wrapped output types,
        the value stored under the wrapper key is returned, so callers receive the output value
        itself rather than the wrapper dict.
        """
        validated = _json.validate_json(json_str, self._type_adapter, partial=False)
        if not self._is_wrapped:
            return validated

        if not isinstance(validated, dict):
            _error_tracing.attach_error_to_current_span(
                SpanError(
                    message="Invalid JSON",
                    data={"details": f"Expected a dict, got {type(validated)}"},
                )
            )
            raise ModelBehaviorError(
                f"Expected a dict, got {type(validated)} for JSON: {json_str}"
            )
        if _WRAPPER_DICT_KEY not in validated:
            _error_tracing.attach_error_to_current_span(
                SpanError(
                    message="Invalid JSON",
                    data={"details": f"Could not find key {_WRAPPER_DICT_KEY} in JSON"},
                )
            )
            raise ModelBehaviorError(
                f"Could not find key {_WRAPPER_DICT_KEY} in JSON: {json_str}"
            )
        return validated[_WRAPPER_DICT_KEY]

    def name(self) -> str:
        """The name of the output type."""
        return _type_to_str(self.output_type)
def _is_subclass_of_base_model_or_dict(t: Any) -> bool:
if not isinstance(t, type):
return False
# If it's a generic alias, 'origin' will be the actual type, e.g. 'list'
origin = get_origin(t)
allowed_types = (BaseModel, dict)
# If it's a generic alias e.g. list[str], then we should check the origin type i.e. list
return issubclass(origin or t, allowed_types)
def _type_to_str(t: type[Any]) -> str:
origin = get_origin(t)
args = get_args(t)
if origin is None:
# It's a simple type like `str`, `int`, etc.
return t.__name__
elif args:
args_str = ", ".join(_type_to_str(arg) for arg in args)
return f"{origin.__name__}[{args_str}]"
else:
return str(t)