server/server.py (101 lines of code) (raw):
import time
from collections.abc import AsyncIterator
from logging import getLogger
from typing import Any, Dict
from agents import Runner, trace
from agents.voice import (
TTSModelSettings,
VoicePipeline,
VoicePipelineConfig,
VoiceWorkflowBase,
)
from app.agent_config import starting_agent
from app.utils import (
WebsocketHelper,
concat_audio_chunks,
extract_audio_chunk,
is_audio_complete,
is_new_audio_chunk,
is_new_text_message,
is_sync_message,
is_text_output,
process_inputs,
)
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, HTMLResponse
from dotenv import load_dotenv
# Load settings from "../.env" before the application is configured. With
# override=True, values found in that file replace variables that are already
# set in the process environment.
# NOTE(review): the path is relative — presumably resolved against the working
# directory the server is started from; confirm.
load_dotenv(dotenv_path="../.env", override=True)
# ASGI application object; the middleware and websocket route below are
# registered on it.
app = FastAPI()
# Module-level logger named after this module.
logger = getLogger(__name__)
# Allow cross-origin requests from any origin, with any method and any header,
# and permit credentials.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive — confirm this is intended outside local development.
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
class Workflow(VoiceWorkflowBase):
    """Voice workflow that runs the current agent and reports progress to the client.

    For every turn the user's text is handed to the ``WebsocketHelper``, the
    agent it returns is run in streaming mode, each stream event is forwarded
    to the helper, and text deltas are yielded to the caller.
    """

    def __init__(self, connection: WebsocketHelper):
        # Helper wrapping the websocket; ``run`` reads the conversation history
        # and the agent to use from it on every turn.
        self.connection = connection

    async def run(self, input_text: str) -> AsyncIterator[str]:
        """Run one turn for ``input_text`` and yield text deltas as they stream.

        Args:
            input_text: The user's message for this turn.

        Yields:
            ``event.data.delta`` for every stream event that ``is_text_output``
            accepts.
        """
        helper = self.connection
        # The helper returns the history to send and the agent that should answer.
        history, agent = await helper.show_user_input(input_text)
        streamed = Runner.run_streamed(agent, history)

        async for item in streamed.stream_events():
            # Every event goes to the helper, whether or not it carries text.
            await helper.handle_new_item(item)
            if not is_text_output(item):
                continue
            yield item.data.delta  # type: ignore

        # Only reached when the stream was consumed to the end.
        await helper.text_output_complete(streamed, is_done=True)
def _first_byte_reporter():
    """Build a pass-through ``transform_data`` callback that reports latency once.

    The clock starts when this factory is called. The first time the returned
    callback is invoked it prints the elapsed time; every call returns its
    argument unchanged.
    """
    started_at = time.perf_counter()

    def transform_data(data):
        nonlocal started_at
        # Compare against None (the "already reported" sentinel) rather than
        # relying on truthiness of a float timestamp.
        if started_at is not None:
            print(f"Time taken to first byte: {time.perf_counter() - started_at}s")
            started_at = None
        return data

    return transform_data


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Serve one chat session over the ``/ws`` websocket.

    After accepting the socket, JSON messages are read in a loop and dispatched
    by the ``is_*`` helper predicates:

    * sync message: replace ``connection.history`` with ``message["inputs"]``
      and, when ``reset_agent`` is truthy, switch back to ``starting_agent``;
    * new text message: run the workflow and stream each text delta back;
    * new audio chunk: append the extracted chunk to the per-connection buffer;
    * audio complete: concatenate the buffered chunks, run them through a
      ``VoicePipeline`` and forward every output event, then clear the buffer.

    Messages matching none of the predicates are ignored. The handler returns
    when the client disconnects.
    """
    with trace("Voice Agent Chat"):
        await websocket.accept()
        connection = WebsocketHelper(websocket, [], starting_agent)
        audio_buffer = []
        workflow = Workflow(connection)

        # The whole loop is guarded, not just receive_json(): a client that
        # hangs up while a response is being sent can raise the same
        # exception from the send path, and should end the session just as
        # quietly.
        try:
            while True:
                message = await websocket.receive_json()

                # Handle text based messages
                if is_sync_message(message):
                    connection.history = message["inputs"]
                    if message.get("reset_agent", False):
                        connection.latest_agent = starting_agent
                elif is_new_text_message(message):
                    user_input = process_inputs(message, connection)
                    async for new_output_tokens in workflow.run(user_input):
                        await connection.stream_response(
                            new_output_tokens, is_text=True
                        )

                # Handle a new audio chunk
                elif is_new_audio_chunk(message):
                    audio_buffer.append(extract_audio_chunk(message))

                # Send full audio to the agent
                elif is_audio_complete(message):
                    # Start the clock before the audio is assembled so the
                    # reported latency covers the whole turn.
                    transform_data = _first_byte_reporter()
                    # NOTE(review): an "audio complete" message with no
                    # buffered chunks passes an empty list here — confirm
                    # concat_audio_chunks tolerates that.
                    audio_input = concat_audio_chunks(audio_buffer)
                    output = await VoicePipeline(
                        workflow=workflow,
                        config=VoicePipelineConfig(
                            tts_settings=TTSModelSettings(
                                buffer_size=512, transform_data=transform_data
                            )
                        ),
                    ).run(audio_input)
                    async for event in output.stream():
                        await connection.send_audio_chunk(event)
                    audio_buffer = []  # reset the audio buffer
        except WebSocketDisconnect:
            print("Client disconnected")
            return
if __name__ == "__main__":
    # Script entry point: uvicorn is imported only when this file is executed
    # directly, not when the module is imported.
    import uvicorn

    # The app is given as the import string "server:app"; the server listens
    # on all interfaces at port 8000 with auto-reload turned on.
    uvicorn.run(
        "server:app",
        host="0.0.0.0",
        port=8000,
        reload=True,
    )